Compare commits
4 Commits
refactorin
...
v0.9.5
| Author | SHA1 | Date | |
|---|---|---|---|
| 889976c7ee | |||
| d3d2cb0915 | |||
| bf82e38b6e | |||
| e98fd1de93 |
@@ -1 +1 @@
|
|||||||
v0.9.4
|
v0.9.5
|
||||||
|
|||||||
@@ -18,8 +18,8 @@ logpath = /dev/null
|
|||||||
backend = auto
|
backend = auto
|
||||||
maxretry = 1
|
maxretry = 1
|
||||||
findtime = 1d
|
findtime = 1d
|
||||||
# Block imported IPs for one week.
|
# Block imported IPs for 24 hours.
|
||||||
bantime = 1w
|
bantime = 86400
|
||||||
|
|
||||||
# Never ban the Docker bridge network or localhost.
|
# Never ban the Docker bridge network or localhost.
|
||||||
ignoreip = 127.0.0.0/8 ::1 172.16.0.0/12
|
ignoreip = 127.0.0.0/8 ::1 172.16.0.0/12
|
||||||
|
|||||||
@@ -56,11 +56,8 @@ echo " Registry : ${REGISTRY}"
|
|||||||
echo " Tag : ${TAG}"
|
echo " Tag : ${TAG}"
|
||||||
echo "============================================"
|
echo "============================================"
|
||||||
|
|
||||||
if [[ "${ENGINE}" == "podman" ]]; then
|
log "Logging in to ${REGISTRY}"
|
||||||
if ! podman login --get-login "${REGISTRY}" &>/dev/null; then
|
"${ENGINE}" login "${REGISTRY}"
|
||||||
err "Not logged in. Run:\n podman login ${REGISTRY}"
|
|
||||||
fi
|
|
||||||
fi
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Build
|
# Build
|
||||||
|
|||||||
@@ -82,12 +82,10 @@ The backend follows a **layered architecture** with strict separation of concern
|
|||||||
backend/
|
backend/
|
||||||
├── app/
|
├── app/
|
||||||
│ ├── __init__.py
|
│ ├── __init__.py
|
||||||
│ ├── `main.py` # FastAPI app factory, lifespan, exception handlers
|
│ ├── main.py # FastAPI app factory, lifespan, exception handlers
|
||||||
│ ├── `config.py` # Pydantic settings (env vars, .env loading)
|
│ ├── config.py # Pydantic settings (env vars, .env loading)
|
||||||
│ ├── `db.py` # Database connection and initialization
|
│ ├── dependencies.py # FastAPI Depends() providers (DB, services, auth)
|
||||||
│ ├── `exceptions.py` # Shared domain exception classes
|
│ ├── models/ # Pydantic schemas
|
||||||
│ ├── `dependencies.py` # FastAPI Depends() providers (DB, services, auth)
|
|
||||||
│ ├── `models/` # Pydantic schemas
|
|
||||||
│ │ ├── auth.py # Login request/response, session models
|
│ │ ├── auth.py # Login request/response, session models
|
||||||
│ │ ├── ban.py # Ban request/response/domain models
|
│ │ ├── ban.py # Ban request/response/domain models
|
||||||
│ │ ├── jail.py # Jail request/response/domain models
|
│ │ ├── jail.py # Jail request/response/domain models
|
||||||
@@ -113,12 +111,6 @@ backend/
|
|||||||
│ │ ├── jail_service.py # Jail listing, start/stop/reload, status aggregation
|
│ │ ├── jail_service.py # Jail listing, start/stop/reload, status aggregation
|
||||||
│ │ ├── ban_service.py # Ban/unban execution, currently-banned queries
|
│ │ ├── ban_service.py # Ban/unban execution, currently-banned queries
|
||||||
│ │ ├── config_service.py # Read/write fail2ban config, regex validation
|
│ │ ├── config_service.py # Read/write fail2ban config, regex validation
|
||||||
│ │ ├── config_file_service.py # Shared config parsing and file-level operations
|
|
||||||
│ │ ├── raw_config_io_service.py # Raw config file I/O wrapper
|
|
||||||
│ │ ├── jail_config_service.py # jail config activation/deactivation logic
|
|
||||||
│ │ ├── filter_config_service.py # filter config lifecycle management
|
|
||||||
│ │ ├── action_config_service.py # action config lifecycle management
|
|
||||||
│ │ ├── log_service.py # Log preview and regex test operations
|
|
||||||
│ │ ├── history_service.py # Historical ban queries, per-IP timeline
|
│ │ ├── history_service.py # Historical ban queries, per-IP timeline
|
||||||
│ │ ├── blocklist_service.py # Download, validate, apply blocklists
|
│ │ ├── blocklist_service.py # Download, validate, apply blocklists
|
||||||
│ │ ├── geo_service.py # IP-to-country resolution, ASN/RIR lookup
|
│ │ ├── geo_service.py # IP-to-country resolution, ASN/RIR lookup
|
||||||
@@ -127,18 +119,17 @@ backend/
|
|||||||
│ ├── repositories/ # Data access layer (raw queries only)
|
│ ├── repositories/ # Data access layer (raw queries only)
|
||||||
│ │ ├── settings_repo.py # App configuration CRUD in SQLite
|
│ │ ├── settings_repo.py # App configuration CRUD in SQLite
|
||||||
│ │ ├── session_repo.py # Session storage and lookup
|
│ │ ├── session_repo.py # Session storage and lookup
|
||||||
│ │ ├── blocklist_repo.py # Blocklist sources and import log persistence│ │ ├── fail2ban_db_repo.py # fail2ban SQLite ban history read operations
|
│ │ ├── blocklist_repo.py # Blocklist sources and import log persistence
|
||||||
│ │ ├── geo_cache_repo.py # IP geolocation cache persistence│ │ └── import_log_repo.py # Import run history records
|
│ │ └── import_log_repo.py # Import run history records
|
||||||
│ ├── tasks/ # APScheduler background jobs
|
│ ├── tasks/ # APScheduler background jobs
|
||||||
│ │ ├── blocklist_import.py# Scheduled blocklist download and application
|
│ │ ├── blocklist_import.py# Scheduled blocklist download and application
|
||||||
│ │ ├── geo_cache_flush.py # Periodic geo cache persistence (dirty-set flush to SQLite)│ │ ├── geo_re_resolve.py # Periodic re-resolution of stale geo cache records│ │ └── health_check.py # Periodic fail2ban connectivity probe
|
│ │ ├── geo_cache_flush.py # Periodic geo cache persistence (dirty-set flush to SQLite)
|
||||||
|
│ │ └── health_check.py # Periodic fail2ban connectivity probe
|
||||||
│ └── utils/ # Helpers, constants, shared types
|
│ └── utils/ # Helpers, constants, shared types
|
||||||
│ ├── fail2ban_client.py # Async wrapper around the fail2ban socket protocol
|
│ ├── fail2ban_client.py # Async wrapper around the fail2ban socket protocol
|
||||||
│ ├── ip_utils.py # IP/CIDR validation and normalisation
|
│ ├── ip_utils.py # IP/CIDR validation and normalisation
|
||||||
│ ├── time_utils.py # Timezone-aware datetime helpers│ ├── jail_config.py # Jail config parser/serializer helper
|
│ ├── time_utils.py # Timezone-aware datetime helpers
|
||||||
│ ├── conffile_parser.py # Fail2ban config file parser/serializer
|
│ └── constants.py # Shared constants (default paths, limits, etc.)
|
||||||
│ ├── config_parser.py # Structured config object parser
|
|
||||||
│ ├── config_writer.py # Atomic config file write operations│ └── constants.py # Shared constants (default paths, limits, etc.)
|
|
||||||
├── tests/
|
├── tests/
|
||||||
│ ├── conftest.py # Shared fixtures (test app, client, mock DB)
|
│ ├── conftest.py # Shared fixtures (test app, client, mock DB)
|
||||||
│ ├── test_routers/ # One test file per router
|
│ ├── test_routers/ # One test file per router
|
||||||
@@ -167,9 +158,8 @@ The HTTP interface layer. Each router maps URL paths to handler functions. Route
|
|||||||
| `blocklist.py` | `/api/blocklists` | CRUD blocklist sources, trigger import, view import logs |
|
| `blocklist.py` | `/api/blocklists` | CRUD blocklist sources, trigger import, view import logs |
|
||||||
| `geo.py` | `/api/geo` | IP geolocation lookup, ASN and RIR data |
|
| `geo.py` | `/api/geo` | IP geolocation lookup, ASN and RIR data |
|
||||||
| `server.py` | `/api/server` | Log level, log target, DB path, purge age, flush logs |
|
| `server.py` | `/api/server` | Log level, log target, DB path, purge age, flush logs |
|
||||||
| `health.py` | `/api/health` | fail2ban connectivity health check and status |
|
|
||||||
|
|
||||||
#### Services (`app/services`)
|
#### Services (`app/services/`)
|
||||||
|
|
||||||
The business logic layer. Services orchestrate operations, enforce rules, and coordinate between repositories, the fail2ban client, and external APIs. Each service covers a single domain.
|
The business logic layer. Services orchestrate operations, enforce rules, and coordinate between repositories, the fail2ban client, and external APIs. Each service covers a single domain.
|
||||||
|
|
||||||
@@ -181,12 +171,8 @@ The business logic layer. Services orchestrate operations, enforce rules, and co
|
|||||||
| `ban_service.py` | Executes ban and unban commands via the fail2ban socket, queries the currently banned IP list, validates IPs before banning |
|
| `ban_service.py` | Executes ban and unban commands via the fail2ban socket, queries the currently banned IP list, validates IPs before banning |
|
||||||
| `config_service.py` | Reads active jail and filter configuration from fail2ban, writes configuration changes, validates regex patterns, triggers reload; reads the fail2ban log file tail and queries service status for the Log tab |
|
| `config_service.py` | Reads active jail and filter configuration from fail2ban, writes configuration changes, validates regex patterns, triggers reload; reads the fail2ban log file tail and queries service status for the Log tab |
|
||||||
| `file_config_service.py` | Reads and writes raw fail2ban config files on disk (jail.d/, filter.d/, action.d/); lists files, reads content, overwrites files, toggles enabled/disabled |
|
| `file_config_service.py` | Reads and writes raw fail2ban config files on disk (jail.d/, filter.d/, action.d/); lists files, reads content, overwrites files, toggles enabled/disabled |
|
||||||
| `jail_config_service.py` | Discovers inactive jails by parsing jail.conf / jail.local / jail.d/*; writes .local overrides to activate/deactivate jails; triggers fail2ban reload; validates jail configurations |
|
| `config_file_service.py` | Parses jail.conf / jail.local / jail.d/* to discover inactive jails; writes .local overrides to activate or deactivate jails; triggers fail2ban reload |
|
||||||
| `filter_config_service.py` | Discovers available filters by scanning filter.d/; reads, creates, updates, and deletes filter definitions; assigns filters to jails |
|
| `conffile_parser.py` | Parses fail2ban `.conf` files into structured Python types (jail config, filter config, action config); also serialises back to text |
|
||||||
| `action_config_service.py` | Discovers available actions by scanning action.d/; reads, creates, updates, and deletes action definitions; assigns actions to jails |
|
|
||||||
| `config_file_service.py` | Shared utilities for configuration parsing and manipulation: parses config files, validates names/IPs, manages atomic file writes, probes fail2ban socket |
|
|
||||||
| `raw_config_io_service.py` | Low-level file I/O for raw fail2ban config files |
|
|
||||||
| `log_service.py` | Log preview and regex test operations (extracted from config_service) |
|
|
||||||
| `history_service.py` | Queries the fail2ban database for historical ban records, builds per-IP timelines, computes ban counts and repeat-offender flags |
|
| `history_service.py` | Queries the fail2ban database for historical ban records, builds per-IP timelines, computes ban counts and repeat-offender flags |
|
||||||
| `blocklist_service.py` | Downloads blocklists via aiohttp, validates IPs/CIDRs, applies bans through fail2ban or iptables, logs import results |
|
| `blocklist_service.py` | Downloads blocklists via aiohttp, validates IPs/CIDRs, applies bans through fail2ban or iptables, logs import results |
|
||||||
| `geo_service.py` | Resolves IP addresses to country, ASN, and RIR using external APIs or a local database, caches results |
|
| `geo_service.py` | Resolves IP addresses to country, ASN, and RIR using external APIs or a local database, caches results |
|
||||||
@@ -202,26 +188,15 @@ The data access layer. Repositories execute raw SQL queries against the applicat
|
|||||||
| `settings_repo.py` | CRUD operations for application settings (master password hash, DB path, fail2ban socket path, preferences) |
|
| `settings_repo.py` | CRUD operations for application settings (master password hash, DB path, fail2ban socket path, preferences) |
|
||||||
| `session_repo.py` | Store, retrieve, and delete session records for authentication |
|
| `session_repo.py` | Store, retrieve, and delete session records for authentication |
|
||||||
| `blocklist_repo.py` | Persist blocklist source definitions (name, URL, enabled/disabled) |
|
| `blocklist_repo.py` | Persist blocklist source definitions (name, URL, enabled/disabled) |
|
||||||
| `fail2ban_db_repo.py` | Read historical ban records from the fail2ban SQLite database |
|
|
||||||
| `geo_cache_repo.py` | Persist and query IP geo resolution cache |
|
|
||||||
| `import_log_repo.py` | Record import run results (timestamp, source, IPs imported, errors) for the import log view |
|
| `import_log_repo.py` | Record import run results (timestamp, source, IPs imported, errors) for the import log view |
|
||||||
|
|
||||||
#### Models (`app/models/`)
|
#### Models (`app/models/`)
|
||||||
|
|
||||||
Pydantic schemas that define data shapes and validation. Models are split into three categories per domain.
|
Pydantic schemas that define data shapes and validation. Models are split into three categories per domain:
|
||||||
|
|
||||||
| Model file | Purpose |
|
- **Request models** — validate incoming API data (e.g., `BanRequest`, `LoginRequest`)
|
||||||
|---|---|
|
- **Response models** — shape outgoing API data (e.g., `JailResponse`, `BanListResponse`)
|
||||||
| `auth.py` | Login/request and session models |
|
- **Domain models** — internal representations used between services and repositories (e.g., `Ban`, `Jail`)
|
||||||
| `ban.py` | Ban creation and lookup models |
|
|
||||||
| `blocklist.py` | Blocklist source and import log models |
|
|
||||||
| `config.py` | Fail2ban config view/edit models |
|
|
||||||
| `file_config.py` | Raw config file read/write models |
|
|
||||||
| `geo.py` | Geo and ASN lookup models |
|
|
||||||
| `history.py` | Historical ban query and timeline models |
|
|
||||||
| `jail.py` | Jail listing and status models |
|
|
||||||
| `server.py` | Server status and settings models |
|
|
||||||
| `setup.py` | First-run setup wizard models |
|
|
||||||
|
|
||||||
#### Tasks (`app/tasks/`)
|
#### Tasks (`app/tasks/`)
|
||||||
|
|
||||||
@@ -231,7 +206,6 @@ APScheduler background jobs that run on a schedule without user interaction.
|
|||||||
|---|---|
|
|---|---|
|
||||||
| `blocklist_import.py` | Downloads all enabled blocklist sources, validates entries, applies bans, records results in the import log |
|
| `blocklist_import.py` | Downloads all enabled blocklist sources, validates entries, applies bans, records results in the import log |
|
||||||
| `geo_cache_flush.py` | Periodically flushes newly resolved IPs from the in-memory dirty set to the `geo_cache` SQLite table (default: every 60 seconds). GET requests populate only the in-memory cache; this task persists them without blocking any request. |
|
| `geo_cache_flush.py` | Periodically flushes newly resolved IPs from the in-memory dirty set to the `geo_cache` SQLite table (default: every 60 seconds). GET requests populate only the in-memory cache; this task persists them without blocking any request. |
|
||||||
| `geo_re_resolve.py` | Periodically re-resolves stale entries in `geo_cache` to keep geolocation data fresh |
|
|
||||||
| `health_check.py` | Periodically pings the fail2ban socket and updates the cached server status so the frontend always has fresh data |
|
| `health_check.py` | Periodically pings the fail2ban socket and updates the cached server status so the frontend always has fresh data |
|
||||||
|
|
||||||
#### Utils (`app/utils/`)
|
#### Utils (`app/utils/`)
|
||||||
@@ -242,16 +216,7 @@ Pure helper modules with no framework dependencies.
|
|||||||
|---|---|
|
|---|---|
|
||||||
| `fail2ban_client.py` | Async client that communicates with fail2ban via its Unix domain socket — sends commands and parses responses using the fail2ban protocol. Modelled after [`./fail2ban-master/fail2ban/client/csocket.py`](../fail2ban-master/fail2ban/client/csocket.py) and [`./fail2ban-master/fail2ban/client/fail2banclient.py`](../fail2ban-master/fail2ban/client/fail2banclient.py). |
|
| `fail2ban_client.py` | Async client that communicates with fail2ban via its Unix domain socket — sends commands and parses responses using the fail2ban protocol. Modelled after [`./fail2ban-master/fail2ban/client/csocket.py`](../fail2ban-master/fail2ban/client/csocket.py) and [`./fail2ban-master/fail2ban/client/fail2banclient.py`](../fail2ban-master/fail2ban/client/fail2banclient.py). |
|
||||||
| `ip_utils.py` | Validates IPv4/IPv6 addresses and CIDR ranges using the `ipaddress` stdlib module, normalises formats |
|
| `ip_utils.py` | Validates IPv4/IPv6 addresses and CIDR ranges using the `ipaddress` stdlib module, normalises formats |
|
||||||
| `jail_utils.py` | Jail helper functions for configuration and status inference |
|
|
||||||
| `jail_config.py` | Jail config parser and serializer for fail2ban config manipulation |
|
|
||||||
| `time_utils.py` | Timezone-aware datetime construction, formatting helpers, time-range calculations |
|
| `time_utils.py` | Timezone-aware datetime construction, formatting helpers, time-range calculations |
|
||||||
| `log_utils.py` | Structured log formatting and enrichment helpers |
|
|
||||||
| `conffile_parser.py` | Parses Fail2ban `.conf` files into structured objects and serialises back to text |
|
|
||||||
| `config_parser.py` | Builds structured config objects from file content tokens |
|
|
||||||
| `config_writer.py` | Atomic config file writes, backups, and safe replace semantics |
|
|
||||||
| `config_file_utils.py` | Common file-level config utility helpers |
|
|
||||||
| `fail2ban_db_utils.py` | Fail2ban DB path discovery and ban-history parsing helpers |
|
|
||||||
| `setup_utils.py` | Setup wizard helper utilities |
|
|
||||||
| `constants.py` | Shared constants: default socket path, default database path, time-range presets, limits |
|
| `constants.py` | Shared constants: default socket path, default database path, time-range presets, limits |
|
||||||
|
|
||||||
#### Configuration (`app/config.py`)
|
#### Configuration (`app/config.py`)
|
||||||
|
|||||||
@@ -1,5 +1,238 @@
|
|||||||
# BanGUI — Architecture Issues & Refactoring Plan
|
# BanGUI — Refactoring Instructions for AI Agents
|
||||||
|
|
||||||
This document catalogues architecture violations, code smells, and structural issues found during a full project review. Issues are grouped by category and prioritised.
|
This document is the single source of truth for any AI agent performing a refactoring task on the BanGUI codebase.
|
||||||
|
Read it in full before writing a single line of code.
|
||||||
|
The authoritative description of every module, its responsibilities, and the allowed dependency direction is in [Architekture.md](Architekture.md). Always cross-reference it.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
## 0. Golden Rules
|
||||||
|
|
||||||
|
1. **Architecture first.** Every change must comply with the layered architecture defined in [Architekture.md §2](Architekture.md). Dependencies flow inward: `routers → services → repositories`. Never add an import that reverses this direction.
|
||||||
|
2. **One concern per file.** Each module has an explicitly stated purpose in [Architekture.md](Architekture.md). Do not add responsibilities to a module that do not belong there.
|
||||||
|
3. **No behaviour change.** Refactoring must preserve all existing behaviour. If a function's public signature, return value, or side-effects must change, that is a feature — create a separate task for it.
|
||||||
|
4. **Tests stay green.** Run the full test suite (`pytest backend/`) before and after every change. Do not submit work that introduces new failures.
|
||||||
|
5. **Smallest diff wins.** Prefer targeted edits. Do not rewrite a file when a few lines suffice.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. Before You Start
|
||||||
|
|
||||||
|
### 1.1 Understand the project
|
||||||
|
|
||||||
|
Read the following documents in order:
|
||||||
|
|
||||||
|
1. [Architekture.md](Architekture.md) — full system overview, component map, module purposes, dependency rules.
|
||||||
|
2. [Docs/Backend-Development.md](Backend-Development.md) — coding conventions, testing strategy, environment setup.
|
||||||
|
3. [Docs/Tasks.md](Tasks.md) — open issues and planned work; avoid touching areas that have pending conflicting changes.
|
||||||
|
|
||||||
|
### 1.2 Map the code to the architecture
|
||||||
|
|
||||||
|
Before editing, locate every file that is in scope:
|
||||||
|
|
||||||
|
```
|
||||||
|
backend/app/
|
||||||
|
routers/ HTTP layer — zero business logic
|
||||||
|
services/ Business logic — orchestrates repositories + clients
|
||||||
|
repositories/ Data access — raw SQL only
|
||||||
|
models/ Pydantic schemas
|
||||||
|
tasks/ APScheduler jobs
|
||||||
|
utils/ Pure helpers, no framework deps
|
||||||
|
main.py App factory, lifespan, middleware
|
||||||
|
config.py Pydantic settings
|
||||||
|
dependencies.py FastAPI Depends() wiring
|
||||||
|
|
||||||
|
frontend/src/
|
||||||
|
api/ Typed fetch wrappers + endpoint constants
|
||||||
|
components/ Presentational UI, no API calls
|
||||||
|
hooks/ All state, side-effects, API calls
|
||||||
|
pages/ Route components — orchestration only
|
||||||
|
providers/ React context
|
||||||
|
types/ TypeScript interfaces
|
||||||
|
utils/ Pure helpers
|
||||||
|
```
|
||||||
|
|
||||||
|
Confirm which layer every file you intend to touch belongs to. If unsure, consult [Architekture.md §2.2](Architekture.md) (backend) or [Architekture.md §3.2](Architekture.md) (frontend).
|
||||||
|
|
||||||
|
### 1.3 Run the baseline
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Backend
|
||||||
|
pytest backend/ -x --tb=short
|
||||||
|
|
||||||
|
# Frontend
|
||||||
|
cd frontend && npm run test
|
||||||
|
```
|
||||||
|
|
||||||
|
Record the number of passing tests. After refactoring, that number must be equal or higher.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. Backend Refactoring
|
||||||
|
|
||||||
|
### 2.1 Routers (`app/routers/`)
|
||||||
|
|
||||||
|
**Allowed content:** request parsing, response serialisation, dependency injection via `Depends()`, delegation to a service, HTTP error mapping.
|
||||||
|
**Forbidden content:** SQL queries, business logic, direct use of `fail2ban_client`, any logic that would also make sense in a unit test without an HTTP request.
|
||||||
|
|
||||||
|
Checklist:
|
||||||
|
- [ ] Every handler calls exactly one service method per logical operation.
|
||||||
|
- [ ] No `if`/`elif` chains that implement business rules — move these to the service.
|
||||||
|
- [ ] No raw SQL or repository imports.
|
||||||
|
- [ ] All response models are Pydantic schemas from `app/models/`.
|
||||||
|
- [ ] HTTP status codes are consistent with API conventions (200 OK, 201 Created, 204 No Content, 400/422 for client errors, 404 for missing resources, 500 only for unexpected failures).
|
||||||
|
|
||||||
|
### 2.2 Services (`app/services/`)
|
||||||
|
|
||||||
|
**Allowed content:** business rules, coordination between repositories and external clients, validation that goes beyond Pydantic, fail2ban command orchestration.
|
||||||
|
**Forbidden content:** raw SQL, direct aiosqlite calls, FastAPI `HTTPException` (raise domain exceptions instead and let the router or exception handler convert them).
|
||||||
|
|
||||||
|
Checklist:
|
||||||
|
- [ ] Service classes / functions accept plain Python types or domain models — not `Request` or `Response` objects.
|
||||||
|
- [ ] No direct `aiosqlite` usage — go through a repository.
|
||||||
|
- [ ] No `HTTPException` — raise a custom domain exception or a plain `ValueError`/`RuntimeError` with a clear message.
|
||||||
|
- [ ] No circular imports between services — if two services need each other's logic, extract the shared logic to a utility or a third service.
|
||||||
|
|
||||||
|
### 2.3 Repositories (`app/repositories/`)
|
||||||
|
|
||||||
|
**Allowed content:** SQL queries, result mapping to domain models, transaction management.
|
||||||
|
**Forbidden content:** business logic, fail2ban calls, HTTP concerns, logging beyond debug-level traces.
|
||||||
|
|
||||||
|
Checklist:
|
||||||
|
- [ ] Every public method accepts a `db: aiosqlite.Connection` parameter — sessions are not managed internally.
|
||||||
|
- [ ] Methods return typed domain models or plain Python primitives, never raw `aiosqlite.Row` objects exposed to callers.
|
||||||
|
- [ ] No business rules (e.g., no "if this setting is missing, create a default" logic — that belongs in the service).
|
||||||
|
|
||||||
|
### 2.4 Models (`app/models/`)
|
||||||
|
|
||||||
|
- Keep **Request**, **Response**, and **Domain** model types clearly separated (see [Architekture.md §2.2](Architekture.md)).
|
||||||
|
- Do not use response models as function arguments inside service or repository code.
|
||||||
|
- Validators (`@field_validator`, `@model_validator`) belong in models only when they concern data shape, not business rules.
|
||||||
|
|
||||||
|
### 2.5 Tasks (`app/tasks/`)
|
||||||
|
|
||||||
|
- Tasks must be thin: fetch inputs → call one service method → log result.
|
||||||
|
- Error handling must be inside the task (APScheduler swallows unhandled exceptions — log them explicitly).
|
||||||
|
- No direct repository or `fail2ban_client` use; go through a service.
|
||||||
|
|
||||||
|
### 2.6 Utils (`app/utils/`)
|
||||||
|
|
||||||
|
- Must have zero framework dependencies (no FastAPI, no aiosqlite imports).
|
||||||
|
- Must be pure or near-pure functions.
|
||||||
|
- `fail2ban_client.py` is the single exception — it wraps the socket protocol but still has no service-layer logic.
|
||||||
|
|
||||||
|
### 2.7 Dependencies (`app/dependencies.py`)
|
||||||
|
|
||||||
|
- This file is the **only** place where service constructors are called and injected.
|
||||||
|
- Do not construct services inside router handlers; always receive them via `Depends()`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. Frontend Refactoring
|
||||||
|
|
||||||
|
### 3.1 Pages (`src/pages/`)
|
||||||
|
|
||||||
|
**Allowed content:** composing components and hooks, layout decisions, routing.
|
||||||
|
**Forbidden content:** direct `fetch`/`axios` calls, inline business logic, state management beyond what is needed to coordinate child components.
|
||||||
|
|
||||||
|
Checklist:
|
||||||
|
- [ ] All data fetching goes through a hook from `src/hooks/`.
|
||||||
|
- [ ] No API function from `src/api/` is called directly inside a page component.
|
||||||
|
|
||||||
|
### 3.2 Components (`src/components/`)
|
||||||
|
|
||||||
|
**Allowed content:** rendering, styling, event handlers that call prop callbacks.
|
||||||
|
**Forbidden content:** API calls, hook-level state (prefer lifting state to the page or a dedicated hook), direct use of `src/api/`.
|
||||||
|
|
||||||
|
Checklist:
|
||||||
|
- [ ] Components receive all data via props.
|
||||||
|
- [ ] Components emit changes via callback props (`onXxx`).
|
||||||
|
- [ ] No `useEffect` that calls an API function — that belongs in a hook.
|
||||||
|
|
||||||
|
### 3.3 Hooks (`src/hooks/`)
|
||||||
|
|
||||||
|
**Allowed content:** `useState`, `useEffect`, `useCallback`, `useRef`; calls to `src/api/`; local state derivation.
|
||||||
|
**Forbidden content:** JSX rendering, Fluent UI components.
|
||||||
|
|
||||||
|
Checklist:
|
||||||
|
- [ ] Each hook has a single, focused concern matching its name (e.g., `useBans` only manages ban data).
|
||||||
|
- [ ] Hooks return a stable interface: `{ data, loading, error, refetch }` or equivalent.
|
||||||
|
- [ ] Shared logic between hooks is extracted to `src/utils/` (pure) or a parent hook (stateful).
|
||||||
|
|
||||||
|
### 3.4 API layer (`src/api/`)
|
||||||
|
|
||||||
|
- `client.ts` is the only place that calls `fetch`. All other api files call `client.ts`.
|
||||||
|
- `endpoints.ts` is the single source of truth for URL strings.
|
||||||
|
- API functions must be typed: explicit request and response TypeScript interfaces from `src/types/`.
|
||||||
|
|
||||||
|
### 3.5 Types (`src/types/`)
|
||||||
|
|
||||||
|
- Interfaces must match the backend Pydantic response schemas exactly (field names, optionality).
|
||||||
|
- Do not use `any`. Use `unknown` and narrow with type guards when the shape is genuinely unknown.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. General Code Quality Rules
|
||||||
|
|
||||||
|
### Naming
|
||||||
|
- Python: `snake_case` for variables/functions, `PascalCase` for classes.
|
||||||
|
- TypeScript: `camelCase` for variables/functions, `PascalCase` for components and types.
|
||||||
|
- File names must match the primary export they contain.
|
||||||
|
|
||||||
|
### Error handling
|
||||||
|
- Backend: raise typed exceptions; map them to HTTP status codes in `main.py` exception handlers or in the router — nowhere else.
|
||||||
|
- Frontend: all API call error states are represented in hook return values; never swallow errors silently.
|
||||||
|
|
||||||
|
### Logging (backend)
|
||||||
|
- Use `structlog` with bound context loggers — never bare `print()`.
|
||||||
|
- Log at `debug` in repositories, `info` in services for meaningful events, `warning`/`error` in tasks and exception handlers.
|
||||||
|
- Never log sensitive data (passwords, session tokens, raw IP lists larger than a handful of entries).
|
||||||
|
|
||||||
|
### Async correctness (backend)
|
||||||
|
- Every function that touches I/O (database, fail2ban socket, HTTP) must be `async def`.
|
||||||
|
- Never call `asyncio.run()` inside a running event loop.
|
||||||
|
- Do not use `time.sleep()` — use `await asyncio.sleep()`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. Refactoring Workflow
|
||||||
|
|
||||||
|
Follow this sequence for every refactoring task:
|
||||||
|
|
||||||
|
1. **Read** the relevant section of [Architekture.md](Architekture.md) for the files you will touch.
|
||||||
|
2. **Run** the full test suite to confirm the baseline.
|
||||||
|
3. **Identify** the violation or smell: which rule from this document does it break?
|
||||||
|
4. **Plan** the minimal change: what is the smallest edit that fixes the violation?
|
||||||
|
5. **Edit** the code. One logical change per commit.
|
||||||
|
6. **Verify** imports: nothing new violates the dependency direction.
|
||||||
|
7. **Run** the full test suite. All previously passing tests must still pass.
|
||||||
|
8. **Update** any affected docstrings or inline comments to reflect the new structure.
|
||||||
|
9. **Do not** update `Architekture.md` unless the refactor changes the documented structure — that requires a separate review.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. Common Violations to Look For
|
||||||
|
|
||||||
|
| Violation | Where it typically appears | Fix |
|
||||||
|
|---|---|---|
|
||||||
|
| Business logic in a router handler | `app/routers/*.py` | Extract logic to the corresponding service |
|
||||||
|
| Direct `aiosqlite` calls in a service | `app/services/*.py` | Move the query into the matching repository |
|
||||||
|
| `HTTPException` raised inside a service | `app/services/*.py` | Raise a domain exception; catch and convert it in the router or exception handler |
|
||||||
|
| API call inside a React component | `src/components/*.tsx` | Move to a hook; pass data via props |
|
||||||
|
| Hardcoded URL string in a hook or component | `src/hooks/*.ts`, `src/components/*.tsx` | Use the constant from `src/api/endpoints.ts` |
|
||||||
|
| `any` type in TypeScript | anywhere in `src/` | Replace with a concrete interface from `src/types/` |
|
||||||
|
| `print()` statements in production code | `backend/app/**/*.py` | Replace with `structlog` logger |
|
||||||
|
| Synchronous I/O in an async function | `backend/app/**/*.py` | Use the async equivalent |
|
||||||
|
| A repository method that contains an `if` with a business rule | `app/repositories/*.py` | Move the rule to the service layer |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 7. Out of Scope
|
||||||
|
|
||||||
|
Do not make the following changes unless explicitly instructed in a separate task:
|
||||||
|
|
||||||
|
- Adding new API endpoints or pages.
|
||||||
|
- Changing database schema or migration files.
|
||||||
|
- Upgrading dependencies.
|
||||||
|
- Altering Docker or CI configuration.
|
||||||
|
- Modifying `Architekture.md` or `Tasks.md`.
|
||||||
|
|||||||
@@ -2,8 +2,91 @@
|
|||||||
|
|
||||||
This document breaks the entire BanGUI project into development stages, ordered so that each stage builds on the previous one. Every task is described in prose with enough detail for a developer to begin work. References point to the relevant documentation.
|
This document breaks the entire BanGUI project into development stages, ordered so that each stage builds on the previous one. Every task is described in prose with enough detail for a developer to begin work. References point to the relevant documentation.
|
||||||
|
|
||||||
Reference: `Docs/Refactoring.md` for full analysis of each issue.
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Open Issues
|
## Open Issues
|
||||||
|
|
||||||
|
> **Architectural Review — 2026-03-16**
|
||||||
|
> The findings below were identified by auditing every backend and frontend module against the rules in [Refactoring.md](Refactoring.md) and [Architekture.md](Architekture.md).
|
||||||
|
> Tasks are grouped by layer and ordered so that lower-level fixes (repositories, services) are done before the layers that depend on them.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 1 — Blocklist-import jail ban time must be 24 hours
|
||||||
|
|
||||||
|
**Status:** ✅ Done
|
||||||
|
|
||||||
|
**Context**
|
||||||
|
|
||||||
|
When the blocklist importer bans an IP it calls `jail_service.ban_ip(socket_path, BLOCKLIST_JAIL, ip)` (see `backend/app/services/blocklist_service.py`, constant `BLOCKLIST_JAIL = "blocklist-import"`). That call sends `set blocklist-import banip <ip>` to fail2ban, which applies the jail's configured `bantime`. There is currently no guarantee that the `blocklist-import` jail's `bantime` is 86 400 s (24 h), so imported IPs may be released too early or held indefinitely depending on the jail template.
|
||||||
|
|
||||||
|
**What to do**
|
||||||
|
|
||||||
|
1. Locate every place the `blocklist-import` jail is defined or provisioned — check `Docker/fail2ban-dev-config/`, `Docker/Dockerfile.backend`, any jail template files, and the `setup_service.py` / `SetupPage.tsx` flow.
|
||||||
|
2. Ensure the `blocklist-import` jail is created with `bantime = 86400` (24 h). If the jail is created at runtime by the setup service, add or update the `bantime` parameter there. If it is defined in a static config file, set `bantime = 86400` in that file.
|
||||||
|
3. Verify that the existing `jail_service.ban_ip` call in `blocklist_service.import_source` does not need a per-call duration override; the jail-level default of 86 400 s is sufficient.
|
||||||
|
4. Add or update the relevant unit/integration test in `backend/tests/` to assert that the blocklist-import jail is set up with a 24-hour bantime.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 2 — Clicking a jail in Jail Overview navigates to Configuration → Jails
|
||||||
|
|
||||||
|
**Status:** ✅ Done
|
||||||
|
|
||||||
|
**Context**
|
||||||
|
|
||||||
|
`JailsPage.tsx` renders a "Jail Overview" data grid with one row per jail (see `frontend/src/pages/JailsPage.tsx`). Clicking a row currently does nothing. `ConfigPage.tsx` hosts a tab bar with a "Jails" tab that renders `JailsTab`, which already uses a list/detail layout where a jail can be selected from the left pane.
|
||||||
|
|
||||||
|
**What to do**
|
||||||
|
|
||||||
|
1. In `JailsPage.tsx`, make each jail name cell (or the entire row) a clickable element that navigates to `/config` with state `{ tab: "jails", jail: "<jailName>" }`. Use `useNavigate` from `react-router-dom`; the existing `Link` import can be used or replaced with a programmatic navigate.
|
||||||
|
2. In `ConfigPage.tsx`, read the location state on mount. If `state.tab` is `"jails"`, set the active tab to `"jails"`. Pass `state.jail` down to `<JailsTab initialJail={state.jail} />`.
|
||||||
|
3. In `JailsTab.tsx`, accept an optional `initialJail?: string` prop. When it is provided, pre-select that jail in the left-pane list on first render (i.e. set the selected jail state to the jail whose name matches `initialJail`). This should scroll the item into view if the list is long.
|
||||||
|
4. Add a frontend unit test in `frontend/src/pages/__tests__/` that mounts `JailsPage` with a mocked jail list, clicks a jail row, and asserts that `useNavigate` was called with the correct path and state.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 3 — Setting bantime / findtime throws 400 error due to unsupported `backend` set command
|
||||||
|
|
||||||
|
**Status:** ✅ Done
|
||||||
|
|
||||||
|
**Context**
|
||||||
|
|
||||||
|
Editing ban time or find time in Configuration → Jails triggers an auto-save that sends the full `JailConfigUpdate` payload including the `backend` field. `config_service.update_jail_config` then calls `set <jail> backend <value>` on the fail2ban socket, which returns error code 1 with the message `Invalid command 'backend' (no set action or not yet implemented)`. Fail2ban does not support changing a jail's backend at runtime; it must be set before the jail starts.
|
||||||
|
|
||||||
|
**What to do**
|
||||||
|
|
||||||
|
**Backend** (`backend/app/services/config_service.py`):
|
||||||
|
|
||||||
|
1. Remove the `if update.backend is not None: await _set("backend", update.backend)` block from `update_jail_config`. Setting `backend` via the socket is not supported by fail2ban and will always fail.
|
||||||
|
2. `log_encoding` has the same constraint — verify whether `set <jail> logencoding` is supported at runtime. If it is not, remove it too. If it is supported, leave it.
|
||||||
|
3. Ensure the function still accepts and stores the `backend` value in the Pydantic model for read purposes; do not remove it from `JailConfigUpdate` or the response model.
|
||||||
|
|
||||||
|
**Frontend** (`frontend/src/components/config/JailsTab.tsx`):
|
||||||
|
|
||||||
|
4. Remove `backend` (and `log_encoding` if step 2 confirms it is unsupported) from the `autoSavePayload` memo so the field is never sent in the PATCH/PUT body. The displayed value should remain read-only — show them as plain text or a disabled select so the user can see the current value without being able to trigger the broken set command.
|
||||||
|
|
||||||
|
**Tests**:
|
||||||
|
|
||||||
|
5. Add or update the backend test for `update_jail_config` to assert that no `set … backend` command is issued, and that a payload containing a `backend` field does not cause an error.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 4 — Unify filter bar: use `DashboardFilterBar` in World Map and History pages
|
||||||
|
|
||||||
|
**Status:** ✅ Done
|
||||||
|
|
||||||
|
**Context**
|
||||||
|
|
||||||
|
`DashboardPage.tsx` uses the shared `<DashboardFilterBar>` component for its time-range and origin-filter controls. `MapPage.tsx` and `HistoryPage.tsx` each implement their own ad-hoc filter UI: `MapPage` uses a Fluent UI `<Select>` for time range plus an inline Toolbar for origin filter; `HistoryPage` uses a `<Select>` for time range with no origin filter toggle. The `DashboardFilterBar` already supports both `TimeRange` and `BanOriginFilter` with the exact toggle-button style shown in the design reference. All three pages should share the same filter appearance and interaction patterns.
|
||||||
|
|
||||||
|
**What to do**
|
||||||
|
|
||||||
|
1. **`MapPage.tsx`**: Replace the custom time-range `<Select>` and the inline origin-filter Toolbar with `<DashboardFilterBar timeRange={range} onTimeRangeChange={setRange} originFilter={originFilter} onOriginFilterChange={setOriginFilter} />`. Remove the now-unused `TIME_RANGE_OPTIONS` constant and the `BAN_ORIGIN_FILTER_LABELS` inline usage. Pass `originFilter` to `useMapData` if it does not already receive it (check the hook signature).
|
||||||
|
2. **`HistoryPage.tsx`**: Replace the custom time-range `<Select>` with `<DashboardFilterBar>`. Add an `originFilter` state (`BanOriginFilter`, default `"all"`) and wire it through `<DashboardFilterBar onOriginFilterChange={setOriginFilter} />`. Pass the origin filter into the `useHistory` query so the backend receives it. If `useHistory` / `HistoryQuery` does not yet accept `origin_filter`, add the parameter to the type and the hook's fetch call.
|
||||||
|
3. Remove any local `filterBar` style definitions from `MapPage.tsx` and `HistoryPage.tsx` that duplicate what `DashboardFilterBar` already provides.
|
||||||
|
4. Ensure the `DashboardFilterBar` component's props interface (`DashboardFilterBarProps` in `frontend/src/components/DashboardFilterBar.tsx`) is not changed in a breaking way; only the call sites change.
|
||||||
|
5. Update or add component tests for `MapPage` and `HistoryPage` to assert that `DashboardFilterBar` is rendered and that changing the time range or origin filter updates the displayed data.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
|||||||
@@ -1,224 +0,0 @@
|
|||||||
# Config File Service Extraction Summary
|
|
||||||
|
|
||||||
## ✓ Extraction Complete
|
|
||||||
|
|
||||||
Three new service modules have been created by extracting functions from `config_file_service.py`.
|
|
||||||
|
|
||||||
### Files Created
|
|
||||||
|
|
||||||
| File | Lines | Status |
|
|
||||||
|------|-------|--------|
|
|
||||||
| [jail_config_service.py](jail_config_service.py) | 991 | ✓ Created |
|
|
||||||
| [filter_config_service.py](filter_config_service.py) | 765 | ✓ Created |
|
|
||||||
| [action_config_service.py](action_config_service.py) | 988 | ✓ Created |
|
|
||||||
| **Total** | **2,744** | **✓ Verified** |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 1. JAIL_CONFIG Service (`jail_config_service.py`)
|
|
||||||
|
|
||||||
### Public Functions (7)
|
|
||||||
- `list_inactive_jails(config_dir, socket_path)` → InactiveJailListResponse
|
|
||||||
- `activate_jail(config_dir, socket_path, name, req)` → JailActivationResponse
|
|
||||||
- `deactivate_jail(config_dir, socket_path, name)` → JailActivationResponse
|
|
||||||
- `delete_jail_local_override(config_dir, socket_path, name)` → None
|
|
||||||
- `validate_jail_config(config_dir, name)` → JailValidationResult
|
|
||||||
- `rollback_jail(config_dir, socket_path, name, start_cmd_parts)` → RollbackResponse
|
|
||||||
- `_rollback_activation_async(config_dir, name, socket_path, original_content)` → bool
|
|
||||||
|
|
||||||
### Helper Functions (5)
|
|
||||||
- `_write_local_override_sync()` - Atomic write of jail.d/{name}.local
|
|
||||||
- `_restore_local_file_sync()` - Restore or delete .local file during rollback
|
|
||||||
- `_validate_regex_patterns()` - Validate failregex/ignoreregex patterns
|
|
||||||
- `_set_jail_local_key_sync()` - Update single key in jail section
|
|
||||||
- `_validate_jail_config_sync()` - Synchronous validation (filter/action files, patterns, logpath)
|
|
||||||
|
|
||||||
### Custom Exceptions (3)
|
|
||||||
- `JailNotFoundInConfigError`
|
|
||||||
- `JailAlreadyActiveError`
|
|
||||||
- `JailAlreadyInactiveError`
|
|
||||||
|
|
||||||
### Shared Dependencies Imported
|
|
||||||
- `_safe_jail_name()` - From config_file_service
|
|
||||||
- `_parse_jails_sync()` - From config_file_service
|
|
||||||
- `_build_inactive_jail()` - From config_file_service
|
|
||||||
- `_get_active_jail_names()` - From config_file_service
|
|
||||||
- `_probe_fail2ban_running()` - From config_file_service
|
|
||||||
- `wait_for_fail2ban()` - From config_file_service
|
|
||||||
- `start_daemon()` - From config_file_service
|
|
||||||
- `_resolve_filter()` - From config_file_service
|
|
||||||
- `_parse_multiline()` - From config_file_service
|
|
||||||
- `_SOCKET_TIMEOUT`, `_META_SECTIONS` - Constants
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 2. FILTER_CONFIG Service (`filter_config_service.py`)
|
|
||||||
|
|
||||||
### Public Functions (6)
|
|
||||||
- `list_filters(config_dir, socket_path)` → FilterListResponse
|
|
||||||
- `get_filter(config_dir, socket_path, name)` → FilterConfig
|
|
||||||
- `update_filter(config_dir, socket_path, name, req, do_reload=False)` → FilterConfig
|
|
||||||
- `create_filter(config_dir, socket_path, req, do_reload=False)` → FilterConfig
|
|
||||||
- `delete_filter(config_dir, name)` → None
|
|
||||||
- `assign_filter_to_jail(config_dir, socket_path, jail_name, req, do_reload=False)` → None
|
|
||||||
|
|
||||||
### Helper Functions (4)
|
|
||||||
- `_extract_filter_base_name(filter_raw)` - Extract base name from filter string
|
|
||||||
- `_build_filter_to_jails_map()` - Map filters to jails using them
|
|
||||||
- `_parse_filters_sync()` - Scan filter.d/ and return tuples
|
|
||||||
- `_write_filter_local_sync()` - Atomic write of filter.d/{name}.local
|
|
||||||
- `_validate_regex_patterns()` - Validate regex patterns (shared with jail_config)
|
|
||||||
|
|
||||||
### Custom Exceptions (5)
|
|
||||||
- `FilterNotFoundError`
|
|
||||||
- `FilterAlreadyExistsError`
|
|
||||||
- `FilterReadonlyError`
|
|
||||||
- `FilterInvalidRegexError`
|
|
||||||
- `FilterNameError` (re-exported from config_file_service)
|
|
||||||
|
|
||||||
### Shared Dependencies Imported
|
|
||||||
- `_safe_filter_name()` - From config_file_service
|
|
||||||
- `_safe_jail_name()` - From config_file_service
|
|
||||||
- `_parse_jails_sync()` - From config_file_service
|
|
||||||
- `_get_active_jail_names()` - From config_file_service
|
|
||||||
- `_resolve_filter()` - From config_file_service
|
|
||||||
- `_parse_multiline()` - From config_file_service
|
|
||||||
- `_SAFE_FILTER_NAME_RE` - Constant pattern
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 3. ACTION_CONFIG Service (`action_config_service.py`)
|
|
||||||
|
|
||||||
### Public Functions (7)
|
|
||||||
- `list_actions(config_dir, socket_path)` → ActionListResponse
|
|
||||||
- `get_action(config_dir, socket_path, name)` → ActionConfig
|
|
||||||
- `update_action(config_dir, socket_path, name, req, do_reload=False)` → ActionConfig
|
|
||||||
- `create_action(config_dir, socket_path, req, do_reload=False)` → ActionConfig
|
|
||||||
- `delete_action(config_dir, name)` → None
|
|
||||||
- `assign_action_to_jail(config_dir, socket_path, jail_name, req, do_reload=False)` → None
|
|
||||||
- `remove_action_from_jail(config_dir, socket_path, jail_name, action_name, do_reload=False)` → None
|
|
||||||
|
|
||||||
### Helper Functions (5)
|
|
||||||
- `_safe_action_name(name)` - Validate action name
|
|
||||||
- `_extract_action_base_name()` - Extract base name from action string
|
|
||||||
- `_build_action_to_jails_map()` - Map actions to jails using them
|
|
||||||
- `_parse_actions_sync()` - Scan action.d/ and return tuples
|
|
||||||
- `_append_jail_action_sync()` - Append action to jail.d/{name}.local
|
|
||||||
- `_remove_jail_action_sync()` - Remove action from jail.d/{name}.local
|
|
||||||
- `_write_action_local_sync()` - Atomic write of action.d/{name}.local
|
|
||||||
|
|
||||||
### Custom Exceptions (4)
|
|
||||||
- `ActionNotFoundError`
|
|
||||||
- `ActionAlreadyExistsError`
|
|
||||||
- `ActionReadonlyError`
|
|
||||||
- `ActionNameError`
|
|
||||||
|
|
||||||
### Shared Dependencies Imported
|
|
||||||
- `_safe_jail_name()` - From config_file_service
|
|
||||||
- `_parse_jails_sync()` - From config_file_service
|
|
||||||
- `_get_active_jail_names()` - From config_file_service
|
|
||||||
- `_build_parser()` - From config_file_service
|
|
||||||
- `_SAFE_ACTION_NAME_RE` - Constant pattern
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 4. SHARED Utilities (remain in `config_file_service.py`)
|
|
||||||
|
|
||||||
### Utility Functions (14)
|
|
||||||
- `_safe_jail_name(name)` → str
|
|
||||||
- `_safe_filter_name(name)` → str
|
|
||||||
- `_ordered_config_files(config_dir)` → list[Path]
|
|
||||||
- `_build_parser()` → configparser.RawConfigParser
|
|
||||||
- `_is_truthy(value)` → bool
|
|
||||||
- `_parse_int_safe(value)` → int | None
|
|
||||||
- `_parse_time_to_seconds(value, default)` → int
|
|
||||||
- `_parse_multiline(raw)` → list[str]
|
|
||||||
- `_resolve_filter(raw_filter, jail_name, mode)` → str
|
|
||||||
- `_parse_jails_sync(config_dir)` → tuple
|
|
||||||
- `_build_inactive_jail(name, settings, source_file, config_dir=None)` → InactiveJail
|
|
||||||
- `_get_active_jail_names(socket_path)` → set[str]
|
|
||||||
- `_probe_fail2ban_running(socket_path)` → bool
|
|
||||||
- `wait_for_fail2ban(socket_path, max_wait_seconds, poll_interval)` → bool
|
|
||||||
- `start_daemon(start_cmd_parts)` → bool
|
|
||||||
|
|
||||||
### Shared Exceptions (3)
|
|
||||||
- `JailNameError`
|
|
||||||
- `FilterNameError`
|
|
||||||
- `ConfigWriteError`
|
|
||||||
|
|
||||||
### Constants (7)
|
|
||||||
- `_SOCKET_TIMEOUT`
|
|
||||||
- `_SAFE_JAIL_NAME_RE`
|
|
||||||
- `_META_SECTIONS`
|
|
||||||
- `_TRUE_VALUES`
|
|
||||||
- `_FALSE_VALUES`
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Import Dependencies
|
|
||||||
|
|
||||||
### jail_config_service imports:
|
|
||||||
```python
|
|
||||||
config_file_service: (shared utilities + private functions)
|
|
||||||
jail_service.reload_all()
|
|
||||||
Fail2BanConnectionError
|
|
||||||
```
|
|
||||||
|
|
||||||
### filter_config_service imports:
|
|
||||||
```python
|
|
||||||
config_file_service: (shared utilities + _set_jail_local_key_sync)
|
|
||||||
jail_service.reload_all()
|
|
||||||
conffile_parser: (parse/merge/serialize filter functions)
|
|
||||||
jail_config_service: (JailNotFoundInConfigError - lazy import)
|
|
||||||
```
|
|
||||||
|
|
||||||
### action_config_service imports:
|
|
||||||
```python
|
|
||||||
config_file_service: (shared utilities + _build_parser)
|
|
||||||
jail_service.reload_all()
|
|
||||||
conffile_parser: (parse/merge/serialize action functions)
|
|
||||||
jail_config_service: (JailNotFoundInConfigError - lazy import)
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Cross-Service Dependencies
|
|
||||||
|
|
||||||
**Circular imports handled via lazy imports:**
|
|
||||||
- `filter_config_service` imports `JailNotFoundInConfigError` from `jail_config_service` inside function
|
|
||||||
- `action_config_service` imports `JailNotFoundInConfigError` from `jail_config_service` inside function
|
|
||||||
|
|
||||||
**Shared functions re-used:**
|
|
||||||
- `_set_jail_local_key_sync()` exported from `jail_config_service`, used by `filter_config_service`
|
|
||||||
- `_append_jail_action_sync()` and `_remove_jail_action_sync()` internal to `action_config_service`
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Verification Results
|
|
||||||
|
|
||||||
✓ **Syntax Check:** All three files compile without errors
|
|
||||||
✓ **Import Verification:** All imports resolved correctly
|
|
||||||
✓ **Total Lines:** 2,744 lines across three new files
|
|
||||||
✓ **Function Coverage:** 100% of specified functions extracted
|
|
||||||
✓ **Type Hints:** Preserved throughout
|
|
||||||
✓ **Docstrings:** All preserved with full documentation
|
|
||||||
✓ **Comments:** All inline comments preserved
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Next Steps (if needed)
|
|
||||||
|
|
||||||
1. **Update router imports** - Point from config_file_service to specific service modules:
|
|
||||||
- `jail_config_service` for jail operations
|
|
||||||
- `filter_config_service` for filter operations
|
|
||||||
- `action_config_service` for action operations
|
|
||||||
|
|
||||||
2. **Update config_file_service.py** - Remove all extracted functions (optional cleanup)
|
|
||||||
- Optionally keep it as a facade/aggregator
|
|
||||||
- Or reduce it to only the shared utilities module
|
|
||||||
|
|
||||||
3. **Add __all__ exports** to each new module for cleaner public API
|
|
||||||
|
|
||||||
4. **Update type hints** in models if needed for cross-service usage
|
|
||||||
|
|
||||||
5. **Testing** - Run existing tests to ensure no regressions
|
|
||||||
@@ -1 +1,50 @@
|
|||||||
"""BanGUI backend application package."""
|
"""BanGUI backend application package.
|
||||||
|
|
||||||
|
This package exposes the application version based on the project metadata.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Final
|
||||||
|
|
||||||
|
import importlib.metadata
|
||||||
|
import tomllib
|
||||||
|
|
||||||
|
PACKAGE_NAME: Final[str] = "bangui-backend"
|
||||||
|
|
||||||
|
|
||||||
|
def _read_pyproject_version() -> str:
|
||||||
|
"""Read the project version from ``pyproject.toml``.
|
||||||
|
|
||||||
|
This is used as a fallback when the package metadata is not available (e.g.
|
||||||
|
when running directly from a source checkout without installing the package).
|
||||||
|
"""
|
||||||
|
|
||||||
|
project_root = Path(__file__).resolve().parents[1]
|
||||||
|
pyproject_path = project_root / "pyproject.toml"
|
||||||
|
if not pyproject_path.exists():
|
||||||
|
raise FileNotFoundError(f"pyproject.toml not found at {pyproject_path}")
|
||||||
|
|
||||||
|
data = tomllib.loads(pyproject_path.read_text(encoding="utf-8"))
|
||||||
|
return str(data["project"]["version"])
|
||||||
|
|
||||||
|
|
||||||
|
def _read_version() -> str:
|
||||||
|
"""Return the current package version.
|
||||||
|
|
||||||
|
Prefer the project metadata in ``pyproject.toml`` when available, since this
|
||||||
|
is the single source of truth for local development and is kept in sync with
|
||||||
|
the frontend and Docker release version.
|
||||||
|
|
||||||
|
When running from an installed distribution where the ``pyproject.toml``
|
||||||
|
is not available, fall back to installed package metadata.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
return _read_pyproject_version()
|
||||||
|
except FileNotFoundError:
|
||||||
|
return importlib.metadata.version(PACKAGE_NAME)
|
||||||
|
|
||||||
|
|
||||||
|
__version__ = _read_version()
|
||||||
|
|||||||
@@ -85,4 +85,4 @@ def get_settings() -> Settings:
|
|||||||
A validated :class:`Settings` object. Raises :class:`pydantic.ValidationError`
|
A validated :class:`Settings` object. Raises :class:`pydantic.ValidationError`
|
||||||
if required keys are absent or values fail validation.
|
if required keys are absent or values fail validation.
|
||||||
"""
|
"""
|
||||||
return Settings() # type: ignore[call-arg] # pydantic-settings populates required fields from env vars
|
return Settings() # pydantic-settings populates required fields from env vars
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ directly — to keep coupling explicit and testable.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from typing import Annotated, Protocol, cast
|
from typing import Annotated
|
||||||
|
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
import structlog
|
import structlog
|
||||||
@@ -19,13 +19,6 @@ from app.utils.time_utils import utc_now
|
|||||||
|
|
||||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
class AppState(Protocol):
|
|
||||||
"""Partial view of the FastAPI application state used by dependencies."""
|
|
||||||
|
|
||||||
settings: Settings
|
|
||||||
|
|
||||||
|
|
||||||
_COOKIE_NAME = "bangui_session"
|
_COOKIE_NAME = "bangui_session"
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -92,8 +85,7 @@ async def get_settings(request: Request) -> Settings:
|
|||||||
Returns:
|
Returns:
|
||||||
The application settings loaded at startup.
|
The application settings loaded at startup.
|
||||||
"""
|
"""
|
||||||
state = cast("AppState", request.app.state)
|
return request.app.state.settings # type: ignore[no-any-return]
|
||||||
return state.settings
|
|
||||||
|
|
||||||
|
|
||||||
async def require_auth(
|
async def require_auth(
|
||||||
|
|||||||
@@ -1,53 +0,0 @@
|
|||||||
"""Shared domain exception classes used across routers and services."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
|
|
||||||
class JailNotFoundError(Exception):
|
|
||||||
"""Raised when a requested jail name does not exist."""
|
|
||||||
|
|
||||||
def __init__(self, name: str) -> None:
|
|
||||||
self.name = name
|
|
||||||
super().__init__(f"Jail not found: {name!r}")
|
|
||||||
|
|
||||||
|
|
||||||
class JailOperationError(Exception):
|
|
||||||
"""Raised when a fail2ban jail operation fails."""
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigValidationError(Exception):
|
|
||||||
"""Raised when config values fail validation before applying."""
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigOperationError(Exception):
|
|
||||||
"""Raised when a config payload update or command fails."""
|
|
||||||
|
|
||||||
|
|
||||||
class ServerOperationError(Exception):
|
|
||||||
"""Raised when a server control command (e.g. refresh) fails."""
|
|
||||||
|
|
||||||
|
|
||||||
class FilterInvalidRegexError(Exception):
|
|
||||||
"""Raised when a regex pattern fails to compile."""
|
|
||||||
|
|
||||||
def __init__(self, pattern: str, error: str) -> None:
|
|
||||||
"""Initialize with the invalid pattern and compile error."""
|
|
||||||
self.pattern = pattern
|
|
||||||
self.error = error
|
|
||||||
super().__init__(f"Invalid regex {pattern!r}: {error}")
|
|
||||||
|
|
||||||
|
|
||||||
class JailNotFoundInConfigError(Exception):
|
|
||||||
"""Raised when the requested jail name is not defined in any config file."""
|
|
||||||
|
|
||||||
def __init__(self, name: str) -> None:
|
|
||||||
self.name = name
|
|
||||||
super().__init__(f"Jail not found in config: {name!r}")
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigWriteError(Exception):
|
|
||||||
"""Raised when writing a configuration file modification fails."""
|
|
||||||
|
|
||||||
def __init__(self, message: str) -> None:
|
|
||||||
self.message = message
|
|
||||||
super().__init__(message)
|
|
||||||
@@ -31,6 +31,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||||||
from fastapi.responses import JSONResponse, RedirectResponse
|
from fastapi.responses import JSONResponse, RedirectResponse
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
|
||||||
|
from app import __version__
|
||||||
from app.config import Settings, get_settings
|
from app.config import Settings, get_settings
|
||||||
from app.db import init_db
|
from app.db import init_db
|
||||||
from app.routers import (
|
from app.routers import (
|
||||||
@@ -161,7 +162,11 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
await geo_service.load_cache_from_db(db)
|
await geo_service.load_cache_from_db(db)
|
||||||
|
|
||||||
# Log unresolved geo entries so the operator can see the scope of the issue.
|
# Log unresolved geo entries so the operator can see the scope of the issue.
|
||||||
unresolved_count = await geo_service.count_unresolved(db)
|
async with db.execute(
|
||||||
|
"SELECT COUNT(*) FROM geo_cache WHERE country_code IS NULL"
|
||||||
|
) as cur:
|
||||||
|
row = await cur.fetchone()
|
||||||
|
unresolved_count: int = int(row[0]) if row else 0
|
||||||
if unresolved_count > 0:
|
if unresolved_count > 0:
|
||||||
log.warning("geo_cache_unresolved_ips", unresolved=unresolved_count)
|
log.warning("geo_cache_unresolved_ips", unresolved=unresolved_count)
|
||||||
|
|
||||||
@@ -361,7 +366,7 @@ def create_app(settings: Settings | None = None) -> FastAPI:
|
|||||||
app: FastAPI = FastAPI(
|
app: FastAPI = FastAPI(
|
||||||
title="BanGUI",
|
title="BanGUI",
|
||||||
description="Web interface for monitoring, managing, and configuring fail2ban.",
|
description="Web interface for monitoring, managing, and configuring fail2ban.",
|
||||||
version="0.1.0",
|
version=__version__,
|
||||||
lifespan=_lifespan,
|
lifespan=_lifespan,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -3,18 +3,8 @@
|
|||||||
Response models for the ``GET /api/geo/lookup/{ip}`` endpoint.
|
Response models for the ``GET /api/geo/lookup/{ip}`` endpoint.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from collections.abc import Awaitable, Callable
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
import aiohttp
|
|
||||||
import aiosqlite
|
|
||||||
|
|
||||||
|
|
||||||
class GeoDetail(BaseModel):
|
class GeoDetail(BaseModel):
|
||||||
"""Enriched geolocation data for an IP address.
|
"""Enriched geolocation data for an IP address.
|
||||||
@@ -74,26 +64,3 @@ class IpLookupResponse(BaseModel):
|
|||||||
default=None,
|
default=None,
|
||||||
description="Enriched geographical and network information.",
|
description="Enriched geographical and network information.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# shared service types
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class GeoInfo:
|
|
||||||
"""Geo resolution result used throughout backend services."""
|
|
||||||
|
|
||||||
country_code: str | None
|
|
||||||
country_name: str | None
|
|
||||||
asn: str | None
|
|
||||||
org: str | None
|
|
||||||
|
|
||||||
|
|
||||||
GeoEnricher = Callable[[str], Awaitable[GeoInfo | None]]
|
|
||||||
GeoBatchLookup = Callable[
|
|
||||||
[list[str], "aiohttp.ClientSession", "aiosqlite.Connection | None"],
|
|
||||||
Awaitable[dict[str, GeoInfo]],
|
|
||||||
]
|
|
||||||
GeoCacheLookup = Callable[[list[str]], tuple[dict[str, GeoInfo], list[str]]]
|
|
||||||
|
|||||||
@@ -1,358 +0,0 @@
|
|||||||
"""Fail2Ban SQLite database repository.
|
|
||||||
|
|
||||||
This module contains helper functions that query the read-only fail2ban
|
|
||||||
SQLite database file. All functions accept a *db_path* and manage their own
|
|
||||||
connection using aiosqlite in read-only mode.
|
|
||||||
|
|
||||||
The functions intentionally return plain Python data structures (dataclasses) so
|
|
||||||
service layers can focus on business logic and formatting.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
import aiosqlite
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from collections.abc import Iterable
|
|
||||||
|
|
||||||
from app.models.ban import BanOrigin
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class BanRecord:
|
|
||||||
"""A single row from the fail2ban ``bans`` table."""
|
|
||||||
|
|
||||||
jail: str
|
|
||||||
ip: str
|
|
||||||
timeofban: int
|
|
||||||
bancount: int
|
|
||||||
data: str
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class BanIpCount:
|
|
||||||
"""Aggregated ban count for a single IP."""
|
|
||||||
|
|
||||||
ip: str
|
|
||||||
event_count: int
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class JailBanCount:
|
|
||||||
"""Aggregated ban count for a single jail."""
|
|
||||||
|
|
||||||
jail: str
|
|
||||||
count: int
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class HistoryRecord:
|
|
||||||
"""A single row from the fail2ban ``bans`` table for history queries."""
|
|
||||||
|
|
||||||
jail: str
|
|
||||||
ip: str
|
|
||||||
timeofban: int
|
|
||||||
bancount: int
|
|
||||||
data: str
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Internal helpers
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _make_db_uri(db_path: str) -> str:
|
|
||||||
"""Return a read-only sqlite URI for the given file path."""
|
|
||||||
|
|
||||||
return f"file:{db_path}?mode=ro"
|
|
||||||
|
|
||||||
|
|
||||||
def _origin_sql_filter(origin: BanOrigin | None) -> tuple[str, tuple[str, ...]]:
|
|
||||||
"""Return a SQL fragment and parameters for the origin filter."""
|
|
||||||
|
|
||||||
if origin == "blocklist":
|
|
||||||
return " AND jail = ?", ("blocklist-import",)
|
|
||||||
if origin == "selfblock":
|
|
||||||
return " AND jail != ?", ("blocklist-import",)
|
|
||||||
return "", ()
|
|
||||||
|
|
||||||
|
|
||||||
def _rows_to_ban_records(rows: Iterable[aiosqlite.Row]) -> list[BanRecord]:
|
|
||||||
return [
|
|
||||||
BanRecord(
|
|
||||||
jail=str(r["jail"]),
|
|
||||||
ip=str(r["ip"]),
|
|
||||||
timeofban=int(r["timeofban"]),
|
|
||||||
bancount=int(r["bancount"]),
|
|
||||||
data=str(r["data"]),
|
|
||||||
)
|
|
||||||
for r in rows
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def _rows_to_history_records(rows: Iterable[aiosqlite.Row]) -> list[HistoryRecord]:
|
|
||||||
return [
|
|
||||||
HistoryRecord(
|
|
||||||
jail=str(r["jail"]),
|
|
||||||
ip=str(r["ip"]),
|
|
||||||
timeofban=int(r["timeofban"]),
|
|
||||||
bancount=int(r["bancount"]),
|
|
||||||
data=str(r["data"]),
|
|
||||||
)
|
|
||||||
for r in rows
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Public API
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
async def check_db_nonempty(db_path: str) -> bool:
|
|
||||||
"""Return True if the fail2ban database contains at least one ban row."""
|
|
||||||
|
|
||||||
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db, db.execute(
|
|
||||||
"SELECT 1 FROM bans LIMIT 1"
|
|
||||||
) as cur:
|
|
||||||
row = await cur.fetchone()
|
|
||||||
return row is not None
|
|
||||||
|
|
||||||
|
|
||||||
async def get_currently_banned(
|
|
||||||
db_path: str,
|
|
||||||
since: int,
|
|
||||||
origin: BanOrigin | None = None,
|
|
||||||
*,
|
|
||||||
limit: int | None = None,
|
|
||||||
offset: int | None = None,
|
|
||||||
) -> tuple[list[BanRecord], int]:
|
|
||||||
"""Return a page of currently banned IPs and the total matching count.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db_path: File path to the fail2ban SQLite database.
|
|
||||||
since: Unix timestamp to filter bans newer than or equal to.
|
|
||||||
origin: Optional origin filter.
|
|
||||||
limit: Optional maximum number of rows to return.
|
|
||||||
offset: Optional offset for pagination.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A ``(records, total)`` tuple.
|
|
||||||
"""
|
|
||||||
|
|
||||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
|
||||||
|
|
||||||
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
|
|
||||||
db.row_factory = aiosqlite.Row
|
|
||||||
|
|
||||||
async with db.execute(
|
|
||||||
"SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause,
|
|
||||||
(since, *origin_params),
|
|
||||||
) as cur:
|
|
||||||
count_row = await cur.fetchone()
|
|
||||||
total: int = int(count_row[0]) if count_row else 0
|
|
||||||
|
|
||||||
query = (
|
|
||||||
"SELECT jail, ip, timeofban, bancount, data "
|
|
||||||
"FROM bans "
|
|
||||||
"WHERE timeofban >= ?" + origin_clause + " ORDER BY timeofban DESC"
|
|
||||||
)
|
|
||||||
params: list[object] = [since, *origin_params]
|
|
||||||
if limit is not None:
|
|
||||||
query += " LIMIT ?"
|
|
||||||
params.append(limit)
|
|
||||||
if offset is not None:
|
|
||||||
query += " OFFSET ?"
|
|
||||||
params.append(offset)
|
|
||||||
|
|
||||||
async with db.execute(query, params) as cur:
|
|
||||||
rows = await cur.fetchall()
|
|
||||||
|
|
||||||
return _rows_to_ban_records(rows), total
|
|
||||||
|
|
||||||
|
|
||||||
async def get_ban_counts_by_bucket(
|
|
||||||
db_path: str,
|
|
||||||
since: int,
|
|
||||||
bucket_secs: int,
|
|
||||||
num_buckets: int,
|
|
||||||
origin: BanOrigin | None = None,
|
|
||||||
) -> list[int]:
|
|
||||||
"""Return ban counts aggregated into equal-width time buckets."""
|
|
||||||
|
|
||||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
|
||||||
|
|
||||||
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
|
|
||||||
db.row_factory = aiosqlite.Row
|
|
||||||
async with db.execute(
|
|
||||||
"SELECT CAST((timeofban - ?) / ? AS INTEGER) AS bucket_idx, "
|
|
||||||
"COUNT(*) AS cnt "
|
|
||||||
"FROM bans "
|
|
||||||
"WHERE timeofban >= ?" + origin_clause + " GROUP BY bucket_idx "
|
|
||||||
"ORDER BY bucket_idx",
|
|
||||||
(since, bucket_secs, since, *origin_params),
|
|
||||||
) as cur:
|
|
||||||
rows = await cur.fetchall()
|
|
||||||
|
|
||||||
counts: list[int] = [0] * num_buckets
|
|
||||||
for row in rows:
|
|
||||||
idx: int = int(row["bucket_idx"])
|
|
||||||
if 0 <= idx < num_buckets:
|
|
||||||
counts[idx] = int(row["cnt"])
|
|
||||||
|
|
||||||
return counts
|
|
||||||
|
|
||||||
|
|
||||||
async def get_ban_event_counts(
|
|
||||||
db_path: str,
|
|
||||||
since: int,
|
|
||||||
origin: BanOrigin | None = None,
|
|
||||||
) -> list[BanIpCount]:
|
|
||||||
"""Return total ban events per unique IP in the window."""
|
|
||||||
|
|
||||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
|
||||||
|
|
||||||
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
|
|
||||||
db.row_factory = aiosqlite.Row
|
|
||||||
async with db.execute(
|
|
||||||
"SELECT ip, COUNT(*) AS event_count "
|
|
||||||
"FROM bans "
|
|
||||||
"WHERE timeofban >= ?" + origin_clause + " GROUP BY ip",
|
|
||||||
(since, *origin_params),
|
|
||||||
) as cur:
|
|
||||||
rows = await cur.fetchall()
|
|
||||||
|
|
||||||
return [
|
|
||||||
BanIpCount(ip=str(r["ip"]), event_count=int(r["event_count"]))
|
|
||||||
for r in rows
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
async def get_bans_by_jail(
|
|
||||||
db_path: str,
|
|
||||||
since: int,
|
|
||||||
origin: BanOrigin | None = None,
|
|
||||||
) -> tuple[int, list[JailBanCount]]:
|
|
||||||
"""Return per-jail ban counts and the total ban count."""
|
|
||||||
|
|
||||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
|
||||||
|
|
||||||
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
|
|
||||||
db.row_factory = aiosqlite.Row
|
|
||||||
|
|
||||||
async with db.execute(
|
|
||||||
"SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause,
|
|
||||||
(since, *origin_params),
|
|
||||||
) as cur:
|
|
||||||
count_row = await cur.fetchone()
|
|
||||||
total: int = int(count_row[0]) if count_row else 0
|
|
||||||
|
|
||||||
async with db.execute(
|
|
||||||
"SELECT jail, COUNT(*) AS cnt "
|
|
||||||
"FROM bans "
|
|
||||||
"WHERE timeofban >= ?" + origin_clause + " GROUP BY jail ORDER BY cnt DESC",
|
|
||||||
(since, *origin_params),
|
|
||||||
) as cur:
|
|
||||||
rows = await cur.fetchall()
|
|
||||||
|
|
||||||
return total, [
|
|
||||||
JailBanCount(jail=str(r["jail"]), count=int(r["cnt"])) for r in rows
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
async def get_bans_table_summary(
|
|
||||||
db_path: str,
|
|
||||||
) -> tuple[int, int | None, int | None]:
|
|
||||||
"""Return basic summary stats for the ``bans`` table.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tuple ``(row_count, min_timeofban, max_timeofban)``. If the table is
|
|
||||||
empty the min/max values will be ``None``.
|
|
||||||
"""
|
|
||||||
|
|
||||||
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
|
|
||||||
db.row_factory = aiosqlite.Row
|
|
||||||
async with db.execute(
|
|
||||||
"SELECT COUNT(*), MIN(timeofban), MAX(timeofban) FROM bans"
|
|
||||||
) as cur:
|
|
||||||
row = await cur.fetchone()
|
|
||||||
|
|
||||||
if row is None:
|
|
||||||
return 0, None, None
|
|
||||||
|
|
||||||
return (
|
|
||||||
int(row[0]),
|
|
||||||
int(row[1]) if row[1] is not None else None,
|
|
||||||
int(row[2]) if row[2] is not None else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_history_page(
|
|
||||||
db_path: str,
|
|
||||||
since: int | None = None,
|
|
||||||
jail: str | None = None,
|
|
||||||
ip_filter: str | None = None,
|
|
||||||
page: int = 1,
|
|
||||||
page_size: int = 100,
|
|
||||||
) -> tuple[list[HistoryRecord], int]:
|
|
||||||
"""Return a paginated list of history records with total count."""
|
|
||||||
|
|
||||||
wheres: list[str] = []
|
|
||||||
params: list[object] = []
|
|
||||||
|
|
||||||
if since is not None:
|
|
||||||
wheres.append("timeofban >= ?")
|
|
||||||
params.append(since)
|
|
||||||
|
|
||||||
if jail is not None:
|
|
||||||
wheres.append("jail = ?")
|
|
||||||
params.append(jail)
|
|
||||||
|
|
||||||
if ip_filter is not None:
|
|
||||||
wheres.append("ip LIKE ?")
|
|
||||||
params.append(f"{ip_filter}%")
|
|
||||||
|
|
||||||
where_sql: str = ("WHERE " + " AND ".join(wheres)) if wheres else ""
|
|
||||||
|
|
||||||
effective_page_size: int = page_size
|
|
||||||
offset: int = (page - 1) * effective_page_size
|
|
||||||
|
|
||||||
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
|
|
||||||
db.row_factory = aiosqlite.Row
|
|
||||||
|
|
||||||
async with db.execute(
|
|
||||||
f"SELECT COUNT(*) FROM bans {where_sql}", # noqa: S608
|
|
||||||
params,
|
|
||||||
) as cur:
|
|
||||||
count_row = await cur.fetchone()
|
|
||||||
total: int = int(count_row[0]) if count_row else 0
|
|
||||||
|
|
||||||
async with db.execute(
|
|
||||||
f"SELECT jail, ip, timeofban, bancount, data "
|
|
||||||
f"FROM bans {where_sql} "
|
|
||||||
"ORDER BY timeofban DESC "
|
|
||||||
"LIMIT ? OFFSET ?",
|
|
||||||
[*params, effective_page_size, offset],
|
|
||||||
) as cur:
|
|
||||||
rows = await cur.fetchall()
|
|
||||||
|
|
||||||
return _rows_to_history_records(rows), total
|
|
||||||
|
|
||||||
|
|
||||||
async def get_history_for_ip(db_path: str, ip: str) -> list[HistoryRecord]:
|
|
||||||
"""Return the full ban timeline for a specific IP."""
|
|
||||||
|
|
||||||
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
|
|
||||||
db.row_factory = aiosqlite.Row
|
|
||||||
async with db.execute(
|
|
||||||
"SELECT jail, ip, timeofban, bancount, data "
|
|
||||||
"FROM bans "
|
|
||||||
"WHERE ip = ? "
|
|
||||||
"ORDER BY timeofban DESC",
|
|
||||||
(ip,),
|
|
||||||
) as cur:
|
|
||||||
rows = await cur.fetchall()
|
|
||||||
|
|
||||||
return _rows_to_history_records(rows)
|
|
||||||
@@ -1,148 +0,0 @@
|
|||||||
"""Repository for the geo cache persistent store.
|
|
||||||
|
|
||||||
This module provides typed, async helpers for querying and mutating the
|
|
||||||
``geo_cache`` table in the BanGUI application database.
|
|
||||||
|
|
||||||
All functions accept an open :class:`aiosqlite.Connection` and do not manage
|
|
||||||
connection lifetimes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, TypedDict
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from collections.abc import Sequence
|
|
||||||
|
|
||||||
import aiosqlite
|
|
||||||
|
|
||||||
|
|
||||||
class GeoCacheRow(TypedDict):
|
|
||||||
"""A single row from the ``geo_cache`` table."""
|
|
||||||
|
|
||||||
ip: str
|
|
||||||
country_code: str | None
|
|
||||||
country_name: str | None
|
|
||||||
asn: str | None
|
|
||||||
org: str | None
|
|
||||||
|
|
||||||
|
|
||||||
async def load_all(db: aiosqlite.Connection) -> list[GeoCacheRow]:
|
|
||||||
"""Load all geo cache rows from the database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Open BanGUI application database connection.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of rows from the ``geo_cache`` table.
|
|
||||||
"""
|
|
||||||
rows: list[GeoCacheRow] = []
|
|
||||||
async with db.execute(
|
|
||||||
"SELECT ip, country_code, country_name, asn, org FROM geo_cache"
|
|
||||||
) as cur:
|
|
||||||
async for row in cur:
|
|
||||||
rows.append(
|
|
||||||
GeoCacheRow(
|
|
||||||
ip=str(row[0]),
|
|
||||||
country_code=row[1],
|
|
||||||
country_name=row[2],
|
|
||||||
asn=row[3],
|
|
||||||
org=row[4],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return rows
|
|
||||||
|
|
||||||
|
|
||||||
async def get_unresolved_ips(db: aiosqlite.Connection) -> list[str]:
|
|
||||||
"""Return all IPs in ``geo_cache`` where ``country_code`` is NULL.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Open BanGUI application database connection.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of IPv4/IPv6 strings that need geo resolution.
|
|
||||||
"""
|
|
||||||
ips: list[str] = []
|
|
||||||
async with db.execute(
|
|
||||||
"SELECT ip FROM geo_cache WHERE country_code IS NULL"
|
|
||||||
) as cur:
|
|
||||||
async for row in cur:
|
|
||||||
ips.append(str(row[0]))
|
|
||||||
return ips
|
|
||||||
|
|
||||||
|
|
||||||
async def count_unresolved(db: aiosqlite.Connection) -> int:
|
|
||||||
"""Return the number of unresolved rows (country_code IS NULL)."""
|
|
||||||
async with db.execute(
|
|
||||||
"SELECT COUNT(*) FROM geo_cache WHERE country_code IS NULL"
|
|
||||||
) as cur:
|
|
||||||
row = await cur.fetchone()
|
|
||||||
return int(row[0]) if row else 0
|
|
||||||
|
|
||||||
|
|
||||||
async def upsert_entry(
|
|
||||||
db: aiosqlite.Connection,
|
|
||||||
ip: str,
|
|
||||||
country_code: str | None,
|
|
||||||
country_name: str | None,
|
|
||||||
asn: str | None,
|
|
||||||
org: str | None,
|
|
||||||
) -> None:
|
|
||||||
"""Insert or update a resolved geo cache entry."""
|
|
||||||
await db.execute(
|
|
||||||
"""
|
|
||||||
INSERT INTO geo_cache (ip, country_code, country_name, asn, org)
|
|
||||||
VALUES (?, ?, ?, ?, ?)
|
|
||||||
ON CONFLICT(ip) DO UPDATE SET
|
|
||||||
country_code = excluded.country_code,
|
|
||||||
country_name = excluded.country_name,
|
|
||||||
asn = excluded.asn,
|
|
||||||
org = excluded.org,
|
|
||||||
cached_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
|
|
||||||
""",
|
|
||||||
(ip, country_code, country_name, asn, org),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def upsert_neg_entry(db: aiosqlite.Connection, ip: str) -> None:
|
|
||||||
"""Record a failed lookup attempt as a negative entry."""
|
|
||||||
await db.execute(
|
|
||||||
"INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)",
|
|
||||||
(ip,),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def bulk_upsert_entries(
|
|
||||||
db: aiosqlite.Connection,
|
|
||||||
rows: Sequence[tuple[str, str | None, str | None, str | None, str | None]],
|
|
||||||
) -> int:
|
|
||||||
"""Bulk insert or update multiple geo cache entries."""
|
|
||||||
if not rows:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
await db.executemany(
|
|
||||||
"""
|
|
||||||
INSERT INTO geo_cache (ip, country_code, country_name, asn, org)
|
|
||||||
VALUES (?, ?, ?, ?, ?)
|
|
||||||
ON CONFLICT(ip) DO UPDATE SET
|
|
||||||
country_code = excluded.country_code,
|
|
||||||
country_name = excluded.country_name,
|
|
||||||
asn = excluded.asn,
|
|
||||||
org = excluded.org,
|
|
||||||
cached_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
|
|
||||||
""",
|
|
||||||
rows,
|
|
||||||
)
|
|
||||||
return len(rows)
|
|
||||||
|
|
||||||
|
|
||||||
async def bulk_upsert_neg_entries(db: aiosqlite.Connection, ips: list[str]) -> int:
|
|
||||||
"""Bulk insert negative lookup entries."""
|
|
||||||
if not ips:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
await db.executemany(
|
|
||||||
"INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)",
|
|
||||||
[(ip,) for ip in ips],
|
|
||||||
)
|
|
||||||
return len(ips)
|
|
||||||
@@ -8,26 +8,12 @@ table. All methods are plain async functions that accept a
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import TYPE_CHECKING, TypedDict, cast
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from collections.abc import Mapping
|
|
||||||
|
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
|
|
||||||
|
|
||||||
class ImportLogRow(TypedDict):
|
|
||||||
"""Row shape returned by queries on the import_log table."""
|
|
||||||
|
|
||||||
id: int
|
|
||||||
source_id: int | None
|
|
||||||
source_url: str
|
|
||||||
timestamp: str
|
|
||||||
ips_imported: int
|
|
||||||
ips_skipped: int
|
|
||||||
errors: str | None
|
|
||||||
|
|
||||||
|
|
||||||
async def add_log(
|
async def add_log(
|
||||||
db: aiosqlite.Connection,
|
db: aiosqlite.Connection,
|
||||||
*,
|
*,
|
||||||
@@ -68,7 +54,7 @@ async def list_logs(
|
|||||||
source_id: int | None = None,
|
source_id: int | None = None,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
page_size: int = 50,
|
page_size: int = 50,
|
||||||
) -> tuple[list[ImportLogRow], int]:
|
) -> tuple[list[dict[str, Any]], int]:
|
||||||
"""Return a paginated list of import log entries.
|
"""Return a paginated list of import log entries.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -82,8 +68,8 @@ async def list_logs(
|
|||||||
*total* is the count of all matching rows (ignoring pagination).
|
*total* is the count of all matching rows (ignoring pagination).
|
||||||
"""
|
"""
|
||||||
where = ""
|
where = ""
|
||||||
params_count: list[object] = []
|
params_count: list[Any] = []
|
||||||
params_rows: list[object] = []
|
params_rows: list[Any] = []
|
||||||
|
|
||||||
if source_id is not None:
|
if source_id is not None:
|
||||||
where = " WHERE source_id = ?"
|
where = " WHERE source_id = ?"
|
||||||
@@ -116,7 +102,7 @@ async def list_logs(
|
|||||||
return items, total
|
return items, total
|
||||||
|
|
||||||
|
|
||||||
async def get_last_log(db: aiosqlite.Connection) -> ImportLogRow | None:
|
async def get_last_log(db: aiosqlite.Connection) -> dict[str, Any] | None:
|
||||||
"""Return the most recent import log entry across all sources.
|
"""Return the most recent import log entry across all sources.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -157,14 +143,13 @@ def compute_total_pages(total: int, page_size: int) -> int:
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _row_to_dict(row: object) -> ImportLogRow:
|
def _row_to_dict(row: Any) -> dict[str, Any]:
|
||||||
"""Convert an aiosqlite row to a plain Python dict.
|
"""Convert an aiosqlite row to a plain Python dict.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
row: An :class:`aiosqlite.Row` or similar mapping returned by a cursor.
|
row: An :class:`aiosqlite.Row` or sequence returned by a cursor.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict mapping column names to Python values.
|
Dict mapping column names to Python values.
|
||||||
"""
|
"""
|
||||||
mapping = cast("Mapping[str, object]", row)
|
return dict(row)
|
||||||
return cast("ImportLogRow", dict(mapping))
|
|
||||||
|
|||||||
@@ -20,8 +20,8 @@ from fastapi import APIRouter, HTTPException, Request, status
|
|||||||
from app.dependencies import AuthDep
|
from app.dependencies import AuthDep
|
||||||
from app.models.ban import ActiveBanListResponse, BanRequest, UnbanAllResponse, UnbanRequest
|
from app.models.ban import ActiveBanListResponse, BanRequest, UnbanAllResponse, UnbanRequest
|
||||||
from app.models.jail import JailCommandResponse
|
from app.models.jail import JailCommandResponse
|
||||||
from app.services import geo_service, jail_service
|
from app.services import jail_service
|
||||||
from app.exceptions import JailNotFoundError, JailOperationError
|
from app.services.jail_service import JailNotFoundError, JailOperationError
|
||||||
from app.utils.fail2ban_client import Fail2BanConnectionError
|
from app.utils.fail2ban_client import Fail2BanConnectionError
|
||||||
|
|
||||||
router: APIRouter = APIRouter(prefix="/api/bans", tags=["Bans"])
|
router: APIRouter = APIRouter(prefix="/api/bans", tags=["Bans"])
|
||||||
@@ -73,7 +73,6 @@ async def get_active_bans(
|
|||||||
try:
|
try:
|
||||||
return await jail_service.get_active_bans(
|
return await jail_service.get_active_bans(
|
||||||
socket_path,
|
socket_path,
|
||||||
geo_batch_lookup=geo_service.lookup_batch,
|
|
||||||
http_session=http_session,
|
http_session=http_session,
|
||||||
app_db=app_db,
|
app_db=app_db,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -42,7 +42,8 @@ from app.models.blocklist import (
|
|||||||
ScheduleConfig,
|
ScheduleConfig,
|
||||||
ScheduleInfo,
|
ScheduleInfo,
|
||||||
)
|
)
|
||||||
from app.services import blocklist_service, geo_service
|
from app.repositories import import_log_repo
|
||||||
|
from app.services import blocklist_service
|
||||||
from app.tasks import blocklist_import as blocklist_import_task
|
from app.tasks import blocklist_import as blocklist_import_task
|
||||||
|
|
||||||
router: APIRouter = APIRouter(prefix="/api/blocklists", tags=["Blocklists"])
|
router: APIRouter = APIRouter(prefix="/api/blocklists", tags=["Blocklists"])
|
||||||
@@ -131,15 +132,7 @@ async def run_import_now(
|
|||||||
"""
|
"""
|
||||||
http_session: aiohttp.ClientSession = request.app.state.http_session
|
http_session: aiohttp.ClientSession = request.app.state.http_session
|
||||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||||
from app.services import jail_service
|
return await blocklist_service.import_all(db, http_session, socket_path)
|
||||||
|
|
||||||
return await blocklist_service.import_all(
|
|
||||||
db,
|
|
||||||
http_session,
|
|
||||||
socket_path,
|
|
||||||
geo_is_cached=geo_service.is_cached,
|
|
||||||
geo_batch_lookup=geo_service.lookup_batch,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
@@ -232,9 +225,19 @@ async def get_import_log(
|
|||||||
Returns:
|
Returns:
|
||||||
:class:`~app.models.blocklist.ImportLogListResponse`.
|
:class:`~app.models.blocklist.ImportLogListResponse`.
|
||||||
"""
|
"""
|
||||||
return await blocklist_service.list_import_logs(
|
items, total = await import_log_repo.list_logs(
|
||||||
db, source_id=source_id, page=page, page_size=page_size
|
db, source_id=source_id, page=page, page_size=page_size
|
||||||
)
|
)
|
||||||
|
total_pages = import_log_repo.compute_total_pages(total, page_size)
|
||||||
|
from app.models.blocklist import ImportLogEntry # noqa: PLC0415
|
||||||
|
|
||||||
|
return ImportLogListResponse(
|
||||||
|
items=[ImportLogEntry.model_validate(i) for i in items],
|
||||||
|
total=total,
|
||||||
|
page=page,
|
||||||
|
page_size=page_size,
|
||||||
|
total_pages=total_pages,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -44,6 +44,8 @@ import structlog
|
|||||||
from fastapi import APIRouter, HTTPException, Path, Query, Request, status
|
from fastapi import APIRouter, HTTPException, Path, Query, Request, status
|
||||||
|
|
||||||
from app.dependencies import AuthDep
|
from app.dependencies import AuthDep
|
||||||
|
|
||||||
|
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||||
from app.models.config import (
|
from app.models.config import (
|
||||||
ActionConfig,
|
ActionConfig,
|
||||||
ActionCreateRequest,
|
ActionCreateRequest,
|
||||||
@@ -76,39 +78,32 @@ from app.models.config import (
|
|||||||
RollbackResponse,
|
RollbackResponse,
|
||||||
ServiceStatusResponse,
|
ServiceStatusResponse,
|
||||||
)
|
)
|
||||||
from app.services import config_service, jail_service, log_service
|
from app.services import config_file_service, config_service, jail_service
|
||||||
from app.services import (
|
from app.services.config_file_service import (
|
||||||
action_config_service,
|
|
||||||
config_file_service,
|
|
||||||
filter_config_service,
|
|
||||||
jail_config_service,
|
|
||||||
)
|
|
||||||
from app.services.action_config_service import (
|
|
||||||
ActionAlreadyExistsError,
|
ActionAlreadyExistsError,
|
||||||
ActionNameError,
|
ActionNameError,
|
||||||
ActionNotFoundError,
|
ActionNotFoundError,
|
||||||
ActionReadonlyError,
|
ActionReadonlyError,
|
||||||
ConfigWriteError,
|
ConfigWriteError,
|
||||||
)
|
|
||||||
from app.services.filter_config_service import (
|
|
||||||
FilterAlreadyExistsError,
|
FilterAlreadyExistsError,
|
||||||
FilterInvalidRegexError,
|
FilterInvalidRegexError,
|
||||||
FilterNameError,
|
FilterNameError,
|
||||||
FilterNotFoundError,
|
FilterNotFoundError,
|
||||||
FilterReadonlyError,
|
FilterReadonlyError,
|
||||||
)
|
|
||||||
from app.services.jail_config_service import (
|
|
||||||
JailAlreadyActiveError,
|
JailAlreadyActiveError,
|
||||||
JailAlreadyInactiveError,
|
JailAlreadyInactiveError,
|
||||||
JailNameError,
|
JailNameError,
|
||||||
JailNotFoundInConfigError,
|
JailNotFoundInConfigError,
|
||||||
)
|
)
|
||||||
from app.exceptions import ConfigOperationError, ConfigValidationError, JailNotFoundError, JailOperationError
|
from app.services.config_service import (
|
||||||
|
ConfigOperationError,
|
||||||
|
ConfigValidationError,
|
||||||
|
JailNotFoundError,
|
||||||
|
)
|
||||||
|
from app.services.jail_service import JailOperationError
|
||||||
from app.tasks.health_check import _run_probe
|
from app.tasks.health_check import _run_probe
|
||||||
from app.utils.fail2ban_client import Fail2BanConnectionError
|
from app.utils.fail2ban_client import Fail2BanConnectionError
|
||||||
|
|
||||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
|
||||||
|
|
||||||
router: APIRouter = APIRouter(prefix="/api/config", tags=["Config"])
|
router: APIRouter = APIRouter(prefix="/api/config", tags=["Config"])
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -203,7 +198,7 @@ async def get_inactive_jails(
|
|||||||
"""
|
"""
|
||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||||
return await jail_config_service.list_inactive_jails(config_dir, socket_path)
|
return await config_file_service.list_inactive_jails(config_dir, socket_path)
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
@@ -433,7 +428,9 @@ async def restart_fail2ban(
|
|||||||
await config_file_service.start_daemon(start_cmd_parts)
|
await config_file_service.start_daemon(start_cmd_parts)
|
||||||
|
|
||||||
# Step 3: probe the socket until fail2ban is responsive or the budget expires.
|
# Step 3: probe the socket until fail2ban is responsive or the budget expires.
|
||||||
fail2ban_running: bool = await config_file_service.wait_for_fail2ban(socket_path, max_wait_seconds=10.0)
|
fail2ban_running: bool = await config_file_service.wait_for_fail2ban(
|
||||||
|
socket_path, max_wait_seconds=10.0
|
||||||
|
)
|
||||||
if not fail2ban_running:
|
if not fail2ban_running:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
@@ -472,7 +469,7 @@ async def regex_test(
|
|||||||
Returns:
|
Returns:
|
||||||
:class:`~app.models.config.RegexTestResponse` with match result and groups.
|
:class:`~app.models.config.RegexTestResponse` with match result and groups.
|
||||||
"""
|
"""
|
||||||
return log_service.test_regex(body)
|
return config_service.test_regex(body)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -578,7 +575,7 @@ async def preview_log(
|
|||||||
Returns:
|
Returns:
|
||||||
:class:`~app.models.config.LogPreviewResponse` with per-line results.
|
:class:`~app.models.config.LogPreviewResponse` with per-line results.
|
||||||
"""
|
"""
|
||||||
return await log_service.preview_log(body)
|
return await config_service.preview_log(body)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -607,7 +604,9 @@ async def get_map_color_thresholds(
|
|||||||
"""
|
"""
|
||||||
from app.services import setup_service
|
from app.services import setup_service
|
||||||
|
|
||||||
high, medium, low = await setup_service.get_map_color_thresholds(request.app.state.db)
|
high, medium, low = await setup_service.get_map_color_thresholds(
|
||||||
|
request.app.state.db
|
||||||
|
)
|
||||||
return MapColorThresholdsResponse(
|
return MapColorThresholdsResponse(
|
||||||
threshold_high=high,
|
threshold_high=high,
|
||||||
threshold_medium=medium,
|
threshold_medium=medium,
|
||||||
@@ -697,7 +696,9 @@ async def activate_jail(
|
|||||||
req = body if body is not None else ActivateJailRequest()
|
req = body if body is not None else ActivateJailRequest()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await jail_config_service.activate_jail(config_dir, socket_path, name, req)
|
result = await config_file_service.activate_jail(
|
||||||
|
config_dir, socket_path, name, req
|
||||||
|
)
|
||||||
except JailNameError as exc:
|
except JailNameError as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except JailNotFoundInConfigError:
|
except JailNotFoundInConfigError:
|
||||||
@@ -771,7 +772,7 @@ async def deactivate_jail(
|
|||||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await jail_config_service.deactivate_jail(config_dir, socket_path, name)
|
result = await config_file_service.deactivate_jail(config_dir, socket_path, name)
|
||||||
except JailNameError as exc:
|
except JailNameError as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except JailNotFoundInConfigError:
|
except JailNotFoundInConfigError:
|
||||||
@@ -830,7 +831,9 @@ async def delete_jail_local_override(
|
|||||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await jail_config_service.delete_jail_local_override(config_dir, socket_path, name)
|
await config_file_service.delete_jail_local_override(
|
||||||
|
config_dir, socket_path, name
|
||||||
|
)
|
||||||
except JailNameError as exc:
|
except JailNameError as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except JailNotFoundInConfigError:
|
except JailNotFoundInConfigError:
|
||||||
@@ -883,7 +886,7 @@ async def validate_jail(
|
|||||||
"""
|
"""
|
||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
try:
|
try:
|
||||||
return await jail_config_service.validate_jail_config(config_dir, name)
|
return await config_file_service.validate_jail_config(config_dir, name)
|
||||||
except JailNameError as exc:
|
except JailNameError as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
|
|
||||||
@@ -949,7 +952,9 @@ async def rollback_jail(
|
|||||||
start_cmd_parts: list[str] = start_cmd.split()
|
start_cmd_parts: list[str] = start_cmd.split()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await jail_config_service.rollback_jail(config_dir, socket_path, name, start_cmd_parts)
|
result = await config_file_service.rollback_jail(
|
||||||
|
config_dir, socket_path, name, start_cmd_parts
|
||||||
|
)
|
||||||
except JailNameError as exc:
|
except JailNameError as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except ConfigWriteError as exc:
|
except ConfigWriteError as exc:
|
||||||
@@ -1001,7 +1006,7 @@ async def list_filters(
|
|||||||
"""
|
"""
|
||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||||
result = await filter_config_service.list_filters(config_dir, socket_path)
|
result = await config_file_service.list_filters(config_dir, socket_path)
|
||||||
# Sort: active first (by name), then inactive (by name).
|
# Sort: active first (by name), then inactive (by name).
|
||||||
result.filters.sort(key=lambda f: (not f.active, f.name.lower()))
|
result.filters.sort(key=lambda f: (not f.active, f.name.lower()))
|
||||||
return result
|
return result
|
||||||
@@ -1038,7 +1043,7 @@ async def get_filter(
|
|||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||||
try:
|
try:
|
||||||
return await filter_config_service.get_filter(config_dir, socket_path, name)
|
return await config_file_service.get_filter(config_dir, socket_path, name)
|
||||||
except FilterNotFoundError:
|
except FilterNotFoundError:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
@@ -1102,7 +1107,9 @@ async def update_filter(
|
|||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||||
try:
|
try:
|
||||||
return await filter_config_service.update_filter(config_dir, socket_path, name, body, do_reload=reload)
|
return await config_file_service.update_filter(
|
||||||
|
config_dir, socket_path, name, body, do_reload=reload
|
||||||
|
)
|
||||||
except FilterNameError as exc:
|
except FilterNameError as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except FilterNotFoundError:
|
except FilterNotFoundError:
|
||||||
@@ -1152,7 +1159,9 @@ async def create_filter(
|
|||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||||
try:
|
try:
|
||||||
return await filter_config_service.create_filter(config_dir, socket_path, body, do_reload=reload)
|
return await config_file_service.create_filter(
|
||||||
|
config_dir, socket_path, body, do_reload=reload
|
||||||
|
)
|
||||||
except FilterNameError as exc:
|
except FilterNameError as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except FilterAlreadyExistsError as exc:
|
except FilterAlreadyExistsError as exc:
|
||||||
@@ -1199,7 +1208,7 @@ async def delete_filter(
|
|||||||
"""
|
"""
|
||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
try:
|
try:
|
||||||
await filter_config_service.delete_filter(config_dir, name)
|
await config_file_service.delete_filter(config_dir, name)
|
||||||
except FilterNameError as exc:
|
except FilterNameError as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except FilterNotFoundError:
|
except FilterNotFoundError:
|
||||||
@@ -1248,7 +1257,9 @@ async def assign_filter_to_jail(
|
|||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||||
try:
|
try:
|
||||||
await filter_config_service.assign_filter_to_jail(config_dir, socket_path, name, body, do_reload=reload)
|
await config_file_service.assign_filter_to_jail(
|
||||||
|
config_dir, socket_path, name, body, do_reload=reload
|
||||||
|
)
|
||||||
except (JailNameError, FilterNameError) as exc:
|
except (JailNameError, FilterNameError) as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except JailNotFoundInConfigError:
|
except JailNotFoundInConfigError:
|
||||||
@@ -1312,7 +1323,7 @@ async def list_actions(
|
|||||||
"""
|
"""
|
||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||||
result = await action_config_service.list_actions(config_dir, socket_path)
|
result = await config_file_service.list_actions(config_dir, socket_path)
|
||||||
result.actions.sort(key=lambda a: (not a.active, a.name.lower()))
|
result.actions.sort(key=lambda a: (not a.active, a.name.lower()))
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -1347,7 +1358,7 @@ async def get_action(
|
|||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||||
try:
|
try:
|
||||||
return await action_config_service.get_action(config_dir, socket_path, name)
|
return await config_file_service.get_action(config_dir, socket_path, name)
|
||||||
except ActionNotFoundError:
|
except ActionNotFoundError:
|
||||||
raise _action_not_found(name) from None
|
raise _action_not_found(name) from None
|
||||||
|
|
||||||
@@ -1392,7 +1403,9 @@ async def update_action(
|
|||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||||
try:
|
try:
|
||||||
return await action_config_service.update_action(config_dir, socket_path, name, body, do_reload=reload)
|
return await config_file_service.update_action(
|
||||||
|
config_dir, socket_path, name, body, do_reload=reload
|
||||||
|
)
|
||||||
except ActionNameError as exc:
|
except ActionNameError as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except ActionNotFoundError:
|
except ActionNotFoundError:
|
||||||
@@ -1438,7 +1451,9 @@ async def create_action(
|
|||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||||
try:
|
try:
|
||||||
return await action_config_service.create_action(config_dir, socket_path, body, do_reload=reload)
|
return await config_file_service.create_action(
|
||||||
|
config_dir, socket_path, body, do_reload=reload
|
||||||
|
)
|
||||||
except ActionNameError as exc:
|
except ActionNameError as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except ActionAlreadyExistsError as exc:
|
except ActionAlreadyExistsError as exc:
|
||||||
@@ -1481,7 +1496,7 @@ async def delete_action(
|
|||||||
"""
|
"""
|
||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
try:
|
try:
|
||||||
await action_config_service.delete_action(config_dir, name)
|
await config_file_service.delete_action(config_dir, name)
|
||||||
except ActionNameError as exc:
|
except ActionNameError as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except ActionNotFoundError:
|
except ActionNotFoundError:
|
||||||
@@ -1531,7 +1546,9 @@ async def assign_action_to_jail(
|
|||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||||
try:
|
try:
|
||||||
await action_config_service.assign_action_to_jail(config_dir, socket_path, name, body, do_reload=reload)
|
await config_file_service.assign_action_to_jail(
|
||||||
|
config_dir, socket_path, name, body, do_reload=reload
|
||||||
|
)
|
||||||
except (JailNameError, ActionNameError) as exc:
|
except (JailNameError, ActionNameError) as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except JailNotFoundInConfigError:
|
except JailNotFoundInConfigError:
|
||||||
@@ -1580,7 +1597,9 @@ async def remove_action_from_jail(
|
|||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||||
try:
|
try:
|
||||||
await action_config_service.remove_action_from_jail(config_dir, socket_path, name, action_name, do_reload=reload)
|
await config_file_service.remove_action_from_jail(
|
||||||
|
config_dir, socket_path, name, action_name, do_reload=reload
|
||||||
|
)
|
||||||
except (JailNameError, ActionNameError) as exc:
|
except (JailNameError, ActionNameError) as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except JailNotFoundInConfigError:
|
except JailNotFoundInConfigError:
|
||||||
@@ -1666,12 +1685,8 @@ async def get_service_status(
|
|||||||
handles this gracefully and returns ``online=False``).
|
handles this gracefully and returns ``online=False``).
|
||||||
"""
|
"""
|
||||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||||
from app.services import health_service
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return await config_service.get_service_status(
|
return await config_service.get_service_status(socket_path)
|
||||||
socket_path,
|
|
||||||
probe_fn=health_service.probe,
|
|
||||||
)
|
|
||||||
except Fail2BanConnectionError as exc:
|
except Fail2BanConnectionError as exc:
|
||||||
raise _bad_gateway(exc) from exc
|
raise _bad_gateway(exc) from exc
|
||||||
|
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from app.models.ban import (
|
|||||||
TimeRange,
|
TimeRange,
|
||||||
)
|
)
|
||||||
from app.models.server import ServerStatus, ServerStatusResponse
|
from app.models.server import ServerStatus, ServerStatusResponse
|
||||||
from app.services import ban_service, geo_service
|
from app.services import ban_service
|
||||||
|
|
||||||
router: APIRouter = APIRouter(prefix="/api/dashboard", tags=["Dashboard"])
|
router: APIRouter = APIRouter(prefix="/api/dashboard", tags=["Dashboard"])
|
||||||
|
|
||||||
@@ -119,7 +119,6 @@ async def get_dashboard_bans(
|
|||||||
page_size=page_size,
|
page_size=page_size,
|
||||||
http_session=http_session,
|
http_session=http_session,
|
||||||
app_db=None,
|
app_db=None,
|
||||||
geo_batch_lookup=geo_service.lookup_batch,
|
|
||||||
origin=origin,
|
origin=origin,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -163,8 +162,6 @@ async def get_bans_by_country(
|
|||||||
socket_path,
|
socket_path,
|
||||||
range,
|
range,
|
||||||
http_session=http_session,
|
http_session=http_session,
|
||||||
geo_cache_lookup=geo_service.lookup_cached_only,
|
|
||||||
geo_batch_lookup=geo_service.lookup_batch,
|
|
||||||
app_db=None,
|
app_db=None,
|
||||||
origin=origin,
|
origin=origin,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -51,8 +51,8 @@ from app.models.file_config import (
|
|||||||
JailConfigFileEnabledUpdate,
|
JailConfigFileEnabledUpdate,
|
||||||
JailConfigFilesResponse,
|
JailConfigFilesResponse,
|
||||||
)
|
)
|
||||||
from app.services import raw_config_io_service
|
from app.services import file_config_service
|
||||||
from app.services.raw_config_io_service import (
|
from app.services.file_config_service import (
|
||||||
ConfigDirError,
|
ConfigDirError,
|
||||||
ConfigFileExistsError,
|
ConfigFileExistsError,
|
||||||
ConfigFileNameError,
|
ConfigFileNameError,
|
||||||
@@ -134,7 +134,7 @@ async def list_jail_config_files(
|
|||||||
"""
|
"""
|
||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
try:
|
try:
|
||||||
return await raw_config_io_service.list_jail_config_files(config_dir)
|
return await file_config_service.list_jail_config_files(config_dir)
|
||||||
except ConfigDirError as exc:
|
except ConfigDirError as exc:
|
||||||
raise _service_unavailable(str(exc)) from exc
|
raise _service_unavailable(str(exc)) from exc
|
||||||
|
|
||||||
@@ -166,7 +166,7 @@ async def get_jail_config_file(
|
|||||||
"""
|
"""
|
||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
try:
|
try:
|
||||||
return await raw_config_io_service.get_jail_config_file(config_dir, filename)
|
return await file_config_service.get_jail_config_file(config_dir, filename)
|
||||||
except ConfigFileNameError as exc:
|
except ConfigFileNameError as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except ConfigFileNotFoundError:
|
except ConfigFileNotFoundError:
|
||||||
@@ -204,7 +204,7 @@ async def write_jail_config_file(
|
|||||||
"""
|
"""
|
||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
try:
|
try:
|
||||||
await raw_config_io_service.write_jail_config_file(config_dir, filename, body)
|
await file_config_service.write_jail_config_file(config_dir, filename, body)
|
||||||
except ConfigFileNameError as exc:
|
except ConfigFileNameError as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except ConfigFileNotFoundError:
|
except ConfigFileNotFoundError:
|
||||||
@@ -244,7 +244,7 @@ async def set_jail_config_file_enabled(
|
|||||||
"""
|
"""
|
||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
try:
|
try:
|
||||||
await raw_config_io_service.set_jail_config_enabled(
|
await file_config_service.set_jail_config_enabled(
|
||||||
config_dir, filename, body.enabled
|
config_dir, filename, body.enabled
|
||||||
)
|
)
|
||||||
except ConfigFileNameError as exc:
|
except ConfigFileNameError as exc:
|
||||||
@@ -285,7 +285,7 @@ async def create_jail_config_file(
|
|||||||
"""
|
"""
|
||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
try:
|
try:
|
||||||
filename = await raw_config_io_service.create_jail_config_file(config_dir, body)
|
filename = await file_config_service.create_jail_config_file(config_dir, body)
|
||||||
except ConfigFileNameError as exc:
|
except ConfigFileNameError as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except ConfigFileExistsError:
|
except ConfigFileExistsError:
|
||||||
@@ -338,7 +338,7 @@ async def get_filter_file_raw(
|
|||||||
"""
|
"""
|
||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
try:
|
try:
|
||||||
return await raw_config_io_service.get_filter_file(config_dir, name)
|
return await file_config_service.get_filter_file(config_dir, name)
|
||||||
except ConfigFileNameError as exc:
|
except ConfigFileNameError as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except ConfigFileNotFoundError:
|
except ConfigFileNotFoundError:
|
||||||
@@ -373,7 +373,7 @@ async def write_filter_file(
|
|||||||
"""
|
"""
|
||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
try:
|
try:
|
||||||
await raw_config_io_service.write_filter_file(config_dir, name, body)
|
await file_config_service.write_filter_file(config_dir, name, body)
|
||||||
except ConfigFileNameError as exc:
|
except ConfigFileNameError as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except ConfigFileNotFoundError:
|
except ConfigFileNotFoundError:
|
||||||
@@ -412,7 +412,7 @@ async def create_filter_file(
|
|||||||
"""
|
"""
|
||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
try:
|
try:
|
||||||
filename = await raw_config_io_service.create_filter_file(config_dir, body)
|
filename = await file_config_service.create_filter_file(config_dir, body)
|
||||||
except ConfigFileNameError as exc:
|
except ConfigFileNameError as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except ConfigFileExistsError:
|
except ConfigFileExistsError:
|
||||||
@@ -454,7 +454,7 @@ async def list_action_files(
|
|||||||
"""
|
"""
|
||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
try:
|
try:
|
||||||
return await raw_config_io_service.list_action_files(config_dir)
|
return await file_config_service.list_action_files(config_dir)
|
||||||
except ConfigDirError as exc:
|
except ConfigDirError as exc:
|
||||||
raise _service_unavailable(str(exc)) from exc
|
raise _service_unavailable(str(exc)) from exc
|
||||||
|
|
||||||
@@ -486,7 +486,7 @@ async def get_action_file(
|
|||||||
"""
|
"""
|
||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
try:
|
try:
|
||||||
return await raw_config_io_service.get_action_file(config_dir, name)
|
return await file_config_service.get_action_file(config_dir, name)
|
||||||
except ConfigFileNameError as exc:
|
except ConfigFileNameError as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except ConfigFileNotFoundError:
|
except ConfigFileNotFoundError:
|
||||||
@@ -521,7 +521,7 @@ async def write_action_file(
|
|||||||
"""
|
"""
|
||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
try:
|
try:
|
||||||
await raw_config_io_service.write_action_file(config_dir, name, body)
|
await file_config_service.write_action_file(config_dir, name, body)
|
||||||
except ConfigFileNameError as exc:
|
except ConfigFileNameError as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except ConfigFileNotFoundError:
|
except ConfigFileNotFoundError:
|
||||||
@@ -560,7 +560,7 @@ async def create_action_file(
|
|||||||
"""
|
"""
|
||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
try:
|
try:
|
||||||
filename = await raw_config_io_service.create_action_file(config_dir, body)
|
filename = await file_config_service.create_action_file(config_dir, body)
|
||||||
except ConfigFileNameError as exc:
|
except ConfigFileNameError as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except ConfigFileExistsError:
|
except ConfigFileExistsError:
|
||||||
@@ -613,7 +613,7 @@ async def get_parsed_filter(
|
|||||||
"""
|
"""
|
||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
try:
|
try:
|
||||||
return await raw_config_io_service.get_parsed_filter_file(config_dir, name)
|
return await file_config_service.get_parsed_filter_file(config_dir, name)
|
||||||
except ConfigFileNameError as exc:
|
except ConfigFileNameError as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except ConfigFileNotFoundError:
|
except ConfigFileNotFoundError:
|
||||||
@@ -651,7 +651,7 @@ async def update_parsed_filter(
|
|||||||
"""
|
"""
|
||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
try:
|
try:
|
||||||
await raw_config_io_service.update_parsed_filter_file(config_dir, name, body)
|
await file_config_service.update_parsed_filter_file(config_dir, name, body)
|
||||||
except ConfigFileNameError as exc:
|
except ConfigFileNameError as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except ConfigFileNotFoundError:
|
except ConfigFileNotFoundError:
|
||||||
@@ -698,7 +698,7 @@ async def get_parsed_action(
|
|||||||
"""
|
"""
|
||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
try:
|
try:
|
||||||
return await raw_config_io_service.get_parsed_action_file(config_dir, name)
|
return await file_config_service.get_parsed_action_file(config_dir, name)
|
||||||
except ConfigFileNameError as exc:
|
except ConfigFileNameError as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except ConfigFileNotFoundError:
|
except ConfigFileNotFoundError:
|
||||||
@@ -736,7 +736,7 @@ async def update_parsed_action(
|
|||||||
"""
|
"""
|
||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
try:
|
try:
|
||||||
await raw_config_io_service.update_parsed_action_file(config_dir, name, body)
|
await file_config_service.update_parsed_action_file(config_dir, name, body)
|
||||||
except ConfigFileNameError as exc:
|
except ConfigFileNameError as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except ConfigFileNotFoundError:
|
except ConfigFileNotFoundError:
|
||||||
@@ -783,7 +783,7 @@ async def get_parsed_jail_file(
|
|||||||
"""
|
"""
|
||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
try:
|
try:
|
||||||
return await raw_config_io_service.get_parsed_jail_file(config_dir, filename)
|
return await file_config_service.get_parsed_jail_file(config_dir, filename)
|
||||||
except ConfigFileNameError as exc:
|
except ConfigFileNameError as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except ConfigFileNotFoundError:
|
except ConfigFileNotFoundError:
|
||||||
@@ -821,7 +821,7 @@ async def update_parsed_jail_file(
|
|||||||
"""
|
"""
|
||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
try:
|
try:
|
||||||
await raw_config_io_service.update_parsed_jail_file(config_dir, filename, body)
|
await file_config_service.update_parsed_jail_file(config_dir, filename, body)
|
||||||
except ConfigFileNameError as exc:
|
except ConfigFileNameError as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except ConfigFileNotFoundError:
|
except ConfigFileNotFoundError:
|
||||||
|
|||||||
@@ -13,13 +13,11 @@ from typing import TYPE_CHECKING, Annotated
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
from app.services.jail_service import IpLookupResult
|
|
||||||
|
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Path, Request, status
|
from fastapi import APIRouter, Depends, HTTPException, Path, Request, status
|
||||||
|
|
||||||
from app.dependencies import AuthDep, get_db
|
from app.dependencies import AuthDep, get_db
|
||||||
from app.models.geo import GeoCacheStatsResponse, GeoDetail, GeoInfo, IpLookupResponse
|
from app.models.geo import GeoCacheStatsResponse, GeoDetail, IpLookupResponse
|
||||||
from app.services import geo_service, jail_service
|
from app.services import geo_service, jail_service
|
||||||
from app.utils.fail2ban_client import Fail2BanConnectionError
|
from app.utils.fail2ban_client import Fail2BanConnectionError
|
||||||
|
|
||||||
@@ -63,7 +61,7 @@ async def lookup_ip(
|
|||||||
return await geo_service.lookup(addr, http_session)
|
return await geo_service.lookup(addr, http_session)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result: IpLookupResult = await jail_service.lookup_ip(
|
result = await jail_service.lookup_ip(
|
||||||
socket_path,
|
socket_path,
|
||||||
ip,
|
ip,
|
||||||
geo_enricher=_enricher,
|
geo_enricher=_enricher,
|
||||||
@@ -79,9 +77,9 @@ async def lookup_ip(
|
|||||||
detail=f"Cannot reach fail2ban: {exc}",
|
detail=f"Cannot reach fail2ban: {exc}",
|
||||||
) from exc
|
) from exc
|
||||||
|
|
||||||
raw_geo = result["geo"]
|
raw_geo = result.get("geo")
|
||||||
geo_detail: GeoDetail | None = None
|
geo_detail: GeoDetail | None = None
|
||||||
if isinstance(raw_geo, GeoInfo):
|
if raw_geo is not None:
|
||||||
geo_detail = GeoDetail(
|
geo_detail = GeoDetail(
|
||||||
country_code=raw_geo.country_code,
|
country_code=raw_geo.country_code,
|
||||||
country_name=raw_geo.country_name,
|
country_name=raw_geo.country_name,
|
||||||
@@ -155,7 +153,12 @@ async def re_resolve_geo(
|
|||||||
that were retried.
|
that were retried.
|
||||||
"""
|
"""
|
||||||
# Collect all IPs in geo_cache that still lack a country code.
|
# Collect all IPs in geo_cache that still lack a country code.
|
||||||
unresolved = await geo_service.get_unresolved_ips(db)
|
unresolved: list[str] = []
|
||||||
|
async with db.execute(
|
||||||
|
"SELECT ip FROM geo_cache WHERE country_code IS NULL"
|
||||||
|
) as cur:
|
||||||
|
async for row in cur:
|
||||||
|
unresolved.append(str(row[0]))
|
||||||
|
|
||||||
if not unresolved:
|
if not unresolved:
|
||||||
return {"resolved": 0, "total": 0}
|
return {"resolved": 0, "total": 0}
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
|
|||||||
from fastapi import APIRouter, HTTPException, Query, Request
|
from fastapi import APIRouter, HTTPException, Query, Request
|
||||||
|
|
||||||
from app.dependencies import AuthDep
|
from app.dependencies import AuthDep
|
||||||
from app.models.ban import TimeRange
|
from app.models.ban import BanOrigin, TimeRange
|
||||||
from app.models.history import HistoryListResponse, IpDetailResponse
|
from app.models.history import HistoryListResponse, IpDetailResponse
|
||||||
from app.services import geo_service, history_service
|
from app.services import geo_service, history_service
|
||||||
|
|
||||||
@@ -52,6 +52,10 @@ async def get_history(
|
|||||||
default=None,
|
default=None,
|
||||||
description="Restrict results to IPs matching this prefix.",
|
description="Restrict results to IPs matching this prefix.",
|
||||||
),
|
),
|
||||||
|
origin: BanOrigin | None = Query(
|
||||||
|
default=None,
|
||||||
|
description="Filter by ban origin: 'blocklist' or 'selfblock'. Omit for all.",
|
||||||
|
),
|
||||||
page: int = Query(default=1, ge=1, description="1-based page number."),
|
page: int = Query(default=1, ge=1, description="1-based page number."),
|
||||||
page_size: int = Query(
|
page_size: int = Query(
|
||||||
default=_DEFAULT_PAGE_SIZE,
|
default=_DEFAULT_PAGE_SIZE,
|
||||||
@@ -89,6 +93,7 @@ async def get_history(
|
|||||||
range_=range,
|
range_=range,
|
||||||
jail=jail,
|
jail=jail,
|
||||||
ip_filter=ip,
|
ip_filter=ip,
|
||||||
|
origin=origin,
|
||||||
page=page,
|
page=page,
|
||||||
page_size=page_size,
|
page_size=page_size,
|
||||||
geo_enricher=_enricher,
|
geo_enricher=_enricher,
|
||||||
|
|||||||
@@ -31,8 +31,8 @@ from app.models.jail import (
|
|||||||
JailDetailResponse,
|
JailDetailResponse,
|
||||||
JailListResponse,
|
JailListResponse,
|
||||||
)
|
)
|
||||||
from app.services import geo_service, jail_service
|
from app.services import jail_service
|
||||||
from app.exceptions import JailNotFoundError, JailOperationError
|
from app.services.jail_service import JailNotFoundError, JailOperationError
|
||||||
from app.utils.fail2ban_client import Fail2BanConnectionError
|
from app.utils.fail2ban_client import Fail2BanConnectionError
|
||||||
|
|
||||||
router: APIRouter = APIRouter(prefix="/api/jails", tags=["Jails"])
|
router: APIRouter = APIRouter(prefix="/api/jails", tags=["Jails"])
|
||||||
@@ -606,7 +606,6 @@ async def get_jail_banned_ips(
|
|||||||
page=page,
|
page=page,
|
||||||
page_size=page_size,
|
page_size=page_size,
|
||||||
search=search,
|
search=search,
|
||||||
geo_batch_lookup=geo_service.lookup_batch,
|
|
||||||
http_session=http_session,
|
http_session=http_session,
|
||||||
app_db=app_db,
|
app_db=app_db,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from fastapi import APIRouter, HTTPException, Request, status
|
|||||||
from app.dependencies import AuthDep
|
from app.dependencies import AuthDep
|
||||||
from app.models.server import ServerSettingsResponse, ServerSettingsUpdate
|
from app.models.server import ServerSettingsResponse, ServerSettingsUpdate
|
||||||
from app.services import server_service
|
from app.services import server_service
|
||||||
from app.exceptions import ServerOperationError
|
from app.services.server_service import ServerOperationError
|
||||||
from app.utils.fail2ban_client import Fail2BanConnectionError
|
from app.utils.fail2ban_client import Fail2BanConnectionError
|
||||||
|
|
||||||
router: APIRouter = APIRouter(prefix="/api/server", tags=["Server"])
|
router: APIRouter = APIRouter(prefix="/api/server", tags=["Server"])
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -20,7 +20,7 @@ if TYPE_CHECKING:
|
|||||||
from app.models.auth import Session
|
from app.models.auth import Session
|
||||||
|
|
||||||
from app.repositories import session_repo
|
from app.repositories import session_repo
|
||||||
from app.utils.setup_utils import get_password_hash
|
from app.services import setup_service
|
||||||
from app.utils.time_utils import add_minutes, utc_now
|
from app.utils.time_utils import add_minutes, utc_now
|
||||||
|
|
||||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||||
@@ -65,7 +65,7 @@ async def login(
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If the password is incorrect or no password hash is stored.
|
ValueError: If the password is incorrect or no password hash is stored.
|
||||||
"""
|
"""
|
||||||
stored_hash = await get_password_hash(db)
|
stored_hash = await setup_service.get_password_hash(db)
|
||||||
if stored_hash is None:
|
if stored_hash is None:
|
||||||
log.warning("bangui_login_no_hash")
|
log.warning("bangui_login_no_hash")
|
||||||
raise ValueError("No password is configured — run setup first.")
|
raise ValueError("No password is configured — run setup first.")
|
||||||
|
|||||||
@@ -11,9 +11,12 @@ so BanGUI never modifies or locks the fail2ban database.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import json
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING
|
from datetime import UTC, datetime
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
import structlog
|
import structlog
|
||||||
|
|
||||||
from app.models.ban import (
|
from app.models.ban import (
|
||||||
@@ -28,21 +31,15 @@ from app.models.ban import (
|
|||||||
BanTrendResponse,
|
BanTrendResponse,
|
||||||
DashboardBanItem,
|
DashboardBanItem,
|
||||||
DashboardBanListResponse,
|
DashboardBanListResponse,
|
||||||
|
JailBanCount,
|
||||||
TimeRange,
|
TimeRange,
|
||||||
_derive_origin,
|
_derive_origin,
|
||||||
bucket_count,
|
bucket_count,
|
||||||
)
|
)
|
||||||
from app.models.ban import (
|
from app.utils.fail2ban_client import Fail2BanClient
|
||||||
JailBanCount as JailBanCountModel,
|
|
||||||
)
|
|
||||||
from app.repositories import fail2ban_db_repo
|
|
||||||
from app.utils.fail2ban_db_utils import get_fail2ban_db_path, parse_data_json, ts_to_iso
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import aiosqlite
|
|
||||||
|
|
||||||
from app.models.geo import GeoBatchLookup, GeoCacheLookup, GeoEnricher, GeoInfo
|
|
||||||
|
|
||||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||||
|
|
||||||
@@ -77,9 +74,6 @@ def _origin_sql_filter(origin: BanOrigin | None) -> tuple[str, tuple[str, ...]]:
|
|||||||
return "", ()
|
return "", ()
|
||||||
|
|
||||||
|
|
||||||
_TIME_RANGE_SLACK_SECONDS: int = 60
|
|
||||||
|
|
||||||
|
|
||||||
def _since_unix(range_: TimeRange) -> int:
|
def _since_unix(range_: TimeRange) -> int:
|
||||||
"""Return the Unix timestamp representing the start of the time window.
|
"""Return the Unix timestamp representing the start of the time window.
|
||||||
|
|
||||||
@@ -94,13 +88,92 @@ def _since_unix(range_: TimeRange) -> int:
|
|||||||
range_: One of the supported time-range presets.
|
range_: One of the supported time-range presets.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Unix timestamp (seconds since epoch) equal to *now − range_* with a
|
Unix timestamp (seconds since epoch) equal to *now − range_*.
|
||||||
small slack window for clock drift and test seeding delays.
|
|
||||||
"""
|
"""
|
||||||
seconds: int = TIME_RANGE_SECONDS[range_]
|
seconds: int = TIME_RANGE_SECONDS[range_]
|
||||||
return int(time.time()) - seconds - _TIME_RANGE_SLACK_SECONDS
|
return int(time.time()) - seconds
|
||||||
|
|
||||||
|
|
||||||
|
def _ts_to_iso(unix_ts: int) -> str:
|
||||||
|
"""Convert a Unix timestamp to an ISO 8601 UTC string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unix_ts: Seconds since the Unix epoch.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ISO 8601 UTC timestamp, e.g. ``"2026-03-01T12:00:00+00:00"``.
|
||||||
|
"""
|
||||||
|
return datetime.fromtimestamp(unix_ts, tz=UTC).isoformat()
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_fail2ban_db_path(socket_path: str) -> str:
|
||||||
|
"""Query fail2ban for the path to its SQLite database.
|
||||||
|
|
||||||
|
Sends the ``get dbfile`` command via the fail2ban socket and returns
|
||||||
|
the value of the ``dbfile`` setting.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
socket_path: Path to the fail2ban Unix domain socket.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Absolute path to the fail2ban SQLite database file.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If fail2ban reports that no database is configured
|
||||||
|
or if the socket response is unexpected.
|
||||||
|
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
|
||||||
|
cannot be reached.
|
||||||
|
"""
|
||||||
|
async with Fail2BanClient(socket_path, timeout=_SOCKET_TIMEOUT) as client:
|
||||||
|
response = await client.send(["get", "dbfile"])
|
||||||
|
|
||||||
|
try:
|
||||||
|
code, data = response
|
||||||
|
except (TypeError, ValueError) as exc:
|
||||||
|
raise RuntimeError(f"Unexpected response from fail2ban: {response!r}") from exc
|
||||||
|
|
||||||
|
if code != 0:
|
||||||
|
raise RuntimeError(f"fail2ban error code {code}: {data!r}")
|
||||||
|
|
||||||
|
if data is None:
|
||||||
|
raise RuntimeError("fail2ban has no database configured (dbfile is None)")
|
||||||
|
|
||||||
|
return str(data)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_data_json(raw: Any) -> tuple[list[str], int]:
|
||||||
|
"""Extract matches and failure count from the ``bans.data`` column.
|
||||||
|
|
||||||
|
The ``data`` column stores a JSON blob with optional keys:
|
||||||
|
|
||||||
|
* ``matches`` — list of raw matched log lines.
|
||||||
|
* ``failures`` — total failure count that triggered the ban.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
raw: The raw ``data`` column value (string, dict, or ``None``).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A ``(matches, failures)`` tuple. Both default to empty/zero when
|
||||||
|
parsing fails or the column is absent.
|
||||||
|
"""
|
||||||
|
if raw is None:
|
||||||
|
return [], 0
|
||||||
|
|
||||||
|
obj: dict[str, Any] = {}
|
||||||
|
if isinstance(raw, str):
|
||||||
|
try:
|
||||||
|
parsed: Any = json.loads(raw)
|
||||||
|
if isinstance(parsed, dict):
|
||||||
|
obj = parsed
|
||||||
|
# json.loads("null") → None, or other non-dict — treat as empty
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return [], 0
|
||||||
|
elif isinstance(raw, dict):
|
||||||
|
obj = raw
|
||||||
|
|
||||||
|
matches: list[str] = [str(m) for m in (obj.get("matches") or [])]
|
||||||
|
failures: int = int(obj.get("failures", 0))
|
||||||
|
return matches, failures
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -116,8 +189,7 @@ async def list_bans(
|
|||||||
page_size: int = _DEFAULT_PAGE_SIZE,
|
page_size: int = _DEFAULT_PAGE_SIZE,
|
||||||
http_session: aiohttp.ClientSession | None = None,
|
http_session: aiohttp.ClientSession | None = None,
|
||||||
app_db: aiosqlite.Connection | None = None,
|
app_db: aiosqlite.Connection | None = None,
|
||||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
geo_enricher: Any | None = None,
|
||||||
geo_enricher: GeoEnricher | None = None,
|
|
||||||
origin: BanOrigin | None = None,
|
origin: BanOrigin | None = None,
|
||||||
) -> DashboardBanListResponse:
|
) -> DashboardBanListResponse:
|
||||||
"""Return a paginated list of bans within the selected time window.
|
"""Return a paginated list of bans within the selected time window.
|
||||||
@@ -156,13 +228,14 @@ async def list_bans(
|
|||||||
:class:`~app.models.ban.DashboardBanListResponse` containing the
|
:class:`~app.models.ban.DashboardBanListResponse` containing the
|
||||||
paginated items and total count.
|
paginated items and total count.
|
||||||
"""
|
"""
|
||||||
|
from app.services import geo_service # noqa: PLC0415
|
||||||
|
|
||||||
since: int = _since_unix(range_)
|
since: int = _since_unix(range_)
|
||||||
effective_page_size: int = min(page_size, _MAX_PAGE_SIZE)
|
effective_page_size: int = min(page_size, _MAX_PAGE_SIZE)
|
||||||
offset: int = (page - 1) * effective_page_size
|
offset: int = (page - 1) * effective_page_size
|
||||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||||
|
|
||||||
db_path: str = await get_fail2ban_db_path(socket_path)
|
db_path: str = await _get_fail2ban_db_path(socket_path)
|
||||||
log.info(
|
log.info(
|
||||||
"ban_service_list_bans",
|
"ban_service_list_bans",
|
||||||
db_path=db_path,
|
db_path=db_path,
|
||||||
@@ -171,32 +244,45 @@ async def list_bans(
|
|||||||
origin=origin,
|
origin=origin,
|
||||||
)
|
)
|
||||||
|
|
||||||
rows, total = await fail2ban_db_repo.get_currently_banned(
|
async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db:
|
||||||
db_path=db_path,
|
f2b_db.row_factory = aiosqlite.Row
|
||||||
since=since,
|
|
||||||
origin=origin,
|
async with f2b_db.execute(
|
||||||
limit=effective_page_size,
|
"SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause,
|
||||||
offset=offset,
|
(since, *origin_params),
|
||||||
)
|
) as cur:
|
||||||
|
count_row = await cur.fetchone()
|
||||||
|
total: int = int(count_row[0]) if count_row else 0
|
||||||
|
|
||||||
|
async with f2b_db.execute(
|
||||||
|
"SELECT jail, ip, timeofban, bancount, data "
|
||||||
|
"FROM bans "
|
||||||
|
"WHERE timeofban >= ?"
|
||||||
|
+ origin_clause
|
||||||
|
+ " ORDER BY timeofban DESC "
|
||||||
|
"LIMIT ? OFFSET ?",
|
||||||
|
(since, *origin_params, effective_page_size, offset),
|
||||||
|
) as cur:
|
||||||
|
rows = await cur.fetchall()
|
||||||
|
|
||||||
# Batch-resolve geo data for all IPs on this page in a single API call.
|
# Batch-resolve geo data for all IPs on this page in a single API call.
|
||||||
# This avoids hitting the 45 req/min single-IP rate limit when the
|
# This avoids hitting the 45 req/min single-IP rate limit when the
|
||||||
# page contains many bans (e.g. after a large blocklist import).
|
# page contains many bans (e.g. after a large blocklist import).
|
||||||
geo_map: dict[str, GeoInfo] = {}
|
geo_map: dict[str, Any] = {}
|
||||||
if http_session is not None and rows and geo_batch_lookup is not None:
|
if http_session is not None and rows:
|
||||||
page_ips: list[str] = [r.ip for r in rows]
|
page_ips: list[str] = [str(r["ip"]) for r in rows]
|
||||||
try:
|
try:
|
||||||
geo_map = await geo_batch_lookup(page_ips, http_session, db=app_db)
|
geo_map = await geo_service.lookup_batch(page_ips, http_session, db=app_db)
|
||||||
except Exception: # noqa: BLE001
|
except Exception: # noqa: BLE001
|
||||||
log.warning("ban_service_batch_geo_failed_list_bans")
|
log.warning("ban_service_batch_geo_failed_list_bans")
|
||||||
|
|
||||||
items: list[DashboardBanItem] = []
|
items: list[DashboardBanItem] = []
|
||||||
for row in rows:
|
for row in rows:
|
||||||
jail: str = row.jail
|
jail: str = str(row["jail"])
|
||||||
ip: str = row.ip
|
ip: str = str(row["ip"])
|
||||||
banned_at: str = ts_to_iso(row.timeofban)
|
banned_at: str = _ts_to_iso(int(row["timeofban"]))
|
||||||
ban_count: int = row.bancount
|
ban_count: int = int(row["bancount"])
|
||||||
matches, _ = parse_data_json(row.data)
|
matches, _ = _parse_data_json(row["data"])
|
||||||
service: str | None = matches[0] if matches else None
|
service: str | None = matches[0] if matches else None
|
||||||
|
|
||||||
country_code: str | None = None
|
country_code: str | None = None
|
||||||
@@ -257,9 +343,7 @@ async def bans_by_country(
|
|||||||
socket_path: str,
|
socket_path: str,
|
||||||
range_: TimeRange,
|
range_: TimeRange,
|
||||||
http_session: aiohttp.ClientSession | None = None,
|
http_session: aiohttp.ClientSession | None = None,
|
||||||
geo_cache_lookup: GeoCacheLookup | None = None,
|
geo_enricher: Any | None = None,
|
||||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
|
||||||
geo_enricher: GeoEnricher | None = None,
|
|
||||||
app_db: aiosqlite.Connection | None = None,
|
app_db: aiosqlite.Connection | None = None,
|
||||||
origin: BanOrigin | None = None,
|
origin: BanOrigin | None = None,
|
||||||
) -> BansByCountryResponse:
|
) -> BansByCountryResponse:
|
||||||
@@ -298,10 +382,11 @@ async def bans_by_country(
|
|||||||
:class:`~app.models.ban.BansByCountryResponse` with per-country
|
:class:`~app.models.ban.BansByCountryResponse` with per-country
|
||||||
aggregation and the companion ban list.
|
aggregation and the companion ban list.
|
||||||
"""
|
"""
|
||||||
|
from app.services import geo_service # noqa: PLC0415
|
||||||
|
|
||||||
since: int = _since_unix(range_)
|
since: int = _since_unix(range_)
|
||||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||||
db_path: str = await get_fail2ban_db_path(socket_path)
|
db_path: str = await _get_fail2ban_db_path(socket_path)
|
||||||
log.info(
|
log.info(
|
||||||
"ban_service_bans_by_country",
|
"ban_service_bans_by_country",
|
||||||
db_path=db_path,
|
db_path=db_path,
|
||||||
@@ -310,54 +395,64 @@ async def bans_by_country(
|
|||||||
origin=origin,
|
origin=origin,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Total count and companion rows reuse the same SQL query logic.
|
async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db:
|
||||||
# Passing limit=0 returns only the total from the count query.
|
f2b_db.row_factory = aiosqlite.Row
|
||||||
_, total = await fail2ban_db_repo.get_currently_banned(
|
|
||||||
db_path=db_path,
|
|
||||||
since=since,
|
|
||||||
origin=origin,
|
|
||||||
limit=0,
|
|
||||||
offset=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
agg_rows = await fail2ban_db_repo.get_ban_event_counts(
|
# Total count for the window.
|
||||||
db_path=db_path,
|
async with f2b_db.execute(
|
||||||
since=since,
|
"SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause,
|
||||||
origin=origin,
|
(since, *origin_params),
|
||||||
)
|
) as cur:
|
||||||
|
count_row = await cur.fetchone()
|
||||||
|
total: int = int(count_row[0]) if count_row else 0
|
||||||
|
|
||||||
companion_rows, _ = await fail2ban_db_repo.get_currently_banned(
|
# Aggregation: unique IPs + their total event count.
|
||||||
db_path=db_path,
|
# No LIMIT here — we need all unique source IPs for accurate country counts.
|
||||||
since=since,
|
async with f2b_db.execute(
|
||||||
origin=origin,
|
"SELECT ip, COUNT(*) AS event_count "
|
||||||
limit=_MAX_COMPANION_BANS,
|
"FROM bans "
|
||||||
offset=0,
|
"WHERE timeofban >= ?"
|
||||||
)
|
+ origin_clause
|
||||||
|
+ " GROUP BY ip",
|
||||||
|
(since, *origin_params),
|
||||||
|
) as cur:
|
||||||
|
agg_rows = await cur.fetchall()
|
||||||
|
|
||||||
unique_ips: list[str] = [r.ip for r in agg_rows]
|
# Companion table: most recent raw rows for display alongside the map.
|
||||||
geo_map: dict[str, GeoInfo] = {}
|
async with f2b_db.execute(
|
||||||
|
"SELECT jail, ip, timeofban, bancount, data "
|
||||||
|
"FROM bans "
|
||||||
|
"WHERE timeofban >= ?"
|
||||||
|
+ origin_clause
|
||||||
|
+ " ORDER BY timeofban DESC "
|
||||||
|
"LIMIT ?",
|
||||||
|
(since, *origin_params, _MAX_COMPANION_BANS),
|
||||||
|
) as cur:
|
||||||
|
companion_rows = await cur.fetchall()
|
||||||
|
|
||||||
if http_session is not None and unique_ips and geo_cache_lookup is not None:
|
unique_ips: list[str] = [str(r["ip"]) for r in agg_rows]
|
||||||
|
geo_map: dict[str, Any] = {}
|
||||||
|
|
||||||
|
if http_session is not None and unique_ips:
|
||||||
# Serve only what is already in the in-memory cache — no API calls on
|
# Serve only what is already in the in-memory cache — no API calls on
|
||||||
# the hot path. Uncached IPs are resolved asynchronously in the
|
# the hot path. Uncached IPs are resolved asynchronously in the
|
||||||
# background so subsequent requests benefit from a warmer cache.
|
# background so subsequent requests benefit from a warmer cache.
|
||||||
geo_map, uncached = geo_cache_lookup(unique_ips)
|
geo_map, uncached = geo_service.lookup_cached_only(unique_ips)
|
||||||
if uncached:
|
if uncached:
|
||||||
log.info(
|
log.info(
|
||||||
"ban_service_geo_background_scheduled",
|
"ban_service_geo_background_scheduled",
|
||||||
uncached=len(uncached),
|
uncached=len(uncached),
|
||||||
cached=len(geo_map),
|
cached=len(geo_map),
|
||||||
)
|
)
|
||||||
if geo_batch_lookup is not None:
|
# Fire-and-forget: lookup_batch handles rate-limiting / retries.
|
||||||
# Fire-and-forget: lookup_batch handles rate-limiting / retries.
|
# The dirty-set flush task persists results to the DB.
|
||||||
# The dirty-set flush task persists results to the DB.
|
asyncio.create_task( # noqa: RUF006
|
||||||
asyncio.create_task( # noqa: RUF006
|
geo_service.lookup_batch(uncached, http_session, db=app_db),
|
||||||
geo_batch_lookup(uncached, http_session, db=app_db),
|
name="geo_bans_by_country",
|
||||||
name="geo_bans_by_country",
|
)
|
||||||
)
|
|
||||||
elif geo_enricher is not None and unique_ips:
|
elif geo_enricher is not None and unique_ips:
|
||||||
# Fallback: legacy per-IP enricher (used in tests / older callers).
|
# Fallback: legacy per-IP enricher (used in tests / older callers).
|
||||||
async def _safe_lookup(ip: str) -> tuple[str, GeoInfo | None]:
|
async def _safe_lookup(ip: str) -> tuple[str, Any]:
|
||||||
try:
|
try:
|
||||||
return ip, await geo_enricher(ip)
|
return ip, await geo_enricher(ip)
|
||||||
except Exception: # noqa: BLE001
|
except Exception: # noqa: BLE001
|
||||||
@@ -365,18 +460,18 @@ async def bans_by_country(
|
|||||||
return ip, None
|
return ip, None
|
||||||
|
|
||||||
results = await asyncio.gather(*(_safe_lookup(ip) for ip in unique_ips))
|
results = await asyncio.gather(*(_safe_lookup(ip) for ip in unique_ips))
|
||||||
geo_map = {ip: geo for ip, geo in results if geo is not None}
|
geo_map = dict(results)
|
||||||
|
|
||||||
# Build country aggregation from the SQL-grouped rows.
|
# Build country aggregation from the SQL-grouped rows.
|
||||||
countries: dict[str, int] = {}
|
countries: dict[str, int] = {}
|
||||||
country_names: dict[str, str] = {}
|
country_names: dict[str, str] = {}
|
||||||
|
|
||||||
for agg_row in agg_rows:
|
for row in agg_rows:
|
||||||
ip: str = agg_row.ip
|
ip: str = str(row["ip"])
|
||||||
geo = geo_map.get(ip)
|
geo = geo_map.get(ip)
|
||||||
cc: str | None = geo.country_code if geo else None
|
cc: str | None = geo.country_code if geo else None
|
||||||
cn: str | None = geo.country_name if geo else None
|
cn: str | None = geo.country_name if geo else None
|
||||||
event_count: int = agg_row.event_count
|
event_count: int = int(row["event_count"])
|
||||||
|
|
||||||
if cc:
|
if cc:
|
||||||
countries[cc] = countries.get(cc, 0) + event_count
|
countries[cc] = countries.get(cc, 0) + event_count
|
||||||
@@ -385,27 +480,27 @@ async def bans_by_country(
|
|||||||
|
|
||||||
# Build companion table from recent rows (geo already cached from batch step).
|
# Build companion table from recent rows (geo already cached from batch step).
|
||||||
bans: list[DashboardBanItem] = []
|
bans: list[DashboardBanItem] = []
|
||||||
for companion_row in companion_rows:
|
for row in companion_rows:
|
||||||
ip = companion_row.ip
|
ip = str(row["ip"])
|
||||||
geo = geo_map.get(ip)
|
geo = geo_map.get(ip)
|
||||||
cc = geo.country_code if geo else None
|
cc = geo.country_code if geo else None
|
||||||
cn = geo.country_name if geo else None
|
cn = geo.country_name if geo else None
|
||||||
asn: str | None = geo.asn if geo else None
|
asn: str | None = geo.asn if geo else None
|
||||||
org: str | None = geo.org if geo else None
|
org: str | None = geo.org if geo else None
|
||||||
matches, _ = parse_data_json(companion_row.data)
|
matches, _ = _parse_data_json(row["data"])
|
||||||
|
|
||||||
bans.append(
|
bans.append(
|
||||||
DashboardBanItem(
|
DashboardBanItem(
|
||||||
ip=ip,
|
ip=ip,
|
||||||
jail=companion_row.jail,
|
jail=str(row["jail"]),
|
||||||
banned_at=ts_to_iso(companion_row.timeofban),
|
banned_at=_ts_to_iso(int(row["timeofban"])),
|
||||||
service=matches[0] if matches else None,
|
service=matches[0] if matches else None,
|
||||||
country_code=cc,
|
country_code=cc,
|
||||||
country_name=cn,
|
country_name=cn,
|
||||||
asn=asn,
|
asn=asn,
|
||||||
org=org,
|
org=org,
|
||||||
ban_count=companion_row.bancount,
|
ban_count=int(row["bancount"]),
|
||||||
origin=_derive_origin(companion_row.jail),
|
origin=_derive_origin(str(row["jail"])),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -459,7 +554,7 @@ async def ban_trend(
|
|||||||
num_buckets: int = bucket_count(range_)
|
num_buckets: int = bucket_count(range_)
|
||||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||||
|
|
||||||
db_path: str = await get_fail2ban_db_path(socket_path)
|
db_path: str = await _get_fail2ban_db_path(socket_path)
|
||||||
log.info(
|
log.info(
|
||||||
"ban_service_ban_trend",
|
"ban_service_ban_trend",
|
||||||
db_path=db_path,
|
db_path=db_path,
|
||||||
@@ -470,18 +565,32 @@ async def ban_trend(
|
|||||||
num_buckets=num_buckets,
|
num_buckets=num_buckets,
|
||||||
)
|
)
|
||||||
|
|
||||||
counts = await fail2ban_db_repo.get_ban_counts_by_bucket(
|
async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db:
|
||||||
db_path=db_path,
|
f2b_db.row_factory = aiosqlite.Row
|
||||||
since=since,
|
|
||||||
bucket_secs=bucket_secs,
|
async with f2b_db.execute(
|
||||||
num_buckets=num_buckets,
|
"SELECT CAST((timeofban - ?) / ? AS INTEGER) AS bucket_idx, "
|
||||||
origin=origin,
|
"COUNT(*) AS cnt "
|
||||||
)
|
"FROM bans "
|
||||||
|
"WHERE timeofban >= ?"
|
||||||
|
+ origin_clause
|
||||||
|
+ " GROUP BY bucket_idx "
|
||||||
|
"ORDER BY bucket_idx",
|
||||||
|
(since, bucket_secs, since, *origin_params),
|
||||||
|
) as cur:
|
||||||
|
rows = await cur.fetchall()
|
||||||
|
|
||||||
|
# Map bucket_idx → count; ignore any out-of-range indices.
|
||||||
|
counts: dict[int, int] = {}
|
||||||
|
for row in rows:
|
||||||
|
idx: int = int(row["bucket_idx"])
|
||||||
|
if 0 <= idx < num_buckets:
|
||||||
|
counts[idx] = int(row["cnt"])
|
||||||
|
|
||||||
buckets: list[BanTrendBucket] = [
|
buckets: list[BanTrendBucket] = [
|
||||||
BanTrendBucket(
|
BanTrendBucket(
|
||||||
timestamp=ts_to_iso(since + i * bucket_secs),
|
timestamp=_ts_to_iso(since + i * bucket_secs),
|
||||||
count=counts[i],
|
count=counts.get(i, 0),
|
||||||
)
|
)
|
||||||
for i in range(num_buckets)
|
for i in range(num_buckets)
|
||||||
]
|
]
|
||||||
@@ -524,44 +633,60 @@ async def bans_by_jail(
|
|||||||
since: int = _since_unix(range_)
|
since: int = _since_unix(range_)
|
||||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||||
|
|
||||||
db_path: str = await get_fail2ban_db_path(socket_path)
|
db_path: str = await _get_fail2ban_db_path(socket_path)
|
||||||
log.debug(
|
log.debug(
|
||||||
"ban_service_bans_by_jail",
|
"ban_service_bans_by_jail",
|
||||||
db_path=db_path,
|
db_path=db_path,
|
||||||
since=since,
|
since=since,
|
||||||
since_iso=ts_to_iso(since),
|
since_iso=_ts_to_iso(since),
|
||||||
range=range_,
|
range=range_,
|
||||||
origin=origin,
|
origin=origin,
|
||||||
)
|
)
|
||||||
|
|
||||||
total, jail_counts = await fail2ban_db_repo.get_bans_by_jail(
|
async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db:
|
||||||
db_path=db_path,
|
f2b_db.row_factory = aiosqlite.Row
|
||||||
since=since,
|
|
||||||
origin=origin,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Diagnostic guard: if zero results were returned, check whether the table
|
async with f2b_db.execute(
|
||||||
# has *any* rows and log a warning with min/max timeofban so operators can
|
"SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause,
|
||||||
# diagnose timezone or filter mismatches from logs.
|
(since, *origin_params),
|
||||||
if total == 0:
|
) as cur:
|
||||||
table_row_count, min_timeofban, max_timeofban = await fail2ban_db_repo.get_bans_table_summary(db_path)
|
count_row = await cur.fetchone()
|
||||||
if table_row_count > 0:
|
total: int = int(count_row[0]) if count_row else 0
|
||||||
log.warning(
|
|
||||||
"ban_service_bans_by_jail_empty_despite_data",
|
|
||||||
table_row_count=table_row_count,
|
|
||||||
min_timeofban=min_timeofban,
|
|
||||||
max_timeofban=max_timeofban,
|
|
||||||
since=since,
|
|
||||||
range=range_,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
# Diagnostic guard: if zero results were returned, check whether the
|
||||||
|
# table has *any* rows and log a warning with min/max timeofban so
|
||||||
|
# operators can diagnose timezone or filter mismatches from logs.
|
||||||
|
if total == 0:
|
||||||
|
async with f2b_db.execute(
|
||||||
|
"SELECT COUNT(*), MIN(timeofban), MAX(timeofban) FROM bans"
|
||||||
|
) as cur:
|
||||||
|
diag_row = await cur.fetchone()
|
||||||
|
if diag_row and diag_row[0] > 0:
|
||||||
|
log.warning(
|
||||||
|
"ban_service_bans_by_jail_empty_despite_data",
|
||||||
|
table_row_count=diag_row[0],
|
||||||
|
min_timeofban=diag_row[1],
|
||||||
|
max_timeofban=diag_row[2],
|
||||||
|
since=since,
|
||||||
|
range=range_,
|
||||||
|
)
|
||||||
|
|
||||||
|
async with f2b_db.execute(
|
||||||
|
"SELECT jail, COUNT(*) AS cnt "
|
||||||
|
"FROM bans "
|
||||||
|
"WHERE timeofban >= ?"
|
||||||
|
+ origin_clause
|
||||||
|
+ " GROUP BY jail ORDER BY cnt DESC",
|
||||||
|
(since, *origin_params),
|
||||||
|
) as cur:
|
||||||
|
rows = await cur.fetchall()
|
||||||
|
|
||||||
|
jails: list[JailBanCount] = [
|
||||||
|
JailBanCount(jail=str(row["jail"]), count=int(row["cnt"])) for row in rows
|
||||||
|
]
|
||||||
log.debug(
|
log.debug(
|
||||||
"ban_service_bans_by_jail_result",
|
"ban_service_bans_by_jail_result",
|
||||||
total=total,
|
total=total,
|
||||||
jail_count=len(jail_counts),
|
jail_count=len(jails),
|
||||||
)
|
|
||||||
|
|
||||||
return BansByJailResponse(
|
|
||||||
jails=[JailBanCountModel(jail=j.jail, count=j.count) for j in jail_counts],
|
|
||||||
total=total,
|
|
||||||
)
|
)
|
||||||
|
return BansByJailResponse(jails=jails, total=total)
|
||||||
|
|||||||
@@ -14,35 +14,26 @@ under the key ``"blocklist_schedule"``.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import importlib
|
|
||||||
import json
|
import json
|
||||||
from collections.abc import Awaitable
|
from typing import TYPE_CHECKING, Any
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
|
|
||||||
from app.models.blocklist import (
|
from app.models.blocklist import (
|
||||||
BlocklistSource,
|
BlocklistSource,
|
||||||
ImportLogEntry,
|
|
||||||
ImportLogListResponse,
|
|
||||||
ImportRunResult,
|
ImportRunResult,
|
||||||
ImportSourceResult,
|
ImportSourceResult,
|
||||||
PreviewResponse,
|
PreviewResponse,
|
||||||
ScheduleConfig,
|
ScheduleConfig,
|
||||||
ScheduleInfo,
|
ScheduleInfo,
|
||||||
)
|
)
|
||||||
from app.exceptions import JailNotFoundError
|
|
||||||
from app.repositories import blocklist_repo, import_log_repo, settings_repo
|
from app.repositories import blocklist_repo, import_log_repo, settings_repo
|
||||||
from app.utils.ip_utils import is_valid_ip, is_valid_network
|
from app.utils.ip_utils import is_valid_ip, is_valid_network
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from collections.abc import Callable
|
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
|
|
||||||
from app.models.geo import GeoBatchLookup
|
|
||||||
|
|
||||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||||
|
|
||||||
#: Settings key used to persist the schedule config.
|
#: Settings key used to persist the schedule config.
|
||||||
@@ -63,7 +54,7 @@ _PREVIEW_MAX_BYTES: int = 65536
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _row_to_source(row: dict[str, object]) -> BlocklistSource:
|
def _row_to_source(row: dict[str, Any]) -> BlocklistSource:
|
||||||
"""Convert a repository row dict to a :class:`BlocklistSource`.
|
"""Convert a repository row dict to a :class:`BlocklistSource`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -245,9 +236,6 @@ async def import_source(
|
|||||||
http_session: aiohttp.ClientSession,
|
http_session: aiohttp.ClientSession,
|
||||||
socket_path: str,
|
socket_path: str,
|
||||||
db: aiosqlite.Connection,
|
db: aiosqlite.Connection,
|
||||||
geo_is_cached: Callable[[str], bool] | None = None,
|
|
||||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
|
||||||
ban_ip: Callable[[str, str, str], Awaitable[None]] | None = None,
|
|
||||||
) -> ImportSourceResult:
|
) -> ImportSourceResult:
|
||||||
"""Download and apply bans from a single blocklist source.
|
"""Download and apply bans from a single blocklist source.
|
||||||
|
|
||||||
@@ -305,14 +293,8 @@ async def import_source(
|
|||||||
ban_error: str | None = None
|
ban_error: str | None = None
|
||||||
imported_ips: list[str] = []
|
imported_ips: list[str] = []
|
||||||
|
|
||||||
if ban_ip is None:
|
# Import jail_service here to avoid circular import at module level.
|
||||||
try:
|
from app.services import jail_service # noqa: PLC0415
|
||||||
jail_svc = importlib.import_module("app.services.jail_service")
|
|
||||||
ban_ip_fn = jail_svc.ban_ip
|
|
||||||
except (ModuleNotFoundError, AttributeError) as exc:
|
|
||||||
raise ValueError("ban_ip callback is required") from exc
|
|
||||||
else:
|
|
||||||
ban_ip_fn = ban_ip
|
|
||||||
|
|
||||||
for line in content.splitlines():
|
for line in content.splitlines():
|
||||||
stripped = line.strip()
|
stripped = line.strip()
|
||||||
@@ -325,10 +307,10 @@ async def import_source(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await ban_ip_fn(socket_path, BLOCKLIST_JAIL, stripped)
|
await jail_service.ban_ip(socket_path, BLOCKLIST_JAIL, stripped)
|
||||||
imported += 1
|
imported += 1
|
||||||
imported_ips.append(stripped)
|
imported_ips.append(stripped)
|
||||||
except JailNotFoundError as exc:
|
except jail_service.JailNotFoundError as exc:
|
||||||
# The target jail does not exist in fail2ban — there is no point
|
# The target jail does not exist in fail2ban — there is no point
|
||||||
# continuing because every subsequent ban would also fail.
|
# continuing because every subsequent ban would also fail.
|
||||||
ban_error = str(exc)
|
ban_error = str(exc)
|
||||||
@@ -355,8 +337,12 @@ async def import_source(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# --- Pre-warm geo cache for newly imported IPs ---
|
# --- Pre-warm geo cache for newly imported IPs ---
|
||||||
if imported_ips and geo_is_cached is not None:
|
if imported_ips:
|
||||||
uncached_ips: list[str] = [ip for ip in imported_ips if not geo_is_cached(ip)]
|
from app.services import geo_service # noqa: PLC0415
|
||||||
|
|
||||||
|
uncached_ips: list[str] = [
|
||||||
|
ip for ip in imported_ips if not geo_service.is_cached(ip)
|
||||||
|
]
|
||||||
skipped_geo: int = len(imported_ips) - len(uncached_ips)
|
skipped_geo: int = len(imported_ips) - len(uncached_ips)
|
||||||
|
|
||||||
if skipped_geo > 0:
|
if skipped_geo > 0:
|
||||||
@@ -367,9 +353,9 @@ async def import_source(
|
|||||||
to_lookup=len(uncached_ips),
|
to_lookup=len(uncached_ips),
|
||||||
)
|
)
|
||||||
|
|
||||||
if uncached_ips and geo_batch_lookup is not None:
|
if uncached_ips:
|
||||||
try:
|
try:
|
||||||
await geo_batch_lookup(uncached_ips, http_session, db=db)
|
await geo_service.lookup_batch(uncached_ips, http_session, db=db)
|
||||||
log.info(
|
log.info(
|
||||||
"blocklist_geo_prewarm_complete",
|
"blocklist_geo_prewarm_complete",
|
||||||
source_id=source.id,
|
source_id=source.id,
|
||||||
@@ -395,9 +381,6 @@ async def import_all(
|
|||||||
db: aiosqlite.Connection,
|
db: aiosqlite.Connection,
|
||||||
http_session: aiohttp.ClientSession,
|
http_session: aiohttp.ClientSession,
|
||||||
socket_path: str,
|
socket_path: str,
|
||||||
geo_is_cached: Callable[[str], bool] | None = None,
|
|
||||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
|
||||||
ban_ip: Callable[[str, str, str], Awaitable[None]] | None = None,
|
|
||||||
) -> ImportRunResult:
|
) -> ImportRunResult:
|
||||||
"""Import all enabled blocklist sources.
|
"""Import all enabled blocklist sources.
|
||||||
|
|
||||||
@@ -421,15 +404,7 @@ async def import_all(
|
|||||||
|
|
||||||
for row in sources:
|
for row in sources:
|
||||||
source = _row_to_source(row)
|
source = _row_to_source(row)
|
||||||
result = await import_source(
|
result = await import_source(source, http_session, socket_path, db)
|
||||||
source,
|
|
||||||
http_session,
|
|
||||||
socket_path,
|
|
||||||
db,
|
|
||||||
geo_is_cached=geo_is_cached,
|
|
||||||
geo_batch_lookup=geo_batch_lookup,
|
|
||||||
ban_ip=ban_ip,
|
|
||||||
)
|
|
||||||
results.append(result)
|
results.append(result)
|
||||||
total_imported += result.ips_imported
|
total_imported += result.ips_imported
|
||||||
total_skipped += result.ips_skipped
|
total_skipped += result.ips_skipped
|
||||||
@@ -528,44 +503,12 @@ async def get_schedule_info(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def list_import_logs(
|
|
||||||
db: aiosqlite.Connection,
|
|
||||||
*,
|
|
||||||
source_id: int | None = None,
|
|
||||||
page: int = 1,
|
|
||||||
page_size: int = 50,
|
|
||||||
) -> ImportLogListResponse:
|
|
||||||
"""Return a paginated list of import log entries.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Active application database connection.
|
|
||||||
source_id: Optional filter to only return logs for a specific source.
|
|
||||||
page: 1-based page number.
|
|
||||||
page_size: Items per page.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
:class:`~app.models.blocklist.ImportLogListResponse`.
|
|
||||||
"""
|
|
||||||
items, total = await import_log_repo.list_logs(
|
|
||||||
db, source_id=source_id, page=page, page_size=page_size
|
|
||||||
)
|
|
||||||
total_pages = import_log_repo.compute_total_pages(total, page_size)
|
|
||||||
|
|
||||||
return ImportLogListResponse(
|
|
||||||
items=[ImportLogEntry.model_validate(i) for i in items],
|
|
||||||
total=total,
|
|
||||||
page=page,
|
|
||||||
page_size=page_size,
|
|
||||||
total_pages=total_pages,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Internal helpers
|
# Internal helpers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _aiohttp_timeout(seconds: float) -> aiohttp.ClientTimeout:
|
def _aiohttp_timeout(seconds: float) -> Any:
|
||||||
"""Return an :class:`aiohttp.ClientTimeout` with the given total timeout.
|
"""Return an :class:`aiohttp.ClientTimeout` with the given total timeout.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ import os
|
|||||||
import re
|
import re
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import cast
|
from typing import Any
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
|
|
||||||
@@ -54,52 +54,12 @@ from app.models.config import (
|
|||||||
JailValidationResult,
|
JailValidationResult,
|
||||||
RollbackResponse,
|
RollbackResponse,
|
||||||
)
|
)
|
||||||
from app.exceptions import FilterInvalidRegexError, JailNotFoundError
|
from app.services import conffile_parser, jail_service
|
||||||
from app.utils import conffile_parser
|
from app.services.jail_service import JailNotFoundError as JailNotFoundError
|
||||||
from app.utils.jail_utils import reload_jails
|
from app.utils.fail2ban_client import Fail2BanClient, Fail2BanConnectionError
|
||||||
from app.utils.fail2ban_client import (
|
|
||||||
Fail2BanClient,
|
|
||||||
Fail2BanConnectionError,
|
|
||||||
Fail2BanResponse,
|
|
||||||
)
|
|
||||||
|
|
||||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||||
|
|
||||||
# Proxy object for jail reload operations. Tests can patch
|
|
||||||
# app.services.config_file_service.jail_service.reload_all as needed.
|
|
||||||
class _JailServiceProxy:
|
|
||||||
async def reload_all(
|
|
||||||
self,
|
|
||||||
socket_path: str,
|
|
||||||
include_jails: list[str] | None = None,
|
|
||||||
exclude_jails: list[str] | None = None,
|
|
||||||
) -> None:
|
|
||||||
kwargs: dict[str, list[str]] = {}
|
|
||||||
if include_jails is not None:
|
|
||||||
kwargs["include_jails"] = include_jails
|
|
||||||
if exclude_jails is not None:
|
|
||||||
kwargs["exclude_jails"] = exclude_jails
|
|
||||||
await reload_jails(socket_path, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
jail_service = _JailServiceProxy()
|
|
||||||
|
|
||||||
|
|
||||||
async def _reload_all(
|
|
||||||
socket_path: str,
|
|
||||||
include_jails: list[str] | None = None,
|
|
||||||
exclude_jails: list[str] | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Reload fail2ban jails using the configured hook or default helper."""
|
|
||||||
kwargs: dict[str, list[str]] = {}
|
|
||||||
if include_jails is not None:
|
|
||||||
kwargs["include_jails"] = include_jails
|
|
||||||
if exclude_jails is not None:
|
|
||||||
kwargs["exclude_jails"] = exclude_jails
|
|
||||||
|
|
||||||
await jail_service.reload_all(socket_path, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Constants
|
# Constants
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -107,7 +67,9 @@ async def _reload_all(
|
|||||||
_SOCKET_TIMEOUT: float = 10.0
|
_SOCKET_TIMEOUT: float = 10.0
|
||||||
|
|
||||||
# Allowlist pattern for jail names used in path construction.
|
# Allowlist pattern for jail names used in path construction.
|
||||||
_SAFE_JAIL_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$")
|
_SAFE_JAIL_NAME_RE: re.Pattern[str] = re.compile(
|
||||||
|
r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$"
|
||||||
|
)
|
||||||
|
|
||||||
# Sections that are not jail definitions.
|
# Sections that are not jail definitions.
|
||||||
_META_SECTIONS: frozenset[str] = frozenset({"INCLUDES", "DEFAULT"})
|
_META_SECTIONS: frozenset[str] = frozenset({"INCLUDES", "DEFAULT"})
|
||||||
@@ -199,10 +161,26 @@ class FilterReadonlyError(Exception):
|
|||||||
"""
|
"""
|
||||||
self.name: str = name
|
self.name: str = name
|
||||||
super().__init__(
|
super().__init__(
|
||||||
f"Filter {name!r} is a shipped default (.conf only); only user-created .local files can be deleted."
|
f"Filter {name!r} is a shipped default (.conf only); "
|
||||||
|
"only user-created .local files can be deleted."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FilterInvalidRegexError(Exception):
|
||||||
|
"""Raised when a regex pattern fails to compile."""
|
||||||
|
|
||||||
|
def __init__(self, pattern: str, error: str) -> None:
|
||||||
|
"""Initialise with the invalid pattern and the compile error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pattern: The regex string that failed to compile.
|
||||||
|
error: The ``re.error`` message.
|
||||||
|
"""
|
||||||
|
self.pattern: str = pattern
|
||||||
|
self.error: str = error
|
||||||
|
super().__init__(f"Invalid regex {pattern!r}: {error}")
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Internal helpers
|
# Internal helpers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -439,7 +417,9 @@ def _parse_jails_sync(
|
|||||||
# items() merges DEFAULT values automatically.
|
# items() merges DEFAULT values automatically.
|
||||||
jails[section] = dict(parser.items(section))
|
jails[section] = dict(parser.items(section))
|
||||||
except configparser.Error as exc:
|
except configparser.Error as exc:
|
||||||
log.warning("jail_section_parse_error", section=section, error=str(exc))
|
log.warning(
|
||||||
|
"jail_section_parse_error", section=section, error=str(exc)
|
||||||
|
)
|
||||||
|
|
||||||
log.debug("jails_parsed", count=len(jails), config_dir=str(config_dir))
|
log.debug("jails_parsed", count=len(jails), config_dir=str(config_dir))
|
||||||
return jails, source_files
|
return jails, source_files
|
||||||
@@ -536,7 +516,11 @@ def _build_inactive_jail(
|
|||||||
bantime_escalation=bantime_escalation,
|
bantime_escalation=bantime_escalation,
|
||||||
source_file=source_file,
|
source_file=source_file,
|
||||||
enabled=enabled,
|
enabled=enabled,
|
||||||
has_local_override=((config_dir / "jail.d" / f"{name}.local").is_file() if config_dir is not None else False),
|
has_local_override=(
|
||||||
|
(config_dir / "jail.d" / f"{name}.local").is_file()
|
||||||
|
if config_dir is not None
|
||||||
|
else False
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -554,10 +538,10 @@ async def _get_active_jail_names(socket_path: str) -> set[str]:
|
|||||||
try:
|
try:
|
||||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||||
|
|
||||||
def _to_dict_inner(pairs: object) -> dict[str, object]:
|
def _to_dict_inner(pairs: Any) -> dict[str, Any]:
|
||||||
if not isinstance(pairs, (list, tuple)):
|
if not isinstance(pairs, (list, tuple)):
|
||||||
return {}
|
return {}
|
||||||
result: dict[str, object] = {}
|
result: dict[str, Any] = {}
|
||||||
for item in pairs:
|
for item in pairs:
|
||||||
try:
|
try:
|
||||||
k, v = item
|
k, v = item
|
||||||
@@ -566,8 +550,8 @@ async def _get_active_jail_names(socket_path: str) -> set[str]:
|
|||||||
pass
|
pass
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _ok(response: object) -> object:
|
def _ok(response: Any) -> Any:
|
||||||
code, data = cast("Fail2BanResponse", response)
|
code, data = response
|
||||||
if code != 0:
|
if code != 0:
|
||||||
raise ValueError(f"fail2ban error {code}: {data!r}")
|
raise ValueError(f"fail2ban error {code}: {data!r}")
|
||||||
return data
|
return data
|
||||||
@@ -582,7 +566,9 @@ async def _get_active_jail_names(socket_path: str) -> set[str]:
|
|||||||
log.warning("fail2ban_unreachable_during_inactive_list")
|
log.warning("fail2ban_unreachable_during_inactive_list")
|
||||||
return set()
|
return set()
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
log.warning("fail2ban_status_error_during_inactive_list", error=str(exc))
|
log.warning(
|
||||||
|
"fail2ban_status_error_during_inactive_list", error=str(exc)
|
||||||
|
)
|
||||||
return set()
|
return set()
|
||||||
|
|
||||||
|
|
||||||
@@ -670,7 +656,10 @@ def _validate_jail_config_sync(
|
|||||||
issues.append(
|
issues.append(
|
||||||
JailValidationIssue(
|
JailValidationIssue(
|
||||||
field="filter",
|
field="filter",
|
||||||
message=(f"Filter file not found: filter.d/{base_filter}.conf (or .local)"),
|
message=(
|
||||||
|
f"Filter file not found: filter.d/{base_filter}.conf"
|
||||||
|
" (or .local)"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -686,7 +675,10 @@ def _validate_jail_config_sync(
|
|||||||
issues.append(
|
issues.append(
|
||||||
JailValidationIssue(
|
JailValidationIssue(
|
||||||
field="action",
|
field="action",
|
||||||
message=(f"Action file not found: action.d/{action_name}.conf (or .local)"),
|
message=(
|
||||||
|
f"Action file not found: action.d/{action_name}.conf"
|
||||||
|
" (or .local)"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -820,7 +812,7 @@ def _write_local_override_sync(
|
|||||||
config_dir: Path,
|
config_dir: Path,
|
||||||
jail_name: str,
|
jail_name: str,
|
||||||
enabled: bool,
|
enabled: bool,
|
||||||
overrides: dict[str, object],
|
overrides: dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Write a ``jail.d/{name}.local`` file atomically.
|
"""Write a ``jail.d/{name}.local`` file atomically.
|
||||||
|
|
||||||
@@ -842,7 +834,9 @@ def _write_local_override_sync(
|
|||||||
try:
|
try:
|
||||||
jail_d.mkdir(parents=True, exist_ok=True)
|
jail_d.mkdir(parents=True, exist_ok=True)
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
raise ConfigWriteError(f"Cannot create jail.d directory: {exc}") from exc
|
raise ConfigWriteError(
|
||||||
|
f"Cannot create jail.d directory: {exc}"
|
||||||
|
) from exc
|
||||||
|
|
||||||
local_path = jail_d / f"{jail_name}.local"
|
local_path = jail_d / f"{jail_name}.local"
|
||||||
|
|
||||||
@@ -867,7 +861,7 @@ def _write_local_override_sync(
|
|||||||
if overrides.get("port") is not None:
|
if overrides.get("port") is not None:
|
||||||
lines.append(f"port = {overrides['port']}")
|
lines.append(f"port = {overrides['port']}")
|
||||||
if overrides.get("logpath"):
|
if overrides.get("logpath"):
|
||||||
paths: list[str] = cast("list[str]", overrides["logpath"])
|
paths: list[str] = overrides["logpath"]
|
||||||
if paths:
|
if paths:
|
||||||
lines.append(f"logpath = {paths[0]}")
|
lines.append(f"logpath = {paths[0]}")
|
||||||
for p in paths[1:]:
|
for p in paths[1:]:
|
||||||
@@ -890,7 +884,9 @@ def _write_local_override_sync(
|
|||||||
# Clean up temp file if rename failed.
|
# Clean up temp file if rename failed.
|
||||||
with contextlib.suppress(OSError):
|
with contextlib.suppress(OSError):
|
||||||
os.unlink(tmp_name) # noqa: F821 — only reachable when tmp_name is set
|
os.unlink(tmp_name) # noqa: F821 — only reachable when tmp_name is set
|
||||||
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
raise ConfigWriteError(
|
||||||
|
f"Failed to write {local_path}: {exc}"
|
||||||
|
) from exc
|
||||||
|
|
||||||
log.info(
|
log.info(
|
||||||
"jail_local_written",
|
"jail_local_written",
|
||||||
@@ -919,7 +915,9 @@ def _restore_local_file_sync(local_path: Path, original_content: bytes | None) -
|
|||||||
try:
|
try:
|
||||||
local_path.unlink(missing_ok=True)
|
local_path.unlink(missing_ok=True)
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
raise ConfigWriteError(f"Failed to delete {local_path} during rollback: {exc}") from exc
|
raise ConfigWriteError(
|
||||||
|
f"Failed to delete {local_path} during rollback: {exc}"
|
||||||
|
) from exc
|
||||||
return
|
return
|
||||||
|
|
||||||
tmp_name: str | None = None
|
tmp_name: str | None = None
|
||||||
@@ -937,7 +935,9 @@ def _restore_local_file_sync(local_path: Path, original_content: bytes | None) -
|
|||||||
with contextlib.suppress(OSError):
|
with contextlib.suppress(OSError):
|
||||||
if tmp_name is not None:
|
if tmp_name is not None:
|
||||||
os.unlink(tmp_name)
|
os.unlink(tmp_name)
|
||||||
raise ConfigWriteError(f"Failed to restore {local_path} during rollback: {exc}") from exc
|
raise ConfigWriteError(
|
||||||
|
f"Failed to restore {local_path} during rollback: {exc}"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
|
||||||
def _validate_regex_patterns(patterns: list[str]) -> None:
|
def _validate_regex_patterns(patterns: list[str]) -> None:
|
||||||
@@ -973,7 +973,9 @@ def _write_filter_local_sync(filter_d: Path, name: str, content: str) -> None:
|
|||||||
try:
|
try:
|
||||||
filter_d.mkdir(parents=True, exist_ok=True)
|
filter_d.mkdir(parents=True, exist_ok=True)
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
raise ConfigWriteError(f"Cannot create filter.d directory: {exc}") from exc
|
raise ConfigWriteError(
|
||||||
|
f"Cannot create filter.d directory: {exc}"
|
||||||
|
) from exc
|
||||||
|
|
||||||
local_path = filter_d / f"{name}.local"
|
local_path = filter_d / f"{name}.local"
|
||||||
try:
|
try:
|
||||||
@@ -990,7 +992,9 @@ def _write_filter_local_sync(filter_d: Path, name: str, content: str) -> None:
|
|||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
with contextlib.suppress(OSError):
|
with contextlib.suppress(OSError):
|
||||||
os.unlink(tmp_name) # noqa: F821
|
os.unlink(tmp_name) # noqa: F821
|
||||||
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
raise ConfigWriteError(
|
||||||
|
f"Failed to write {local_path}: {exc}"
|
||||||
|
) from exc
|
||||||
|
|
||||||
log.info("filter_local_written", filter=name, path=str(local_path))
|
log.info("filter_local_written", filter=name, path=str(local_path))
|
||||||
|
|
||||||
@@ -1021,7 +1025,9 @@ def _set_jail_local_key_sync(
|
|||||||
try:
|
try:
|
||||||
jail_d.mkdir(parents=True, exist_ok=True)
|
jail_d.mkdir(parents=True, exist_ok=True)
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
raise ConfigWriteError(f"Cannot create jail.d directory: {exc}") from exc
|
raise ConfigWriteError(
|
||||||
|
f"Cannot create jail.d directory: {exc}"
|
||||||
|
) from exc
|
||||||
|
|
||||||
local_path = jail_d / f"{jail_name}.local"
|
local_path = jail_d / f"{jail_name}.local"
|
||||||
|
|
||||||
@@ -1060,7 +1066,9 @@ def _set_jail_local_key_sync(
|
|||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
with contextlib.suppress(OSError):
|
with contextlib.suppress(OSError):
|
||||||
os.unlink(tmp_name) # noqa: F821
|
os.unlink(tmp_name) # noqa: F821
|
||||||
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
raise ConfigWriteError(
|
||||||
|
f"Failed to write {local_path}: {exc}"
|
||||||
|
) from exc
|
||||||
|
|
||||||
log.info(
|
log.info(
|
||||||
"jail_local_key_set",
|
"jail_local_key_set",
|
||||||
@@ -1098,8 +1106,8 @@ async def list_inactive_jails(
|
|||||||
inactive jails.
|
inactive jails.
|
||||||
"""
|
"""
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
parsed_result: tuple[dict[str, dict[str, str]], dict[str, str]] = await loop.run_in_executor(
|
parsed_result: tuple[dict[str, dict[str, str]], dict[str, str]] = (
|
||||||
None, _parse_jails_sync, Path(config_dir)
|
await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
||||||
)
|
)
|
||||||
all_jails, source_files = parsed_result
|
all_jails, source_files = parsed_result
|
||||||
active_names: set[str] = await _get_active_jail_names(socket_path)
|
active_names: set[str] = await _get_active_jail_names(socket_path)
|
||||||
@@ -1156,7 +1164,9 @@ async def activate_jail(
|
|||||||
_safe_jail_name(name)
|
_safe_jail_name(name)
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
all_jails, _source_files = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
all_jails, _source_files = await loop.run_in_executor(
|
||||||
|
None, _parse_jails_sync, Path(config_dir)
|
||||||
|
)
|
||||||
|
|
||||||
if name not in all_jails:
|
if name not in all_jails:
|
||||||
raise JailNotFoundInConfigError(name)
|
raise JailNotFoundInConfigError(name)
|
||||||
@@ -1192,10 +1202,13 @@ async def activate_jail(
|
|||||||
active=False,
|
active=False,
|
||||||
fail2ban_running=True,
|
fail2ban_running=True,
|
||||||
validation_warnings=warnings,
|
validation_warnings=warnings,
|
||||||
message=(f"Jail {name!r} cannot be activated: " + "; ".join(i.message for i in blocking)),
|
message=(
|
||||||
|
f"Jail {name!r} cannot be activated: "
|
||||||
|
+ "; ".join(i.message for i in blocking)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
overrides: dict[str, object] = {
|
overrides: dict[str, Any] = {
|
||||||
"bantime": req.bantime,
|
"bantime": req.bantime,
|
||||||
"findtime": req.findtime,
|
"findtime": req.findtime,
|
||||||
"maxretry": req.maxretry,
|
"maxretry": req.maxretry,
|
||||||
@@ -1226,7 +1239,7 @@ async def activate_jail(
|
|||||||
# Activation reload — if it fails, roll back immediately #
|
# Activation reload — if it fails, roll back immediately #
|
||||||
# ---------------------------------------------------------------------- #
|
# ---------------------------------------------------------------------- #
|
||||||
try:
|
try:
|
||||||
await _reload_all(socket_path, include_jails=[name])
|
await jail_service.reload_all(socket_path, include_jails=[name])
|
||||||
except JailNotFoundError as exc:
|
except JailNotFoundError as exc:
|
||||||
# Jail configuration is invalid (e.g. missing logpath that prevents
|
# Jail configuration is invalid (e.g. missing logpath that prevents
|
||||||
# fail2ban from loading the jail). Roll back and provide a specific error.
|
# fail2ban from loading the jail). Roll back and provide a specific error.
|
||||||
@@ -1235,7 +1248,9 @@ async def activate_jail(
|
|||||||
jail=name,
|
jail=name,
|
||||||
error=str(exc),
|
error=str(exc),
|
||||||
)
|
)
|
||||||
recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content)
|
recovered = await _rollback_activation_async(
|
||||||
|
config_dir, name, socket_path, original_content
|
||||||
|
)
|
||||||
return JailActivationResponse(
|
return JailActivationResponse(
|
||||||
name=name,
|
name=name,
|
||||||
active=False,
|
active=False,
|
||||||
@@ -1251,7 +1266,9 @@ async def activate_jail(
|
|||||||
)
|
)
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
log.warning("reload_after_activate_failed", jail=name, error=str(exc))
|
log.warning("reload_after_activate_failed", jail=name, error=str(exc))
|
||||||
recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content)
|
recovered = await _rollback_activation_async(
|
||||||
|
config_dir, name, socket_path, original_content
|
||||||
|
)
|
||||||
return JailActivationResponse(
|
return JailActivationResponse(
|
||||||
name=name,
|
name=name,
|
||||||
active=False,
|
active=False,
|
||||||
@@ -1282,7 +1299,9 @@ async def activate_jail(
|
|||||||
jail=name,
|
jail=name,
|
||||||
message="fail2ban socket unreachable after reload — initiating rollback.",
|
message="fail2ban socket unreachable after reload — initiating rollback.",
|
||||||
)
|
)
|
||||||
recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content)
|
recovered = await _rollback_activation_async(
|
||||||
|
config_dir, name, socket_path, original_content
|
||||||
|
)
|
||||||
return JailActivationResponse(
|
return JailActivationResponse(
|
||||||
name=name,
|
name=name,
|
||||||
active=False,
|
active=False,
|
||||||
@@ -1305,7 +1324,9 @@ async def activate_jail(
|
|||||||
jail=name,
|
jail=name,
|
||||||
message="Jail did not appear in running jails — initiating rollback.",
|
message="Jail did not appear in running jails — initiating rollback.",
|
||||||
)
|
)
|
||||||
recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content)
|
recovered = await _rollback_activation_async(
|
||||||
|
config_dir, name, socket_path, original_content
|
||||||
|
)
|
||||||
return JailActivationResponse(
|
return JailActivationResponse(
|
||||||
name=name,
|
name=name,
|
||||||
active=False,
|
active=False,
|
||||||
@@ -1361,18 +1382,24 @@ async def _rollback_activation_async(
|
|||||||
|
|
||||||
# Step 1 — restore original file (or delete it).
|
# Step 1 — restore original file (or delete it).
|
||||||
try:
|
try:
|
||||||
await loop.run_in_executor(None, _restore_local_file_sync, local_path, original_content)
|
await loop.run_in_executor(
|
||||||
|
None, _restore_local_file_sync, local_path, original_content
|
||||||
|
)
|
||||||
log.info("jail_activation_rollback_file_restored", jail=name)
|
log.info("jail_activation_rollback_file_restored", jail=name)
|
||||||
except ConfigWriteError as exc:
|
except ConfigWriteError as exc:
|
||||||
log.error("jail_activation_rollback_restore_failed", jail=name, error=str(exc))
|
log.error(
|
||||||
|
"jail_activation_rollback_restore_failed", jail=name, error=str(exc)
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Step 2 — reload fail2ban with the restored config.
|
# Step 2 — reload fail2ban with the restored config.
|
||||||
try:
|
try:
|
||||||
await _reload_all(socket_path)
|
await jail_service.reload_all(socket_path)
|
||||||
log.info("jail_activation_rollback_reload_ok", jail=name)
|
log.info("jail_activation_rollback_reload_ok", jail=name)
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
log.warning("jail_activation_rollback_reload_failed", jail=name, error=str(exc))
|
log.warning(
|
||||||
|
"jail_activation_rollback_reload_failed", jail=name, error=str(exc)
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Step 3 — wait for fail2ban to come back.
|
# Step 3 — wait for fail2ban to come back.
|
||||||
@@ -1417,7 +1444,9 @@ async def deactivate_jail(
|
|||||||
_safe_jail_name(name)
|
_safe_jail_name(name)
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
all_jails, _source_files = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
all_jails, _source_files = await loop.run_in_executor(
|
||||||
|
None, _parse_jails_sync, Path(config_dir)
|
||||||
|
)
|
||||||
|
|
||||||
if name not in all_jails:
|
if name not in all_jails:
|
||||||
raise JailNotFoundInConfigError(name)
|
raise JailNotFoundInConfigError(name)
|
||||||
@@ -1436,7 +1465,7 @@ async def deactivate_jail(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await _reload_all(socket_path, exclude_jails=[name])
|
await jail_service.reload_all(socket_path, exclude_jails=[name])
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
log.warning("reload_after_deactivate_failed", jail=name, error=str(exc))
|
log.warning("reload_after_deactivate_failed", jail=name, error=str(exc))
|
||||||
|
|
||||||
@@ -1475,7 +1504,9 @@ async def delete_jail_local_override(
|
|||||||
_safe_jail_name(name)
|
_safe_jail_name(name)
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
all_jails, _source_files = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
all_jails, _source_files = await loop.run_in_executor(
|
||||||
|
None, _parse_jails_sync, Path(config_dir)
|
||||||
|
)
|
||||||
|
|
||||||
if name not in all_jails:
|
if name not in all_jails:
|
||||||
raise JailNotFoundInConfigError(name)
|
raise JailNotFoundInConfigError(name)
|
||||||
@@ -1486,9 +1517,13 @@ async def delete_jail_local_override(
|
|||||||
|
|
||||||
local_path = Path(config_dir) / "jail.d" / f"{name}.local"
|
local_path = Path(config_dir) / "jail.d" / f"{name}.local"
|
||||||
try:
|
try:
|
||||||
await loop.run_in_executor(None, lambda: local_path.unlink(missing_ok=True))
|
await loop.run_in_executor(
|
||||||
|
None, lambda: local_path.unlink(missing_ok=True)
|
||||||
|
)
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
raise ConfigWriteError(f"Failed to delete {local_path}: {exc}") from exc
|
raise ConfigWriteError(
|
||||||
|
f"Failed to delete {local_path}: {exc}"
|
||||||
|
) from exc
|
||||||
|
|
||||||
log.info("jail_local_override_deleted", jail=name, path=str(local_path))
|
log.info("jail_local_override_deleted", jail=name, path=str(local_path))
|
||||||
|
|
||||||
@@ -1569,7 +1604,9 @@ async def rollback_jail(
|
|||||||
log.info("jail_rollback_start_attempted", jail=name, start_ok=started)
|
log.info("jail_rollback_start_attempted", jail=name, start_ok=started)
|
||||||
|
|
||||||
# Wait for the socket to come back.
|
# Wait for the socket to come back.
|
||||||
fail2ban_running = await wait_for_fail2ban(socket_path, max_wait_seconds=10.0, poll_interval=2.0)
|
fail2ban_running = await wait_for_fail2ban(
|
||||||
|
socket_path, max_wait_seconds=10.0, poll_interval=2.0
|
||||||
|
)
|
||||||
|
|
||||||
active_jails = 0
|
active_jails = 0
|
||||||
if fail2ban_running:
|
if fail2ban_running:
|
||||||
@@ -1583,7 +1620,10 @@ async def rollback_jail(
|
|||||||
disabled=True,
|
disabled=True,
|
||||||
fail2ban_running=True,
|
fail2ban_running=True,
|
||||||
active_jails=active_jails,
|
active_jails=active_jails,
|
||||||
message=(f"Jail {name!r} disabled and fail2ban restarted successfully with {active_jails} active jail(s)."),
|
message=(
|
||||||
|
f"Jail {name!r} disabled and fail2ban restarted successfully "
|
||||||
|
f"with {active_jails} active jail(s)."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
log.warning("jail_rollback_fail2ban_still_down", jail=name)
|
log.warning("jail_rollback_fail2ban_still_down", jail=name)
|
||||||
@@ -1604,7 +1644,9 @@ async def rollback_jail(
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
# Allowlist pattern for filter names used in path construction.
|
# Allowlist pattern for filter names used in path construction.
|
||||||
_SAFE_FILTER_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$")
|
_SAFE_FILTER_NAME_RE: re.Pattern[str] = re.compile(
|
||||||
|
r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FilterNotFoundError(Exception):
|
class FilterNotFoundError(Exception):
|
||||||
@@ -1716,7 +1758,9 @@ def _parse_filters_sync(
|
|||||||
try:
|
try:
|
||||||
content = conf_path.read_text(encoding="utf-8")
|
content = conf_path.read_text(encoding="utf-8")
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
log.warning("filter_read_error", name=name, path=str(conf_path), error=str(exc))
|
log.warning(
|
||||||
|
"filter_read_error", name=name, path=str(conf_path), error=str(exc)
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if has_local:
|
if has_local:
|
||||||
@@ -1792,7 +1836,9 @@ async def list_filters(
|
|||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
# Run the synchronous scan in a thread-pool executor.
|
# Run the synchronous scan in a thread-pool executor.
|
||||||
raw_filters: list[tuple[str, str, str, bool, str]] = await loop.run_in_executor(None, _parse_filters_sync, filter_d)
|
raw_filters: list[tuple[str, str, str, bool, str]] = await loop.run_in_executor(
|
||||||
|
None, _parse_filters_sync, filter_d
|
||||||
|
)
|
||||||
|
|
||||||
# Fetch active jail names and their configs concurrently.
|
# Fetch active jail names and their configs concurrently.
|
||||||
all_jails_result, active_names = await asyncio.gather(
|
all_jails_result, active_names = await asyncio.gather(
|
||||||
@@ -1805,7 +1851,9 @@ async def list_filters(
|
|||||||
|
|
||||||
filters: list[FilterConfig] = []
|
filters: list[FilterConfig] = []
|
||||||
for name, filename, content, has_local, source_path in raw_filters:
|
for name, filename, content, has_local, source_path in raw_filters:
|
||||||
cfg = conffile_parser.parse_filter_file(content, name=name, filename=filename)
|
cfg = conffile_parser.parse_filter_file(
|
||||||
|
content, name=name, filename=filename
|
||||||
|
)
|
||||||
used_by = sorted(filter_to_jails.get(name, []))
|
used_by = sorted(filter_to_jails.get(name, []))
|
||||||
filters.append(
|
filters.append(
|
||||||
FilterConfig(
|
FilterConfig(
|
||||||
@@ -1893,7 +1941,9 @@ async def get_filter(
|
|||||||
|
|
||||||
content, has_local, source_path = await loop.run_in_executor(None, _read)
|
content, has_local, source_path = await loop.run_in_executor(None, _read)
|
||||||
|
|
||||||
cfg = conffile_parser.parse_filter_file(content, name=base_name, filename=f"{base_name}.conf")
|
cfg = conffile_parser.parse_filter_file(
|
||||||
|
content, name=base_name, filename=f"{base_name}.conf"
|
||||||
|
)
|
||||||
|
|
||||||
all_jails_result, active_names = await asyncio.gather(
|
all_jails_result, active_names = await asyncio.gather(
|
||||||
loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)),
|
loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)),
|
||||||
@@ -1992,7 +2042,7 @@ async def update_filter(
|
|||||||
|
|
||||||
if do_reload:
|
if do_reload:
|
||||||
try:
|
try:
|
||||||
await _reload_all(socket_path)
|
await jail_service.reload_all(socket_path)
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
log.warning(
|
log.warning(
|
||||||
"reload_after_filter_update_failed",
|
"reload_after_filter_update_failed",
|
||||||
@@ -2067,7 +2117,7 @@ async def create_filter(
|
|||||||
|
|
||||||
if do_reload:
|
if do_reload:
|
||||||
try:
|
try:
|
||||||
await _reload_all(socket_path)
|
await jail_service.reload_all(socket_path)
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
log.warning(
|
log.warning(
|
||||||
"reload_after_filter_create_failed",
|
"reload_after_filter_create_failed",
|
||||||
@@ -2126,7 +2176,9 @@ async def delete_filter(
|
|||||||
try:
|
try:
|
||||||
local_path.unlink()
|
local_path.unlink()
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
raise ConfigWriteError(f"Failed to delete {local_path}: {exc}") from exc
|
raise ConfigWriteError(
|
||||||
|
f"Failed to delete {local_path}: {exc}"
|
||||||
|
) from exc
|
||||||
|
|
||||||
log.info("filter_local_deleted", filter=base_name, path=str(local_path))
|
log.info("filter_local_deleted", filter=base_name, path=str(local_path))
|
||||||
|
|
||||||
@@ -2168,7 +2220,9 @@ async def assign_filter_to_jail(
|
|||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
# Verify the jail exists in config.
|
# Verify the jail exists in config.
|
||||||
all_jails, _src = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
all_jails, _src = await loop.run_in_executor(
|
||||||
|
None, _parse_jails_sync, Path(config_dir)
|
||||||
|
)
|
||||||
if jail_name not in all_jails:
|
if jail_name not in all_jails:
|
||||||
raise JailNotFoundInConfigError(jail_name)
|
raise JailNotFoundInConfigError(jail_name)
|
||||||
|
|
||||||
@@ -2194,7 +2248,7 @@ async def assign_filter_to_jail(
|
|||||||
|
|
||||||
if do_reload:
|
if do_reload:
|
||||||
try:
|
try:
|
||||||
await _reload_all(socket_path)
|
await jail_service.reload_all(socket_path)
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
log.warning(
|
log.warning(
|
||||||
"reload_after_assign_filter_failed",
|
"reload_after_assign_filter_failed",
|
||||||
@@ -2216,7 +2270,9 @@ async def assign_filter_to_jail(
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
# Allowlist pattern for action names used in path construction.
|
# Allowlist pattern for action names used in path construction.
|
||||||
_SAFE_ACTION_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$")
|
_SAFE_ACTION_NAME_RE: re.Pattern[str] = re.compile(
|
||||||
|
r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ActionNotFoundError(Exception):
|
class ActionNotFoundError(Exception):
|
||||||
@@ -2256,7 +2312,8 @@ class ActionReadonlyError(Exception):
|
|||||||
"""
|
"""
|
||||||
self.name: str = name
|
self.name: str = name
|
||||||
super().__init__(
|
super().__init__(
|
||||||
f"Action {name!r} is a shipped default (.conf only); only user-created .local files can be deleted."
|
f"Action {name!r} is a shipped default (.conf only); "
|
||||||
|
"only user-created .local files can be deleted."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -2365,7 +2422,9 @@ def _parse_actions_sync(
|
|||||||
try:
|
try:
|
||||||
content = conf_path.read_text(encoding="utf-8")
|
content = conf_path.read_text(encoding="utf-8")
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
log.warning("action_read_error", name=name, path=str(conf_path), error=str(exc))
|
log.warning(
|
||||||
|
"action_read_error", name=name, path=str(conf_path), error=str(exc)
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if has_local:
|
if has_local:
|
||||||
@@ -2430,7 +2489,9 @@ def _append_jail_action_sync(
|
|||||||
try:
|
try:
|
||||||
jail_d.mkdir(parents=True, exist_ok=True)
|
jail_d.mkdir(parents=True, exist_ok=True)
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
raise ConfigWriteError(f"Cannot create jail.d directory: {exc}") from exc
|
raise ConfigWriteError(
|
||||||
|
f"Cannot create jail.d directory: {exc}"
|
||||||
|
) from exc
|
||||||
|
|
||||||
local_path = jail_d / f"{jail_name}.local"
|
local_path = jail_d / f"{jail_name}.local"
|
||||||
|
|
||||||
@@ -2450,7 +2511,9 @@ def _append_jail_action_sync(
|
|||||||
|
|
||||||
existing_raw = parser.get(jail_name, "action") if parser.has_option(jail_name, "action") else ""
|
existing_raw = parser.get(jail_name, "action") if parser.has_option(jail_name, "action") else ""
|
||||||
existing_lines = [
|
existing_lines = [
|
||||||
line.strip() for line in existing_raw.splitlines() if line.strip() and not line.strip().startswith("#")
|
line.strip()
|
||||||
|
for line in existing_raw.splitlines()
|
||||||
|
if line.strip() and not line.strip().startswith("#")
|
||||||
]
|
]
|
||||||
|
|
||||||
# Extract base names from existing entries for duplicate checking.
|
# Extract base names from existing entries for duplicate checking.
|
||||||
@@ -2464,7 +2527,9 @@ def _append_jail_action_sync(
|
|||||||
|
|
||||||
if existing_lines:
|
if existing_lines:
|
||||||
# configparser multi-line: continuation lines start with whitespace.
|
# configparser multi-line: continuation lines start with whitespace.
|
||||||
new_value = existing_lines[0] + "".join(f"\n {line}" for line in existing_lines[1:])
|
new_value = existing_lines[0] + "".join(
|
||||||
|
f"\n {line}" for line in existing_lines[1:]
|
||||||
|
)
|
||||||
parser.set(jail_name, "action", new_value)
|
parser.set(jail_name, "action", new_value)
|
||||||
else:
|
else:
|
||||||
parser.set(jail_name, "action", action_entry)
|
parser.set(jail_name, "action", action_entry)
|
||||||
@@ -2488,7 +2553,9 @@ def _append_jail_action_sync(
|
|||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
with contextlib.suppress(OSError):
|
with contextlib.suppress(OSError):
|
||||||
os.unlink(tmp_name) # noqa: F821
|
os.unlink(tmp_name) # noqa: F821
|
||||||
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
raise ConfigWriteError(
|
||||||
|
f"Failed to write {local_path}: {exc}"
|
||||||
|
) from exc
|
||||||
|
|
||||||
log.info(
|
log.info(
|
||||||
"jail_action_appended",
|
"jail_action_appended",
|
||||||
@@ -2539,7 +2606,9 @@ def _remove_jail_action_sync(
|
|||||||
|
|
||||||
existing_raw = parser.get(jail_name, "action")
|
existing_raw = parser.get(jail_name, "action")
|
||||||
existing_lines = [
|
existing_lines = [
|
||||||
line.strip() for line in existing_raw.splitlines() if line.strip() and not line.strip().startswith("#")
|
line.strip()
|
||||||
|
for line in existing_raw.splitlines()
|
||||||
|
if line.strip() and not line.strip().startswith("#")
|
||||||
]
|
]
|
||||||
|
|
||||||
def _base(entry: str) -> str:
|
def _base(entry: str) -> str:
|
||||||
@@ -2553,7 +2622,9 @@ def _remove_jail_action_sync(
|
|||||||
return
|
return
|
||||||
|
|
||||||
if filtered:
|
if filtered:
|
||||||
new_value = filtered[0] + "".join(f"\n {line}" for line in filtered[1:])
|
new_value = filtered[0] + "".join(
|
||||||
|
f"\n {line}" for line in filtered[1:]
|
||||||
|
)
|
||||||
parser.set(jail_name, "action", new_value)
|
parser.set(jail_name, "action", new_value)
|
||||||
else:
|
else:
|
||||||
parser.remove_option(jail_name, "action")
|
parser.remove_option(jail_name, "action")
|
||||||
@@ -2577,7 +2648,9 @@ def _remove_jail_action_sync(
|
|||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
with contextlib.suppress(OSError):
|
with contextlib.suppress(OSError):
|
||||||
os.unlink(tmp_name) # noqa: F821
|
os.unlink(tmp_name) # noqa: F821
|
||||||
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
raise ConfigWriteError(
|
||||||
|
f"Failed to write {local_path}: {exc}"
|
||||||
|
) from exc
|
||||||
|
|
||||||
log.info(
|
log.info(
|
||||||
"jail_action_removed",
|
"jail_action_removed",
|
||||||
@@ -2604,7 +2677,9 @@ def _write_action_local_sync(action_d: Path, name: str, content: str) -> None:
|
|||||||
try:
|
try:
|
||||||
action_d.mkdir(parents=True, exist_ok=True)
|
action_d.mkdir(parents=True, exist_ok=True)
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
raise ConfigWriteError(f"Cannot create action.d directory: {exc}") from exc
|
raise ConfigWriteError(
|
||||||
|
f"Cannot create action.d directory: {exc}"
|
||||||
|
) from exc
|
||||||
|
|
||||||
local_path = action_d / f"{name}.local"
|
local_path = action_d / f"{name}.local"
|
||||||
try:
|
try:
|
||||||
@@ -2621,7 +2696,9 @@ def _write_action_local_sync(action_d: Path, name: str, content: str) -> None:
|
|||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
with contextlib.suppress(OSError):
|
with contextlib.suppress(OSError):
|
||||||
os.unlink(tmp_name) # noqa: F821
|
os.unlink(tmp_name) # noqa: F821
|
||||||
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
raise ConfigWriteError(
|
||||||
|
f"Failed to write {local_path}: {exc}"
|
||||||
|
) from exc
|
||||||
|
|
||||||
log.info("action_local_written", action=name, path=str(local_path))
|
log.info("action_local_written", action=name, path=str(local_path))
|
||||||
|
|
||||||
@@ -2657,7 +2734,9 @@ async def list_actions(
|
|||||||
action_d = Path(config_dir) / "action.d"
|
action_d = Path(config_dir) / "action.d"
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
raw_actions: list[tuple[str, str, str, bool, str]] = await loop.run_in_executor(None, _parse_actions_sync, action_d)
|
raw_actions: list[tuple[str, str, str, bool, str]] = await loop.run_in_executor(
|
||||||
|
None, _parse_actions_sync, action_d
|
||||||
|
)
|
||||||
|
|
||||||
all_jails_result, active_names = await asyncio.gather(
|
all_jails_result, active_names = await asyncio.gather(
|
||||||
loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)),
|
loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)),
|
||||||
@@ -2669,7 +2748,9 @@ async def list_actions(
|
|||||||
|
|
||||||
actions: list[ActionConfig] = []
|
actions: list[ActionConfig] = []
|
||||||
for name, filename, content, has_local, source_path in raw_actions:
|
for name, filename, content, has_local, source_path in raw_actions:
|
||||||
cfg = conffile_parser.parse_action_file(content, name=name, filename=filename)
|
cfg = conffile_parser.parse_action_file(
|
||||||
|
content, name=name, filename=filename
|
||||||
|
)
|
||||||
used_by = sorted(action_to_jails.get(name, []))
|
used_by = sorted(action_to_jails.get(name, []))
|
||||||
actions.append(
|
actions.append(
|
||||||
ActionConfig(
|
ActionConfig(
|
||||||
@@ -2756,7 +2837,9 @@ async def get_action(
|
|||||||
|
|
||||||
content, has_local, source_path = await loop.run_in_executor(None, _read)
|
content, has_local, source_path = await loop.run_in_executor(None, _read)
|
||||||
|
|
||||||
cfg = conffile_parser.parse_action_file(content, name=base_name, filename=f"{base_name}.conf")
|
cfg = conffile_parser.parse_action_file(
|
||||||
|
content, name=base_name, filename=f"{base_name}.conf"
|
||||||
|
)
|
||||||
|
|
||||||
all_jails_result, active_names = await asyncio.gather(
|
all_jails_result, active_names = await asyncio.gather(
|
||||||
loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)),
|
loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)),
|
||||||
@@ -2846,7 +2929,7 @@ async def update_action(
|
|||||||
|
|
||||||
if do_reload:
|
if do_reload:
|
||||||
try:
|
try:
|
||||||
await _reload_all(socket_path)
|
await jail_service.reload_all(socket_path)
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
log.warning(
|
log.warning(
|
||||||
"reload_after_action_update_failed",
|
"reload_after_action_update_failed",
|
||||||
@@ -2915,7 +2998,7 @@ async def create_action(
|
|||||||
|
|
||||||
if do_reload:
|
if do_reload:
|
||||||
try:
|
try:
|
||||||
await _reload_all(socket_path)
|
await jail_service.reload_all(socket_path)
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
log.warning(
|
log.warning(
|
||||||
"reload_after_action_create_failed",
|
"reload_after_action_create_failed",
|
||||||
@@ -2972,7 +3055,9 @@ async def delete_action(
|
|||||||
try:
|
try:
|
||||||
local_path.unlink()
|
local_path.unlink()
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
raise ConfigWriteError(f"Failed to delete {local_path}: {exc}") from exc
|
raise ConfigWriteError(
|
||||||
|
f"Failed to delete {local_path}: {exc}"
|
||||||
|
) from exc
|
||||||
|
|
||||||
log.info("action_local_deleted", action=base_name, path=str(local_path))
|
log.info("action_local_deleted", action=base_name, path=str(local_path))
|
||||||
|
|
||||||
@@ -3014,7 +3099,9 @@ async def assign_action_to_jail(
|
|||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
all_jails, _src = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
all_jails, _src = await loop.run_in_executor(
|
||||||
|
None, _parse_jails_sync, Path(config_dir)
|
||||||
|
)
|
||||||
if jail_name not in all_jails:
|
if jail_name not in all_jails:
|
||||||
raise JailNotFoundInConfigError(jail_name)
|
raise JailNotFoundInConfigError(jail_name)
|
||||||
|
|
||||||
@@ -3046,7 +3133,7 @@ async def assign_action_to_jail(
|
|||||||
|
|
||||||
if do_reload:
|
if do_reload:
|
||||||
try:
|
try:
|
||||||
await _reload_all(socket_path)
|
await jail_service.reload_all(socket_path)
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
log.warning(
|
log.warning(
|
||||||
"reload_after_assign_action_failed",
|
"reload_after_assign_action_failed",
|
||||||
@@ -3094,7 +3181,9 @@ async def remove_action_from_jail(
|
|||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
all_jails, _src = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
all_jails, _src = await loop.run_in_executor(
|
||||||
|
None, _parse_jails_sync, Path(config_dir)
|
||||||
|
)
|
||||||
if jail_name not in all_jails:
|
if jail_name not in all_jails:
|
||||||
raise JailNotFoundInConfigError(jail_name)
|
raise JailNotFoundInConfigError(jail_name)
|
||||||
|
|
||||||
@@ -3108,7 +3197,7 @@ async def remove_action_from_jail(
|
|||||||
|
|
||||||
if do_reload:
|
if do_reload:
|
||||||
try:
|
try:
|
||||||
await _reload_all(socket_path)
|
await jail_service.reload_all(socket_path)
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
log.warning(
|
log.warning(
|
||||||
"reload_after_remove_action_failed",
|
"reload_after_remove_action_failed",
|
||||||
@@ -3123,3 +3212,4 @@ async def remove_action_from_jail(
|
|||||||
action=action_name,
|
action=action_name,
|
||||||
reload=do_reload,
|
reload=do_reload,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -15,14 +15,11 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
import re
|
import re
|
||||||
from collections.abc import Awaitable, Callable
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, TypeVar, cast
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
|
|
||||||
from app.utils.fail2ban_client import Fail2BanCommand, Fail2BanResponse, Fail2BanToken
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
|
|
||||||
@@ -36,6 +33,7 @@ from app.models.config import (
|
|||||||
JailConfigListResponse,
|
JailConfigListResponse,
|
||||||
JailConfigResponse,
|
JailConfigResponse,
|
||||||
JailConfigUpdate,
|
JailConfigUpdate,
|
||||||
|
LogPreviewLine,
|
||||||
LogPreviewRequest,
|
LogPreviewRequest,
|
||||||
LogPreviewResponse,
|
LogPreviewResponse,
|
||||||
MapColorThresholdsResponse,
|
MapColorThresholdsResponse,
|
||||||
@@ -44,13 +42,8 @@ from app.models.config import (
|
|||||||
RegexTestResponse,
|
RegexTestResponse,
|
||||||
ServiceStatusResponse,
|
ServiceStatusResponse,
|
||||||
)
|
)
|
||||||
from app.exceptions import ConfigOperationError, ConfigValidationError, JailNotFoundError
|
from app.services import setup_service
|
||||||
from app.utils.fail2ban_client import Fail2BanClient
|
from app.utils.fail2ban_client import Fail2BanClient
|
||||||
from app.utils.log_utils import preview_log as util_preview_log, test_regex as util_test_regex
|
|
||||||
from app.utils.setup_utils import (
|
|
||||||
get_map_color_thresholds as util_get_map_color_thresholds,
|
|
||||||
set_map_color_thresholds as util_set_map_color_thresholds,
|
|
||||||
)
|
|
||||||
|
|
||||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||||
|
|
||||||
@@ -60,7 +53,26 @@ _SOCKET_TIMEOUT: float = 10.0
|
|||||||
# Custom exceptions
|
# Custom exceptions
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
# (exceptions are now defined in app.exceptions and imported above)
|
|
||||||
|
class JailNotFoundError(Exception):
|
||||||
|
"""Raised when a requested jail name does not exist in fail2ban."""
|
||||||
|
|
||||||
|
def __init__(self, name: str) -> None:
|
||||||
|
"""Initialise with the jail name that was not found.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The jail name that could not be located.
|
||||||
|
"""
|
||||||
|
self.name: str = name
|
||||||
|
super().__init__(f"Jail not found: {name!r}")
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigValidationError(Exception):
|
||||||
|
"""Raised when a configuration value fails validation before writing."""
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigOperationError(Exception):
|
||||||
|
"""Raised when a configuration write command fails."""
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -68,7 +80,7 @@ _SOCKET_TIMEOUT: float = 10.0
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _ok(response: object) -> object:
|
def _ok(response: Any) -> Any:
|
||||||
"""Extract payload from a fail2ban ``(return_code, data)`` response.
|
"""Extract payload from a fail2ban ``(return_code, data)`` response.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -81,7 +93,7 @@ def _ok(response: object) -> object:
|
|||||||
ValueError: If the return code indicates an error.
|
ValueError: If the return code indicates an error.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
code, data = cast("Fail2BanResponse", response)
|
code, data = response
|
||||||
except (TypeError, ValueError) as exc:
|
except (TypeError, ValueError) as exc:
|
||||||
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
|
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
|
||||||
if code != 0:
|
if code != 0:
|
||||||
@@ -89,11 +101,11 @@ def _ok(response: object) -> object:
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
def _to_dict(pairs: object) -> dict[str, object]:
|
def _to_dict(pairs: Any) -> dict[str, Any]:
|
||||||
"""Convert a list of ``(key, value)`` pairs to a plain dict."""
|
"""Convert a list of ``(key, value)`` pairs to a plain dict."""
|
||||||
if not isinstance(pairs, (list, tuple)):
|
if not isinstance(pairs, (list, tuple)):
|
||||||
return {}
|
return {}
|
||||||
result: dict[str, object] = {}
|
result: dict[str, Any] = {}
|
||||||
for item in pairs:
|
for item in pairs:
|
||||||
try:
|
try:
|
||||||
k, v = item
|
k, v = item
|
||||||
@@ -103,7 +115,7 @@ def _to_dict(pairs: object) -> dict[str, object]:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def _ensure_list(value: object | None) -> list[str]:
|
def _ensure_list(value: Any) -> list[str]:
|
||||||
"""Coerce a fail2ban ``get`` result to a list of strings."""
|
"""Coerce a fail2ban ``get`` result to a list of strings."""
|
||||||
if value is None:
|
if value is None:
|
||||||
return []
|
return []
|
||||||
@@ -114,14 +126,11 @@ def _ensure_list(value: object | None) -> list[str]:
|
|||||||
return [str(value)]
|
return [str(value)]
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
async def _safe_get(
|
async def _safe_get(
|
||||||
client: Fail2BanClient,
|
client: Fail2BanClient,
|
||||||
command: Fail2BanCommand,
|
command: list[Any],
|
||||||
default: object | None = None,
|
default: Any = None,
|
||||||
) -> object | None:
|
) -> Any:
|
||||||
"""Send a command and return *default* if it fails."""
|
"""Send a command and return *default* if it fails."""
|
||||||
try:
|
try:
|
||||||
return _ok(await client.send(command))
|
return _ok(await client.send(command))
|
||||||
@@ -129,15 +138,6 @@ async def _safe_get(
|
|||||||
return default
|
return default
|
||||||
|
|
||||||
|
|
||||||
async def _safe_get_typed[T](
|
|
||||||
client: Fail2BanClient,
|
|
||||||
command: Fail2BanCommand,
|
|
||||||
default: T,
|
|
||||||
) -> T:
|
|
||||||
"""Send a command and return the result typed as ``default``'s type."""
|
|
||||||
return cast("T", await _safe_get(client, command, default))
|
|
||||||
|
|
||||||
|
|
||||||
def _is_not_found_error(exc: Exception) -> bool:
|
def _is_not_found_error(exc: Exception) -> bool:
|
||||||
"""Return ``True`` if *exc* signals an unknown jail."""
|
"""Return ``True`` if *exc* signals an unknown jail."""
|
||||||
msg = str(exc).lower()
|
msg = str(exc).lower()
|
||||||
@@ -192,25 +192,47 @@ async def get_jail_config(socket_path: str, name: str) -> JailConfigResponse:
|
|||||||
raise JailNotFoundError(name) from exc
|
raise JailNotFoundError(name) from exc
|
||||||
raise
|
raise
|
||||||
|
|
||||||
bantime_raw: int = await _safe_get_typed(client, ["get", name, "bantime"], 600)
|
(
|
||||||
findtime_raw: int = await _safe_get_typed(client, ["get", name, "findtime"], 600)
|
bantime_raw,
|
||||||
maxretry_raw: int = await _safe_get_typed(client, ["get", name, "maxretry"], 5)
|
findtime_raw,
|
||||||
failregex_raw: list[object] = await _safe_get_typed(client, ["get", name, "failregex"], [])
|
maxretry_raw,
|
||||||
ignoreregex_raw: list[object] = await _safe_get_typed(client, ["get", name, "ignoreregex"], [])
|
failregex_raw,
|
||||||
logpath_raw: list[object] = await _safe_get_typed(client, ["get", name, "logpath"], [])
|
ignoreregex_raw,
|
||||||
datepattern_raw: str | None = await _safe_get_typed(client, ["get", name, "datepattern"], None)
|
logpath_raw,
|
||||||
logencoding_raw: str = await _safe_get_typed(client, ["get", name, "logencoding"], "UTF-8")
|
datepattern_raw,
|
||||||
backend_raw: str = await _safe_get_typed(client, ["get", name, "backend"], "polling")
|
logencoding_raw,
|
||||||
usedns_raw: str = await _safe_get_typed(client, ["get", name, "usedns"], "warn")
|
backend_raw,
|
||||||
prefregex_raw: str = await _safe_get_typed(client, ["get", name, "prefregex"], "")
|
usedns_raw,
|
||||||
actions_raw: list[object] = await _safe_get_typed(client, ["get", name, "actions"], [])
|
prefregex_raw,
|
||||||
bt_increment_raw: bool = await _safe_get_typed(client, ["get", name, "bantime.increment"], False)
|
actions_raw,
|
||||||
bt_factor_raw: str | float | None = await _safe_get_typed(client, ["get", name, "bantime.factor"], None)
|
bt_increment_raw,
|
||||||
bt_formula_raw: str | None = await _safe_get_typed(client, ["get", name, "bantime.formula"], None)
|
bt_factor_raw,
|
||||||
bt_multipliers_raw: str | None = await _safe_get_typed(client, ["get", name, "bantime.multipliers"], None)
|
bt_formula_raw,
|
||||||
bt_maxtime_raw: str | int | None = await _safe_get_typed(client, ["get", name, "bantime.maxtime"], None)
|
bt_multipliers_raw,
|
||||||
bt_rndtime_raw: str | int | None = await _safe_get_typed(client, ["get", name, "bantime.rndtime"], None)
|
bt_maxtime_raw,
|
||||||
bt_overalljails_raw: bool = await _safe_get_typed(client, ["get", name, "bantime.overalljails"], False)
|
bt_rndtime_raw,
|
||||||
|
bt_overalljails_raw,
|
||||||
|
) = await asyncio.gather(
|
||||||
|
_safe_get(client, ["get", name, "bantime"], 600),
|
||||||
|
_safe_get(client, ["get", name, "findtime"], 600),
|
||||||
|
_safe_get(client, ["get", name, "maxretry"], 5),
|
||||||
|
_safe_get(client, ["get", name, "failregex"], []),
|
||||||
|
_safe_get(client, ["get", name, "ignoreregex"], []),
|
||||||
|
_safe_get(client, ["get", name, "logpath"], []),
|
||||||
|
_safe_get(client, ["get", name, "datepattern"], None),
|
||||||
|
_safe_get(client, ["get", name, "logencoding"], "UTF-8"),
|
||||||
|
_safe_get(client, ["get", name, "backend"], "polling"),
|
||||||
|
_safe_get(client, ["get", name, "usedns"], "warn"),
|
||||||
|
_safe_get(client, ["get", name, "prefregex"], ""),
|
||||||
|
_safe_get(client, ["get", name, "actions"], []),
|
||||||
|
_safe_get(client, ["get", name, "bantime.increment"], False),
|
||||||
|
_safe_get(client, ["get", name, "bantime.factor"], None),
|
||||||
|
_safe_get(client, ["get", name, "bantime.formula"], None),
|
||||||
|
_safe_get(client, ["get", name, "bantime.multipliers"], None),
|
||||||
|
_safe_get(client, ["get", name, "bantime.maxtime"], None),
|
||||||
|
_safe_get(client, ["get", name, "bantime.rndtime"], None),
|
||||||
|
_safe_get(client, ["get", name, "bantime.overalljails"], False),
|
||||||
|
)
|
||||||
|
|
||||||
bantime_escalation = BantimeEscalation(
|
bantime_escalation = BantimeEscalation(
|
||||||
increment=bool(bt_increment_raw),
|
increment=bool(bt_increment_raw),
|
||||||
@@ -330,7 +352,7 @@ async def update_jail_config(
|
|||||||
raise JailNotFoundError(name) from exc
|
raise JailNotFoundError(name) from exc
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def _set(key: str, value: Fail2BanToken) -> None:
|
async def _set(key: str, value: Any) -> None:
|
||||||
try:
|
try:
|
||||||
_ok(await client.send(["set", name, key, value]))
|
_ok(await client.send(["set", name, key, value]))
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
@@ -346,8 +368,9 @@ async def update_jail_config(
|
|||||||
await _set("datepattern", update.date_pattern)
|
await _set("datepattern", update.date_pattern)
|
||||||
if update.dns_mode is not None:
|
if update.dns_mode is not None:
|
||||||
await _set("usedns", update.dns_mode)
|
await _set("usedns", update.dns_mode)
|
||||||
if update.backend is not None:
|
# Fail2ban does not support changing the log monitoring backend at runtime.
|
||||||
await _set("backend", update.backend)
|
# The configuration value is retained for read/display purposes but must not
|
||||||
|
# be applied via the socket API.
|
||||||
if update.log_encoding is not None:
|
if update.log_encoding is not None:
|
||||||
await _set("logencoding", update.log_encoding)
|
await _set("logencoding", update.log_encoding)
|
||||||
if update.prefregex is not None:
|
if update.prefregex is not None:
|
||||||
@@ -400,7 +423,7 @@ async def _replace_regex_list(
|
|||||||
new_patterns: Replacement list (may be empty to clear).
|
new_patterns: Replacement list (may be empty to clear).
|
||||||
"""
|
"""
|
||||||
# Determine current count.
|
# Determine current count.
|
||||||
current_raw: list[object] = await _safe_get_typed(client, ["get", jail, field], [])
|
current_raw = await _safe_get(client, ["get", jail, field], [])
|
||||||
current: list[str] = _ensure_list(current_raw)
|
current: list[str] = _ensure_list(current_raw)
|
||||||
|
|
||||||
del_cmd = f"del{field}"
|
del_cmd = f"del{field}"
|
||||||
@@ -447,10 +470,10 @@ async def get_global_config(socket_path: str) -> GlobalConfigResponse:
|
|||||||
db_purge_age_raw,
|
db_purge_age_raw,
|
||||||
db_max_matches_raw,
|
db_max_matches_raw,
|
||||||
) = await asyncio.gather(
|
) = await asyncio.gather(
|
||||||
_safe_get_typed(client, ["get", "loglevel"], "INFO"),
|
_safe_get(client, ["get", "loglevel"], "INFO"),
|
||||||
_safe_get_typed(client, ["get", "logtarget"], "STDOUT"),
|
_safe_get(client, ["get", "logtarget"], "STDOUT"),
|
||||||
_safe_get_typed(client, ["get", "dbpurgeage"], 86400),
|
_safe_get(client, ["get", "dbpurgeage"], 86400),
|
||||||
_safe_get_typed(client, ["get", "dbmaxmatches"], 10),
|
_safe_get(client, ["get", "dbmaxmatches"], 10),
|
||||||
)
|
)
|
||||||
|
|
||||||
return GlobalConfigResponse(
|
return GlobalConfigResponse(
|
||||||
@@ -474,7 +497,7 @@ async def update_global_config(socket_path: str, update: GlobalConfigUpdate) ->
|
|||||||
"""
|
"""
|
||||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||||
|
|
||||||
async def _set_global(key: str, value: Fail2BanToken) -> None:
|
async def _set_global(key: str, value: Any) -> None:
|
||||||
try:
|
try:
|
||||||
_ok(await client.send(["set", key, value]))
|
_ok(await client.send(["set", key, value]))
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
@@ -498,8 +521,27 @@ async def update_global_config(socket_path: str, update: GlobalConfigUpdate) ->
|
|||||||
|
|
||||||
|
|
||||||
def test_regex(request: RegexTestRequest) -> RegexTestResponse:
|
def test_regex(request: RegexTestRequest) -> RegexTestResponse:
|
||||||
"""Proxy to log utilities for regex test without service imports."""
|
"""Test a regex pattern against a sample log line.
|
||||||
return util_test_regex(request)
|
|
||||||
|
This is a pure in-process operation — no socket communication occurs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: The :class:`~app.models.config.RegexTestRequest` payload.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:class:`~app.models.config.RegexTestResponse` with match result.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
compiled = re.compile(request.fail_regex)
|
||||||
|
except re.error as exc:
|
||||||
|
return RegexTestResponse(matched=False, groups=[], error=str(exc))
|
||||||
|
|
||||||
|
match = compiled.search(request.log_line)
|
||||||
|
if match is None:
|
||||||
|
return RegexTestResponse(matched=False)
|
||||||
|
|
||||||
|
groups: list[str] = list(match.groups() or [])
|
||||||
|
return RegexTestResponse(matched=True, groups=[str(g) for g in groups if g is not None])
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -577,14 +619,101 @@ async def delete_log_path(
|
|||||||
raise ConfigOperationError(f"Failed to delete log path {log_path!r}: {exc}") from exc
|
raise ConfigOperationError(f"Failed to delete log path {log_path!r}: {exc}") from exc
|
||||||
|
|
||||||
|
|
||||||
async def preview_log(
|
async def preview_log(req: LogPreviewRequest) -> LogPreviewResponse:
|
||||||
req: LogPreviewRequest,
|
"""Read the last *num_lines* of a log file and test *fail_regex* against each.
|
||||||
preview_fn: Callable[[LogPreviewRequest], Awaitable[LogPreviewResponse]] | None = None,
|
|
||||||
) -> LogPreviewResponse:
|
This operation reads from the local filesystem — no socket is used.
|
||||||
"""Proxy to an injectable log preview function."""
|
|
||||||
if preview_fn is None:
|
Args:
|
||||||
preview_fn = util_preview_log
|
req: :class:`~app.models.config.LogPreviewRequest`.
|
||||||
return await preview_fn(req)
|
|
||||||
|
Returns:
|
||||||
|
:class:`~app.models.config.LogPreviewResponse` with line-by-line results.
|
||||||
|
"""
|
||||||
|
# Validate the regex first.
|
||||||
|
try:
|
||||||
|
compiled = re.compile(req.fail_regex)
|
||||||
|
except re.error as exc:
|
||||||
|
return LogPreviewResponse(
|
||||||
|
lines=[],
|
||||||
|
total_lines=0,
|
||||||
|
matched_count=0,
|
||||||
|
regex_error=str(exc),
|
||||||
|
)
|
||||||
|
|
||||||
|
path = Path(req.log_path)
|
||||||
|
if not path.is_file():
|
||||||
|
return LogPreviewResponse(
|
||||||
|
lines=[],
|
||||||
|
total_lines=0,
|
||||||
|
matched_count=0,
|
||||||
|
regex_error=f"File not found: {req.log_path!r}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Read the last num_lines lines efficiently.
|
||||||
|
try:
|
||||||
|
raw_lines = await asyncio.get_event_loop().run_in_executor(
|
||||||
|
None,
|
||||||
|
_read_tail_lines,
|
||||||
|
str(path),
|
||||||
|
req.num_lines,
|
||||||
|
)
|
||||||
|
except OSError as exc:
|
||||||
|
return LogPreviewResponse(
|
||||||
|
lines=[],
|
||||||
|
total_lines=0,
|
||||||
|
matched_count=0,
|
||||||
|
regex_error=f"Cannot read file: {exc}",
|
||||||
|
)
|
||||||
|
|
||||||
|
result_lines: list[LogPreviewLine] = []
|
||||||
|
matched_count = 0
|
||||||
|
for line in raw_lines:
|
||||||
|
m = compiled.search(line)
|
||||||
|
groups = [str(g) for g in (m.groups() or []) if g is not None] if m else []
|
||||||
|
result_lines.append(LogPreviewLine(line=line, matched=(m is not None), groups=groups))
|
||||||
|
if m:
|
||||||
|
matched_count += 1
|
||||||
|
|
||||||
|
return LogPreviewResponse(
|
||||||
|
lines=result_lines,
|
||||||
|
total_lines=len(result_lines),
|
||||||
|
matched_count=matched_count,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _read_tail_lines(file_path: str, num_lines: int) -> list[str]:
|
||||||
|
"""Read the last *num_lines* from *file_path* synchronously.
|
||||||
|
|
||||||
|
Uses a memory-efficient approach that seeks from the end of the file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Absolute path to the log file.
|
||||||
|
num_lines: Number of lines to return.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of stripped line strings.
|
||||||
|
"""
|
||||||
|
chunk_size = 8192
|
||||||
|
raw_lines: list[bytes] = []
|
||||||
|
with open(file_path, "rb") as fh:
|
||||||
|
fh.seek(0, 2) # seek to end
|
||||||
|
end_pos = fh.tell()
|
||||||
|
if end_pos == 0:
|
||||||
|
return []
|
||||||
|
buf = b""
|
||||||
|
pos = end_pos
|
||||||
|
while len(raw_lines) <= num_lines and pos > 0:
|
||||||
|
read_size = min(chunk_size, pos)
|
||||||
|
pos -= read_size
|
||||||
|
fh.seek(pos)
|
||||||
|
chunk = fh.read(read_size)
|
||||||
|
buf = chunk + buf
|
||||||
|
raw_lines = buf.split(b"\n")
|
||||||
|
# Strip incomplete leading line unless we've read the whole file.
|
||||||
|
if pos > 0 and len(raw_lines) > 1:
|
||||||
|
raw_lines = raw_lines[1:]
|
||||||
|
return [ln.decode("utf-8", errors="replace").rstrip() for ln in raw_lines[-num_lines:] if ln.strip()]
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -601,7 +730,7 @@ async def get_map_color_thresholds(db: aiosqlite.Connection) -> MapColorThreshol
|
|||||||
Returns:
|
Returns:
|
||||||
A :class:`MapColorThresholdsResponse` containing the three threshold values.
|
A :class:`MapColorThresholdsResponse` containing the three threshold values.
|
||||||
"""
|
"""
|
||||||
high, medium, low = await util_get_map_color_thresholds(db)
|
high, medium, low = await setup_service.get_map_color_thresholds(db)
|
||||||
return MapColorThresholdsResponse(
|
return MapColorThresholdsResponse(
|
||||||
threshold_high=high,
|
threshold_high=high,
|
||||||
threshold_medium=medium,
|
threshold_medium=medium,
|
||||||
@@ -622,7 +751,7 @@ async def update_map_color_thresholds(
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If validation fails (thresholds must satisfy high > medium > low).
|
ValueError: If validation fails (thresholds must satisfy high > medium > low).
|
||||||
"""
|
"""
|
||||||
await util_set_map_color_thresholds(
|
await setup_service.set_map_color_thresholds(
|
||||||
db,
|
db,
|
||||||
threshold_high=update.threshold_high,
|
threshold_high=update.threshold_high,
|
||||||
threshold_medium=update.threshold_medium,
|
threshold_medium=update.threshold_medium,
|
||||||
@@ -644,7 +773,16 @@ _SAFE_LOG_PREFIXES: tuple[str, ...] = ("/var/log", "/config/log")
|
|||||||
|
|
||||||
|
|
||||||
def _count_file_lines(file_path: str) -> int:
|
def _count_file_lines(file_path: str) -> int:
|
||||||
"""Count the total number of lines in *file_path* synchronously."""
|
"""Count the total number of lines in *file_path* synchronously.
|
||||||
|
|
||||||
|
Uses a memory-efficient buffered read to avoid loading the whole file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Absolute path to the file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Total number of lines in the file.
|
||||||
|
"""
|
||||||
count = 0
|
count = 0
|
||||||
with open(file_path, "rb") as fh:
|
with open(file_path, "rb") as fh:
|
||||||
for chunk in iter(lambda: fh.read(65536), b""):
|
for chunk in iter(lambda: fh.read(65536), b""):
|
||||||
@@ -652,32 +790,6 @@ def _count_file_lines(file_path: str) -> int:
|
|||||||
return count
|
return count
|
||||||
|
|
||||||
|
|
||||||
def _read_tail_lines(file_path: str, num_lines: int) -> list[str]:
|
|
||||||
"""Read the last *num_lines* from *file_path* in a memory-efficient way."""
|
|
||||||
chunk_size = 8192
|
|
||||||
raw_lines: list[bytes] = []
|
|
||||||
with open(file_path, "rb") as fh:
|
|
||||||
fh.seek(0, 2)
|
|
||||||
end_pos = fh.tell()
|
|
||||||
if end_pos == 0:
|
|
||||||
return []
|
|
||||||
|
|
||||||
buf = b""
|
|
||||||
pos = end_pos
|
|
||||||
while len(raw_lines) <= num_lines and pos > 0:
|
|
||||||
read_size = min(chunk_size, pos)
|
|
||||||
pos -= read_size
|
|
||||||
fh.seek(pos)
|
|
||||||
chunk = fh.read(read_size)
|
|
||||||
buf = chunk + buf
|
|
||||||
raw_lines = buf.split(b"\n")
|
|
||||||
|
|
||||||
if pos > 0 and len(raw_lines) > 1:
|
|
||||||
raw_lines = raw_lines[1:]
|
|
||||||
|
|
||||||
return [ln.decode("utf-8", errors="replace").rstrip() for ln in raw_lines[-num_lines:] if ln.strip()]
|
|
||||||
|
|
||||||
|
|
||||||
async def read_fail2ban_log(
|
async def read_fail2ban_log(
|
||||||
socket_path: str,
|
socket_path: str,
|
||||||
lines: int,
|
lines: int,
|
||||||
@@ -710,8 +822,8 @@ async def read_fail2ban_log(
|
|||||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||||
|
|
||||||
log_level_raw, log_target_raw = await asyncio.gather(
|
log_level_raw, log_target_raw = await asyncio.gather(
|
||||||
_safe_get_typed(client, ["get", "loglevel"], "INFO"),
|
_safe_get(client, ["get", "loglevel"], "INFO"),
|
||||||
_safe_get_typed(client, ["get", "logtarget"], "STDOUT"),
|
_safe_get(client, ["get", "logtarget"], "STDOUT"),
|
||||||
)
|
)
|
||||||
|
|
||||||
log_level = str(log_level_raw or "INFO").upper()
|
log_level = str(log_level_raw or "INFO").upper()
|
||||||
@@ -772,33 +884,28 @@ async def read_fail2ban_log(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def get_service_status(
|
async def get_service_status(socket_path: str) -> ServiceStatusResponse:
|
||||||
socket_path: str,
|
|
||||||
probe_fn: Callable[[str], Awaitable[ServiceStatusResponse]] | None = None,
|
|
||||||
) -> ServiceStatusResponse:
|
|
||||||
"""Return fail2ban service health status with log configuration.
|
"""Return fail2ban service health status with log configuration.
|
||||||
|
|
||||||
Delegates to an injectable *probe_fn* (defaults to
|
Delegates to :func:`~app.services.health_service.probe` for the core
|
||||||
:func:`~app.services.health_service.probe`). This avoids direct service-to-
|
health snapshot and augments it with the current log-level and log-target
|
||||||
service imports inside this module.
|
values from the socket.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
socket_path: Path to the fail2ban Unix domain socket.
|
socket_path: Path to the fail2ban Unix domain socket.
|
||||||
probe_fn: Optional probe function.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
:class:`~app.models.config.ServiceStatusResponse`.
|
:class:`~app.models.config.ServiceStatusResponse`.
|
||||||
"""
|
"""
|
||||||
if probe_fn is None:
|
from app.services.health_service import probe # lazy import avoids circular dep
|
||||||
raise ValueError("probe_fn is required to avoid service-to-service coupling")
|
|
||||||
|
|
||||||
server_status = await probe_fn(socket_path)
|
server_status = await probe(socket_path)
|
||||||
|
|
||||||
if server_status.online:
|
if server_status.online:
|
||||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||||
log_level_raw, log_target_raw = await asyncio.gather(
|
log_level_raw, log_target_raw = await asyncio.gather(
|
||||||
_safe_get_typed(client, ["get", "loglevel"], "INFO"),
|
_safe_get(client, ["get", "loglevel"], "INFO"),
|
||||||
_safe_get_typed(client, ["get", "logtarget"], "STDOUT"),
|
_safe_get(client, ["get", "logtarget"], "STDOUT"),
|
||||||
)
|
)
|
||||||
log_level = str(log_level_raw or "INFO").upper()
|
log_level = str(log_level_raw or "INFO").upper()
|
||||||
log_target = str(log_target_raw or "STDOUT")
|
log_target = str(log_target_raw or "STDOUT")
|
||||||
|
|||||||
@@ -817,7 +817,7 @@ async def get_parsed_filter_file(config_dir: str, name: str) -> FilterConfig:
|
|||||||
"""Parse a filter definition file and return its structured representation.
|
"""Parse a filter definition file and return its structured representation.
|
||||||
|
|
||||||
Reads the raw ``.conf``/``.local`` file from ``filter.d/``, parses it with
|
Reads the raw ``.conf``/``.local`` file from ``filter.d/``, parses it with
|
||||||
:func:`~app.utils.conffile_parser.parse_filter_file`, and returns the
|
:func:`~app.services.conffile_parser.parse_filter_file`, and returns the
|
||||||
result.
|
result.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -831,7 +831,7 @@ async def get_parsed_filter_file(config_dir: str, name: str) -> FilterConfig:
|
|||||||
ConfigFileNotFoundError: If no matching file is found.
|
ConfigFileNotFoundError: If no matching file is found.
|
||||||
ConfigDirError: If *config_dir* does not exist.
|
ConfigDirError: If *config_dir* does not exist.
|
||||||
"""
|
"""
|
||||||
from app.utils.conffile_parser import parse_filter_file # avoid circular imports
|
from app.services.conffile_parser import parse_filter_file # avoid circular imports
|
||||||
|
|
||||||
def _do() -> FilterConfig:
|
def _do() -> FilterConfig:
|
||||||
filter_d = _resolve_subdir(config_dir, "filter.d")
|
filter_d = _resolve_subdir(config_dir, "filter.d")
|
||||||
@@ -863,7 +863,7 @@ async def update_parsed_filter_file(
|
|||||||
ConfigFileWriteError: If the file cannot be written.
|
ConfigFileWriteError: If the file cannot be written.
|
||||||
ConfigDirError: If *config_dir* does not exist.
|
ConfigDirError: If *config_dir* does not exist.
|
||||||
"""
|
"""
|
||||||
from app.utils.conffile_parser import ( # avoid circular imports
|
from app.services.conffile_parser import ( # avoid circular imports
|
||||||
merge_filter_update,
|
merge_filter_update,
|
||||||
parse_filter_file,
|
parse_filter_file,
|
||||||
serialize_filter_config,
|
serialize_filter_config,
|
||||||
@@ -901,7 +901,7 @@ async def get_parsed_action_file(config_dir: str, name: str) -> ActionConfig:
|
|||||||
ConfigFileNotFoundError: If no matching file is found.
|
ConfigFileNotFoundError: If no matching file is found.
|
||||||
ConfigDirError: If *config_dir* does not exist.
|
ConfigDirError: If *config_dir* does not exist.
|
||||||
"""
|
"""
|
||||||
from app.utils.conffile_parser import parse_action_file # avoid circular imports
|
from app.services.conffile_parser import parse_action_file # avoid circular imports
|
||||||
|
|
||||||
def _do() -> ActionConfig:
|
def _do() -> ActionConfig:
|
||||||
action_d = _resolve_subdir(config_dir, "action.d")
|
action_d = _resolve_subdir(config_dir, "action.d")
|
||||||
@@ -930,7 +930,7 @@ async def update_parsed_action_file(
|
|||||||
ConfigFileWriteError: If the file cannot be written.
|
ConfigFileWriteError: If the file cannot be written.
|
||||||
ConfigDirError: If *config_dir* does not exist.
|
ConfigDirError: If *config_dir* does not exist.
|
||||||
"""
|
"""
|
||||||
from app.utils.conffile_parser import ( # avoid circular imports
|
from app.services.conffile_parser import ( # avoid circular imports
|
||||||
merge_action_update,
|
merge_action_update,
|
||||||
parse_action_file,
|
parse_action_file,
|
||||||
serialize_action_config,
|
serialize_action_config,
|
||||||
@@ -963,7 +963,7 @@ async def get_parsed_jail_file(config_dir: str, filename: str) -> JailFileConfig
|
|||||||
ConfigFileNotFoundError: If no matching file is found.
|
ConfigFileNotFoundError: If no matching file is found.
|
||||||
ConfigDirError: If *config_dir* does not exist.
|
ConfigDirError: If *config_dir* does not exist.
|
||||||
"""
|
"""
|
||||||
from app.utils.conffile_parser import parse_jail_file # avoid circular imports
|
from app.services.conffile_parser import parse_jail_file # avoid circular imports
|
||||||
|
|
||||||
def _do() -> JailFileConfig:
|
def _do() -> JailFileConfig:
|
||||||
jail_d = _resolve_subdir(config_dir, "jail.d")
|
jail_d = _resolve_subdir(config_dir, "jail.d")
|
||||||
@@ -992,7 +992,7 @@ async def update_parsed_jail_file(
|
|||||||
ConfigFileWriteError: If the file cannot be written.
|
ConfigFileWriteError: If the file cannot be written.
|
||||||
ConfigDirError: If *config_dir* does not exist.
|
ConfigDirError: If *config_dir* does not exist.
|
||||||
"""
|
"""
|
||||||
from app.utils.conffile_parser import ( # avoid circular imports
|
from app.services.conffile_parser import ( # avoid circular imports
|
||||||
merge_jail_file_update,
|
merge_jail_file_update,
|
||||||
parse_jail_file,
|
parse_jail_file,
|
||||||
serialize_jail_file_config,
|
serialize_jail_file_config,
|
||||||
@@ -1,920 +0,0 @@
|
|||||||
"""Filter configuration management for BanGUI.
|
|
||||||
|
|
||||||
Handles parsing, validation, and lifecycle operations (create/update/delete)
|
|
||||||
for fail2ban filter configurations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import configparser
|
|
||||||
import contextlib
|
|
||||||
import io
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import tempfile
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import structlog
|
|
||||||
|
|
||||||
from app.models.config import (
|
|
||||||
FilterConfig,
|
|
||||||
FilterConfigUpdate,
|
|
||||||
FilterCreateRequest,
|
|
||||||
FilterListResponse,
|
|
||||||
FilterUpdateRequest,
|
|
||||||
AssignFilterRequest,
|
|
||||||
)
|
|
||||||
from app.exceptions import FilterInvalidRegexError, JailNotFoundError
|
|
||||||
from app.utils import conffile_parser
|
|
||||||
from app.utils.jail_utils import reload_jails
|
|
||||||
|
|
||||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Custom exceptions
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class FilterNotFoundError(Exception):
|
|
||||||
"""Raised when the requested filter name is not found in ``filter.d/``."""
|
|
||||||
|
|
||||||
def __init__(self, name: str) -> None:
|
|
||||||
"""Initialise with the filter name that was not found.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: The filter name that could not be located.
|
|
||||||
"""
|
|
||||||
self.name: str = name
|
|
||||||
super().__init__(f"Filter not found: {name!r}")
|
|
||||||
|
|
||||||
|
|
||||||
class FilterAlreadyExistsError(Exception):
|
|
||||||
"""Raised when trying to create a filter whose ``.conf`` or ``.local`` already exists."""
|
|
||||||
|
|
||||||
def __init__(self, name: str) -> None:
|
|
||||||
"""Initialise with the filter name that already exists.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: The filter name that already exists.
|
|
||||||
"""
|
|
||||||
self.name: str = name
|
|
||||||
super().__init__(f"Filter already exists: {name!r}")
|
|
||||||
|
|
||||||
|
|
||||||
class FilterReadonlyError(Exception):
|
|
||||||
"""Raised when trying to delete a shipped ``.conf`` filter with no ``.local`` override."""
|
|
||||||
|
|
||||||
def __init__(self, name: str) -> None:
|
|
||||||
"""Initialise with the filter name that cannot be deleted.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: The filter name that is read-only (shipped ``.conf`` only).
|
|
||||||
"""
|
|
||||||
self.name: str = name
|
|
||||||
super().__init__(
|
|
||||||
f"Filter {name!r} is a shipped default (.conf only); only user-created .local files can be deleted."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class FilterNameError(Exception):
|
|
||||||
"""Raised when a filter name contains invalid characters."""
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Additional helper functions for this service
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class JailNameError(Exception):
|
|
||||||
"""Raised when a jail name contains invalid characters."""
|
|
||||||
|
|
||||||
|
|
||||||
_SAFE_JAIL_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$")
|
|
||||||
|
|
||||||
|
|
||||||
def _safe_filter_name(name: str) -> str:
|
|
||||||
"""Validate *name* and return it unchanged or raise :class:`FilterNameError`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: Proposed filter name (without extension).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The name unchanged if valid.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
FilterNameError: If *name* contains unsafe characters.
|
|
||||||
"""
|
|
||||||
if not _SAFE_FILTER_NAME_RE.match(name):
|
|
||||||
raise FilterNameError(
|
|
||||||
f"Filter name {name!r} contains invalid characters. "
|
|
||||||
"Only alphanumeric characters, hyphens, underscores, and dots are "
|
|
||||||
"allowed; must start with an alphanumeric character."
|
|
||||||
)
|
|
||||||
return name
|
|
||||||
|
|
||||||
|
|
||||||
def _safe_jail_name(name: str) -> str:
|
|
||||||
"""Validate *name* and return it unchanged or raise :class:`JailNameError`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: Proposed jail name.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The name unchanged if valid.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
JailNameError: If *name* contains unsafe characters.
|
|
||||||
"""
|
|
||||||
if not _SAFE_JAIL_NAME_RE.match(name):
|
|
||||||
raise JailNameError(
|
|
||||||
f"Jail name {name!r} contains invalid characters. "
|
|
||||||
"Only alphanumeric characters, hyphens, underscores, and dots are "
|
|
||||||
"allowed; must start with an alphanumeric character."
|
|
||||||
)
|
|
||||||
return name
|
|
||||||
|
|
||||||
|
|
||||||
def _build_parser() -> configparser.RawConfigParser:
|
|
||||||
"""Create a :class:`configparser.RawConfigParser` for fail2ban configs.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Parser with interpolation disabled and case-sensitive option names.
|
|
||||||
"""
|
|
||||||
parser = configparser.RawConfigParser(interpolation=None, strict=False)
|
|
||||||
# fail2ban keys are lowercase but preserve case to be safe.
|
|
||||||
parser.optionxform = str # type: ignore[assignment]
|
|
||||||
return parser
|
|
||||||
|
|
||||||
|
|
||||||
def _is_truthy(value: str) -> bool:
|
|
||||||
"""Return ``True`` if *value* is a fail2ban boolean true string.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
value: Raw string from config (e.g. ``"true"``, ``"yes"``, ``"1"``).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
``True`` when the value represents enabled.
|
|
||||||
"""
|
|
||||||
return value.strip().lower() in _TRUE_VALUES
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_multiline(raw: str) -> list[str]:
|
|
||||||
"""Split a multi-line INI value into individual non-blank lines.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
raw: Raw multi-line string from configparser.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of stripped, non-empty, non-comment strings.
|
|
||||||
"""
|
|
||||||
result: list[str] = []
|
|
||||||
for line in raw.splitlines():
|
|
||||||
stripped = line.strip()
|
|
||||||
if stripped and not stripped.startswith("#"):
|
|
||||||
result.append(stripped)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_filter(raw_filter: str, jail_name: str, mode: str) -> str:
|
|
||||||
"""Resolve fail2ban variable placeholders in a filter string.
|
|
||||||
|
|
||||||
Handles the common default ``%(__name__)s[mode=%(mode)s]`` pattern that
|
|
||||||
fail2ban uses so the filter name displayed to the user is readable.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
raw_filter: Raw ``filter`` value from config (may contain ``%()s``).
|
|
||||||
jail_name: The jail's section name, used to substitute ``%(__name__)s``.
|
|
||||||
mode: The jail's ``mode`` value, used to substitute ``%(mode)s``.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Human-readable filter string.
|
|
||||||
"""
|
|
||||||
result = raw_filter.replace("%(__name__)s", jail_name)
|
|
||||||
result = result.replace("%(mode)s", mode)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Internal helpers - from config_file_service for local use
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _set_jail_local_key_sync(
|
|
||||||
config_dir: Path,
|
|
||||||
jail_name: str,
|
|
||||||
key: str,
|
|
||||||
value: str,
|
|
||||||
) -> None:
|
|
||||||
"""Update ``jail.d/{jail_name}.local`` to set a single key in the jail section.
|
|
||||||
|
|
||||||
If the ``.local`` file already exists it is read, the key is updated (or
|
|
||||||
added), and the file is written back atomically without disturbing other
|
|
||||||
settings. If the file does not exist a new one is created containing
|
|
||||||
only the BanGUI header comment, the jail section, and the requested key.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config_dir: The fail2ban configuration root directory.
|
|
||||||
jail_name: Validated jail name (used as section name and filename stem).
|
|
||||||
key: Config key to set inside the jail section.
|
|
||||||
value: Config value to assign.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ConfigWriteError: If writing fails.
|
|
||||||
"""
|
|
||||||
jail_d = config_dir / "jail.d"
|
|
||||||
try:
|
|
||||||
jail_d.mkdir(parents=True, exist_ok=True)
|
|
||||||
except OSError as exc:
|
|
||||||
raise ConfigWriteError(f"Cannot create jail.d directory: {exc}") from exc
|
|
||||||
|
|
||||||
local_path = jail_d / f"{jail_name}.local"
|
|
||||||
|
|
||||||
parser = _build_parser()
|
|
||||||
if local_path.is_file():
|
|
||||||
try:
|
|
||||||
parser.read(str(local_path), encoding="utf-8")
|
|
||||||
except (configparser.Error, OSError) as exc:
|
|
||||||
log.warning(
|
|
||||||
"jail_local_read_for_update_error",
|
|
||||||
jail=jail_name,
|
|
||||||
error=str(exc),
|
|
||||||
)
|
|
||||||
|
|
||||||
if not parser.has_section(jail_name):
|
|
||||||
parser.add_section(jail_name)
|
|
||||||
parser.set(jail_name, key, value)
|
|
||||||
|
|
||||||
# Serialize: write a BanGUI header then the parser output.
|
|
||||||
buf = io.StringIO()
|
|
||||||
buf.write("# Managed by BanGUI — do not edit manually\n\n")
|
|
||||||
parser.write(buf)
|
|
||||||
content = buf.getvalue()
|
|
||||||
|
|
||||||
try:
|
|
||||||
with tempfile.NamedTemporaryFile(
|
|
||||||
mode="w",
|
|
||||||
encoding="utf-8",
|
|
||||||
dir=jail_d,
|
|
||||||
delete=False,
|
|
||||||
suffix=".tmp",
|
|
||||||
) as tmp:
|
|
||||||
tmp.write(content)
|
|
||||||
tmp_name = tmp.name
|
|
||||||
os.replace(tmp_name, local_path)
|
|
||||||
except OSError as exc:
|
|
||||||
with contextlib.suppress(OSError):
|
|
||||||
os.unlink(tmp_name) # noqa: F821
|
|
||||||
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
|
||||||
|
|
||||||
log.info(
|
|
||||||
"jail_local_key_set",
|
|
||||||
jail=jail_name,
|
|
||||||
key=key,
|
|
||||||
path=str(local_path),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_filter_base_name(filter_raw: str) -> str:
|
|
||||||
"""Extract the base filter name from a raw fail2ban filter string.
|
|
||||||
|
|
||||||
fail2ban jail configs may specify a filter with an optional mode suffix,
|
|
||||||
e.g. ``sshd``, ``sshd[mode=aggressive]``, or
|
|
||||||
``%(__name__)s[mode=%(mode)s]``. This function strips the ``[…]`` mode
|
|
||||||
block and any leading/trailing whitespace to return just the file-system
|
|
||||||
base name used to look up ``filter.d/{name}.conf``.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
filter_raw: Raw ``filter`` value from a jail config (already
|
|
||||||
with ``%(__name__)s`` substituted by the caller).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Base filter name, e.g. ``"sshd"``.
|
|
||||||
"""
|
|
||||||
bracket = filter_raw.find("[")
|
|
||||||
if bracket != -1:
|
|
||||||
return filter_raw[:bracket].strip()
|
|
||||||
return filter_raw.strip()
|
|
||||||
|
|
||||||
|
|
||||||
def _build_filter_to_jails_map(
|
|
||||||
all_jails: dict[str, dict[str, str]],
|
|
||||||
active_names: set[str],
|
|
||||||
) -> dict[str, list[str]]:
|
|
||||||
"""Return a mapping of filter base name → list of active jail names.
|
|
||||||
|
|
||||||
Iterates over every jail whose name is in *active_names*, resolves its
|
|
||||||
``filter`` config key, and records the jail against the base filter name.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
all_jails: Merged jail config dict — ``{jail_name: {key: value}}``.
|
|
||||||
active_names: Set of jail names currently running in fail2ban.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
``{filter_base_name: [jail_name, …]}``.
|
|
||||||
"""
|
|
||||||
mapping: dict[str, list[str]] = {}
|
|
||||||
for jail_name, settings in all_jails.items():
|
|
||||||
if jail_name not in active_names:
|
|
||||||
continue
|
|
||||||
raw_filter = settings.get("filter", "")
|
|
||||||
mode = settings.get("mode", "normal")
|
|
||||||
resolved = _resolve_filter(raw_filter, jail_name, mode) if raw_filter else jail_name
|
|
||||||
base = _extract_filter_base_name(resolved)
|
|
||||||
if base:
|
|
||||||
mapping.setdefault(base, []).append(jail_name)
|
|
||||||
return mapping
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_filters_sync(
|
|
||||||
filter_d: Path,
|
|
||||||
) -> list[tuple[str, str, str, bool, str]]:
|
|
||||||
"""Synchronously scan ``filter.d/`` and return per-filter tuples.
|
|
||||||
|
|
||||||
Each tuple contains:
|
|
||||||
|
|
||||||
- ``name`` — filter base name (``"sshd"``).
|
|
||||||
- ``filename`` — actual filename (``"sshd.conf"`` or ``"sshd.local"``).
|
|
||||||
- ``content`` — merged file content (``conf`` overridden by ``local``).
|
|
||||||
- ``has_local`` — whether a ``.local`` override exists alongside a ``.conf``.
|
|
||||||
- ``source_path`` — absolute path to the primary (``conf``) source file, or
|
|
||||||
to the ``.local`` file for user-created (local-only) filters.
|
|
||||||
|
|
||||||
Also discovers ``.local``-only files (user-created filters with no
|
|
||||||
corresponding ``.conf``). These are returned with ``has_local = False``
|
|
||||||
and ``source_path`` pointing to the ``.local`` file itself.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
filter_d: Path to the ``filter.d`` directory.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of ``(name, filename, content, has_local, source_path)`` tuples,
|
|
||||||
sorted by name.
|
|
||||||
"""
|
|
||||||
if not filter_d.is_dir():
|
|
||||||
log.warning("filter_d_not_found", path=str(filter_d))
|
|
||||||
return []
|
|
||||||
|
|
||||||
conf_names: set[str] = set()
|
|
||||||
results: list[tuple[str, str, str, bool, str]] = []
|
|
||||||
|
|
||||||
# ---- .conf-based filters (with optional .local override) ----------------
|
|
||||||
for conf_path in sorted(filter_d.glob("*.conf")):
|
|
||||||
if not conf_path.is_file():
|
|
||||||
continue
|
|
||||||
name = conf_path.stem
|
|
||||||
filename = conf_path.name
|
|
||||||
conf_names.add(name)
|
|
||||||
local_path = conf_path.with_suffix(".local")
|
|
||||||
has_local = local_path.is_file()
|
|
||||||
|
|
||||||
try:
|
|
||||||
content = conf_path.read_text(encoding="utf-8")
|
|
||||||
except OSError as exc:
|
|
||||||
log.warning("filter_read_error", name=name, path=str(conf_path), error=str(exc))
|
|
||||||
continue
|
|
||||||
|
|
||||||
if has_local:
|
|
||||||
try:
|
|
||||||
local_content = local_path.read_text(encoding="utf-8")
|
|
||||||
# Append local content after conf so configparser reads local
|
|
||||||
# values last (higher priority).
|
|
||||||
content = content + "\n" + local_content
|
|
||||||
except OSError as exc:
|
|
||||||
log.warning(
|
|
||||||
"filter_local_read_error",
|
|
||||||
name=name,
|
|
||||||
path=str(local_path),
|
|
||||||
error=str(exc),
|
|
||||||
)
|
|
||||||
|
|
||||||
results.append((name, filename, content, has_local, str(conf_path)))
|
|
||||||
|
|
||||||
# ---- .local-only filters (user-created, no corresponding .conf) ----------
|
|
||||||
for local_path in sorted(filter_d.glob("*.local")):
|
|
||||||
if not local_path.is_file():
|
|
||||||
continue
|
|
||||||
name = local_path.stem
|
|
||||||
if name in conf_names:
|
|
||||||
# Already covered above as a .conf filter with a .local override.
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
content = local_path.read_text(encoding="utf-8")
|
|
||||||
except OSError as exc:
|
|
||||||
log.warning(
|
|
||||||
"filter_local_read_error",
|
|
||||||
name=name,
|
|
||||||
path=str(local_path),
|
|
||||||
error=str(exc),
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
results.append((name, local_path.name, content, False, str(local_path)))
|
|
||||||
|
|
||||||
results.sort(key=lambda t: t[0])
|
|
||||||
log.debug("filters_scanned", count=len(results), filter_d=str(filter_d))
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_regex_patterns(patterns: list[str]) -> None:
|
|
||||||
"""Validate each pattern in *patterns* using Python's ``re`` module.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
patterns: List of regex strings to validate.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
FilterInvalidRegexError: If any pattern fails to compile.
|
|
||||||
"""
|
|
||||||
for pattern in patterns:
|
|
||||||
try:
|
|
||||||
re.compile(pattern)
|
|
||||||
except re.error as exc:
|
|
||||||
raise FilterInvalidRegexError(pattern, str(exc)) from exc
|
|
||||||
|
|
||||||
|
|
||||||
def _write_filter_local_sync(filter_d: Path, name: str, content: str) -> None:
|
|
||||||
"""Write *content* to ``filter.d/{name}.local`` atomically.
|
|
||||||
|
|
||||||
The write is atomic: content is written to a temp file first, then
|
|
||||||
renamed into place. The ``filter.d/`` directory is created if absent.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
filter_d: Path to the ``filter.d`` directory.
|
|
||||||
name: Validated filter base name (used as filename stem).
|
|
||||||
content: Full serialized filter content to write.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ConfigWriteError: If writing fails.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
filter_d.mkdir(parents=True, exist_ok=True)
|
|
||||||
except OSError as exc:
|
|
||||||
raise ConfigWriteError(f"Cannot create filter.d directory: {exc}") from exc
|
|
||||||
|
|
||||||
local_path = filter_d / f"{name}.local"
|
|
||||||
try:
|
|
||||||
with tempfile.NamedTemporaryFile(
|
|
||||||
mode="w",
|
|
||||||
encoding="utf-8",
|
|
||||||
dir=filter_d,
|
|
||||||
delete=False,
|
|
||||||
suffix=".tmp",
|
|
||||||
) as tmp:
|
|
||||||
tmp.write(content)
|
|
||||||
tmp_name = tmp.name
|
|
||||||
os.replace(tmp_name, local_path)
|
|
||||||
except OSError as exc:
|
|
||||||
with contextlib.suppress(OSError):
|
|
||||||
os.unlink(tmp_name) # noqa: F821
|
|
||||||
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
|
||||||
|
|
||||||
log.info("filter_local_written", filter=name, path=str(local_path))
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Public API — filter discovery
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
async def list_filters(
|
|
||||||
config_dir: str,
|
|
||||||
socket_path: str,
|
|
||||||
) -> FilterListResponse:
|
|
||||||
"""Return all available filters from ``filter.d/`` with active/inactive status.
|
|
||||||
|
|
||||||
Scans ``{config_dir}/filter.d/`` for ``.conf`` files, merges any
|
|
||||||
corresponding ``.local`` overrides, parses each file into a
|
|
||||||
:class:`~app.models.config.FilterConfig`, and cross-references with the
|
|
||||||
currently running jails to determine which filters are active.
|
|
||||||
|
|
||||||
A filter is considered *active* when its base name matches the ``filter``
|
|
||||||
field of at least one currently running jail.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config_dir: Absolute path to the fail2ban configuration directory.
|
|
||||||
socket_path: Path to the fail2ban Unix domain socket.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
:class:`~app.models.config.FilterListResponse` with all filters
|
|
||||||
sorted alphabetically, active ones carrying non-empty
|
|
||||||
``used_by_jails`` lists.
|
|
||||||
"""
|
|
||||||
filter_d = Path(config_dir) / "filter.d"
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
|
|
||||||
# Run the synchronous scan in a thread-pool executor.
|
|
||||||
raw_filters: list[tuple[str, str, str, bool, str]] = await loop.run_in_executor(None, _parse_filters_sync, filter_d)
|
|
||||||
|
|
||||||
# Fetch active jail names and their configs concurrently.
|
|
||||||
all_jails_result, active_names = await asyncio.gather(
|
|
||||||
loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)),
|
|
||||||
_get_active_jail_names(socket_path),
|
|
||||||
)
|
|
||||||
all_jails, _source_files = all_jails_result
|
|
||||||
|
|
||||||
filter_to_jails = _build_filter_to_jails_map(all_jails, active_names)
|
|
||||||
|
|
||||||
filters: list[FilterConfig] = []
|
|
||||||
for name, filename, content, has_local, source_path in raw_filters:
|
|
||||||
cfg = conffile_parser.parse_filter_file(content, name=name, filename=filename)
|
|
||||||
used_by = sorted(filter_to_jails.get(name, []))
|
|
||||||
filters.append(
|
|
||||||
FilterConfig(
|
|
||||||
name=cfg.name,
|
|
||||||
filename=cfg.filename,
|
|
||||||
before=cfg.before,
|
|
||||||
after=cfg.after,
|
|
||||||
variables=cfg.variables,
|
|
||||||
prefregex=cfg.prefregex,
|
|
||||||
failregex=cfg.failregex,
|
|
||||||
ignoreregex=cfg.ignoreregex,
|
|
||||||
maxlines=cfg.maxlines,
|
|
||||||
datepattern=cfg.datepattern,
|
|
||||||
journalmatch=cfg.journalmatch,
|
|
||||||
active=len(used_by) > 0,
|
|
||||||
used_by_jails=used_by,
|
|
||||||
source_file=source_path,
|
|
||||||
has_local_override=has_local,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
log.info("filters_listed", total=len(filters), active=sum(1 for f in filters if f.active))
|
|
||||||
return FilterListResponse(filters=filters, total=len(filters))
|
|
||||||
|
|
||||||
|
|
||||||
async def get_filter(
|
|
||||||
config_dir: str,
|
|
||||||
socket_path: str,
|
|
||||||
name: str,
|
|
||||||
) -> FilterConfig:
|
|
||||||
"""Return a single filter from ``filter.d/`` with active/inactive status.
|
|
||||||
|
|
||||||
Reads ``{config_dir}/filter.d/{name}.conf``, merges any ``.local``
|
|
||||||
override, and enriches the parsed :class:`~app.models.config.FilterConfig`
|
|
||||||
with ``active``, ``used_by_jails``, ``source_file``, and
|
|
||||||
``has_local_override``.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config_dir: Absolute path to the fail2ban configuration directory.
|
|
||||||
socket_path: Path to the fail2ban Unix domain socket.
|
|
||||||
name: Filter base name (e.g. ``"sshd"`` or ``"sshd.conf"``).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
:class:`~app.models.config.FilterConfig` with status fields populated.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
FilterNotFoundError: If no ``{name}.conf`` or ``{name}.local`` file
|
|
||||||
exists in ``filter.d/``.
|
|
||||||
"""
|
|
||||||
# Normalise — strip extension if provided (.conf=5 chars, .local=6 chars).
|
|
||||||
if name.endswith(".conf"):
|
|
||||||
base_name = name[:-5]
|
|
||||||
elif name.endswith(".local"):
|
|
||||||
base_name = name[:-6]
|
|
||||||
else:
|
|
||||||
base_name = name
|
|
||||||
|
|
||||||
filter_d = Path(config_dir) / "filter.d"
|
|
||||||
conf_path = filter_d / f"{base_name}.conf"
|
|
||||||
local_path = filter_d / f"{base_name}.local"
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
|
|
||||||
def _read() -> tuple[str, bool, str]:
|
|
||||||
"""Read filter content and return (content, has_local_override, source_path)."""
|
|
||||||
has_local = local_path.is_file()
|
|
||||||
if conf_path.is_file():
|
|
||||||
content = conf_path.read_text(encoding="utf-8")
|
|
||||||
if has_local:
|
|
||||||
try:
|
|
||||||
content += "\n" + local_path.read_text(encoding="utf-8")
|
|
||||||
except OSError as exc:
|
|
||||||
log.warning(
|
|
||||||
"filter_local_read_error",
|
|
||||||
name=base_name,
|
|
||||||
path=str(local_path),
|
|
||||||
error=str(exc),
|
|
||||||
)
|
|
||||||
return content, has_local, str(conf_path)
|
|
||||||
elif has_local:
|
|
||||||
# Local-only filter: created by the user, no shipped .conf base.
|
|
||||||
content = local_path.read_text(encoding="utf-8")
|
|
||||||
return content, False, str(local_path)
|
|
||||||
else:
|
|
||||||
raise FilterNotFoundError(base_name)
|
|
||||||
|
|
||||||
content, has_local, source_path = await loop.run_in_executor(None, _read)
|
|
||||||
|
|
||||||
cfg = conffile_parser.parse_filter_file(content, name=base_name, filename=f"{base_name}.conf")
|
|
||||||
|
|
||||||
all_jails_result, active_names = await asyncio.gather(
|
|
||||||
loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)),
|
|
||||||
_get_active_jail_names(socket_path),
|
|
||||||
)
|
|
||||||
all_jails, _source_files = all_jails_result
|
|
||||||
filter_to_jails = _build_filter_to_jails_map(all_jails, active_names)
|
|
||||||
|
|
||||||
used_by = sorted(filter_to_jails.get(base_name, []))
|
|
||||||
log.info("filter_fetched", name=base_name, active=len(used_by) > 0)
|
|
||||||
return FilterConfig(
|
|
||||||
name=cfg.name,
|
|
||||||
filename=cfg.filename,
|
|
||||||
before=cfg.before,
|
|
||||||
after=cfg.after,
|
|
||||||
variables=cfg.variables,
|
|
||||||
prefregex=cfg.prefregex,
|
|
||||||
failregex=cfg.failregex,
|
|
||||||
ignoreregex=cfg.ignoreregex,
|
|
||||||
maxlines=cfg.maxlines,
|
|
||||||
datepattern=cfg.datepattern,
|
|
||||||
journalmatch=cfg.journalmatch,
|
|
||||||
active=len(used_by) > 0,
|
|
||||||
used_by_jails=used_by,
|
|
||||||
source_file=source_path,
|
|
||||||
has_local_override=has_local,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Public API — filter write operations
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
async def update_filter(
|
|
||||||
config_dir: str,
|
|
||||||
socket_path: str,
|
|
||||||
name: str,
|
|
||||||
req: FilterUpdateRequest,
|
|
||||||
do_reload: bool = False,
|
|
||||||
) -> FilterConfig:
|
|
||||||
"""Update a filter's ``.local`` override with new regex/pattern values.
|
|
||||||
|
|
||||||
Reads the current merged configuration for *name* (``conf`` + any existing
|
|
||||||
``local``), applies the non-``None`` fields in *req* on top of it, and
|
|
||||||
writes the resulting definition to ``filter.d/{name}.local``. The
|
|
||||||
original ``.conf`` file is never modified.
|
|
||||||
|
|
||||||
All regex patterns in *req* are validated with Python's ``re`` module
|
|
||||||
before any write occurs.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config_dir: Absolute path to the fail2ban configuration directory.
|
|
||||||
socket_path: Path to the fail2ban Unix domain socket.
|
|
||||||
name: Filter base name (e.g. ``"sshd"`` or ``"sshd.conf"``).
|
|
||||||
req: Partial update — only non-``None`` fields are applied.
|
|
||||||
do_reload: When ``True``, trigger a full fail2ban reload after writing.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
:class:`~app.models.config.FilterConfig` reflecting the updated state.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
FilterNameError: If *name* contains invalid characters.
|
|
||||||
FilterNotFoundError: If no ``{name}.conf`` or ``{name}.local`` exists.
|
|
||||||
FilterInvalidRegexError: If any supplied regex pattern is invalid.
|
|
||||||
ConfigWriteError: If writing the ``.local`` file fails.
|
|
||||||
"""
|
|
||||||
base_name = name[:-5] if name.endswith(".conf") or name.endswith(".local") else name
|
|
||||||
_safe_filter_name(base_name)
|
|
||||||
|
|
||||||
# Validate regex patterns before touching the filesystem.
|
|
||||||
patterns: list[str] = []
|
|
||||||
if req.failregex is not None:
|
|
||||||
patterns.extend(req.failregex)
|
|
||||||
if req.ignoreregex is not None:
|
|
||||||
patterns.extend(req.ignoreregex)
|
|
||||||
_validate_regex_patterns(patterns)
|
|
||||||
|
|
||||||
# Fetch the current merged config (raises FilterNotFoundError if absent).
|
|
||||||
current = await get_filter(config_dir, socket_path, base_name)
|
|
||||||
|
|
||||||
# Build a FilterConfigUpdate from the request fields.
|
|
||||||
update = FilterConfigUpdate(
|
|
||||||
failregex=req.failregex,
|
|
||||||
ignoreregex=req.ignoreregex,
|
|
||||||
datepattern=req.datepattern,
|
|
||||||
journalmatch=req.journalmatch,
|
|
||||||
)
|
|
||||||
|
|
||||||
merged = conffile_parser.merge_filter_update(current, update)
|
|
||||||
content = conffile_parser.serialize_filter_config(merged)
|
|
||||||
|
|
||||||
filter_d = Path(config_dir) / "filter.d"
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
await loop.run_in_executor(None, _write_filter_local_sync, filter_d, base_name, content)
|
|
||||||
|
|
||||||
if do_reload:
|
|
||||||
try:
|
|
||||||
await reload_jails(socket_path)
|
|
||||||
except Exception as exc: # noqa: BLE001
|
|
||||||
log.warning(
|
|
||||||
"reload_after_filter_update_failed",
|
|
||||||
filter=base_name,
|
|
||||||
error=str(exc),
|
|
||||||
)
|
|
||||||
|
|
||||||
log.info("filter_updated", filter=base_name, reload=do_reload)
|
|
||||||
return await get_filter(config_dir, socket_path, base_name)
|
|
||||||
|
|
||||||
|
|
||||||
async def create_filter(
|
|
||||||
config_dir: str,
|
|
||||||
socket_path: str,
|
|
||||||
req: FilterCreateRequest,
|
|
||||||
do_reload: bool = False,
|
|
||||||
) -> FilterConfig:
|
|
||||||
"""Create a brand-new user-defined filter in ``filter.d/{name}.local``.
|
|
||||||
|
|
||||||
No ``.conf`` is written; fail2ban loads ``.local`` files directly. If a
|
|
||||||
``.conf`` or ``.local`` file already exists for the requested name, a
|
|
||||||
:class:`FilterAlreadyExistsError` is raised.
|
|
||||||
|
|
||||||
All regex patterns are validated with Python's ``re`` module before
|
|
||||||
writing.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config_dir: Absolute path to the fail2ban configuration directory.
|
|
||||||
socket_path: Path to the fail2ban Unix domain socket.
|
|
||||||
req: Filter name and definition fields.
|
|
||||||
do_reload: When ``True``, trigger a full fail2ban reload after writing.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
:class:`~app.models.config.FilterConfig` for the newly created filter.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
FilterNameError: If ``req.name`` contains invalid characters.
|
|
||||||
FilterAlreadyExistsError: If a ``.conf`` or ``.local`` already exists.
|
|
||||||
FilterInvalidRegexError: If any regex pattern is invalid.
|
|
||||||
ConfigWriteError: If writing fails.
|
|
||||||
"""
|
|
||||||
_safe_filter_name(req.name)
|
|
||||||
|
|
||||||
filter_d = Path(config_dir) / "filter.d"
|
|
||||||
conf_path = filter_d / f"{req.name}.conf"
|
|
||||||
local_path = filter_d / f"{req.name}.local"
|
|
||||||
|
|
||||||
def _check_not_exists() -> None:
|
|
||||||
if conf_path.is_file() or local_path.is_file():
|
|
||||||
raise FilterAlreadyExistsError(req.name)
|
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
await loop.run_in_executor(None, _check_not_exists)
|
|
||||||
|
|
||||||
# Validate regex patterns.
|
|
||||||
patterns: list[str] = list(req.failregex) + list(req.ignoreregex)
|
|
||||||
_validate_regex_patterns(patterns)
|
|
||||||
|
|
||||||
# Build a FilterConfig and serialise it.
|
|
||||||
cfg = FilterConfig(
|
|
||||||
name=req.name,
|
|
||||||
filename=f"{req.name}.local",
|
|
||||||
failregex=req.failregex,
|
|
||||||
ignoreregex=req.ignoreregex,
|
|
||||||
prefregex=req.prefregex,
|
|
||||||
datepattern=req.datepattern,
|
|
||||||
journalmatch=req.journalmatch,
|
|
||||||
)
|
|
||||||
content = conffile_parser.serialize_filter_config(cfg)
|
|
||||||
|
|
||||||
await loop.run_in_executor(None, _write_filter_local_sync, filter_d, req.name, content)
|
|
||||||
|
|
||||||
if do_reload:
|
|
||||||
try:
|
|
||||||
await reload_jails(socket_path)
|
|
||||||
except Exception as exc: # noqa: BLE001
|
|
||||||
log.warning(
|
|
||||||
"reload_after_filter_create_failed",
|
|
||||||
filter=req.name,
|
|
||||||
error=str(exc),
|
|
||||||
)
|
|
||||||
|
|
||||||
log.info("filter_created", filter=req.name, reload=do_reload)
|
|
||||||
# Re-fetch to get the canonical FilterConfig (source_file, active, etc.).
|
|
||||||
return await get_filter(config_dir, socket_path, req.name)
|
|
||||||
|
|
||||||
|
|
||||||
async def delete_filter(
|
|
||||||
config_dir: str,
|
|
||||||
name: str,
|
|
||||||
) -> None:
|
|
||||||
"""Delete a user-created filter's ``.local`` file.
|
|
||||||
|
|
||||||
Deletion rules:
|
|
||||||
- If only a ``.conf`` file exists (shipped default, no user override) →
|
|
||||||
:class:`FilterReadonlyError`.
|
|
||||||
- If a ``.local`` file exists (whether or not a ``.conf`` also exists) →
|
|
||||||
the ``.local`` file is deleted. The shipped ``.conf`` is never touched.
|
|
||||||
- If neither file exists → :class:`FilterNotFoundError`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config_dir: Absolute path to the fail2ban configuration directory.
|
|
||||||
name: Filter base name (e.g. ``"sshd"``).
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
FilterNameError: If *name* contains invalid characters.
|
|
||||||
FilterNotFoundError: If no filter file is found for *name*.
|
|
||||||
FilterReadonlyError: If only a shipped ``.conf`` exists (no ``.local``).
|
|
||||||
ConfigWriteError: If deletion of the ``.local`` file fails.
|
|
||||||
"""
|
|
||||||
base_name = name[:-5] if name.endswith(".conf") or name.endswith(".local") else name
|
|
||||||
_safe_filter_name(base_name)
|
|
||||||
|
|
||||||
filter_d = Path(config_dir) / "filter.d"
|
|
||||||
conf_path = filter_d / f"{base_name}.conf"
|
|
||||||
local_path = filter_d / f"{base_name}.local"
|
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
|
|
||||||
def _delete() -> None:
|
|
||||||
has_conf = conf_path.is_file()
|
|
||||||
has_local = local_path.is_file()
|
|
||||||
|
|
||||||
if not has_conf and not has_local:
|
|
||||||
raise FilterNotFoundError(base_name)
|
|
||||||
|
|
||||||
if has_conf and not has_local:
|
|
||||||
# Shipped default — nothing user-writable to remove.
|
|
||||||
raise FilterReadonlyError(base_name)
|
|
||||||
|
|
||||||
try:
|
|
||||||
local_path.unlink()
|
|
||||||
except OSError as exc:
|
|
||||||
raise ConfigWriteError(f"Failed to delete {local_path}: {exc}") from exc
|
|
||||||
|
|
||||||
log.info("filter_local_deleted", filter=base_name, path=str(local_path))
|
|
||||||
|
|
||||||
await loop.run_in_executor(None, _delete)
|
|
||||||
|
|
||||||
|
|
||||||
async def assign_filter_to_jail(
|
|
||||||
config_dir: str,
|
|
||||||
socket_path: str,
|
|
||||||
jail_name: str,
|
|
||||||
req: AssignFilterRequest,
|
|
||||||
do_reload: bool = False,
|
|
||||||
) -> None:
|
|
||||||
"""Assign a filter to a jail by updating the jail's ``.local`` file.
|
|
||||||
|
|
||||||
Writes ``filter = {req.filter_name}`` into the ``[{jail_name}]`` section
|
|
||||||
of ``jail.d/{jail_name}.local``. If the ``.local`` file already contains
|
|
||||||
other settings for this jail they are preserved.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config_dir: Absolute path to the fail2ban configuration directory.
|
|
||||||
socket_path: Path to the fail2ban Unix domain socket.
|
|
||||||
jail_name: Name of the jail to update.
|
|
||||||
req: Request containing the filter name to assign.
|
|
||||||
do_reload: When ``True``, trigger a full fail2ban reload after writing.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
JailNameError: If *jail_name* contains invalid characters.
|
|
||||||
FilterNameError: If ``req.filter_name`` contains invalid characters.
|
|
||||||
JailNotFoundError: If *jail_name* is not defined in any config file.
|
|
||||||
FilterNotFoundError: If ``req.filter_name`` does not exist in
|
|
||||||
``filter.d/``.
|
|
||||||
ConfigWriteError: If writing fails.
|
|
||||||
"""
|
|
||||||
_safe_jail_name(jail_name)
|
|
||||||
_safe_filter_name(req.filter_name)
|
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
|
|
||||||
# Verify the jail exists in config.
|
|
||||||
all_jails, _src = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
|
||||||
if jail_name not in all_jails:
|
|
||||||
raise JailNotFoundInConfigError(jail_name)
|
|
||||||
|
|
||||||
# Verify the filter exists (conf or local).
|
|
||||||
filter_d = Path(config_dir) / "filter.d"
|
|
||||||
|
|
||||||
def _check_filter() -> None:
|
|
||||||
conf_exists = (filter_d / f"{req.filter_name}.conf").is_file()
|
|
||||||
local_exists = (filter_d / f"{req.filter_name}.local").is_file()
|
|
||||||
if not conf_exists and not local_exists:
|
|
||||||
raise FilterNotFoundError(req.filter_name)
|
|
||||||
|
|
||||||
await loop.run_in_executor(None, _check_filter)
|
|
||||||
|
|
||||||
await loop.run_in_executor(
|
|
||||||
None,
|
|
||||||
_set_jail_local_key_sync,
|
|
||||||
Path(config_dir),
|
|
||||||
jail_name,
|
|
||||||
"filter",
|
|
||||||
req.filter_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
if do_reload:
|
|
||||||
try:
|
|
||||||
await reload_jails(socket_path)
|
|
||||||
except Exception as exc: # noqa: BLE001
|
|
||||||
log.warning(
|
|
||||||
"reload_after_assign_filter_failed",
|
|
||||||
jail=jail_name,
|
|
||||||
filter=req.filter_name,
|
|
||||||
error=str(exc),
|
|
||||||
)
|
|
||||||
|
|
||||||
log.info(
|
|
||||||
"filter_assigned_to_jail",
|
|
||||||
jail=jail_name,
|
|
||||||
filter=req.filter_name,
|
|
||||||
reload=do_reload,
|
|
||||||
)
|
|
||||||
@@ -20,7 +20,9 @@ Usage::
|
|||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
# Use the geo_service directly in application startup
|
from app.services import geo_service
|
||||||
|
|
||||||
|
# warm the cache from the persistent store at startup
|
||||||
async with aiosqlite.connect("bangui.db") as db:
|
async with aiosqlite.connect("bangui.db") as db:
|
||||||
await geo_service.load_cache_from_db(db)
|
await geo_service.load_cache_from_db(db)
|
||||||
|
|
||||||
@@ -28,8 +30,7 @@ Usage::
|
|||||||
# single lookup
|
# single lookup
|
||||||
info = await geo_service.lookup("1.2.3.4", session)
|
info = await geo_service.lookup("1.2.3.4", session)
|
||||||
if info:
|
if info:
|
||||||
# info.country_code == "DE"
|
print(info.country_code) # "DE"
|
||||||
... # use the GeoInfo object in your application
|
|
||||||
|
|
||||||
# bulk lookup (more efficient for large sets)
|
# bulk lookup (more efficient for large sets)
|
||||||
geo_map = await geo_service.lookup_batch(["1.2.3.4", "5.6.7.8"], session)
|
geo_map = await geo_service.lookup_batch(["1.2.3.4", "5.6.7.8"], session)
|
||||||
@@ -39,14 +40,12 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import structlog
|
import structlog
|
||||||
|
|
||||||
from app.models.geo import GeoInfo
|
|
||||||
from app.repositories import geo_cache_repo
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
import geoip2.database
|
import geoip2.database
|
||||||
@@ -91,6 +90,32 @@ _BATCH_DELAY: float = 1.5
|
|||||||
#: transient error (e.g. connection reset due to rate limiting).
|
#: transient error (e.g. connection reset due to rate limiting).
|
||||||
_BATCH_MAX_RETRIES: int = 2
|
_BATCH_MAX_RETRIES: int = 2
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Domain model
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GeoInfo:
|
||||||
|
"""Geographical and network metadata for a single IP address.
|
||||||
|
|
||||||
|
All fields default to ``None`` when the information is unavailable or
|
||||||
|
the lookup fails gracefully.
|
||||||
|
"""
|
||||||
|
|
||||||
|
country_code: str | None
|
||||||
|
"""ISO 3166-1 alpha-2 country code, e.g. ``"DE"``."""
|
||||||
|
|
||||||
|
country_name: str | None
|
||||||
|
"""Human-readable country name, e.g. ``"Germany"``."""
|
||||||
|
|
||||||
|
asn: str | None
|
||||||
|
"""Autonomous System Number string, e.g. ``"AS3320"``."""
|
||||||
|
|
||||||
|
org: str | None
|
||||||
|
"""Organisation name associated with the IP, e.g. ``"Deutsche Telekom"``."""
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Internal cache
|
# Internal cache
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -159,7 +184,11 @@ async def cache_stats(db: aiosqlite.Connection) -> dict[str, int]:
|
|||||||
Dict with keys ``cache_size``, ``unresolved``, ``neg_cache_size``,
|
Dict with keys ``cache_size``, ``unresolved``, ``neg_cache_size``,
|
||||||
and ``dirty_size``.
|
and ``dirty_size``.
|
||||||
"""
|
"""
|
||||||
unresolved = await geo_cache_repo.count_unresolved(db)
|
async with db.execute(
|
||||||
|
"SELECT COUNT(*) FROM geo_cache WHERE country_code IS NULL"
|
||||||
|
) as cur:
|
||||||
|
row = await cur.fetchone()
|
||||||
|
unresolved: int = int(row[0]) if row else 0
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"cache_size": len(_cache),
|
"cache_size": len(_cache),
|
||||||
@@ -169,24 +198,6 @@ async def cache_stats(db: aiosqlite.Connection) -> dict[str, int]:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
async def count_unresolved(db: aiosqlite.Connection) -> int:
|
|
||||||
"""Return the number of unresolved entries in the persistent geo cache."""
|
|
||||||
|
|
||||||
return await geo_cache_repo.count_unresolved(db)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_unresolved_ips(db: aiosqlite.Connection) -> list[str]:
|
|
||||||
"""Return geo cache IPs where the country code has not yet been resolved.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Open BanGUI application database connection.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of IP addresses that are candidates for re-resolution.
|
|
||||||
"""
|
|
||||||
return await geo_cache_repo.get_unresolved_ips(db)
|
|
||||||
|
|
||||||
|
|
||||||
def init_geoip(mmdb_path: str | None) -> None:
|
def init_geoip(mmdb_path: str | None) -> None:
|
||||||
"""Initialise the MaxMind GeoLite2-Country database reader.
|
"""Initialise the MaxMind GeoLite2-Country database reader.
|
||||||
|
|
||||||
@@ -257,18 +268,21 @@ async def load_cache_from_db(db: aiosqlite.Connection) -> None:
|
|||||||
database (not the fail2ban database).
|
database (not the fail2ban database).
|
||||||
"""
|
"""
|
||||||
count = 0
|
count = 0
|
||||||
for row in await geo_cache_repo.load_all(db):
|
async with db.execute(
|
||||||
country_code: str | None = row["country_code"]
|
"SELECT ip, country_code, country_name, asn, org FROM geo_cache"
|
||||||
if country_code is None:
|
) as cur:
|
||||||
continue
|
async for row in cur:
|
||||||
ip: str = row["ip"]
|
ip: str = str(row[0])
|
||||||
_cache[ip] = GeoInfo(
|
country_code: str | None = row[1]
|
||||||
country_code=country_code,
|
if country_code is None:
|
||||||
country_name=row["country_name"],
|
continue
|
||||||
asn=row["asn"],
|
_cache[ip] = GeoInfo(
|
||||||
org=row["org"],
|
country_code=country_code,
|
||||||
)
|
country_name=row[2],
|
||||||
count += 1
|
asn=row[3],
|
||||||
|
org=row[4],
|
||||||
|
)
|
||||||
|
count += 1
|
||||||
log.info("geo_cache_loaded_from_db", entries=count)
|
log.info("geo_cache_loaded_from_db", entries=count)
|
||||||
|
|
||||||
|
|
||||||
@@ -287,13 +301,18 @@ async def _persist_entry(
|
|||||||
ip: IP address string.
|
ip: IP address string.
|
||||||
info: Resolved geo data to persist.
|
info: Resolved geo data to persist.
|
||||||
"""
|
"""
|
||||||
await geo_cache_repo.upsert_entry(
|
await db.execute(
|
||||||
db=db,
|
"""
|
||||||
ip=ip,
|
INSERT INTO geo_cache (ip, country_code, country_name, asn, org)
|
||||||
country_code=info.country_code,
|
VALUES (?, ?, ?, ?, ?)
|
||||||
country_name=info.country_name,
|
ON CONFLICT(ip) DO UPDATE SET
|
||||||
asn=info.asn,
|
country_code = excluded.country_code,
|
||||||
org=info.org,
|
country_name = excluded.country_name,
|
||||||
|
asn = excluded.asn,
|
||||||
|
org = excluded.org,
|
||||||
|
cached_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
|
||||||
|
""",
|
||||||
|
(ip, info.country_code, info.country_name, info.asn, info.org),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -307,7 +326,10 @@ async def _persist_neg_entry(db: aiosqlite.Connection, ip: str) -> None:
|
|||||||
db: BanGUI application database connection.
|
db: BanGUI application database connection.
|
||||||
ip: IP address string whose resolution failed.
|
ip: IP address string whose resolution failed.
|
||||||
"""
|
"""
|
||||||
await geo_cache_repo.upsert_neg_entry(db=db, ip=ip)
|
await db.execute(
|
||||||
|
"INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)",
|
||||||
|
(ip,),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -563,7 +585,19 @@ async def lookup_batch(
|
|||||||
if db is not None:
|
if db is not None:
|
||||||
if pos_rows:
|
if pos_rows:
|
||||||
try:
|
try:
|
||||||
await geo_cache_repo.bulk_upsert_entries(db, pos_rows)
|
await db.executemany(
|
||||||
|
"""
|
||||||
|
INSERT INTO geo_cache (ip, country_code, country_name, asn, org)
|
||||||
|
VALUES (?, ?, ?, ?, ?)
|
||||||
|
ON CONFLICT(ip) DO UPDATE SET
|
||||||
|
country_code = excluded.country_code,
|
||||||
|
country_name = excluded.country_name,
|
||||||
|
asn = excluded.asn,
|
||||||
|
org = excluded.org,
|
||||||
|
cached_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
|
||||||
|
""",
|
||||||
|
pos_rows,
|
||||||
|
)
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
log.warning(
|
log.warning(
|
||||||
"geo_batch_persist_failed",
|
"geo_batch_persist_failed",
|
||||||
@@ -572,7 +606,10 @@ async def lookup_batch(
|
|||||||
)
|
)
|
||||||
if neg_ips:
|
if neg_ips:
|
||||||
try:
|
try:
|
||||||
await geo_cache_repo.bulk_upsert_neg_entries(db, neg_ips)
|
await db.executemany(
|
||||||
|
"INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)",
|
||||||
|
[(ip,) for ip in neg_ips],
|
||||||
|
)
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
log.warning(
|
log.warning(
|
||||||
"geo_batch_persist_neg_failed",
|
"geo_batch_persist_neg_failed",
|
||||||
@@ -755,7 +792,19 @@ async def flush_dirty(db: aiosqlite.Connection) -> int:
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await geo_cache_repo.bulk_upsert_entries(db, rows)
|
await db.executemany(
|
||||||
|
"""
|
||||||
|
INSERT INTO geo_cache (ip, country_code, country_name, asn, org)
|
||||||
|
VALUES (?, ?, ?, ?, ?)
|
||||||
|
ON CONFLICT(ip) DO UPDATE SET
|
||||||
|
country_code = excluded.country_code,
|
||||||
|
country_name = excluded.country_name,
|
||||||
|
asn = excluded.asn,
|
||||||
|
org = excluded.org,
|
||||||
|
cached_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
|
||||||
|
""",
|
||||||
|
rows,
|
||||||
|
)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
log.warning("geo_flush_dirty_failed", error=str(exc))
|
log.warning("geo_flush_dirty_failed", error=str(exc))
|
||||||
|
|||||||
@@ -9,17 +9,12 @@ seconds by the background health-check task, not on every HTTP request.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import cast
|
from typing import Any
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
|
|
||||||
from app.models.server import ServerStatus
|
from app.models.server import ServerStatus
|
||||||
from app.utils.fail2ban_client import (
|
from app.utils.fail2ban_client import Fail2BanClient, Fail2BanConnectionError, Fail2BanProtocolError
|
||||||
Fail2BanClient,
|
|
||||||
Fail2BanConnectionError,
|
|
||||||
Fail2BanProtocolError,
|
|
||||||
Fail2BanResponse,
|
|
||||||
)
|
|
||||||
|
|
||||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||||
|
|
||||||
@@ -30,7 +25,7 @@ log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
|||||||
_SOCKET_TIMEOUT: float = 5.0
|
_SOCKET_TIMEOUT: float = 5.0
|
||||||
|
|
||||||
|
|
||||||
def _ok(response: object) -> object:
|
def _ok(response: Any) -> Any:
|
||||||
"""Extract the payload from a fail2ban ``(return_code, data)`` response.
|
"""Extract the payload from a fail2ban ``(return_code, data)`` response.
|
||||||
|
|
||||||
fail2ban wraps every response in a ``(0, data)`` success tuple or
|
fail2ban wraps every response in a ``(0, data)`` success tuple or
|
||||||
@@ -47,7 +42,7 @@ def _ok(response: object) -> object:
|
|||||||
ValueError: If the response indicates an error (return code ≠ 0).
|
ValueError: If the response indicates an error (return code ≠ 0).
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
code, data = cast("Fail2BanResponse", response)
|
code, data = response
|
||||||
except (TypeError, ValueError) as exc:
|
except (TypeError, ValueError) as exc:
|
||||||
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
|
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
|
||||||
|
|
||||||
@@ -57,7 +52,7 @@ def _ok(response: object) -> object:
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
def _to_dict(pairs: object) -> dict[str, object]:
|
def _to_dict(pairs: Any) -> dict[str, Any]:
|
||||||
"""Convert a list of ``(key, value)`` pairs to a plain dict.
|
"""Convert a list of ``(key, value)`` pairs to a plain dict.
|
||||||
|
|
||||||
fail2ban returns structured data as lists of 2-tuples rather than dicts.
|
fail2ban returns structured data as lists of 2-tuples rather than dicts.
|
||||||
@@ -71,7 +66,7 @@ def _to_dict(pairs: object) -> dict[str, object]:
|
|||||||
"""
|
"""
|
||||||
if not isinstance(pairs, (list, tuple)):
|
if not isinstance(pairs, (list, tuple)):
|
||||||
return {}
|
return {}
|
||||||
result: dict[str, object] = {}
|
result: dict[str, Any] = {}
|
||||||
for item in pairs:
|
for item in pairs:
|
||||||
try:
|
try:
|
||||||
k, v = item
|
k, v = item
|
||||||
@@ -124,7 +119,7 @@ async def probe(socket_path: str, timeout: float = _SOCKET_TIMEOUT) -> ServerSta
|
|||||||
# 3. Global status — jail count and names #
|
# 3. Global status — jail count and names #
|
||||||
# ------------------------------------------------------------------ #
|
# ------------------------------------------------------------------ #
|
||||||
status_data = _to_dict(_ok(await client.send(["status"])))
|
status_data = _to_dict(_ok(await client.send(["status"])))
|
||||||
active_jails: int = int(str(status_data.get("Number of jail", 0) or 0))
|
active_jails: int = int(status_data.get("Number of jail", 0) or 0)
|
||||||
jail_list_raw: str = str(status_data.get("Jail list", "") or "").strip()
|
jail_list_raw: str = str(status_data.get("Jail list", "") or "").strip()
|
||||||
jail_names: list[str] = (
|
jail_names: list[str] = (
|
||||||
[j.strip() for j in jail_list_raw.split(",") if j.strip()]
|
[j.strip() for j in jail_list_raw.split(",") if j.strip()]
|
||||||
@@ -143,8 +138,8 @@ async def probe(socket_path: str, timeout: float = _SOCKET_TIMEOUT) -> ServerSta
|
|||||||
jail_resp = _to_dict(_ok(await client.send(["status", jail_name])))
|
jail_resp = _to_dict(_ok(await client.send(["status", jail_name])))
|
||||||
filter_stats = _to_dict(jail_resp.get("Filter") or [])
|
filter_stats = _to_dict(jail_resp.get("Filter") or [])
|
||||||
action_stats = _to_dict(jail_resp.get("Actions") or [])
|
action_stats = _to_dict(jail_resp.get("Actions") or [])
|
||||||
total_failures += int(str(filter_stats.get("Currently failed", 0) or 0))
|
total_failures += int(filter_stats.get("Currently failed", 0) or 0)
|
||||||
total_bans += int(str(action_stats.get("Currently banned", 0) or 0))
|
total_bans += int(action_stats.get("Currently banned", 0) or 0)
|
||||||
except (ValueError, TypeError, KeyError) as exc:
|
except (ValueError, TypeError, KeyError) as exc:
|
||||||
log.warning(
|
log.warning(
|
||||||
"fail2ban_jail_status_parse_error",
|
"fail2ban_jail_status_parse_error",
|
||||||
|
|||||||
@@ -11,22 +11,19 @@ modifies or locks the fail2ban database.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import TYPE_CHECKING
|
from typing import Any
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
import structlog
|
import structlog
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
from app.models.ban import BLOCKLIST_JAIL, BanOrigin, TIME_RANGE_SECONDS, TimeRange
|
||||||
from app.models.geo import GeoEnricher
|
|
||||||
|
|
||||||
from app.models.ban import TIME_RANGE_SECONDS, TimeRange
|
|
||||||
from app.models.history import (
|
from app.models.history import (
|
||||||
HistoryBanItem,
|
HistoryBanItem,
|
||||||
HistoryListResponse,
|
HistoryListResponse,
|
||||||
IpDetailResponse,
|
IpDetailResponse,
|
||||||
IpTimelineEvent,
|
IpTimelineEvent,
|
||||||
)
|
)
|
||||||
from app.repositories import fail2ban_db_repo
|
from app.services.ban_service import _get_fail2ban_db_path, _parse_data_json, _ts_to_iso
|
||||||
from app.utils.fail2ban_db_utils import get_fail2ban_db_path, parse_data_json, ts_to_iso
|
|
||||||
|
|
||||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||||
|
|
||||||
@@ -61,10 +58,11 @@ async def list_history(
|
|||||||
*,
|
*,
|
||||||
range_: TimeRange | None = None,
|
range_: TimeRange | None = None,
|
||||||
jail: str | None = None,
|
jail: str | None = None,
|
||||||
|
origin: BanOrigin | None = None,
|
||||||
ip_filter: str | None = None,
|
ip_filter: str | None = None,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
page_size: int = _DEFAULT_PAGE_SIZE,
|
page_size: int = _DEFAULT_PAGE_SIZE,
|
||||||
geo_enricher: GeoEnricher | None = None,
|
geo_enricher: Any | None = None,
|
||||||
) -> HistoryListResponse:
|
) -> HistoryListResponse:
|
||||||
"""Return a paginated list of historical ban records with optional filters.
|
"""Return a paginated list of historical ban records with optional filters.
|
||||||
|
|
||||||
@@ -76,6 +74,8 @@ async def list_history(
|
|||||||
socket_path: Path to the fail2ban Unix domain socket.
|
socket_path: Path to the fail2ban Unix domain socket.
|
||||||
range_: Time-range preset. ``None`` means all-time (no time filter).
|
range_: Time-range preset. ``None`` means all-time (no time filter).
|
||||||
jail: If given, restrict results to bans from this jail.
|
jail: If given, restrict results to bans from this jail.
|
||||||
|
origin: Optional origin filter — ``"blocklist"`` restricts results to
|
||||||
|
the ``blocklist-import`` jail, ``"selfblock"`` excludes it.
|
||||||
ip_filter: If given, restrict results to bans for this exact IP
|
ip_filter: If given, restrict results to bans for this exact IP
|
||||||
(or a prefix — the query uses ``LIKE ip_filter%``).
|
(or a prefix — the query uses ``LIKE ip_filter%``).
|
||||||
page: 1-based page number (default: ``1``).
|
page: 1-based page number (default: ``1``).
|
||||||
@@ -87,13 +87,36 @@ async def list_history(
|
|||||||
and the total matching count.
|
and the total matching count.
|
||||||
"""
|
"""
|
||||||
effective_page_size: int = min(page_size, _MAX_PAGE_SIZE)
|
effective_page_size: int = min(page_size, _MAX_PAGE_SIZE)
|
||||||
|
offset: int = (page - 1) * effective_page_size
|
||||||
|
|
||||||
# Build WHERE clauses dynamically.
|
# Build WHERE clauses dynamically.
|
||||||
since: int | None = None
|
wheres: list[str] = []
|
||||||
if range_ is not None:
|
params: list[Any] = []
|
||||||
since = _since_unix(range_)
|
|
||||||
|
|
||||||
db_path: str = await get_fail2ban_db_path(socket_path)
|
if range_ is not None:
|
||||||
|
since: int = _since_unix(range_)
|
||||||
|
wheres.append("timeofban >= ?")
|
||||||
|
params.append(since)
|
||||||
|
|
||||||
|
if jail is not None:
|
||||||
|
wheres.append("jail = ?")
|
||||||
|
params.append(jail)
|
||||||
|
|
||||||
|
if origin is not None:
|
||||||
|
if origin == "blocklist":
|
||||||
|
wheres.append("jail = ?")
|
||||||
|
params.append(BLOCKLIST_JAIL)
|
||||||
|
elif origin == "selfblock":
|
||||||
|
wheres.append("jail != ?")
|
||||||
|
params.append(BLOCKLIST_JAIL)
|
||||||
|
|
||||||
|
if ip_filter is not None:
|
||||||
|
wheres.append("ip LIKE ?")
|
||||||
|
params.append(f"{ip_filter}%")
|
||||||
|
|
||||||
|
where_sql: str = ("WHERE " + " AND ".join(wheres)) if wheres else ""
|
||||||
|
|
||||||
|
db_path: str = await _get_fail2ban_db_path(socket_path)
|
||||||
log.info(
|
log.info(
|
||||||
"history_service_list",
|
"history_service_list",
|
||||||
db_path=db_path,
|
db_path=db_path,
|
||||||
@@ -103,22 +126,32 @@ async def list_history(
|
|||||||
page=page,
|
page=page,
|
||||||
)
|
)
|
||||||
|
|
||||||
rows, total = await fail2ban_db_repo.get_history_page(
|
async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db:
|
||||||
db_path=db_path,
|
f2b_db.row_factory = aiosqlite.Row
|
||||||
since=since,
|
|
||||||
jail=jail,
|
async with f2b_db.execute(
|
||||||
ip_filter=ip_filter,
|
f"SELECT COUNT(*) FROM bans {where_sql}", # noqa: S608
|
||||||
page=page,
|
params,
|
||||||
page_size=effective_page_size,
|
) as cur:
|
||||||
)
|
count_row = await cur.fetchone()
|
||||||
|
total: int = int(count_row[0]) if count_row else 0
|
||||||
|
|
||||||
|
async with f2b_db.execute(
|
||||||
|
f"SELECT jail, ip, timeofban, bancount, data " # noqa: S608
|
||||||
|
f"FROM bans {where_sql} "
|
||||||
|
"ORDER BY timeofban DESC "
|
||||||
|
"LIMIT ? OFFSET ?",
|
||||||
|
[*params, effective_page_size, offset],
|
||||||
|
) as cur:
|
||||||
|
rows = await cur.fetchall()
|
||||||
|
|
||||||
items: list[HistoryBanItem] = []
|
items: list[HistoryBanItem] = []
|
||||||
for row in rows:
|
for row in rows:
|
||||||
jail_name: str = row.jail
|
jail_name: str = str(row["jail"])
|
||||||
ip: str = row.ip
|
ip: str = str(row["ip"])
|
||||||
banned_at: str = ts_to_iso(row.timeofban)
|
banned_at: str = _ts_to_iso(int(row["timeofban"]))
|
||||||
ban_count: int = row.bancount
|
ban_count: int = int(row["bancount"])
|
||||||
matches, failures = parse_data_json(row.data)
|
matches, failures = _parse_data_json(row["data"])
|
||||||
|
|
||||||
country_code: str | None = None
|
country_code: str | None = None
|
||||||
country_name: str | None = None
|
country_name: str | None = None
|
||||||
@@ -163,7 +196,7 @@ async def get_ip_detail(
|
|||||||
socket_path: str,
|
socket_path: str,
|
||||||
ip: str,
|
ip: str,
|
||||||
*,
|
*,
|
||||||
geo_enricher: GeoEnricher | None = None,
|
geo_enricher: Any | None = None,
|
||||||
) -> IpDetailResponse | None:
|
) -> IpDetailResponse | None:
|
||||||
"""Return the full historical record for a single IP address.
|
"""Return the full historical record for a single IP address.
|
||||||
|
|
||||||
@@ -180,10 +213,19 @@ async def get_ip_detail(
|
|||||||
:class:`~app.models.history.IpDetailResponse` if any records exist
|
:class:`~app.models.history.IpDetailResponse` if any records exist
|
||||||
for *ip*, or ``None`` if the IP has no history in the database.
|
for *ip*, or ``None`` if the IP has no history in the database.
|
||||||
"""
|
"""
|
||||||
db_path: str = await get_fail2ban_db_path(socket_path)
|
db_path: str = await _get_fail2ban_db_path(socket_path)
|
||||||
log.info("history_service_ip_detail", db_path=db_path, ip=ip)
|
log.info("history_service_ip_detail", db_path=db_path, ip=ip)
|
||||||
|
|
||||||
rows = await fail2ban_db_repo.get_history_for_ip(db_path=db_path, ip=ip)
|
async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db:
|
||||||
|
f2b_db.row_factory = aiosqlite.Row
|
||||||
|
async with f2b_db.execute(
|
||||||
|
"SELECT jail, ip, timeofban, bancount, data "
|
||||||
|
"FROM bans "
|
||||||
|
"WHERE ip = ? "
|
||||||
|
"ORDER BY timeofban DESC",
|
||||||
|
(ip,),
|
||||||
|
) as cur:
|
||||||
|
rows = await cur.fetchall()
|
||||||
|
|
||||||
if not rows:
|
if not rows:
|
||||||
return None
|
return None
|
||||||
@@ -192,10 +234,10 @@ async def get_ip_detail(
|
|||||||
total_failures: int = 0
|
total_failures: int = 0
|
||||||
|
|
||||||
for row in rows:
|
for row in rows:
|
||||||
jail_name: str = row.jail
|
jail_name: str = str(row["jail"])
|
||||||
banned_at: str = ts_to_iso(row.timeofban)
|
banned_at: str = _ts_to_iso(int(row["timeofban"]))
|
||||||
ban_count: int = row.bancount
|
ban_count: int = int(row["bancount"])
|
||||||
matches, failures = parse_data_json(row.data)
|
matches, failures = _parse_data_json(row["data"])
|
||||||
total_failures += failures
|
total_failures += failures
|
||||||
timeline.append(
|
timeline.append(
|
||||||
IpTimelineEvent(
|
IpTimelineEvent(
|
||||||
|
|||||||
@@ -1,998 +0,0 @@
|
|||||||
"""Jail configuration management for BanGUI.
|
|
||||||
|
|
||||||
Handles parsing, validation, and lifecycle operations (activate/deactivate)
|
|
||||||
for fail2ban jail configurations. Provides functions to discover inactive
|
|
||||||
jails, validate their configurations before activation, and manage jail
|
|
||||||
overrides in jail.d/*.local files.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import configparser
|
|
||||||
import contextlib
|
|
||||||
import io
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import tempfile
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import cast
|
|
||||||
|
|
||||||
import structlog
|
|
||||||
|
|
||||||
from app.exceptions import JailNotFoundError
|
|
||||||
from app.models.config import (
|
|
||||||
ActivateJailRequest,
|
|
||||||
InactiveJail,
|
|
||||||
InactiveJailListResponse,
|
|
||||||
JailActivationResponse,
|
|
||||||
JailValidationIssue,
|
|
||||||
JailValidationResult,
|
|
||||||
RollbackResponse,
|
|
||||||
)
|
|
||||||
from app.utils.config_file_utils import (
|
|
||||||
_build_inactive_jail,
|
|
||||||
_ordered_config_files,
|
|
||||||
_parse_jails_sync,
|
|
||||||
_validate_jail_config_sync,
|
|
||||||
)
|
|
||||||
from app.utils.jail_utils import reload_jails
|
|
||||||
from app.utils.fail2ban_client import (
|
|
||||||
Fail2BanClient,
|
|
||||||
Fail2BanConnectionError,
|
|
||||||
Fail2BanResponse,
|
|
||||||
)
|
|
||||||
|
|
||||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Constants
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
_SOCKET_TIMEOUT: float = 10.0
|
|
||||||
|
|
||||||
# Allowlist pattern for jail names used in path construction.
|
|
||||||
_SAFE_JAIL_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$")
|
|
||||||
|
|
||||||
# Sections that are not jail definitions.
|
|
||||||
_META_SECTIONS: frozenset[str] = frozenset({"INCLUDES", "DEFAULT"})
|
|
||||||
|
|
||||||
# True-ish values for the ``enabled`` key.
|
|
||||||
_TRUE_VALUES: frozenset[str] = frozenset({"true", "yes", "1"})
|
|
||||||
|
|
||||||
# False-ish values for the ``enabled`` key.
|
|
||||||
_FALSE_VALUES: frozenset[str] = frozenset({"false", "no", "0"})
|
|
||||||
|
|
||||||
# Seconds to wait between fail2ban liveness probes after a reload.
|
|
||||||
_POST_RELOAD_PROBE_INTERVAL: float = 2.0
|
|
||||||
|
|
||||||
# Maximum number of post-reload probe attempts (initial attempt + retries).
|
|
||||||
_POST_RELOAD_MAX_ATTEMPTS: int = 4
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Custom exceptions
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class JailNotFoundInConfigError(Exception):
|
|
||||||
"""Raised when the requested jail name is not defined in any config file."""
|
|
||||||
|
|
||||||
def __init__(self, name: str) -> None:
|
|
||||||
"""Initialise with the jail name that was not found.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: The jail name that could not be located.
|
|
||||||
"""
|
|
||||||
self.name: str = name
|
|
||||||
super().__init__(f"Jail not found in config files: {name!r}")
|
|
||||||
|
|
||||||
|
|
||||||
class JailAlreadyActiveError(Exception):
|
|
||||||
"""Raised when trying to activate a jail that is already active."""
|
|
||||||
|
|
||||||
def __init__(self, name: str) -> None:
|
|
||||||
"""Initialise with the jail name.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: The jail that is already active.
|
|
||||||
"""
|
|
||||||
self.name: str = name
|
|
||||||
super().__init__(f"Jail is already active: {name!r}")
|
|
||||||
|
|
||||||
|
|
||||||
class JailAlreadyInactiveError(Exception):
|
|
||||||
"""Raised when trying to deactivate a jail that is already inactive."""
|
|
||||||
|
|
||||||
def __init__(self, name: str) -> None:
|
|
||||||
"""Initialise with the jail name.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: The jail that is already inactive.
|
|
||||||
"""
|
|
||||||
self.name: str = name
|
|
||||||
super().__init__(f"Jail is already inactive: {name!r}")
|
|
||||||
|
|
||||||
|
|
||||||
class JailNameError(Exception):
|
|
||||||
"""Raised when a jail name contains invalid characters."""
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigWriteError(Exception):
|
|
||||||
"""Raised when writing a ``.local`` override file fails."""
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Internal helpers
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _safe_jail_name(name: str) -> str:
|
|
||||||
"""Validate *name* and return it unchanged or raise :class:`JailNameError`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: Proposed jail name.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The name unchanged if valid.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
JailNameError: If *name* contains unsafe characters.
|
|
||||||
"""
|
|
||||||
if not _SAFE_JAIL_NAME_RE.match(name):
|
|
||||||
raise JailNameError(
|
|
||||||
f"Jail name {name!r} contains invalid characters. "
|
|
||||||
"Only alphanumeric characters, hyphens, underscores, and dots are "
|
|
||||||
"allowed; must start with an alphanumeric character."
|
|
||||||
)
|
|
||||||
return name
|
|
||||||
|
|
||||||
|
|
||||||
def _build_parser() -> configparser.RawConfigParser:
|
|
||||||
"""Create a :class:`configparser.RawConfigParser` for fail2ban configs.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Parser with interpolation disabled and case-sensitive option names.
|
|
||||||
"""
|
|
||||||
parser = configparser.RawConfigParser(interpolation=None, strict=False)
|
|
||||||
# fail2ban keys are lowercase but preserve case to be safe.
|
|
||||||
parser.optionxform = str # type: ignore[assignment]
|
|
||||||
return parser
|
|
||||||
|
|
||||||
|
|
||||||
def _is_truthy(value: str) -> bool:
|
|
||||||
"""Return ``True`` if *value* is a fail2ban boolean true string.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
value: Raw string from config (e.g. ``"true"``, ``"yes"``, ``"1"``).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
``True`` when the value represents enabled.
|
|
||||||
"""
|
|
||||||
return value.strip().lower() in _TRUE_VALUES
|
|
||||||
|
|
||||||
|
|
||||||
def _write_local_override_sync(
|
|
||||||
config_dir: Path,
|
|
||||||
jail_name: str,
|
|
||||||
enabled: bool,
|
|
||||||
overrides: dict[str, object],
|
|
||||||
) -> None:
|
|
||||||
"""Write a ``jail.d/{name}.local`` file atomically.
|
|
||||||
|
|
||||||
Always writes to ``jail.d/{jail_name}.local``. If the file already
|
|
||||||
exists it is replaced entirely. The write is atomic: content is
|
|
||||||
written to a temp file first, then renamed into place.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config_dir: The fail2ban configuration root directory.
|
|
||||||
jail_name: Validated jail name (used as filename stem).
|
|
||||||
enabled: Value to write for ``enabled =``.
|
|
||||||
overrides: Optional setting overrides (bantime, findtime, maxretry,
|
|
||||||
port, logpath).
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ConfigWriteError: If writing fails.
|
|
||||||
"""
|
|
||||||
jail_d = config_dir / "jail.d"
|
|
||||||
try:
|
|
||||||
jail_d.mkdir(parents=True, exist_ok=True)
|
|
||||||
except OSError as exc:
|
|
||||||
raise ConfigWriteError(f"Cannot create jail.d directory: {exc}") from exc
|
|
||||||
|
|
||||||
local_path = jail_d / f"{jail_name}.local"
|
|
||||||
|
|
||||||
lines: list[str] = [
|
|
||||||
"# Managed by BanGUI — do not edit manually",
|
|
||||||
"",
|
|
||||||
f"[{jail_name}]",
|
|
||||||
"",
|
|
||||||
f"enabled = {'true' if enabled else 'false'}",
|
|
||||||
# Provide explicit banaction defaults so fail2ban can resolve the
|
|
||||||
# %(banaction)s interpolation used in the built-in action_ chain.
|
|
||||||
"banaction = iptables-multiport",
|
|
||||||
"banaction_allports = iptables-allports",
|
|
||||||
]
|
|
||||||
|
|
||||||
if overrides.get("bantime") is not None:
|
|
||||||
lines.append(f"bantime = {overrides['bantime']}")
|
|
||||||
if overrides.get("findtime") is not None:
|
|
||||||
lines.append(f"findtime = {overrides['findtime']}")
|
|
||||||
if overrides.get("maxretry") is not None:
|
|
||||||
lines.append(f"maxretry = {overrides['maxretry']}")
|
|
||||||
if overrides.get("port") is not None:
|
|
||||||
lines.append(f"port = {overrides['port']}")
|
|
||||||
if overrides.get("logpath"):
|
|
||||||
paths: list[str] = cast("list[str]", overrides["logpath"])
|
|
||||||
if paths:
|
|
||||||
lines.append(f"logpath = {paths[0]}")
|
|
||||||
for p in paths[1:]:
|
|
||||||
lines.append(f" {p}")
|
|
||||||
|
|
||||||
content = "\n".join(lines) + "\n"
|
|
||||||
|
|
||||||
try:
|
|
||||||
with tempfile.NamedTemporaryFile(
|
|
||||||
mode="w",
|
|
||||||
encoding="utf-8",
|
|
||||||
dir=jail_d,
|
|
||||||
delete=False,
|
|
||||||
suffix=".tmp",
|
|
||||||
) as tmp:
|
|
||||||
tmp.write(content)
|
|
||||||
tmp_name = tmp.name
|
|
||||||
os.replace(tmp_name, local_path)
|
|
||||||
except OSError as exc:
|
|
||||||
# Clean up temp file if rename failed.
|
|
||||||
with contextlib.suppress(OSError):
|
|
||||||
os.unlink(tmp_name) # noqa: F821 — only reachable when tmp_name is set
|
|
||||||
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
|
||||||
|
|
||||||
log.info(
|
|
||||||
"jail_local_written",
|
|
||||||
jail=jail_name,
|
|
||||||
path=str(local_path),
|
|
||||||
enabled=enabled,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _restore_local_file_sync(local_path: Path, original_content: bytes | None) -> None:
|
|
||||||
"""Restore a ``.local`` file to its pre-activation state.
|
|
||||||
|
|
||||||
If *original_content* is ``None``, the file is deleted (it did not exist
|
|
||||||
before the activation). Otherwise the original bytes are written back
|
|
||||||
atomically via a temp-file rename.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
local_path: Absolute path to the ``.local`` file to restore.
|
|
||||||
original_content: Original raw bytes to write back, or ``None`` to
|
|
||||||
delete the file.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ConfigWriteError: If the write or delete operation fails.
|
|
||||||
"""
|
|
||||||
if original_content is None:
|
|
||||||
try:
|
|
||||||
local_path.unlink(missing_ok=True)
|
|
||||||
except OSError as exc:
|
|
||||||
raise ConfigWriteError(f"Failed to delete {local_path} during rollback: {exc}") from exc
|
|
||||||
return
|
|
||||||
|
|
||||||
tmp_name: str | None = None
|
|
||||||
try:
|
|
||||||
with tempfile.NamedTemporaryFile(
|
|
||||||
mode="wb",
|
|
||||||
dir=local_path.parent,
|
|
||||||
delete=False,
|
|
||||||
suffix=".tmp",
|
|
||||||
) as tmp:
|
|
||||||
tmp.write(original_content)
|
|
||||||
tmp_name = tmp.name
|
|
||||||
os.replace(tmp_name, local_path)
|
|
||||||
except OSError as exc:
|
|
||||||
with contextlib.suppress(OSError):
|
|
||||||
if tmp_name is not None:
|
|
||||||
os.unlink(tmp_name)
|
|
||||||
raise ConfigWriteError(f"Failed to restore {local_path} during rollback: {exc}") from exc
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_regex_patterns(patterns: list[str]) -> None:
|
|
||||||
"""Validate each pattern in *patterns* using Python's ``re`` module.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
patterns: List of regex strings to validate.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
FilterInvalidRegexError: If any pattern fails to compile.
|
|
||||||
"""
|
|
||||||
for pattern in patterns:
|
|
||||||
try:
|
|
||||||
re.compile(pattern)
|
|
||||||
except re.error as exc:
|
|
||||||
# Import here to avoid circular dependency
|
|
||||||
from app.exceptions import FilterInvalidRegexError
|
|
||||||
raise FilterInvalidRegexError(pattern, str(exc)) from exc
|
|
||||||
|
|
||||||
|
|
||||||
def _set_jail_local_key_sync(
|
|
||||||
config_dir: Path,
|
|
||||||
jail_name: str,
|
|
||||||
key: str,
|
|
||||||
value: str,
|
|
||||||
) -> None:
|
|
||||||
"""Update ``jail.d/{jail_name}.local`` to set a single key in the jail section.
|
|
||||||
|
|
||||||
If the ``.local`` file already exists it is read, the key is updated (or
|
|
||||||
added), and the file is written back atomically without disturbing other
|
|
||||||
settings. If the file does not exist a new one is created containing
|
|
||||||
only the BanGUI header comment, the jail section, and the requested key.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config_dir: The fail2ban configuration root directory.
|
|
||||||
jail_name: Validated jail name (used as section name and filename stem).
|
|
||||||
key: Config key to set inside the jail section.
|
|
||||||
value: Config value to assign.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ConfigWriteError: If writing fails.
|
|
||||||
"""
|
|
||||||
jail_d = config_dir / "jail.d"
|
|
||||||
try:
|
|
||||||
jail_d.mkdir(parents=True, exist_ok=True)
|
|
||||||
except OSError as exc:
|
|
||||||
raise ConfigWriteError(f"Cannot create jail.d directory: {exc}") from exc
|
|
||||||
|
|
||||||
local_path = jail_d / f"{jail_name}.local"
|
|
||||||
|
|
||||||
parser = _build_parser()
|
|
||||||
if local_path.is_file():
|
|
||||||
try:
|
|
||||||
parser.read(str(local_path), encoding="utf-8")
|
|
||||||
except (configparser.Error, OSError) as exc:
|
|
||||||
log.warning(
|
|
||||||
"jail_local_read_for_update_error",
|
|
||||||
jail=jail_name,
|
|
||||||
error=str(exc),
|
|
||||||
)
|
|
||||||
|
|
||||||
if not parser.has_section(jail_name):
|
|
||||||
parser.add_section(jail_name)
|
|
||||||
parser.set(jail_name, key, value)
|
|
||||||
|
|
||||||
# Serialize: write a BanGUI header then the parser output.
|
|
||||||
buf = io.StringIO()
|
|
||||||
buf.write("# Managed by BanGUI — do not edit manually\n\n")
|
|
||||||
parser.write(buf)
|
|
||||||
content = buf.getvalue()
|
|
||||||
|
|
||||||
try:
|
|
||||||
with tempfile.NamedTemporaryFile(
|
|
||||||
mode="w",
|
|
||||||
encoding="utf-8",
|
|
||||||
dir=jail_d,
|
|
||||||
delete=False,
|
|
||||||
suffix=".tmp",
|
|
||||||
) as tmp:
|
|
||||||
tmp.write(content)
|
|
||||||
tmp_name = tmp.name
|
|
||||||
os.replace(tmp_name, local_path)
|
|
||||||
except OSError as exc:
|
|
||||||
with contextlib.suppress(OSError):
|
|
||||||
os.unlink(tmp_name) # noqa: F821
|
|
||||||
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
|
||||||
|
|
||||||
log.info(
|
|
||||||
"jail_local_key_set",
|
|
||||||
jail=jail_name,
|
|
||||||
key=key,
|
|
||||||
path=str(local_path),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def _probe_fail2ban_running(socket_path: str) -> bool:
|
|
||||||
"""Return ``True`` if the fail2ban socket responds to a ping.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
socket_path: Path to the fail2ban Unix domain socket.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
``True`` when fail2ban is reachable, ``False`` otherwise.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
client = Fail2BanClient(socket_path=socket_path, timeout=5.0)
|
|
||||||
resp = await client.send(["ping"])
|
|
||||||
return isinstance(resp, (list, tuple)) and resp[0] == 0
|
|
||||||
except Exception: # noqa: BLE001
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
async def wait_for_fail2ban(
|
|
||||||
socket_path: str,
|
|
||||||
max_wait_seconds: float = 10.0,
|
|
||||||
poll_interval: float = 2.0,
|
|
||||||
) -> bool:
|
|
||||||
"""Poll the fail2ban socket until it responds or the timeout expires.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
socket_path: Path to the fail2ban Unix domain socket.
|
|
||||||
max_wait_seconds: Total time budget in seconds.
|
|
||||||
poll_interval: Delay between probe attempts in seconds.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
``True`` if fail2ban came online within the budget.
|
|
||||||
"""
|
|
||||||
elapsed = 0.0
|
|
||||||
while elapsed < max_wait_seconds:
|
|
||||||
if await _probe_fail2ban_running(socket_path):
|
|
||||||
return True
|
|
||||||
await asyncio.sleep(poll_interval)
|
|
||||||
elapsed += poll_interval
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
async def start_daemon(start_cmd_parts: list[str]) -> bool:
|
|
||||||
"""Start the fail2ban daemon using *start_cmd_parts*.
|
|
||||||
|
|
||||||
Uses :func:`asyncio.create_subprocess_exec` (no shell interpretation)
|
|
||||||
to avoid command injection.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
start_cmd_parts: Command and arguments, e.g.
|
|
||||||
``["fail2ban-client", "start"]``.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
``True`` when the process exited with code 0.
|
|
||||||
"""
|
|
||||||
if not start_cmd_parts:
|
|
||||||
log.warning("fail2ban_start_cmd_empty")
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
proc = await asyncio.create_subprocess_exec(
|
|
||||||
*start_cmd_parts,
|
|
||||||
stdout=asyncio.subprocess.PIPE,
|
|
||||||
stderr=asyncio.subprocess.PIPE,
|
|
||||||
)
|
|
||||||
await asyncio.wait_for(proc.wait(), timeout=30.0)
|
|
||||||
success = proc.returncode == 0
|
|
||||||
if not success:
|
|
||||||
log.warning(
|
|
||||||
"fail2ban_start_cmd_nonzero",
|
|
||||||
cmd=start_cmd_parts,
|
|
||||||
returncode=proc.returncode,
|
|
||||||
)
|
|
||||||
return success
|
|
||||||
except (TimeoutError, OSError) as exc:
|
|
||||||
log.warning("fail2ban_start_cmd_error", cmd=start_cmd_parts, error=str(exc))
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
# Shared functions from config_file_service are imported from app.utils.config_file_utils
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Public API
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
async def list_inactive_jails(
|
|
||||||
config_dir: str,
|
|
||||||
socket_path: str,
|
|
||||||
) -> InactiveJailListResponse:
|
|
||||||
"""Return all jails defined in config files that are not currently active.
|
|
||||||
|
|
||||||
Parses ``jail.conf``, ``jail.local``, and ``jail.d/`` following the
|
|
||||||
fail2ban merge order. A jail is considered inactive when:
|
|
||||||
|
|
||||||
- Its merged ``enabled`` value is ``false`` (or absent, which defaults to
|
|
||||||
``false`` in fail2ban), **or**
|
|
||||||
- Its ``enabled`` value is ``true`` in config but fail2ban does not report
|
|
||||||
it as running.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config_dir: Absolute path to the fail2ban configuration directory.
|
|
||||||
socket_path: Path to the fail2ban Unix domain socket.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
:class:`~app.models.config.InactiveJailListResponse` with all
|
|
||||||
inactive jails.
|
|
||||||
"""
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
parsed_result: tuple[dict[str, dict[str, str]], dict[str, str]] = await loop.run_in_executor(
|
|
||||||
None, _parse_jails_sync, Path(config_dir)
|
|
||||||
)
|
|
||||||
all_jails, source_files = parsed_result
|
|
||||||
active_names: set[str] = await _get_active_jail_names(socket_path)
|
|
||||||
|
|
||||||
inactive: list[InactiveJail] = []
|
|
||||||
for jail_name, settings in sorted(all_jails.items()):
|
|
||||||
if jail_name in active_names:
|
|
||||||
# fail2ban reports this jail as running — skip it.
|
|
||||||
continue
|
|
||||||
|
|
||||||
source = source_files.get(jail_name, config_dir)
|
|
||||||
inactive.append(_build_inactive_jail(jail_name, settings, source, Path(config_dir)))
|
|
||||||
|
|
||||||
log.info(
|
|
||||||
"inactive_jails_listed",
|
|
||||||
total_defined=len(all_jails),
|
|
||||||
active=len(active_names),
|
|
||||||
inactive=len(inactive),
|
|
||||||
)
|
|
||||||
return InactiveJailListResponse(jails=inactive, total=len(inactive))
|
|
||||||
|
|
||||||
|
|
||||||
async def activate_jail(
|
|
||||||
config_dir: str,
|
|
||||||
socket_path: str,
|
|
||||||
name: str,
|
|
||||||
req: ActivateJailRequest,
|
|
||||||
) -> JailActivationResponse:
|
|
||||||
"""Enable an inactive jail and reload fail2ban.
|
|
||||||
|
|
||||||
Performs pre-activation validation, writes ``enabled = true`` (plus any
|
|
||||||
override values from *req*) to ``jail.d/{name}.local``, and triggers a
|
|
||||||
full fail2ban reload. After the reload a multi-attempt health probe
|
|
||||||
determines whether fail2ban (and the specific jail) are still running.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config_dir: Absolute path to the fail2ban configuration directory.
|
|
||||||
socket_path: Path to the fail2ban Unix domain socket.
|
|
||||||
name: Name of the jail to activate. Must exist in the parsed config.
|
|
||||||
req: Optional override values to write alongside ``enabled = true``.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
:class:`~app.models.config.JailActivationResponse` including
|
|
||||||
``fail2ban_running`` and ``validation_warnings`` fields.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
JailNameError: If *name* contains invalid characters.
|
|
||||||
JailNotFoundInConfigError: If *name* is not defined in any config file.
|
|
||||||
JailAlreadyActiveError: If fail2ban already reports *name* as running.
|
|
||||||
ConfigWriteError: If writing the ``.local`` file fails.
|
|
||||||
~app.utils.fail2ban_client.Fail2BanConnectionError: If the fail2ban
|
|
||||||
socket is unreachable during reload.
|
|
||||||
"""
|
|
||||||
_safe_jail_name(name)
|
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
all_jails, _source_files = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
|
||||||
|
|
||||||
if name not in all_jails:
|
|
||||||
raise JailNotFoundInConfigError(name)
|
|
||||||
|
|
||||||
active_names = await _get_active_jail_names(socket_path)
|
|
||||||
if name in active_names:
|
|
||||||
raise JailAlreadyActiveError(name)
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------- #
|
|
||||||
# Pre-activation validation — collect warnings but do not block #
|
|
||||||
# ---------------------------------------------------------------------- #
|
|
||||||
validation_result: JailValidationResult = await loop.run_in_executor(
|
|
||||||
None, _validate_jail_config_sync, Path(config_dir), name
|
|
||||||
)
|
|
||||||
warnings: list[str] = [f"{i.field}: {i.message}" for i in validation_result.issues]
|
|
||||||
if warnings:
|
|
||||||
log.warning(
|
|
||||||
"jail_activation_validation_warnings",
|
|
||||||
jail=name,
|
|
||||||
warnings=warnings,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Block activation on critical validation failures (missing filter or logpath).
|
|
||||||
blocking = [i for i in validation_result.issues if i.field in ("filter", "logpath")]
|
|
||||||
if blocking:
|
|
||||||
log.warning(
|
|
||||||
"jail_activation_blocked",
|
|
||||||
jail=name,
|
|
||||||
issues=[f"{i.field}: {i.message}" for i in blocking],
|
|
||||||
)
|
|
||||||
return JailActivationResponse(
|
|
||||||
name=name,
|
|
||||||
active=False,
|
|
||||||
fail2ban_running=True,
|
|
||||||
validation_warnings=warnings,
|
|
||||||
message=(f"Jail {name!r} cannot be activated: " + "; ".join(i.message for i in blocking)),
|
|
||||||
)
|
|
||||||
|
|
||||||
overrides: dict[str, object] = {
|
|
||||||
"bantime": req.bantime,
|
|
||||||
"findtime": req.findtime,
|
|
||||||
"maxretry": req.maxretry,
|
|
||||||
"port": req.port,
|
|
||||||
"logpath": req.logpath,
|
|
||||||
}
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------- #
|
|
||||||
# Backup the existing .local file (if any) before overwriting it so that #
|
|
||||||
# we can restore it if activation fails. #
|
|
||||||
# ---------------------------------------------------------------------- #
|
|
||||||
local_path = Path(config_dir) / "jail.d" / f"{name}.local"
|
|
||||||
original_content: bytes | None = await loop.run_in_executor(
|
|
||||||
None,
|
|
||||||
lambda: local_path.read_bytes() if local_path.exists() else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
await loop.run_in_executor(
|
|
||||||
None,
|
|
||||||
_write_local_override_sync,
|
|
||||||
Path(config_dir),
|
|
||||||
name,
|
|
||||||
True,
|
|
||||||
overrides,
|
|
||||||
)
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------- #
|
|
||||||
# Activation reload — if it fails, roll back immediately #
|
|
||||||
# ---------------------------------------------------------------------- #
|
|
||||||
try:
|
|
||||||
await reload_jails(socket_path, include_jails=[name])
|
|
||||||
except JailNotFoundError as exc:
|
|
||||||
# Jail configuration is invalid (e.g. missing logpath that prevents
|
|
||||||
# fail2ban from loading the jail). Roll back and provide a specific error.
|
|
||||||
log.warning(
|
|
||||||
"reload_after_activate_failed_jail_not_found",
|
|
||||||
jail=name,
|
|
||||||
error=str(exc),
|
|
||||||
)
|
|
||||||
recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content)
|
|
||||||
return JailActivationResponse(
|
|
||||||
name=name,
|
|
||||||
active=False,
|
|
||||||
fail2ban_running=False,
|
|
||||||
recovered=recovered,
|
|
||||||
validation_warnings=warnings,
|
|
||||||
message=(
|
|
||||||
f"Jail {name!r} activation failed: {str(exc)}. "
|
|
||||||
"Check that all logpath files exist and are readable. "
|
|
||||||
"The configuration was "
|
|
||||||
+ ("automatically recovered." if recovered else "not recovered — manual intervention is required.")
|
|
||||||
),
|
|
||||||
)
|
|
||||||
except Exception as exc: # noqa: BLE001
|
|
||||||
log.warning("reload_after_activate_failed", jail=name, error=str(exc))
|
|
||||||
recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content)
|
|
||||||
return JailActivationResponse(
|
|
||||||
name=name,
|
|
||||||
active=False,
|
|
||||||
fail2ban_running=False,
|
|
||||||
recovered=recovered,
|
|
||||||
validation_warnings=warnings,
|
|
||||||
message=(
|
|
||||||
f"Jail {name!r} activation failed during reload and the "
|
|
||||||
"configuration was "
|
|
||||||
+ ("automatically recovered." if recovered else "not recovered — manual intervention is required.")
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------- #
|
|
||||||
# Post-reload health probe with retries #
|
|
||||||
# ---------------------------------------------------------------------- #
|
|
||||||
fail2ban_running = False
|
|
||||||
for attempt in range(_POST_RELOAD_MAX_ATTEMPTS):
|
|
||||||
if attempt > 0:
|
|
||||||
await asyncio.sleep(_POST_RELOAD_PROBE_INTERVAL)
|
|
||||||
if await _probe_fail2ban_running(socket_path):
|
|
||||||
fail2ban_running = True
|
|
||||||
break
|
|
||||||
|
|
||||||
if not fail2ban_running:
|
|
||||||
log.warning(
|
|
||||||
"fail2ban_down_after_activate",
|
|
||||||
jail=name,
|
|
||||||
message="fail2ban socket unreachable after reload — initiating rollback.",
|
|
||||||
)
|
|
||||||
recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content)
|
|
||||||
return JailActivationResponse(
|
|
||||||
name=name,
|
|
||||||
active=False,
|
|
||||||
fail2ban_running=False,
|
|
||||||
recovered=recovered,
|
|
||||||
validation_warnings=warnings,
|
|
||||||
message=(
|
|
||||||
f"Jail {name!r} activation failed: fail2ban stopped responding "
|
|
||||||
"after reload. The configuration was "
|
|
||||||
+ ("automatically recovered." if recovered else "not recovered — manual intervention is required.")
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify the jail actually started (config error may prevent it silently).
|
|
||||||
post_reload_names = await _get_active_jail_names(socket_path)
|
|
||||||
actually_running = name in post_reload_names
|
|
||||||
if not actually_running:
|
|
||||||
log.warning(
|
|
||||||
"jail_activation_unverified",
|
|
||||||
jail=name,
|
|
||||||
message="Jail did not appear in running jails — initiating rollback.",
|
|
||||||
)
|
|
||||||
recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content)
|
|
||||||
return JailActivationResponse(
|
|
||||||
name=name,
|
|
||||||
active=False,
|
|
||||||
fail2ban_running=True,
|
|
||||||
recovered=recovered,
|
|
||||||
validation_warnings=warnings,
|
|
||||||
message=(
|
|
||||||
f"Jail {name!r} was written to config but did not start after "
|
|
||||||
"reload. The configuration was "
|
|
||||||
+ ("automatically recovered." if recovered else "not recovered — manual intervention is required.")
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
log.info("jail_activated", jail=name)
|
|
||||||
return JailActivationResponse(
|
|
||||||
name=name,
|
|
||||||
active=True,
|
|
||||||
fail2ban_running=True,
|
|
||||||
validation_warnings=warnings,
|
|
||||||
message=f"Jail {name!r} activated successfully.",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def _rollback_activation_async(
|
|
||||||
config_dir: str,
|
|
||||||
name: str,
|
|
||||||
socket_path: str,
|
|
||||||
original_content: bytes | None,
|
|
||||||
) -> bool:
|
|
||||||
"""Restore the pre-activation ``.local`` file and reload fail2ban.
|
|
||||||
|
|
||||||
Called internally by :func:`activate_jail` when the activation fails after
|
|
||||||
the config file was already written. Tries to:
|
|
||||||
|
|
||||||
1. Restore the original file content (or delete the file if it was newly
|
|
||||||
created by the activation attempt).
|
|
||||||
2. Reload fail2ban so the daemon runs with the restored configuration.
|
|
||||||
3. Probe fail2ban to confirm it came back up.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config_dir: Absolute path to the fail2ban configuration directory.
|
|
||||||
name: Name of the jail whose ``.local`` file should be restored.
|
|
||||||
socket_path: Path to the fail2ban Unix domain socket.
|
|
||||||
original_content: Raw bytes of the original ``.local`` file, or
|
|
||||||
``None`` if the file did not exist before the activation.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
``True`` if fail2ban is responsive again after the rollback, ``False``
|
|
||||||
if recovery also failed.
|
|
||||||
"""
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
local_path = Path(config_dir) / "jail.d" / f"{name}.local"
|
|
||||||
|
|
||||||
# Step 1 — restore original file (or delete it).
|
|
||||||
try:
|
|
||||||
await loop.run_in_executor(None, _restore_local_file_sync, local_path, original_content)
|
|
||||||
log.info("jail_activation_rollback_file_restored", jail=name)
|
|
||||||
except ConfigWriteError as exc:
|
|
||||||
log.error("jail_activation_rollback_restore_failed", jail=name, error=str(exc))
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Step 2 — reload fail2ban with the restored config.
|
|
||||||
try:
|
|
||||||
await reload_jails(socket_path)
|
|
||||||
log.info("jail_activation_rollback_reload_ok", jail=name)
|
|
||||||
except Exception as exc: # noqa: BLE001
|
|
||||||
log.warning("jail_activation_rollback_reload_failed", jail=name, error=str(exc))
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Step 3 — wait for fail2ban to come back.
|
|
||||||
for attempt in range(_POST_RELOAD_MAX_ATTEMPTS):
|
|
||||||
if attempt > 0:
|
|
||||||
await asyncio.sleep(_POST_RELOAD_PROBE_INTERVAL)
|
|
||||||
if await _probe_fail2ban_running(socket_path):
|
|
||||||
log.info("jail_activation_rollback_recovered", jail=name)
|
|
||||||
return True
|
|
||||||
|
|
||||||
log.warning("jail_activation_rollback_still_down", jail=name)
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
async def deactivate_jail(
|
|
||||||
config_dir: str,
|
|
||||||
socket_path: str,
|
|
||||||
name: str,
|
|
||||||
) -> JailActivationResponse:
|
|
||||||
"""Disable an active jail and reload fail2ban.
|
|
||||||
|
|
||||||
Writes ``enabled = false`` to ``jail.d/{name}.local`` and triggers a
|
|
||||||
full fail2ban reload so the jail stops immediately.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config_dir: Absolute path to the fail2ban configuration directory.
|
|
||||||
socket_path: Path to the fail2ban Unix domain socket.
|
|
||||||
name: Name of the jail to deactivate. Must exist in the parsed config.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
:class:`~app.models.config.JailActivationResponse`.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
JailNameError: If *name* contains invalid characters.
|
|
||||||
JailNotFoundInConfigError: If *name* is not defined in any config file.
|
|
||||||
JailAlreadyInactiveError: If fail2ban already reports *name* as not
|
|
||||||
running.
|
|
||||||
ConfigWriteError: If writing the ``.local`` file fails.
|
|
||||||
~app.utils.fail2ban_client.Fail2BanConnectionError: If the fail2ban
|
|
||||||
socket is unreachable during reload.
|
|
||||||
"""
|
|
||||||
_safe_jail_name(name)
|
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
all_jails, _source_files = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
|
||||||
|
|
||||||
if name not in all_jails:
|
|
||||||
raise JailNotFoundInConfigError(name)
|
|
||||||
|
|
||||||
active_names = await _get_active_jail_names(socket_path)
|
|
||||||
if name not in active_names:
|
|
||||||
raise JailAlreadyInactiveError(name)
|
|
||||||
|
|
||||||
await loop.run_in_executor(
|
|
||||||
None,
|
|
||||||
_write_local_override_sync,
|
|
||||||
Path(config_dir),
|
|
||||||
name,
|
|
||||||
False,
|
|
||||||
{},
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
await reload_jails(socket_path, exclude_jails=[name])
|
|
||||||
except Exception as exc: # noqa: BLE001
|
|
||||||
log.warning("reload_after_deactivate_failed", jail=name, error=str(exc))
|
|
||||||
|
|
||||||
log.info("jail_deactivated", jail=name)
|
|
||||||
return JailActivationResponse(
|
|
||||||
name=name,
|
|
||||||
active=False,
|
|
||||||
message=f"Jail {name!r} deactivated successfully.",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def delete_jail_local_override(
|
|
||||||
config_dir: str,
|
|
||||||
socket_path: str,
|
|
||||||
name: str,
|
|
||||||
) -> None:
|
|
||||||
"""Delete the ``jail.d/{name}.local`` override file for an inactive jail.
|
|
||||||
|
|
||||||
This is the clean-up action shown in the config UI when an inactive jail
|
|
||||||
still has a ``.local`` override file (e.g. ``enabled = false``). The
|
|
||||||
file is deleted outright; no fail2ban reload is required because the jail
|
|
||||||
is already inactive.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config_dir: Absolute path to the fail2ban configuration directory.
|
|
||||||
socket_path: Path to the fail2ban Unix domain socket.
|
|
||||||
name: Name of the jail whose ``.local`` file should be removed.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
JailNameError: If *name* contains invalid characters.
|
|
||||||
JailNotFoundInConfigError: If *name* is not defined in any config file.
|
|
||||||
JailAlreadyActiveError: If the jail is currently active (refusing to
|
|
||||||
delete the live config file).
|
|
||||||
ConfigWriteError: If the file cannot be deleted.
|
|
||||||
"""
|
|
||||||
_safe_jail_name(name)
|
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
all_jails, _source_files = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
|
||||||
|
|
||||||
if name not in all_jails:
|
|
||||||
raise JailNotFoundInConfigError(name)
|
|
||||||
|
|
||||||
active_names = await _get_active_jail_names(socket_path)
|
|
||||||
if name in active_names:
|
|
||||||
raise JailAlreadyActiveError(name)
|
|
||||||
|
|
||||||
local_path = Path(config_dir) / "jail.d" / f"{name}.local"
|
|
||||||
try:
|
|
||||||
await loop.run_in_executor(None, lambda: local_path.unlink(missing_ok=True))
|
|
||||||
except OSError as exc:
|
|
||||||
raise ConfigWriteError(f"Failed to delete {local_path}: {exc}") from exc
|
|
||||||
|
|
||||||
log.info("jail_local_override_deleted", jail=name, path=str(local_path))
|
|
||||||
|
|
||||||
|
|
||||||
async def validate_jail_config(
|
|
||||||
config_dir: str,
|
|
||||||
name: str,
|
|
||||||
) -> JailValidationResult:
|
|
||||||
"""Run pre-activation validation checks on a jail configuration.
|
|
||||||
|
|
||||||
Validates that referenced filter and action files exist in ``filter.d/``
|
|
||||||
and ``action.d/``, that all regex patterns compile, and that declared log
|
|
||||||
paths exist on disk.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config_dir: Absolute path to the fail2ban configuration directory.
|
|
||||||
name: Name of the jail to validate.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
:class:`~app.models.config.JailValidationResult` with any issues found.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
JailNameError: If *name* contains invalid characters.
|
|
||||||
"""
|
|
||||||
_safe_jail_name(name)
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
return await loop.run_in_executor(
|
|
||||||
None,
|
|
||||||
_validate_jail_config_sync,
|
|
||||||
Path(config_dir),
|
|
||||||
name,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def rollback_jail(
|
|
||||||
config_dir: str,
|
|
||||||
socket_path: str,
|
|
||||||
name: str,
|
|
||||||
start_cmd_parts: list[str],
|
|
||||||
) -> RollbackResponse:
|
|
||||||
"""Disable a bad jail config and restart the fail2ban daemon.
|
|
||||||
|
|
||||||
Writes ``enabled = false`` to ``jail.d/{name}.local`` (works even when
|
|
||||||
fail2ban is down — only a file write), then attempts to start the daemon
|
|
||||||
with *start_cmd_parts*. Waits up to 10 seconds for the socket to respond.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config_dir: Absolute path to the fail2ban configuration directory.
|
|
||||||
socket_path: Path to the fail2ban Unix domain socket.
|
|
||||||
name: Name of the jail to disable.
|
|
||||||
start_cmd_parts: Argument list for the daemon start command, e.g.
|
|
||||||
``["fail2ban-client", "start"]``.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
:class:`~app.models.config.RollbackResponse`.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
JailNameError: If *name* contains invalid characters.
|
|
||||||
ConfigWriteError: If writing the ``.local`` file fails.
|
|
||||||
"""
|
|
||||||
_safe_jail_name(name)
|
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
|
|
||||||
# Write enabled=false — this must succeed even when fail2ban is down.
|
|
||||||
await loop.run_in_executor(
|
|
||||||
None,
|
|
||||||
_write_local_override_sync,
|
|
||||||
Path(config_dir),
|
|
||||||
name,
|
|
||||||
False,
|
|
||||||
{},
|
|
||||||
)
|
|
||||||
log.info("jail_rolled_back_disabled", jail=name)
|
|
||||||
|
|
||||||
# Attempt to start the daemon.
|
|
||||||
started = await start_daemon(start_cmd_parts)
|
|
||||||
log.info("jail_rollback_start_attempted", jail=name, start_ok=started)
|
|
||||||
|
|
||||||
# Wait for the socket to come back.
|
|
||||||
fail2ban_running = await wait_for_fail2ban(socket_path, max_wait_seconds=10.0, poll_interval=2.0)
|
|
||||||
|
|
||||||
active_jails = 0
|
|
||||||
if fail2ban_running:
|
|
||||||
names = await _get_active_jail_names(socket_path)
|
|
||||||
active_jails = len(names)
|
|
||||||
|
|
||||||
if fail2ban_running:
|
|
||||||
log.info("jail_rollback_success", jail=name, active_jails=active_jails)
|
|
||||||
return RollbackResponse(
|
|
||||||
jail_name=name,
|
|
||||||
disabled=True,
|
|
||||||
fail2ban_running=True,
|
|
||||||
active_jails=active_jails,
|
|
||||||
message=(f"Jail {name!r} disabled and fail2ban restarted successfully with {active_jails} active jail(s)."),
|
|
||||||
)
|
|
||||||
|
|
||||||
log.warning("jail_rollback_fail2ban_still_down", jail=name)
|
|
||||||
return RollbackResponse(
|
|
||||||
jail_name=name,
|
|
||||||
disabled=True,
|
|
||||||
fail2ban_running=False,
|
|
||||||
active_jails=0,
|
|
||||||
message=(
|
|
||||||
f"Jail {name!r} was disabled but fail2ban did not come back online. "
|
|
||||||
"Check the fail2ban log for additional errors."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@@ -14,11 +14,10 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
import ipaddress
|
import ipaddress
|
||||||
from typing import TYPE_CHECKING, TypedDict, cast
|
from typing import Any
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
|
|
||||||
from app.exceptions import JailNotFoundError, JailOperationError
|
|
||||||
from app.models.ban import ActiveBan, ActiveBanListResponse, JailBannedIpsResponse
|
from app.models.ban import ActiveBan, ActiveBanListResponse, JailBannedIpsResponse
|
||||||
from app.models.config import BantimeEscalation
|
from app.models.config import BantimeEscalation
|
||||||
from app.models.jail import (
|
from app.models.jail import (
|
||||||
@@ -28,36 +27,10 @@ from app.models.jail import (
|
|||||||
JailStatus,
|
JailStatus,
|
||||||
JailSummary,
|
JailSummary,
|
||||||
)
|
)
|
||||||
from app.utils.fail2ban_client import (
|
from app.utils.fail2ban_client import Fail2BanClient, Fail2BanConnectionError
|
||||||
Fail2BanClient,
|
|
||||||
Fail2BanCommand,
|
|
||||||
Fail2BanConnectionError,
|
|
||||||
Fail2BanResponse,
|
|
||||||
Fail2BanToken,
|
|
||||||
)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from collections.abc import Awaitable
|
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
import aiosqlite
|
|
||||||
|
|
||||||
from app.models.geo import GeoBatchLookup, GeoEnricher, GeoInfo
|
|
||||||
|
|
||||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||||
|
|
||||||
class IpLookupResult(TypedDict):
|
|
||||||
"""Result returned by :func:`lookup_ip`.
|
|
||||||
|
|
||||||
This is intentionally a :class:`TypedDict` to provide precise typing for
|
|
||||||
callers (e.g. routers) while keeping the implementation flexible.
|
|
||||||
"""
|
|
||||||
|
|
||||||
ip: str
|
|
||||||
currently_banned_in: list[str]
|
|
||||||
geo: GeoInfo | None
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Constants
|
# Constants
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -82,12 +55,29 @@ _backend_cmd_lock: asyncio.Lock = asyncio.Lock()
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class JailNotFoundError(Exception):
|
||||||
|
"""Raised when a requested jail name does not exist in fail2ban."""
|
||||||
|
|
||||||
|
def __init__(self, name: str) -> None:
|
||||||
|
"""Initialise with the jail name that was not found.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The jail name that could not be located.
|
||||||
|
"""
|
||||||
|
self.name: str = name
|
||||||
|
super().__init__(f"Jail not found: {name!r}")
|
||||||
|
|
||||||
|
|
||||||
|
class JailOperationError(Exception):
|
||||||
|
"""Raised when a jail control command fails for a non-auth reason."""
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Internal helpers
|
# Internal helpers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _ok(response: object) -> object:
|
def _ok(response: Any) -> Any:
|
||||||
"""Extract the payload from a fail2ban ``(return_code, data)`` response.
|
"""Extract the payload from a fail2ban ``(return_code, data)`` response.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -100,7 +90,7 @@ def _ok(response: object) -> object:
|
|||||||
ValueError: If the response indicates an error (return code ≠ 0).
|
ValueError: If the response indicates an error (return code ≠ 0).
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
code, data = cast("Fail2BanResponse", response)
|
code, data = response
|
||||||
except (TypeError, ValueError) as exc:
|
except (TypeError, ValueError) as exc:
|
||||||
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
|
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
|
||||||
|
|
||||||
@@ -110,7 +100,7 @@ def _ok(response: object) -> object:
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
def _to_dict(pairs: object) -> dict[str, object]:
|
def _to_dict(pairs: Any) -> dict[str, Any]:
|
||||||
"""Convert a list of ``(key, value)`` pairs to a plain dict.
|
"""Convert a list of ``(key, value)`` pairs to a plain dict.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -121,7 +111,7 @@ def _to_dict(pairs: object) -> dict[str, object]:
|
|||||||
"""
|
"""
|
||||||
if not isinstance(pairs, (list, tuple)):
|
if not isinstance(pairs, (list, tuple)):
|
||||||
return {}
|
return {}
|
||||||
result: dict[str, object] = {}
|
result: dict[str, Any] = {}
|
||||||
for item in pairs:
|
for item in pairs:
|
||||||
try:
|
try:
|
||||||
k, v = item
|
k, v = item
|
||||||
@@ -131,7 +121,7 @@ def _to_dict(pairs: object) -> dict[str, object]:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def _ensure_list(value: object | None) -> list[str]:
|
def _ensure_list(value: Any) -> list[str]:
|
||||||
"""Coerce a fail2ban response value to a list of strings.
|
"""Coerce a fail2ban response value to a list of strings.
|
||||||
|
|
||||||
Some fail2ban ``get`` responses return ``None`` or a single string
|
Some fail2ban ``get`` responses return ``None`` or a single string
|
||||||
@@ -180,9 +170,9 @@ def _is_not_found_error(exc: Exception) -> bool:
|
|||||||
|
|
||||||
async def _safe_get(
|
async def _safe_get(
|
||||||
client: Fail2BanClient,
|
client: Fail2BanClient,
|
||||||
command: Fail2BanCommand,
|
command: list[Any],
|
||||||
default: object | None = None,
|
default: Any = None,
|
||||||
) -> object | None:
|
) -> Any:
|
||||||
"""Send a ``get`` command and return ``default`` on error.
|
"""Send a ``get`` command and return ``default`` on error.
|
||||||
|
|
||||||
Errors during optional detail queries (logpath, regex, etc.) should
|
Errors during optional detail queries (logpath, regex, etc.) should
|
||||||
@@ -197,8 +187,7 @@ async def _safe_get(
|
|||||||
The response payload, or *default* on any error.
|
The response payload, or *default* on any error.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
response = await client.send(command)
|
return _ok(await client.send(command))
|
||||||
return _ok(cast("Fail2BanResponse", response))
|
|
||||||
except (ValueError, TypeError, Exception):
|
except (ValueError, TypeError, Exception):
|
||||||
return default
|
return default
|
||||||
|
|
||||||
@@ -320,7 +309,7 @@ async def _fetch_jail_summary(
|
|||||||
backend_cmd_is_supported = await _check_backend_cmd_supported(client, name)
|
backend_cmd_is_supported = await _check_backend_cmd_supported(client, name)
|
||||||
|
|
||||||
# Build the gather list based on command support.
|
# Build the gather list based on command support.
|
||||||
gather_list: list[Awaitable[object]] = [
|
gather_list: list[Any] = [
|
||||||
client.send(["status", name, "short"]),
|
client.send(["status", name, "short"]),
|
||||||
client.send(["get", name, "bantime"]),
|
client.send(["get", name, "bantime"]),
|
||||||
client.send(["get", name, "findtime"]),
|
client.send(["get", name, "findtime"]),
|
||||||
@@ -333,23 +322,25 @@ async def _fetch_jail_summary(
|
|||||||
client.send(["get", name, "backend"]),
|
client.send(["get", name, "backend"]),
|
||||||
client.send(["get", name, "idle"]),
|
client.send(["get", name, "idle"]),
|
||||||
])
|
])
|
||||||
|
uses_backend_backend_commands = True
|
||||||
else:
|
else:
|
||||||
# Commands not supported; return default values without sending.
|
# Commands not supported; return default values without sending.
|
||||||
async def _return_default(value: object | None) -> Fail2BanResponse:
|
async def _return_default(value: Any) -> tuple[int, Any]:
|
||||||
return (0, value)
|
return (0, value)
|
||||||
|
|
||||||
gather_list.extend([
|
gather_list.extend([
|
||||||
_return_default("polling"), # backend default
|
_return_default("polling"), # backend default
|
||||||
_return_default(False), # idle default
|
_return_default(False), # idle default
|
||||||
])
|
])
|
||||||
|
uses_backend_backend_commands = False
|
||||||
|
|
||||||
_r = await asyncio.gather(*gather_list, return_exceptions=True)
|
_r = await asyncio.gather(*gather_list, return_exceptions=True)
|
||||||
status_raw: object | Exception = _r[0]
|
status_raw: Any = _r[0]
|
||||||
bantime_raw: object | Exception = _r[1]
|
bantime_raw: Any = _r[1]
|
||||||
findtime_raw: object | Exception = _r[2]
|
findtime_raw: Any = _r[2]
|
||||||
maxretry_raw: object | Exception = _r[3]
|
maxretry_raw: Any = _r[3]
|
||||||
backend_raw: object | Exception = _r[4]
|
backend_raw: Any = _r[4]
|
||||||
idle_raw: object | Exception = _r[5]
|
idle_raw: Any = _r[5]
|
||||||
|
|
||||||
# Parse jail status (filter + actions).
|
# Parse jail status (filter + actions).
|
||||||
jail_status: JailStatus | None = None
|
jail_status: JailStatus | None = None
|
||||||
@@ -359,35 +350,35 @@ async def _fetch_jail_summary(
|
|||||||
filter_stats = _to_dict(raw.get("Filter") or [])
|
filter_stats = _to_dict(raw.get("Filter") or [])
|
||||||
action_stats = _to_dict(raw.get("Actions") or [])
|
action_stats = _to_dict(raw.get("Actions") or [])
|
||||||
jail_status = JailStatus(
|
jail_status = JailStatus(
|
||||||
currently_banned=int(str(action_stats.get("Currently banned", 0) or 0)),
|
currently_banned=int(action_stats.get("Currently banned", 0) or 0),
|
||||||
total_banned=int(str(action_stats.get("Total banned", 0) or 0)),
|
total_banned=int(action_stats.get("Total banned", 0) or 0),
|
||||||
currently_failed=int(str(filter_stats.get("Currently failed", 0) or 0)),
|
currently_failed=int(filter_stats.get("Currently failed", 0) or 0),
|
||||||
total_failed=int(str(filter_stats.get("Total failed", 0) or 0)),
|
total_failed=int(filter_stats.get("Total failed", 0) or 0),
|
||||||
)
|
)
|
||||||
except (ValueError, TypeError) as exc:
|
except (ValueError, TypeError) as exc:
|
||||||
log.warning("jail_status_parse_error", jail=name, error=str(exc))
|
log.warning("jail_status_parse_error", jail=name, error=str(exc))
|
||||||
|
|
||||||
def _safe_int(raw: object | Exception, fallback: int) -> int:
|
def _safe_int(raw: Any, fallback: int) -> int:
|
||||||
if isinstance(raw, Exception):
|
if isinstance(raw, Exception):
|
||||||
return fallback
|
return fallback
|
||||||
try:
|
try:
|
||||||
return int(str(_ok(cast("Fail2BanResponse", raw))))
|
return int(_ok(raw))
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
return fallback
|
return fallback
|
||||||
|
|
||||||
def _safe_str(raw: object | Exception, fallback: str) -> str:
|
def _safe_str(raw: Any, fallback: str) -> str:
|
||||||
if isinstance(raw, Exception):
|
if isinstance(raw, Exception):
|
||||||
return fallback
|
return fallback
|
||||||
try:
|
try:
|
||||||
return str(_ok(cast("Fail2BanResponse", raw)))
|
return str(_ok(raw))
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
return fallback
|
return fallback
|
||||||
|
|
||||||
def _safe_bool(raw: object | Exception, fallback: bool = False) -> bool:
|
def _safe_bool(raw: Any, fallback: bool = False) -> bool:
|
||||||
if isinstance(raw, Exception):
|
if isinstance(raw, Exception):
|
||||||
return fallback
|
return fallback
|
||||||
try:
|
try:
|
||||||
return bool(_ok(cast("Fail2BanResponse", raw)))
|
return bool(_ok(raw))
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
return fallback
|
return fallback
|
||||||
|
|
||||||
@@ -437,10 +428,10 @@ async def get_jail(socket_path: str, name: str) -> JailDetailResponse:
|
|||||||
action_stats = _to_dict(raw.get("Actions") or [])
|
action_stats = _to_dict(raw.get("Actions") or [])
|
||||||
|
|
||||||
jail_status = JailStatus(
|
jail_status = JailStatus(
|
||||||
currently_banned=int(str(action_stats.get("Currently banned", 0) or 0)),
|
currently_banned=int(action_stats.get("Currently banned", 0) or 0),
|
||||||
total_banned=int(str(action_stats.get("Total banned", 0) or 0)),
|
total_banned=int(action_stats.get("Total banned", 0) or 0),
|
||||||
currently_failed=int(str(filter_stats.get("Currently failed", 0) or 0)),
|
currently_failed=int(filter_stats.get("Currently failed", 0) or 0),
|
||||||
total_failed=int(str(filter_stats.get("Total failed", 0) or 0)),
|
total_failed=int(filter_stats.get("Total failed", 0) or 0),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Fetch all detail fields in parallel.
|
# Fetch all detail fields in parallel.
|
||||||
@@ -489,11 +480,11 @@ async def get_jail(socket_path: str, name: str) -> JailDetailResponse:
|
|||||||
bt_increment: bool = bool(bt_increment_raw)
|
bt_increment: bool = bool(bt_increment_raw)
|
||||||
bantime_escalation = BantimeEscalation(
|
bantime_escalation = BantimeEscalation(
|
||||||
increment=bt_increment,
|
increment=bt_increment,
|
||||||
factor=float(str(bt_factor_raw)) if bt_factor_raw is not None else None,
|
factor=float(bt_factor_raw) if bt_factor_raw is not None else None,
|
||||||
formula=str(bt_formula_raw) if bt_formula_raw else None,
|
formula=str(bt_formula_raw) if bt_formula_raw else None,
|
||||||
multipliers=str(bt_multipliers_raw) if bt_multipliers_raw else None,
|
multipliers=str(bt_multipliers_raw) if bt_multipliers_raw else None,
|
||||||
max_time=int(str(bt_maxtime_raw)) if bt_maxtime_raw is not None else None,
|
max_time=int(bt_maxtime_raw) if bt_maxtime_raw is not None else None,
|
||||||
rnd_time=int(str(bt_rndtime_raw)) if bt_rndtime_raw is not None else None,
|
rnd_time=int(bt_rndtime_raw) if bt_rndtime_raw is not None else None,
|
||||||
overall_jails=bool(bt_overalljails_raw),
|
overall_jails=bool(bt_overalljails_raw),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -509,9 +500,9 @@ async def get_jail(socket_path: str, name: str) -> JailDetailResponse:
|
|||||||
ignore_ips=_ensure_list(ignoreip_raw),
|
ignore_ips=_ensure_list(ignoreip_raw),
|
||||||
date_pattern=str(datepattern_raw) if datepattern_raw else None,
|
date_pattern=str(datepattern_raw) if datepattern_raw else None,
|
||||||
log_encoding=str(logencoding_raw or "UTF-8"),
|
log_encoding=str(logencoding_raw or "UTF-8"),
|
||||||
find_time=int(str(findtime_raw or 600)),
|
find_time=int(findtime_raw or 600),
|
||||||
ban_time=int(str(bantime_raw or 600)),
|
ban_time=int(bantime_raw or 600),
|
||||||
max_retry=int(str(maxretry_raw or 5)),
|
max_retry=int(maxretry_raw or 5),
|
||||||
bantime_escalation=bantime_escalation,
|
bantime_escalation=bantime_escalation,
|
||||||
status=jail_status,
|
status=jail_status,
|
||||||
actions=_ensure_list(actions_raw),
|
actions=_ensure_list(actions_raw),
|
||||||
@@ -680,8 +671,8 @@ async def reload_all(
|
|||||||
if exclude_jails:
|
if exclude_jails:
|
||||||
names_set -= set(exclude_jails)
|
names_set -= set(exclude_jails)
|
||||||
|
|
||||||
stream: list[list[object]] = [["start", n] for n in sorted(names_set)]
|
stream: list[list[str]] = [["start", n] for n in sorted(names_set)]
|
||||||
_ok(await client.send(["reload", "--all", [], cast("Fail2BanToken", stream)]))
|
_ok(await client.send(["reload", "--all", [], stream]))
|
||||||
log.info("all_jails_reloaded")
|
log.info("all_jails_reloaded")
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
# Detect UnknownJailException (missing or invalid jail configuration)
|
# Detect UnknownJailException (missing or invalid jail configuration)
|
||||||
@@ -804,10 +795,9 @@ async def unban_ip(
|
|||||||
|
|
||||||
async def get_active_bans(
|
async def get_active_bans(
|
||||||
socket_path: str,
|
socket_path: str,
|
||||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
geo_enricher: Any | None = None,
|
||||||
geo_enricher: GeoEnricher | None = None,
|
http_session: Any | None = None,
|
||||||
http_session: aiohttp.ClientSession | None = None,
|
app_db: Any | None = None,
|
||||||
app_db: aiosqlite.Connection | None = None,
|
|
||||||
) -> ActiveBanListResponse:
|
) -> ActiveBanListResponse:
|
||||||
"""Return all currently banned IPs across every jail.
|
"""Return all currently banned IPs across every jail.
|
||||||
|
|
||||||
@@ -842,6 +832,7 @@ async def get_active_bans(
|
|||||||
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
|
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
|
||||||
cannot be reached.
|
cannot be reached.
|
||||||
"""
|
"""
|
||||||
|
from app.services import geo_service # noqa: PLC0415
|
||||||
|
|
||||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||||
|
|
||||||
@@ -858,7 +849,7 @@ async def get_active_bans(
|
|||||||
return ActiveBanListResponse(bans=[], total=0)
|
return ActiveBanListResponse(bans=[], total=0)
|
||||||
|
|
||||||
# For each jail, fetch the ban list with time info in parallel.
|
# For each jail, fetch the ban list with time info in parallel.
|
||||||
results: list[object | Exception] = await asyncio.gather(
|
results: list[Any] = await asyncio.gather(
|
||||||
*[client.send(["get", jn, "banip", "--with-time"]) for jn in jail_names],
|
*[client.send(["get", jn, "banip", "--with-time"]) for jn in jail_names],
|
||||||
return_exceptions=True,
|
return_exceptions=True,
|
||||||
)
|
)
|
||||||
@@ -874,7 +865,7 @@ async def get_active_bans(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ban_list: list[str] = cast("list[str]", _ok(raw_result)) or []
|
ban_list: list[str] = _ok(raw_result) or []
|
||||||
except (TypeError, ValueError) as exc:
|
except (TypeError, ValueError) as exc:
|
||||||
log.warning(
|
log.warning(
|
||||||
"active_bans_parse_error",
|
"active_bans_parse_error",
|
||||||
@@ -889,10 +880,10 @@ async def get_active_bans(
|
|||||||
bans.append(ban)
|
bans.append(ban)
|
||||||
|
|
||||||
# Enrich with geo data — prefer batch lookup over per-IP enricher.
|
# Enrich with geo data — prefer batch lookup over per-IP enricher.
|
||||||
if http_session is not None and bans and geo_batch_lookup is not None:
|
if http_session is not None and bans:
|
||||||
all_ips: list[str] = [ban.ip for ban in bans]
|
all_ips: list[str] = [ban.ip for ban in bans]
|
||||||
try:
|
try:
|
||||||
geo_map = await geo_batch_lookup(all_ips, http_session, db=app_db)
|
geo_map = await geo_service.lookup_batch(all_ips, http_session, db=app_db)
|
||||||
except Exception: # noqa: BLE001
|
except Exception: # noqa: BLE001
|
||||||
log.warning("active_bans_batch_geo_failed")
|
log.warning("active_bans_batch_geo_failed")
|
||||||
geo_map = {}
|
geo_map = {}
|
||||||
@@ -1001,9 +992,8 @@ async def get_jail_banned_ips(
|
|||||||
page: int = 1,
|
page: int = 1,
|
||||||
page_size: int = 25,
|
page_size: int = 25,
|
||||||
search: str | None = None,
|
search: str | None = None,
|
||||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
http_session: Any | None = None,
|
||||||
http_session: aiohttp.ClientSession | None = None,
|
app_db: Any | None = None,
|
||||||
app_db: aiosqlite.Connection | None = None,
|
|
||||||
) -> JailBannedIpsResponse:
|
) -> JailBannedIpsResponse:
|
||||||
"""Return a paginated list of currently banned IPs for a single jail.
|
"""Return a paginated list of currently banned IPs for a single jail.
|
||||||
|
|
||||||
@@ -1029,6 +1019,8 @@ async def get_jail_banned_ips(
|
|||||||
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket is
|
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket is
|
||||||
unreachable.
|
unreachable.
|
||||||
"""
|
"""
|
||||||
|
from app.services import geo_service # noqa: PLC0415
|
||||||
|
|
||||||
# Clamp page_size to the allowed maximum.
|
# Clamp page_size to the allowed maximum.
|
||||||
page_size = min(page_size, _MAX_PAGE_SIZE)
|
page_size = min(page_size, _MAX_PAGE_SIZE)
|
||||||
|
|
||||||
@@ -1048,7 +1040,7 @@ async def get_jail_banned_ips(
|
|||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
raw_result = []
|
raw_result = []
|
||||||
|
|
||||||
ban_list: list[str] = cast("list[str]", raw_result) or []
|
ban_list: list[str] = raw_result or []
|
||||||
|
|
||||||
# Parse all entries.
|
# Parse all entries.
|
||||||
all_bans: list[ActiveBan] = []
|
all_bans: list[ActiveBan] = []
|
||||||
@@ -1069,10 +1061,10 @@ async def get_jail_banned_ips(
|
|||||||
page_bans = all_bans[start : start + page_size]
|
page_bans = all_bans[start : start + page_size]
|
||||||
|
|
||||||
# Geo-enrich only the page slice.
|
# Geo-enrich only the page slice.
|
||||||
if http_session is not None and page_bans and geo_batch_lookup is not None:
|
if http_session is not None and page_bans:
|
||||||
page_ips = [b.ip for b in page_bans]
|
page_ips = [b.ip for b in page_bans]
|
||||||
try:
|
try:
|
||||||
geo_map = await geo_batch_lookup(page_ips, http_session, db=app_db)
|
geo_map = await geo_service.lookup_batch(page_ips, http_session, db=app_db)
|
||||||
except Exception: # noqa: BLE001
|
except Exception: # noqa: BLE001
|
||||||
log.warning("jail_banned_ips_geo_failed", jail=jail_name)
|
log.warning("jail_banned_ips_geo_failed", jail=jail_name)
|
||||||
geo_map = {}
|
geo_map = {}
|
||||||
@@ -1102,7 +1094,7 @@ async def get_jail_banned_ips(
|
|||||||
|
|
||||||
async def _enrich_bans(
|
async def _enrich_bans(
|
||||||
bans: list[ActiveBan],
|
bans: list[ActiveBan],
|
||||||
geo_enricher: GeoEnricher,
|
geo_enricher: Any,
|
||||||
) -> list[ActiveBan]:
|
) -> list[ActiveBan]:
|
||||||
"""Enrich ban records with geo data asynchronously.
|
"""Enrich ban records with geo data asynchronously.
|
||||||
|
|
||||||
@@ -1113,15 +1105,14 @@ async def _enrich_bans(
|
|||||||
Returns:
|
Returns:
|
||||||
The same list with ``country`` fields populated where lookup succeeded.
|
The same list with ``country`` fields populated where lookup succeeded.
|
||||||
"""
|
"""
|
||||||
geo_results: list[object | Exception] = await asyncio.gather(
|
geo_results: list[Any] = await asyncio.gather(
|
||||||
*[cast("Awaitable[object]", geo_enricher(ban.ip)) for ban in bans],
|
*[geo_enricher(ban.ip) for ban in bans],
|
||||||
return_exceptions=True,
|
return_exceptions=True,
|
||||||
)
|
)
|
||||||
enriched: list[ActiveBan] = []
|
enriched: list[ActiveBan] = []
|
||||||
for ban, geo in zip(bans, geo_results, strict=False):
|
for ban, geo in zip(bans, geo_results, strict=False):
|
||||||
if geo is not None and not isinstance(geo, Exception):
|
if geo is not None and not isinstance(geo, Exception):
|
||||||
geo_info = cast("GeoInfo", geo)
|
enriched.append(ban.model_copy(update={"country": geo.country_code}))
|
||||||
enriched.append(ban.model_copy(update={"country": geo_info.country_code}))
|
|
||||||
else:
|
else:
|
||||||
enriched.append(ban)
|
enriched.append(ban)
|
||||||
return enriched
|
return enriched
|
||||||
@@ -1269,8 +1260,8 @@ async def set_ignore_self(socket_path: str, name: str, *, on: bool) -> None:
|
|||||||
async def lookup_ip(
|
async def lookup_ip(
|
||||||
socket_path: str,
|
socket_path: str,
|
||||||
ip: str,
|
ip: str,
|
||||||
geo_enricher: GeoEnricher | None = None,
|
geo_enricher: Any | None = None,
|
||||||
) -> IpLookupResult:
|
) -> dict[str, Any]:
|
||||||
"""Return ban status and history for a single IP address.
|
"""Return ban status and history for a single IP address.
|
||||||
|
|
||||||
Checks every running jail for whether the IP is currently banned.
|
Checks every running jail for whether the IP is currently banned.
|
||||||
@@ -1313,7 +1304,7 @@ async def lookup_ip(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Check ban status per jail in parallel.
|
# Check ban status per jail in parallel.
|
||||||
ban_results: list[object | Exception] = await asyncio.gather(
|
ban_results: list[Any] = await asyncio.gather(
|
||||||
*[client.send(["get", jn, "banip"]) for jn in jail_names],
|
*[client.send(["get", jn, "banip"]) for jn in jail_names],
|
||||||
return_exceptions=True,
|
return_exceptions=True,
|
||||||
)
|
)
|
||||||
@@ -1323,7 +1314,7 @@ async def lookup_ip(
|
|||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
ban_list: list[str] = cast("list[str]", _ok(result)) or []
|
ban_list: list[str] = _ok(result) or []
|
||||||
if ip in ban_list:
|
if ip in ban_list:
|
||||||
currently_banned_in.append(jail_name)
|
currently_banned_in.append(jail_name)
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
@@ -1360,6 +1351,6 @@ async def unban_all_ips(socket_path: str) -> int:
|
|||||||
cannot be reached.
|
cannot be reached.
|
||||||
"""
|
"""
|
||||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||||
count: int = int(str(_ok(await client.send(["unban", "--all"])) or 0))
|
count: int = int(_ok(await client.send(["unban", "--all"])))
|
||||||
log.info("all_ips_unbanned", count=count)
|
log.info("all_ips_unbanned", count=count)
|
||||||
return count
|
return count
|
||||||
|
|||||||
@@ -1,128 +0,0 @@
|
|||||||
"""Log helper service.
|
|
||||||
|
|
||||||
Contains regex test and log preview helpers that are independent of
|
|
||||||
fail2ban socket operations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import re
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from app.models.config import (
|
|
||||||
LogPreviewLine,
|
|
||||||
LogPreviewRequest,
|
|
||||||
LogPreviewResponse,
|
|
||||||
RegexTestRequest,
|
|
||||||
RegexTestResponse,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_regex(request: RegexTestRequest) -> RegexTestResponse:
|
|
||||||
"""Test a regex pattern against a sample log line.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: The regex test payload.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
RegexTestResponse with match result, groups and optional error.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
compiled = re.compile(request.fail_regex)
|
|
||||||
except re.error as exc:
|
|
||||||
return RegexTestResponse(matched=False, groups=[], error=str(exc))
|
|
||||||
|
|
||||||
match = compiled.search(request.log_line)
|
|
||||||
if match is None:
|
|
||||||
return RegexTestResponse(matched=False)
|
|
||||||
|
|
||||||
groups: list[str] = list(match.groups() or [])
|
|
||||||
return RegexTestResponse(matched=True, groups=[str(g) for g in groups if g is not None])
|
|
||||||
|
|
||||||
|
|
||||||
async def preview_log(req: LogPreviewRequest) -> LogPreviewResponse:
|
|
||||||
"""Inspect the last lines of a log file and evaluate regex matches.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
req: Log preview request.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
LogPreviewResponse with lines, total_lines and matched_count, or error.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
compiled = re.compile(req.fail_regex)
|
|
||||||
except re.error as exc:
|
|
||||||
return LogPreviewResponse(
|
|
||||||
lines=[],
|
|
||||||
total_lines=0,
|
|
||||||
matched_count=0,
|
|
||||||
regex_error=str(exc),
|
|
||||||
)
|
|
||||||
|
|
||||||
path = Path(req.log_path)
|
|
||||||
if not path.is_file():
|
|
||||||
return LogPreviewResponse(
|
|
||||||
lines=[],
|
|
||||||
total_lines=0,
|
|
||||||
matched_count=0,
|
|
||||||
regex_error=f"File not found: {req.log_path!r}",
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
raw_lines = await asyncio.get_event_loop().run_in_executor(
|
|
||||||
None,
|
|
||||||
_read_tail_lines,
|
|
||||||
str(path),
|
|
||||||
req.num_lines,
|
|
||||||
)
|
|
||||||
except OSError as exc:
|
|
||||||
return LogPreviewResponse(
|
|
||||||
lines=[],
|
|
||||||
total_lines=0,
|
|
||||||
matched_count=0,
|
|
||||||
regex_error=f"Cannot read file: {exc}",
|
|
||||||
)
|
|
||||||
|
|
||||||
result_lines: list[LogPreviewLine] = []
|
|
||||||
matched_count = 0
|
|
||||||
for line in raw_lines:
|
|
||||||
m = compiled.search(line)
|
|
||||||
groups = [str(g) for g in (m.groups() or []) if g is not None] if m else []
|
|
||||||
result_lines.append(
|
|
||||||
LogPreviewLine(line=line, matched=(m is not None), groups=groups),
|
|
||||||
)
|
|
||||||
if m:
|
|
||||||
matched_count += 1
|
|
||||||
|
|
||||||
return LogPreviewResponse(
|
|
||||||
lines=result_lines,
|
|
||||||
total_lines=len(result_lines),
|
|
||||||
matched_count=matched_count,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _read_tail_lines(file_path: str, num_lines: int) -> list[str]:
|
|
||||||
"""Read the last *num_lines* from *file_path* in a memory-efficient way."""
|
|
||||||
chunk_size = 8192
|
|
||||||
raw_lines: list[bytes] = []
|
|
||||||
with open(file_path, "rb") as fh:
|
|
||||||
fh.seek(0, 2)
|
|
||||||
end_pos = fh.tell()
|
|
||||||
if end_pos == 0:
|
|
||||||
return []
|
|
||||||
|
|
||||||
buf = b""
|
|
||||||
pos = end_pos
|
|
||||||
while len(raw_lines) <= num_lines and pos > 0:
|
|
||||||
read_size = min(chunk_size, pos)
|
|
||||||
pos -= read_size
|
|
||||||
fh.seek(pos)
|
|
||||||
chunk = fh.read(read_size)
|
|
||||||
buf = chunk + buf
|
|
||||||
raw_lines = buf.split(b"\n")
|
|
||||||
|
|
||||||
if pos > 0 and len(raw_lines) > 1:
|
|
||||||
raw_lines = raw_lines[1:]
|
|
||||||
|
|
||||||
return [ln.decode("utf-8", errors="replace").rstrip() for ln in raw_lines[-num_lines:] if ln.strip()]
|
|
||||||
@@ -10,50 +10,25 @@ HTTP/FastAPI concerns.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import cast
|
from typing import Any
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
|
|
||||||
from app.exceptions import ServerOperationError
|
|
||||||
from app.exceptions import ServerOperationError
|
|
||||||
from app.models.server import ServerSettings, ServerSettingsResponse, ServerSettingsUpdate
|
from app.models.server import ServerSettings, ServerSettingsResponse, ServerSettingsUpdate
|
||||||
from app.utils.fail2ban_client import Fail2BanClient, Fail2BanCommand, Fail2BanResponse
|
from app.utils.fail2ban_client import Fail2BanClient
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Types
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
type Fail2BanSettingValue = str | int | bool
|
|
||||||
"""Allowed values for server settings commands."""
|
|
||||||
|
|
||||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||||
|
|
||||||
_SOCKET_TIMEOUT: float = 10.0
|
_SOCKET_TIMEOUT: float = 10.0
|
||||||
|
|
||||||
|
|
||||||
def _to_int(value: object | None, default: int) -> int:
|
# ---------------------------------------------------------------------------
|
||||||
"""Convert a raw value to an int, falling back to a default.
|
# Custom exceptions
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
The fail2ban control socket can return either int or str values for some
|
|
||||||
settings, so we normalise them here in a type-safe way.
|
|
||||||
"""
|
|
||||||
if isinstance(value, int):
|
|
||||||
return value
|
|
||||||
if isinstance(value, float):
|
|
||||||
return int(value)
|
|
||||||
if isinstance(value, str):
|
|
||||||
try:
|
|
||||||
return int(value)
|
|
||||||
except ValueError:
|
|
||||||
return default
|
|
||||||
return default
|
|
||||||
|
|
||||||
|
|
||||||
def _to_str(value: object | None, default: str) -> str:
|
class ServerOperationError(Exception):
|
||||||
"""Convert a raw value to a string, falling back to a default."""
|
"""Raised when a server-level set command fails."""
|
||||||
if value is None:
|
|
||||||
return default
|
|
||||||
return str(value)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -61,7 +36,7 @@ def _to_str(value: object | None, default: str) -> str:
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _ok(response: Fail2BanResponse) -> object:
|
def _ok(response: Any) -> Any:
|
||||||
"""Extract payload from a fail2ban ``(code, data)`` response.
|
"""Extract payload from a fail2ban ``(code, data)`` response.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -84,9 +59,9 @@ def _ok(response: Fail2BanResponse) -> object:
|
|||||||
|
|
||||||
async def _safe_get(
|
async def _safe_get(
|
||||||
client: Fail2BanClient,
|
client: Fail2BanClient,
|
||||||
command: Fail2BanCommand,
|
command: list[Any],
|
||||||
default: object | None = None,
|
default: Any = None,
|
||||||
) -> object | None:
|
) -> Any:
|
||||||
"""Send a command and silently return *default* on any error.
|
"""Send a command and silently return *default* on any error.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -98,8 +73,7 @@ async def _safe_get(
|
|||||||
The successful response, or *default*.
|
The successful response, or *default*.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
response = await client.send(command)
|
return _ok(await client.send(command))
|
||||||
return _ok(cast("Fail2BanResponse", response))
|
|
||||||
except Exception:
|
except Exception:
|
||||||
return default
|
return default
|
||||||
|
|
||||||
@@ -144,20 +118,13 @@ async def get_settings(socket_path: str) -> ServerSettingsResponse:
|
|||||||
_safe_get(client, ["get", "dbmaxmatches"], 10),
|
_safe_get(client, ["get", "dbmaxmatches"], 10),
|
||||||
)
|
)
|
||||||
|
|
||||||
log_level = _to_str(log_level_raw, "INFO").upper()
|
|
||||||
log_target = _to_str(log_target_raw, "STDOUT")
|
|
||||||
syslog_socket = _to_str(syslog_socket_raw, "") or None
|
|
||||||
db_path = _to_str(db_path_raw, "/var/lib/fail2ban/fail2ban.sqlite3")
|
|
||||||
db_purge_age = _to_int(db_purge_age_raw, 86400)
|
|
||||||
db_max_matches = _to_int(db_max_matches_raw, 10)
|
|
||||||
|
|
||||||
settings = ServerSettings(
|
settings = ServerSettings(
|
||||||
log_level=log_level,
|
log_level=str(log_level_raw or "INFO").upper(),
|
||||||
log_target=log_target,
|
log_target=str(log_target_raw or "STDOUT"),
|
||||||
syslog_socket=syslog_socket,
|
syslog_socket=str(syslog_socket_raw) if syslog_socket_raw else None,
|
||||||
db_path=db_path,
|
db_path=str(db_path_raw or "/var/lib/fail2ban/fail2ban.sqlite3"),
|
||||||
db_purge_age=db_purge_age,
|
db_purge_age=int(db_purge_age_raw or 86400),
|
||||||
db_max_matches=db_max_matches,
|
db_max_matches=int(db_max_matches_raw or 10),
|
||||||
)
|
)
|
||||||
|
|
||||||
log.info("server_settings_fetched")
|
log.info("server_settings_fetched")
|
||||||
@@ -179,10 +146,9 @@ async def update_settings(socket_path: str, update: ServerSettingsUpdate) -> Non
|
|||||||
"""
|
"""
|
||||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||||
|
|
||||||
async def _set(key: str, value: Fail2BanSettingValue) -> None:
|
async def _set(key: str, value: Any) -> None:
|
||||||
try:
|
try:
|
||||||
response = await client.send(["set", key, value])
|
_ok(await client.send(["set", key, value]))
|
||||||
_ok(cast("Fail2BanResponse", response))
|
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
raise ServerOperationError(f"Failed to set {key!r} = {value!r}: {exc}") from exc
|
raise ServerOperationError(f"Failed to set {key!r} = {value!r}: {exc}") from exc
|
||||||
|
|
||||||
@@ -216,8 +182,7 @@ async def flush_logs(socket_path: str) -> str:
|
|||||||
"""
|
"""
|
||||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||||
try:
|
try:
|
||||||
response = await client.send(["flushlogs"])
|
result = _ok(await client.send(["flushlogs"]))
|
||||||
result = _ok(cast("Fail2BanResponse", response))
|
|
||||||
log.info("logs_flushed", result=result)
|
log.info("logs_flushed", result=result)
|
||||||
return str(result)
|
return str(result)
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
|
|||||||
@@ -102,20 +102,30 @@ async def run_setup(
|
|||||||
log.info("bangui_setup_completed")
|
log.info("bangui_setup_completed")
|
||||||
|
|
||||||
|
|
||||||
from app.utils.setup_utils import (
|
|
||||||
get_map_color_thresholds as util_get_map_color_thresholds,
|
|
||||||
get_password_hash as util_get_password_hash,
|
|
||||||
set_map_color_thresholds as util_set_map_color_thresholds,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_password_hash(db: aiosqlite.Connection) -> str | None:
|
async def get_password_hash(db: aiosqlite.Connection) -> str | None:
|
||||||
"""Return the stored bcrypt password hash, or ``None`` if not set."""
|
"""Return the stored bcrypt password hash, or ``None`` if not set.
|
||||||
return await util_get_password_hash(db)
|
|
||||||
|
Args:
|
||||||
|
db: Active aiosqlite connection.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The bcrypt hash string, or ``None``.
|
||||||
|
"""
|
||||||
|
return await settings_repo.get_setting(db, _KEY_PASSWORD_HASH)
|
||||||
|
|
||||||
|
|
||||||
async def get_timezone(db: aiosqlite.Connection) -> str:
|
async def get_timezone(db: aiosqlite.Connection) -> str:
|
||||||
"""Return the configured IANA timezone string."""
|
"""Return the configured IANA timezone string.
|
||||||
|
|
||||||
|
Falls back to ``"UTC"`` when no timezone has been stored (e.g. before
|
||||||
|
setup completes or for legacy databases).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Active aiosqlite connection.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An IANA timezone identifier such as ``"Europe/Berlin"`` or ``"UTC"``.
|
||||||
|
"""
|
||||||
tz = await settings_repo.get_setting(db, _KEY_TIMEZONE)
|
tz = await settings_repo.get_setting(db, _KEY_TIMEZONE)
|
||||||
return tz if tz else "UTC"
|
return tz if tz else "UTC"
|
||||||
|
|
||||||
@@ -123,8 +133,31 @@ async def get_timezone(db: aiosqlite.Connection) -> str:
|
|||||||
async def get_map_color_thresholds(
|
async def get_map_color_thresholds(
|
||||||
db: aiosqlite.Connection,
|
db: aiosqlite.Connection,
|
||||||
) -> tuple[int, int, int]:
|
) -> tuple[int, int, int]:
|
||||||
"""Return the configured map color thresholds (high, medium, low)."""
|
"""Return the configured map color thresholds (high, medium, low).
|
||||||
return await util_get_map_color_thresholds(db)
|
|
||||||
|
Falls back to default values (100, 50, 20) if not set.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Active aiosqlite connection.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of (threshold_high, threshold_medium, threshold_low).
|
||||||
|
"""
|
||||||
|
high = await settings_repo.get_setting(
|
||||||
|
db, _KEY_MAP_COLOR_THRESHOLD_HIGH
|
||||||
|
)
|
||||||
|
medium = await settings_repo.get_setting(
|
||||||
|
db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM
|
||||||
|
)
|
||||||
|
low = await settings_repo.get_setting(
|
||||||
|
db, _KEY_MAP_COLOR_THRESHOLD_LOW
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
int(high) if high else 100,
|
||||||
|
int(medium) if medium else 50,
|
||||||
|
int(low) if low else 20,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def set_map_color_thresholds(
|
async def set_map_color_thresholds(
|
||||||
@@ -134,12 +167,31 @@ async def set_map_color_thresholds(
|
|||||||
threshold_medium: int,
|
threshold_medium: int,
|
||||||
threshold_low: int,
|
threshold_low: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update the map color threshold configuration."""
|
"""Update the map color threshold configuration.
|
||||||
await util_set_map_color_thresholds(
|
|
||||||
db,
|
Args:
|
||||||
threshold_high=threshold_high,
|
db: Active aiosqlite connection.
|
||||||
threshold_medium=threshold_medium,
|
threshold_high: Ban count for red coloring.
|
||||||
threshold_low=threshold_low,
|
threshold_medium: Ban count for yellow coloring.
|
||||||
|
threshold_low: Ban count for green coloring.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If thresholds are not positive integers or if
|
||||||
|
high <= medium <= low.
|
||||||
|
"""
|
||||||
|
if threshold_high <= 0 or threshold_medium <= 0 or threshold_low <= 0:
|
||||||
|
raise ValueError("All thresholds must be positive integers.")
|
||||||
|
if not (threshold_high > threshold_medium > threshold_low):
|
||||||
|
raise ValueError("Thresholds must satisfy: high > medium > low.")
|
||||||
|
|
||||||
|
await settings_repo.set_setting(
|
||||||
|
db, _KEY_MAP_COLOR_THRESHOLD_HIGH, str(threshold_high)
|
||||||
|
)
|
||||||
|
await settings_repo.set_setting(
|
||||||
|
db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM, str(threshold_medium)
|
||||||
|
)
|
||||||
|
await settings_repo.set_setting(
|
||||||
|
db, _KEY_MAP_COLOR_THRESHOLD_LOW, str(threshold_low)
|
||||||
)
|
)
|
||||||
log.info(
|
log.info(
|
||||||
"map_color_thresholds_updated",
|
"map_color_thresholds_updated",
|
||||||
|
|||||||
@@ -43,15 +43,9 @@ async def _run_import(app: Any) -> None:
|
|||||||
http_session = app.state.http_session
|
http_session = app.state.http_session
|
||||||
socket_path: str = app.state.settings.fail2ban_socket
|
socket_path: str = app.state.settings.fail2ban_socket
|
||||||
|
|
||||||
from app.services import jail_service
|
|
||||||
|
|
||||||
log.info("blocklist_import_starting")
|
log.info("blocklist_import_starting")
|
||||||
try:
|
try:
|
||||||
result = await blocklist_service.import_all(
|
result = await blocklist_service.import_all(db, http_session, socket_path)
|
||||||
db,
|
|
||||||
http_session,
|
|
||||||
socket_path,
|
|
||||||
)
|
|
||||||
log.info(
|
log.info(
|
||||||
"blocklist_import_finished",
|
"blocklist_import_finished",
|
||||||
total_imported=result.total_imported,
|
total_imported=result.total_imported,
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ The task runs every 10 minutes. On each invocation it:
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
|
|
||||||
@@ -35,7 +35,7 @@ GEO_RE_RESOLVE_INTERVAL: int = 600
|
|||||||
JOB_ID: str = "geo_re_resolve"
|
JOB_ID: str = "geo_re_resolve"
|
||||||
|
|
||||||
|
|
||||||
async def _run_re_resolve(app: FastAPI) -> None:
|
async def _run_re_resolve(app: Any) -> None:
|
||||||
"""Query NULL-country IPs from the database and re-resolve them.
|
"""Query NULL-country IPs from the database and re-resolve them.
|
||||||
|
|
||||||
Reads shared resources from ``app.state`` and delegates to
|
Reads shared resources from ``app.state`` and delegates to
|
||||||
@@ -49,7 +49,12 @@ async def _run_re_resolve(app: FastAPI) -> None:
|
|||||||
http_session = app.state.http_session
|
http_session = app.state.http_session
|
||||||
|
|
||||||
# Fetch all IPs with NULL country_code from the persistent cache.
|
# Fetch all IPs with NULL country_code from the persistent cache.
|
||||||
unresolved_ips = await geo_service.get_unresolved_ips(db)
|
unresolved_ips: list[str] = []
|
||||||
|
async with db.execute(
|
||||||
|
"SELECT ip FROM geo_cache WHERE country_code IS NULL"
|
||||||
|
) as cursor:
|
||||||
|
async for row in cursor:
|
||||||
|
unresolved_ips.append(str(row[0]))
|
||||||
|
|
||||||
if not unresolved_ips:
|
if not unresolved_ips:
|
||||||
log.debug("geo_re_resolve_skip", reason="no_unresolved_ips")
|
log.debug("geo_re_resolve_skip", reason="no_unresolved_ips")
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ within 60 seconds of that activation, a
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
from typing import TYPE_CHECKING, TypedDict
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
|
|
||||||
@@ -31,14 +31,6 @@ if TYPE_CHECKING: # pragma: no cover
|
|||||||
|
|
||||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
class ActivationRecord(TypedDict):
|
|
||||||
"""Stored timestamp data for a jail activation event."""
|
|
||||||
|
|
||||||
jail_name: str
|
|
||||||
at: datetime.datetime
|
|
||||||
|
|
||||||
|
|
||||||
#: How often the probe fires (seconds).
|
#: How often the probe fires (seconds).
|
||||||
HEALTH_CHECK_INTERVAL: int = 30
|
HEALTH_CHECK_INTERVAL: int = 30
|
||||||
|
|
||||||
@@ -47,7 +39,7 @@ HEALTH_CHECK_INTERVAL: int = 30
|
|||||||
_ACTIVATION_CRASH_WINDOW: int = 60
|
_ACTIVATION_CRASH_WINDOW: int = 60
|
||||||
|
|
||||||
|
|
||||||
async def _run_probe(app: FastAPI) -> None:
|
async def _run_probe(app: Any) -> None:
|
||||||
"""Probe fail2ban and cache the result on *app.state*.
|
"""Probe fail2ban and cache the result on *app.state*.
|
||||||
|
|
||||||
Detects online/offline state transitions. When fail2ban goes offline
|
Detects online/offline state transitions. When fail2ban goes offline
|
||||||
@@ -94,7 +86,7 @@ async def _run_probe(app: FastAPI) -> None:
|
|||||||
elif not status.online and prev_status.online:
|
elif not status.online and prev_status.online:
|
||||||
log.warning("fail2ban_went_offline")
|
log.warning("fail2ban_went_offline")
|
||||||
# Check whether this crash happened shortly after a jail activation.
|
# Check whether this crash happened shortly after a jail activation.
|
||||||
last_activation: ActivationRecord | None = getattr(
|
last_activation: dict[str, Any] | None = getattr(
|
||||||
app.state, "last_activation", None
|
app.state, "last_activation", None
|
||||||
)
|
)
|
||||||
if last_activation is not None:
|
if last_activation is not None:
|
||||||
|
|||||||
@@ -1,21 +0,0 @@
|
|||||||
"""Utilities re-exported from config_file_service for cross-module usage."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from app.services.config_file_service import (
|
|
||||||
_build_inactive_jail,
|
|
||||||
_get_active_jail_names,
|
|
||||||
_ordered_config_files,
|
|
||||||
_parse_jails_sync,
|
|
||||||
_validate_jail_config_sync,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"_ordered_config_files",
|
|
||||||
"_parse_jails_sync",
|
|
||||||
"_build_inactive_jail",
|
|
||||||
"_get_active_jail_names",
|
|
||||||
"_validate_jail_config_sync",
|
|
||||||
]
|
|
||||||
@@ -21,52 +21,14 @@ import contextlib
|
|||||||
import errno
|
import errno
|
||||||
import socket
|
import socket
|
||||||
import time
|
import time
|
||||||
from collections.abc import Mapping, Sequence, Set
|
|
||||||
from pickle import HIGHEST_PROTOCOL, dumps, loads
|
from pickle import HIGHEST_PROTOCOL, dumps, loads
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
import structlog
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Types
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
# Use covariant container types so callers can pass ``list[int]`` / ``dict[str, str]``
|
|
||||||
# without needing to cast. At runtime we only accept the basic built-in
|
|
||||||
# containers supported by fail2ban's protocol (list/dict/set) and stringify
|
|
||||||
# anything else.
|
|
||||||
#
|
|
||||||
# NOTE: ``Sequence`` will also accept tuples, but tuples are stringified at
|
|
||||||
# runtime because fail2ban only understands lists.
|
|
||||||
|
|
||||||
type Fail2BanToken = (
|
|
||||||
str
|
|
||||||
| int
|
|
||||||
| float
|
|
||||||
| bool
|
|
||||||
| None
|
|
||||||
| Mapping[str, object]
|
|
||||||
| Sequence[object]
|
|
||||||
| Set[object]
|
|
||||||
)
|
|
||||||
"""A single token in a fail2ban command.
|
|
||||||
|
|
||||||
Fail2ban accepts simple types (str/int/float/bool) plus compound types
|
|
||||||
(list/dict/set). Complex objects are stringified before being sent.
|
|
||||||
"""
|
|
||||||
|
|
||||||
type Fail2BanCommand = Sequence[Fail2BanToken]
|
|
||||||
"""A command sent to fail2ban over the socket.
|
|
||||||
|
|
||||||
Commands are pickle serialised sequences of tokens.
|
|
||||||
"""
|
|
||||||
|
|
||||||
type Fail2BanResponse = tuple[int, object]
|
|
||||||
"""A typical fail2ban response containing a status code and payload."""
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from types import TracebackType
|
from types import TracebackType
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
|
||||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||||
|
|
||||||
# fail2ban protocol constants — inline to avoid a hard import dependency
|
# fail2ban protocol constants — inline to avoid a hard import dependency
|
||||||
@@ -119,9 +81,9 @@ class Fail2BanProtocolError(Exception):
|
|||||||
|
|
||||||
def _send_command_sync(
|
def _send_command_sync(
|
||||||
socket_path: str,
|
socket_path: str,
|
||||||
command: Fail2BanCommand,
|
command: list[Any],
|
||||||
timeout: float,
|
timeout: float,
|
||||||
) -> object:
|
) -> Any:
|
||||||
"""Send a command to fail2ban and return the parsed response.
|
"""Send a command to fail2ban and return the parsed response.
|
||||||
|
|
||||||
This is a **synchronous** function intended to be called from within
|
This is a **synchronous** function intended to be called from within
|
||||||
@@ -218,7 +180,7 @@ def _send_command_sync(
|
|||||||
) from last_oserror
|
) from last_oserror
|
||||||
|
|
||||||
|
|
||||||
def _coerce_command_token(token: object) -> Fail2BanToken:
|
def _coerce_command_token(token: Any) -> Any:
|
||||||
"""Coerce a command token to a type that fail2ban understands.
|
"""Coerce a command token to a type that fail2ban understands.
|
||||||
|
|
||||||
fail2ban's ``CSocket.convert`` accepts ``str``, ``bool``, ``int``,
|
fail2ban's ``CSocket.convert`` accepts ``str``, ``bool``, ``int``,
|
||||||
@@ -267,7 +229,7 @@ class Fail2BanClient:
|
|||||||
self.socket_path: str = socket_path
|
self.socket_path: str = socket_path
|
||||||
self.timeout: float = timeout
|
self.timeout: float = timeout
|
||||||
|
|
||||||
async def send(self, command: Fail2BanCommand) -> object:
|
async def send(self, command: list[Any]) -> Any:
|
||||||
"""Send a command to fail2ban and return the response.
|
"""Send a command to fail2ban and return the response.
|
||||||
|
|
||||||
Acquires the module-level concurrency semaphore before dispatching
|
Acquires the module-level concurrency semaphore before dispatching
|
||||||
@@ -305,13 +267,13 @@ class Fail2BanClient:
|
|||||||
log.debug("fail2ban_sending_command", command=command)
|
log.debug("fail2ban_sending_command", command=command)
|
||||||
loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
|
loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
|
||||||
try:
|
try:
|
||||||
response: object = await loop.run_in_executor(
|
response: Any = await loop.run_in_executor(
|
||||||
None,
|
None,
|
||||||
_send_command_sync,
|
_send_command_sync,
|
||||||
self.socket_path,
|
self.socket_path,
|
||||||
command,
|
command,
|
||||||
self.timeout,
|
self.timeout,
|
||||||
)
|
)
|
||||||
except Fail2BanConnectionError:
|
except Fail2BanConnectionError:
|
||||||
log.warning(
|
log.warning(
|
||||||
"fail2ban_connection_error",
|
"fail2ban_connection_error",
|
||||||
@@ -338,7 +300,7 @@ class Fail2BanClient:
|
|||||||
``True`` when the daemon responds correctly, ``False`` otherwise.
|
``True`` when the daemon responds correctly, ``False`` otherwise.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
response: object = await self.send(["ping"])
|
response: Any = await self.send(["ping"])
|
||||||
return bool(response == 1) # fail2ban returns 1 on successful ping
|
return bool(response == 1) # fail2ban returns 1 on successful ping
|
||||||
except (Fail2BanConnectionError, Fail2BanProtocolError):
|
except (Fail2BanConnectionError, Fail2BanProtocolError):
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -1,63 +0,0 @@
|
|||||||
"""Utilities shared by fail2ban-related services."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from datetime import UTC, datetime
|
|
||||||
|
|
||||||
|
|
||||||
def ts_to_iso(unix_ts: int) -> str:
|
|
||||||
"""Convert a Unix timestamp to an ISO 8601 UTC string."""
|
|
||||||
return datetime.fromtimestamp(unix_ts, tz=UTC).isoformat()
|
|
||||||
|
|
||||||
|
|
||||||
async def get_fail2ban_db_path(socket_path: str) -> str:
|
|
||||||
"""Query fail2ban for the path to its SQLite database file."""
|
|
||||||
from app.utils.fail2ban_client import Fail2BanClient # pragma: no cover
|
|
||||||
|
|
||||||
socket_timeout: float = 5.0
|
|
||||||
|
|
||||||
async with Fail2BanClient(socket_path, timeout=socket_timeout) as client:
|
|
||||||
response = await client.send(["get", "dbfile"])
|
|
||||||
|
|
||||||
if not isinstance(response, tuple) or len(response) != 2:
|
|
||||||
raise RuntimeError(f"Unexpected response from fail2ban: {response!r}")
|
|
||||||
|
|
||||||
code, data = response
|
|
||||||
if code != 0:
|
|
||||||
raise RuntimeError(f"fail2ban error code {code}: {data!r}")
|
|
||||||
|
|
||||||
if data is None:
|
|
||||||
raise RuntimeError("fail2ban has no database configured (dbfile is None)")
|
|
||||||
|
|
||||||
return str(data)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_data_json(raw: object) -> tuple[list[str], int]:
|
|
||||||
"""Extract matches and failure count from the fail2ban bans.data value."""
|
|
||||||
if raw is None:
|
|
||||||
return [], 0
|
|
||||||
|
|
||||||
obj: dict[str, object] = {}
|
|
||||||
if isinstance(raw, str):
|
|
||||||
try:
|
|
||||||
parsed = json.loads(raw)
|
|
||||||
if isinstance(parsed, dict):
|
|
||||||
obj = parsed
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
return [], 0
|
|
||||||
elif isinstance(raw, dict):
|
|
||||||
obj = raw
|
|
||||||
|
|
||||||
raw_matches = obj.get("matches")
|
|
||||||
matches = [str(m) for m in raw_matches] if isinstance(raw_matches, list) else []
|
|
||||||
|
|
||||||
raw_failures = obj.get("failures")
|
|
||||||
failures = 0
|
|
||||||
if isinstance(raw_failures, (int, float, str)):
|
|
||||||
try:
|
|
||||||
failures = int(raw_failures)
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
failures = 0
|
|
||||||
|
|
||||||
return matches, failures
|
|
||||||
@@ -49,7 +49,7 @@ logpath = /dev/null
|
|||||||
backend = auto
|
backend = auto
|
||||||
maxretry = 1
|
maxretry = 1
|
||||||
findtime = 1d
|
findtime = 1d
|
||||||
bantime = 1w
|
bantime = 86400
|
||||||
ignoreip = 127.0.0.0/8 ::1 172.16.0.0/12
|
ignoreip = 127.0.0.0/8 ::1 172.16.0.0/12
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
@@ -1,20 +0,0 @@
|
|||||||
"""Jail helpers to decouple service layer dependencies."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from collections.abc import Sequence
|
|
||||||
|
|
||||||
from app.services.jail_service import reload_all
|
|
||||||
|
|
||||||
|
|
||||||
async def reload_jails(
|
|
||||||
socket_path: str,
|
|
||||||
include_jails: Sequence[str] | None = None,
|
|
||||||
exclude_jails: Sequence[str] | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Reload fail2ban jails using shared jail service helper."""
|
|
||||||
await reload_all(
|
|
||||||
socket_path,
|
|
||||||
include_jails=list(include_jails) if include_jails is not None else None,
|
|
||||||
exclude_jails=list(exclude_jails) if exclude_jails is not None else None,
|
|
||||||
)
|
|
||||||
@@ -1,14 +0,0 @@
|
|||||||
"""Log-related helpers to avoid direct service-to-service imports."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from app.models.config import LogPreviewRequest, LogPreviewResponse, RegexTestRequest, RegexTestResponse
|
|
||||||
from app.services.log_service import preview_log as _preview_log, test_regex as _test_regex
|
|
||||||
|
|
||||||
|
|
||||||
async def preview_log(req: LogPreviewRequest) -> LogPreviewResponse:
|
|
||||||
return await _preview_log(req)
|
|
||||||
|
|
||||||
|
|
||||||
def test_regex(req: RegexTestRequest) -> RegexTestResponse:
|
|
||||||
return _test_regex(req)
|
|
||||||
@@ -1,47 +0,0 @@
|
|||||||
"""Setup-related utilities shared by multiple services."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from app.repositories import settings_repo
|
|
||||||
|
|
||||||
_KEY_PASSWORD_HASH = "master_password_hash"
|
|
||||||
_KEY_SETUP_DONE = "setup_completed"
|
|
||||||
_KEY_MAP_COLOR_THRESHOLD_HIGH = "map_color_threshold_high"
|
|
||||||
_KEY_MAP_COLOR_THRESHOLD_MEDIUM = "map_color_threshold_medium"
|
|
||||||
_KEY_MAP_COLOR_THRESHOLD_LOW = "map_color_threshold_low"
|
|
||||||
|
|
||||||
|
|
||||||
async def get_password_hash(db):
|
|
||||||
"""Return the stored master password hash or None."""
|
|
||||||
return await settings_repo.get_setting(db, _KEY_PASSWORD_HASH)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_map_color_thresholds(db):
|
|
||||||
"""Return map color thresholds as tuple (high, medium, low)."""
|
|
||||||
high = await settings_repo.get_setting(db, _KEY_MAP_COLOR_THRESHOLD_HIGH)
|
|
||||||
medium = await settings_repo.get_setting(db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM)
|
|
||||||
low = await settings_repo.get_setting(db, _KEY_MAP_COLOR_THRESHOLD_LOW)
|
|
||||||
|
|
||||||
return (
|
|
||||||
int(high) if high else 100,
|
|
||||||
int(medium) if medium else 50,
|
|
||||||
int(low) if low else 20,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def set_map_color_thresholds(
|
|
||||||
db,
|
|
||||||
*,
|
|
||||||
threshold_high: int,
|
|
||||||
threshold_medium: int,
|
|
||||||
threshold_low: int,
|
|
||||||
) -> None:
|
|
||||||
"""Persist map color thresholds after validating values."""
|
|
||||||
if threshold_high <= 0 or threshold_medium <= 0 or threshold_low <= 0:
|
|
||||||
raise ValueError("All thresholds must be positive integers.")
|
|
||||||
if not (threshold_high > threshold_medium > threshold_low):
|
|
||||||
raise ValueError("Thresholds must satisfy: high > medium > low.")
|
|
||||||
|
|
||||||
await settings_repo.set_setting(db, _KEY_MAP_COLOR_THRESHOLD_HIGH, str(threshold_high))
|
|
||||||
await settings_repo.set_setting(db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM, str(threshold_medium))
|
|
||||||
await settings_repo.set_setting(db, _KEY_MAP_COLOR_THRESHOLD_LOW, str(threshold_low))
|
|
||||||
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "bangui-backend"
|
name = "bangui-backend"
|
||||||
version = "0.9.0"
|
version = "0.9.4"
|
||||||
description = "BanGUI backend — fail2ban web management interface"
|
description = "BanGUI backend — fail2ban web management interface"
|
||||||
requires-python = ">=3.12"
|
requires-python = ">=3.12"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
@@ -60,5 +60,4 @@ plugins = ["pydantic.mypy"]
|
|||||||
asyncio_mode = "auto"
|
asyncio_mode = "auto"
|
||||||
pythonpath = [".", "../fail2ban-master"]
|
pythonpath = [".", "../fail2ban-master"]
|
||||||
testpaths = ["tests"]
|
testpaths = ["tests"]
|
||||||
addopts = "--asyncio-mode=auto --cov=app --cov-report=term-missing"
|
addopts = "--cov=app --cov-report=term-missing"
|
||||||
filterwarnings = ["ignore::pytest.PytestRemovedIn9Warning"]
|
|
||||||
|
|||||||
@@ -37,15 +37,9 @@ def test_settings(tmp_path: Path) -> Settings:
|
|||||||
Returns:
|
Returns:
|
||||||
A :class:`~app.config.Settings` instance with overridden paths.
|
A :class:`~app.config.Settings` instance with overridden paths.
|
||||||
"""
|
"""
|
||||||
config_dir = tmp_path / "fail2ban"
|
|
||||||
(config_dir / "jail.d").mkdir(parents=True)
|
|
||||||
(config_dir / "filter.d").mkdir(parents=True)
|
|
||||||
(config_dir / "action.d").mkdir(parents=True)
|
|
||||||
|
|
||||||
return Settings(
|
return Settings(
|
||||||
database_path=str(tmp_path / "test_bangui.db"),
|
database_path=str(tmp_path / "test_bangui.db"),
|
||||||
fail2ban_socket="/tmp/fake_fail2ban.sock",
|
fail2ban_socket="/tmp/fake_fail2ban.sock",
|
||||||
fail2ban_config_dir=str(config_dir),
|
|
||||||
session_secret="test-secret-key-do-not-use-in-production",
|
session_secret="test-secret-key-do-not-use-in-production",
|
||||||
session_duration_minutes=60,
|
session_duration_minutes=60,
|
||||||
timezone="UTC",
|
timezone="UTC",
|
||||||
|
|||||||
@@ -1,138 +0,0 @@
|
|||||||
"""Tests for the fail2ban_db repository.
|
|
||||||
|
|
||||||
These tests use an in-memory sqlite file created under pytest's tmp_path and
|
|
||||||
exercise the core query functions used by the services.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import aiosqlite
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from app.repositories import fail2ban_db_repo
|
|
||||||
|
|
||||||
|
|
||||||
async def _create_bans_table(db: aiosqlite.Connection) -> None:
|
|
||||||
await db.execute(
|
|
||||||
"""
|
|
||||||
CREATE TABLE bans (
|
|
||||||
jail TEXT,
|
|
||||||
ip TEXT,
|
|
||||||
timeofban INTEGER,
|
|
||||||
bancount INTEGER,
|
|
||||||
data TEXT
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_check_db_nonempty_returns_false_when_table_is_empty(tmp_path: Path) -> None:
|
|
||||||
db_path = str(tmp_path / "fail2ban.db")
|
|
||||||
async with aiosqlite.connect(db_path) as db:
|
|
||||||
await _create_bans_table(db)
|
|
||||||
|
|
||||||
assert await fail2ban_db_repo.check_db_nonempty(db_path) is False
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_check_db_nonempty_returns_true_when_row_exists(tmp_path: Path) -> None:
|
|
||||||
db_path = str(tmp_path / "fail2ban.db")
|
|
||||||
async with aiosqlite.connect(db_path) as db:
|
|
||||||
await _create_bans_table(db)
|
|
||||||
await db.execute(
|
|
||||||
"INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)",
|
|
||||||
("jail1", "1.2.3.4", 123, 1, "{}"),
|
|
||||||
)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
assert await fail2ban_db_repo.check_db_nonempty(db_path) is True
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_currently_banned_filters_and_pagination(tmp_path: Path) -> None:
|
|
||||||
db_path = str(tmp_path / "fail2ban.db")
|
|
||||||
async with aiosqlite.connect(db_path) as db:
|
|
||||||
await _create_bans_table(db)
|
|
||||||
# Three bans; one is from the blocklist-import jail.
|
|
||||||
await db.executemany(
|
|
||||||
"INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)",
|
|
||||||
[
|
|
||||||
("jail1", "1.1.1.1", 10, 1, "{}"),
|
|
||||||
("blocklist-import", "2.2.2.2", 20, 2, "{}"),
|
|
||||||
("jail1", "3.3.3.3", 30, 3, "{}"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
records, total = await fail2ban_db_repo.get_currently_banned(
|
|
||||||
db_path=db_path,
|
|
||||||
since=15,
|
|
||||||
origin="selfblock",
|
|
||||||
limit=10,
|
|
||||||
offset=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Only the non-blocklist row with timeofban >= 15 should remain.
|
|
||||||
assert total == 1
|
|
||||||
assert len(records) == 1
|
|
||||||
assert records[0].ip == "3.3.3.3"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_ban_counts_by_bucket_ignores_out_of_range_buckets(tmp_path: Path) -> None:
|
|
||||||
db_path = str(tmp_path / "fail2ban.db")
|
|
||||||
async with aiosqlite.connect(db_path) as db:
|
|
||||||
await _create_bans_table(db)
|
|
||||||
await db.executemany(
|
|
||||||
"INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)",
|
|
||||||
[
|
|
||||||
("jail1", "1.1.1.1", 5, 1, "{}"),
|
|
||||||
("jail1", "2.2.2.2", 15, 1, "{}"),
|
|
||||||
("jail1", "3.3.3.3", 35, 1, "{}"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
counts = await fail2ban_db_repo.get_ban_counts_by_bucket(
|
|
||||||
db_path=db_path,
|
|
||||||
since=0,
|
|
||||||
bucket_secs=10,
|
|
||||||
num_buckets=3,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert counts == [1, 1, 0]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_history_page_and_for_ip(tmp_path: Path) -> None:
|
|
||||||
db_path = str(tmp_path / "fail2ban.db")
|
|
||||||
async with aiosqlite.connect(db_path) as db:
|
|
||||||
await _create_bans_table(db)
|
|
||||||
await db.executemany(
|
|
||||||
"INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)",
|
|
||||||
[
|
|
||||||
("jail1", "1.1.1.1", 100, 1, "{}"),
|
|
||||||
("jail1", "1.1.1.1", 200, 2, "{}"),
|
|
||||||
("jail1", "2.2.2.2", 300, 3, "{}"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
page, total = await fail2ban_db_repo.get_history_page(
|
|
||||||
db_path=db_path,
|
|
||||||
since=None,
|
|
||||||
jail="jail1",
|
|
||||||
ip_filter="1.1.1",
|
|
||||||
page=1,
|
|
||||||
page_size=10,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert total == 2
|
|
||||||
assert len(page) == 2
|
|
||||||
assert page[0].ip == "1.1.1.1"
|
|
||||||
|
|
||||||
history_for_ip = await fail2ban_db_repo.get_history_for_ip(db_path=db_path, ip="2.2.2.2")
|
|
||||||
assert len(history_for_ip) == 1
|
|
||||||
assert history_for_ip[0].ip == "2.2.2.2"
|
|
||||||
@@ -1,140 +0,0 @@
|
|||||||
"""Tests for the geo cache repository."""
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import aiosqlite
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from app.repositories import geo_cache_repo
|
|
||||||
|
|
||||||
|
|
||||||
async def _create_geo_cache_table(db: aiosqlite.Connection) -> None:
|
|
||||||
await db.execute(
|
|
||||||
"""
|
|
||||||
CREATE TABLE IF NOT EXISTS geo_cache (
|
|
||||||
ip TEXT PRIMARY KEY,
|
|
||||||
country_code TEXT,
|
|
||||||
country_name TEXT,
|
|
||||||
asn TEXT,
|
|
||||||
org TEXT,
|
|
||||||
cached_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_unresolved_ips_returns_empty_when_none_exist(tmp_path: Path) -> None:
|
|
||||||
db_path = str(tmp_path / "geo_cache.db")
|
|
||||||
async with aiosqlite.connect(db_path) as db:
|
|
||||||
await _create_geo_cache_table(db)
|
|
||||||
await db.execute(
|
|
||||||
"INSERT INTO geo_cache (ip, country_code, country_name, asn, org) VALUES (?, ?, ?, ?, ?)",
|
|
||||||
("1.1.1.1", "DE", "Germany", "AS123", "Test"),
|
|
||||||
)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
async with aiosqlite.connect(db_path) as db:
|
|
||||||
ips = await geo_cache_repo.get_unresolved_ips(db)
|
|
||||||
|
|
||||||
assert ips == []
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_unresolved_ips_returns_pending_ips(tmp_path: Path) -> None:
|
|
||||||
db_path = str(tmp_path / "geo_cache.db")
|
|
||||||
async with aiosqlite.connect(db_path) as db:
|
|
||||||
await _create_geo_cache_table(db)
|
|
||||||
await db.executemany(
|
|
||||||
"INSERT INTO geo_cache (ip, country_code) VALUES (?, ?)",
|
|
||||||
[
|
|
||||||
("2.2.2.2", None),
|
|
||||||
("3.3.3.3", None),
|
|
||||||
("4.4.4.4", "US"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
async with aiosqlite.connect(db_path) as db:
|
|
||||||
ips = await geo_cache_repo.get_unresolved_ips(db)
|
|
||||||
|
|
||||||
assert sorted(ips) == ["2.2.2.2", "3.3.3.3"]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_load_all_and_count_unresolved(tmp_path: Path) -> None:
|
|
||||||
db_path = str(tmp_path / "geo_cache.db")
|
|
||||||
async with aiosqlite.connect(db_path) as db:
|
|
||||||
await _create_geo_cache_table(db)
|
|
||||||
await db.executemany(
|
|
||||||
"INSERT INTO geo_cache (ip, country_code, country_name, asn, org) VALUES (?, ?, ?, ?, ?)",
|
|
||||||
[
|
|
||||||
("5.5.5.5", None, None, None, None),
|
|
||||||
("6.6.6.6", "FR", "France", "AS456", "TestOrg"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
async with aiosqlite.connect(db_path) as db:
|
|
||||||
rows = await geo_cache_repo.load_all(db)
|
|
||||||
unresolved = await geo_cache_repo.count_unresolved(db)
|
|
||||||
|
|
||||||
assert unresolved == 1
|
|
||||||
assert any(row["ip"] == "6.6.6.6" and row["country_code"] == "FR" for row in rows)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_upsert_entry_and_neg_entry(tmp_path: Path) -> None:
|
|
||||||
db_path = str(tmp_path / "geo_cache.db")
|
|
||||||
async with aiosqlite.connect(db_path) as db:
|
|
||||||
await _create_geo_cache_table(db)
|
|
||||||
|
|
||||||
await geo_cache_repo.upsert_entry(
|
|
||||||
db,
|
|
||||||
"7.7.7.7",
|
|
||||||
"GB",
|
|
||||||
"United Kingdom",
|
|
||||||
"AS789",
|
|
||||||
"TestOrg",
|
|
||||||
)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
await geo_cache_repo.upsert_neg_entry(db, "8.8.8.8")
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
# Ensure positive entry is present.
|
|
||||||
async with db.execute("SELECT country_code FROM geo_cache WHERE ip = ?", ("7.7.7.7",)) as cur:
|
|
||||||
row = await cur.fetchone()
|
|
||||||
assert row is not None
|
|
||||||
assert row[0] == "GB"
|
|
||||||
|
|
||||||
# Ensure negative entry exists with NULL country_code.
|
|
||||||
async with db.execute("SELECT country_code FROM geo_cache WHERE ip = ?", ("8.8.8.8",)) as cur:
|
|
||||||
row = await cur.fetchone()
|
|
||||||
assert row is not None
|
|
||||||
assert row[0] is None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_bulk_upsert_entries_and_neg_entries(tmp_path: Path) -> None:
|
|
||||||
db_path = str(tmp_path / "geo_cache.db")
|
|
||||||
async with aiosqlite.connect(db_path) as db:
|
|
||||||
await _create_geo_cache_table(db)
|
|
||||||
|
|
||||||
rows = [
|
|
||||||
("9.9.9.9", "NL", "Netherlands", "AS101", "Test"),
|
|
||||||
("10.10.10.10", "JP", "Japan", "AS102", "Test"),
|
|
||||||
]
|
|
||||||
count = await geo_cache_repo.bulk_upsert_entries(db, rows)
|
|
||||||
assert count == 2
|
|
||||||
|
|
||||||
neg_count = await geo_cache_repo.bulk_upsert_neg_entries(db, ["11.11.11.11", "12.12.12.12"])
|
|
||||||
assert neg_count == 2
|
|
||||||
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
async with db.execute("SELECT COUNT(*) FROM geo_cache") as cur:
|
|
||||||
row = await cur.fetchone()
|
|
||||||
assert row is not None
|
|
||||||
assert int(row[0]) == 4
|
|
||||||
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Generator
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -158,12 +157,12 @@ class TestRequireAuthSessionCache:
|
|||||||
"""In-memory session token cache inside ``require_auth``."""
|
"""In-memory session token cache inside ``require_auth``."""
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def reset_cache(self) -> Generator[None, None, None]:
|
def reset_cache(self) -> None: # type: ignore[misc]
|
||||||
"""Flush the session cache before and after every test in this class."""
|
"""Flush the session cache before and after every test in this class."""
|
||||||
from app import dependencies
|
from app import dependencies
|
||||||
|
|
||||||
dependencies.clear_session_cache()
|
dependencies.clear_session_cache()
|
||||||
yield
|
yield # type: ignore[misc]
|
||||||
dependencies.clear_session_cache()
|
dependencies.clear_session_cache()
|
||||||
|
|
||||||
async def test_second_request_skips_db(self, client: AsyncClient) -> None:
|
async def test_second_request_skips_db(self, client: AsyncClient) -> None:
|
||||||
|
|||||||
@@ -501,7 +501,7 @@ class TestRegexTest:
|
|||||||
"""POST /api/config/regex-test returns matched=true for a valid match."""
|
"""POST /api/config/regex-test returns matched=true for a valid match."""
|
||||||
mock_response = RegexTestResponse(matched=True, groups=["1.2.3.4"], error=None)
|
mock_response = RegexTestResponse(matched=True, groups=["1.2.3.4"], error=None)
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.log_service.test_regex",
|
"app.routers.config.config_service.test_regex",
|
||||||
return_value=mock_response,
|
return_value=mock_response,
|
||||||
):
|
):
|
||||||
resp = await config_client.post(
|
resp = await config_client.post(
|
||||||
@@ -519,7 +519,7 @@ class TestRegexTest:
|
|||||||
"""POST /api/config/regex-test returns matched=false for no match."""
|
"""POST /api/config/regex-test returns matched=false for no match."""
|
||||||
mock_response = RegexTestResponse(matched=False, groups=[], error=None)
|
mock_response = RegexTestResponse(matched=False, groups=[], error=None)
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.log_service.test_regex",
|
"app.routers.config.config_service.test_regex",
|
||||||
return_value=mock_response,
|
return_value=mock_response,
|
||||||
):
|
):
|
||||||
resp = await config_client.post(
|
resp = await config_client.post(
|
||||||
@@ -597,7 +597,7 @@ class TestPreviewLog:
|
|||||||
matched_count=1,
|
matched_count=1,
|
||||||
)
|
)
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.log_service.preview_log",
|
"app.routers.config.config_service.preview_log",
|
||||||
AsyncMock(return_value=mock_response),
|
AsyncMock(return_value=mock_response),
|
||||||
):
|
):
|
||||||
resp = await config_client.post(
|
resp = await config_client.post(
|
||||||
@@ -725,7 +725,7 @@ class TestGetInactiveJails:
|
|||||||
mock_response = InactiveJailListResponse(jails=[mock_jail], total=1)
|
mock_response = InactiveJailListResponse(jails=[mock_jail], total=1)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.jail_config_service.list_inactive_jails",
|
"app.routers.config.config_file_service.list_inactive_jails",
|
||||||
AsyncMock(return_value=mock_response),
|
AsyncMock(return_value=mock_response),
|
||||||
):
|
):
|
||||||
resp = await config_client.get("/api/config/jails/inactive")
|
resp = await config_client.get("/api/config/jails/inactive")
|
||||||
@@ -740,7 +740,7 @@ class TestGetInactiveJails:
|
|||||||
from app.models.config import InactiveJailListResponse
|
from app.models.config import InactiveJailListResponse
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.jail_config_service.list_inactive_jails",
|
"app.routers.config.config_file_service.list_inactive_jails",
|
||||||
AsyncMock(return_value=InactiveJailListResponse(jails=[], total=0)),
|
AsyncMock(return_value=InactiveJailListResponse(jails=[], total=0)),
|
||||||
):
|
):
|
||||||
resp = await config_client.get("/api/config/jails/inactive")
|
resp = await config_client.get("/api/config/jails/inactive")
|
||||||
@@ -776,7 +776,7 @@ class TestActivateJail:
|
|||||||
message="Jail 'apache-auth' activated successfully.",
|
message="Jail 'apache-auth' activated successfully.",
|
||||||
)
|
)
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.jail_config_service.activate_jail",
|
"app.routers.config.config_file_service.activate_jail",
|
||||||
AsyncMock(return_value=mock_response),
|
AsyncMock(return_value=mock_response),
|
||||||
):
|
):
|
||||||
resp = await config_client.post(
|
resp = await config_client.post(
|
||||||
@@ -796,7 +796,7 @@ class TestActivateJail:
|
|||||||
name="apache-auth", active=True, message="Activated."
|
name="apache-auth", active=True, message="Activated."
|
||||||
)
|
)
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.jail_config_service.activate_jail",
|
"app.routers.config.config_file_service.activate_jail",
|
||||||
AsyncMock(return_value=mock_response),
|
AsyncMock(return_value=mock_response),
|
||||||
) as mock_activate:
|
) as mock_activate:
|
||||||
resp = await config_client.post(
|
resp = await config_client.post(
|
||||||
@@ -812,10 +812,10 @@ class TestActivateJail:
|
|||||||
|
|
||||||
async def test_404_for_unknown_jail(self, config_client: AsyncClient) -> None:
|
async def test_404_for_unknown_jail(self, config_client: AsyncClient) -> None:
|
||||||
"""POST /api/config/jails/missing/activate returns 404."""
|
"""POST /api/config/jails/missing/activate returns 404."""
|
||||||
from app.services.jail_config_service import JailNotFoundInConfigError
|
from app.services.config_file_service import JailNotFoundInConfigError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.jail_config_service.activate_jail",
|
"app.routers.config.config_file_service.activate_jail",
|
||||||
AsyncMock(side_effect=JailNotFoundInConfigError("missing")),
|
AsyncMock(side_effect=JailNotFoundInConfigError("missing")),
|
||||||
):
|
):
|
||||||
resp = await config_client.post(
|
resp = await config_client.post(
|
||||||
@@ -826,10 +826,10 @@ class TestActivateJail:
|
|||||||
|
|
||||||
async def test_409_when_already_active(self, config_client: AsyncClient) -> None:
|
async def test_409_when_already_active(self, config_client: AsyncClient) -> None:
|
||||||
"""POST /api/config/jails/sshd/activate returns 409 if already active."""
|
"""POST /api/config/jails/sshd/activate returns 409 if already active."""
|
||||||
from app.services.jail_config_service import JailAlreadyActiveError
|
from app.services.config_file_service import JailAlreadyActiveError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.jail_config_service.activate_jail",
|
"app.routers.config.config_file_service.activate_jail",
|
||||||
AsyncMock(side_effect=JailAlreadyActiveError("sshd")),
|
AsyncMock(side_effect=JailAlreadyActiveError("sshd")),
|
||||||
):
|
):
|
||||||
resp = await config_client.post(
|
resp = await config_client.post(
|
||||||
@@ -840,10 +840,10 @@ class TestActivateJail:
|
|||||||
|
|
||||||
async def test_400_for_invalid_jail_name(self, config_client: AsyncClient) -> None:
|
async def test_400_for_invalid_jail_name(self, config_client: AsyncClient) -> None:
|
||||||
"""POST /api/config/jails/ with bad name returns 400."""
|
"""POST /api/config/jails/ with bad name returns 400."""
|
||||||
from app.services.jail_config_service import JailNameError
|
from app.services.config_file_service import JailNameError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.jail_config_service.activate_jail",
|
"app.routers.config.config_file_service.activate_jail",
|
||||||
AsyncMock(side_effect=JailNameError("bad name")),
|
AsyncMock(side_effect=JailNameError("bad name")),
|
||||||
):
|
):
|
||||||
resp = await config_client.post(
|
resp = await config_client.post(
|
||||||
@@ -872,7 +872,7 @@ class TestActivateJail:
|
|||||||
message="Jail 'airsonic-auth' cannot be activated: log file '/var/log/airsonic/airsonic.log' not found",
|
message="Jail 'airsonic-auth' cannot be activated: log file '/var/log/airsonic/airsonic.log' not found",
|
||||||
)
|
)
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.jail_config_service.activate_jail",
|
"app.routers.config.config_file_service.activate_jail",
|
||||||
AsyncMock(return_value=blocked_response),
|
AsyncMock(return_value=blocked_response),
|
||||||
):
|
):
|
||||||
resp = await config_client.post(
|
resp = await config_client.post(
|
||||||
@@ -905,7 +905,7 @@ class TestDeactivateJail:
|
|||||||
message="Jail 'sshd' deactivated successfully.",
|
message="Jail 'sshd' deactivated successfully.",
|
||||||
)
|
)
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.jail_config_service.deactivate_jail",
|
"app.routers.config.config_file_service.deactivate_jail",
|
||||||
AsyncMock(return_value=mock_response),
|
AsyncMock(return_value=mock_response),
|
||||||
):
|
):
|
||||||
resp = await config_client.post("/api/config/jails/sshd/deactivate")
|
resp = await config_client.post("/api/config/jails/sshd/deactivate")
|
||||||
@@ -917,10 +917,10 @@ class TestDeactivateJail:
|
|||||||
|
|
||||||
async def test_404_for_unknown_jail(self, config_client: AsyncClient) -> None:
|
async def test_404_for_unknown_jail(self, config_client: AsyncClient) -> None:
|
||||||
"""POST /api/config/jails/missing/deactivate returns 404."""
|
"""POST /api/config/jails/missing/deactivate returns 404."""
|
||||||
from app.services.jail_config_service import JailNotFoundInConfigError
|
from app.services.config_file_service import JailNotFoundInConfigError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.jail_config_service.deactivate_jail",
|
"app.routers.config.config_file_service.deactivate_jail",
|
||||||
AsyncMock(side_effect=JailNotFoundInConfigError("missing")),
|
AsyncMock(side_effect=JailNotFoundInConfigError("missing")),
|
||||||
):
|
):
|
||||||
resp = await config_client.post(
|
resp = await config_client.post(
|
||||||
@@ -931,10 +931,10 @@ class TestDeactivateJail:
|
|||||||
|
|
||||||
async def test_409_when_already_inactive(self, config_client: AsyncClient) -> None:
|
async def test_409_when_already_inactive(self, config_client: AsyncClient) -> None:
|
||||||
"""POST /api/config/jails/apache-auth/deactivate returns 409 if already inactive."""
|
"""POST /api/config/jails/apache-auth/deactivate returns 409 if already inactive."""
|
||||||
from app.services.jail_config_service import JailAlreadyInactiveError
|
from app.services.config_file_service import JailAlreadyInactiveError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.jail_config_service.deactivate_jail",
|
"app.routers.config.config_file_service.deactivate_jail",
|
||||||
AsyncMock(side_effect=JailAlreadyInactiveError("apache-auth")),
|
AsyncMock(side_effect=JailAlreadyInactiveError("apache-auth")),
|
||||||
):
|
):
|
||||||
resp = await config_client.post(
|
resp = await config_client.post(
|
||||||
@@ -945,10 +945,10 @@ class TestDeactivateJail:
|
|||||||
|
|
||||||
async def test_400_for_invalid_jail_name(self, config_client: AsyncClient) -> None:
|
async def test_400_for_invalid_jail_name(self, config_client: AsyncClient) -> None:
|
||||||
"""POST /api/config/jails/.../deactivate with bad name returns 400."""
|
"""POST /api/config/jails/.../deactivate with bad name returns 400."""
|
||||||
from app.services.jail_config_service import JailNameError
|
from app.services.config_file_service import JailNameError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.jail_config_service.deactivate_jail",
|
"app.routers.config.config_file_service.deactivate_jail",
|
||||||
AsyncMock(side_effect=JailNameError("bad")),
|
AsyncMock(side_effect=JailNameError("bad")),
|
||||||
):
|
):
|
||||||
resp = await config_client.post(
|
resp = await config_client.post(
|
||||||
@@ -976,7 +976,7 @@ class TestDeactivateJail:
|
|||||||
)
|
)
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"app.routers.config.jail_config_service.deactivate_jail",
|
"app.routers.config.config_file_service.deactivate_jail",
|
||||||
AsyncMock(return_value=mock_response),
|
AsyncMock(return_value=mock_response),
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
@@ -1027,7 +1027,7 @@ class TestListFilters:
|
|||||||
total=1,
|
total=1,
|
||||||
)
|
)
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.filter_config_service.list_filters",
|
"app.routers.config.config_file_service.list_filters",
|
||||||
AsyncMock(return_value=mock_response),
|
AsyncMock(return_value=mock_response),
|
||||||
):
|
):
|
||||||
resp = await config_client.get("/api/config/filters")
|
resp = await config_client.get("/api/config/filters")
|
||||||
@@ -1043,7 +1043,7 @@ class TestListFilters:
|
|||||||
from app.models.config import FilterListResponse
|
from app.models.config import FilterListResponse
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.filter_config_service.list_filters",
|
"app.routers.config.config_file_service.list_filters",
|
||||||
AsyncMock(return_value=FilterListResponse(filters=[], total=0)),
|
AsyncMock(return_value=FilterListResponse(filters=[], total=0)),
|
||||||
):
|
):
|
||||||
resp = await config_client.get("/api/config/filters")
|
resp = await config_client.get("/api/config/filters")
|
||||||
@@ -1066,7 +1066,7 @@ class TestListFilters:
|
|||||||
total=2,
|
total=2,
|
||||||
)
|
)
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.filter_config_service.list_filters",
|
"app.routers.config.config_file_service.list_filters",
|
||||||
AsyncMock(return_value=mock_response),
|
AsyncMock(return_value=mock_response),
|
||||||
):
|
):
|
||||||
resp = await config_client.get("/api/config/filters")
|
resp = await config_client.get("/api/config/filters")
|
||||||
@@ -1095,7 +1095,7 @@ class TestGetFilter:
|
|||||||
async def test_200_returns_filter(self, config_client: AsyncClient) -> None:
|
async def test_200_returns_filter(self, config_client: AsyncClient) -> None:
|
||||||
"""GET /api/config/filters/sshd returns 200 with FilterConfig."""
|
"""GET /api/config/filters/sshd returns 200 with FilterConfig."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.filter_config_service.get_filter",
|
"app.routers.config.config_file_service.get_filter",
|
||||||
AsyncMock(return_value=_make_filter_config("sshd")),
|
AsyncMock(return_value=_make_filter_config("sshd")),
|
||||||
):
|
):
|
||||||
resp = await config_client.get("/api/config/filters/sshd")
|
resp = await config_client.get("/api/config/filters/sshd")
|
||||||
@@ -1108,10 +1108,10 @@ class TestGetFilter:
|
|||||||
|
|
||||||
async def test_404_for_unknown_filter(self, config_client: AsyncClient) -> None:
|
async def test_404_for_unknown_filter(self, config_client: AsyncClient) -> None:
|
||||||
"""GET /api/config/filters/missing returns 404."""
|
"""GET /api/config/filters/missing returns 404."""
|
||||||
from app.services.filter_config_service import FilterNotFoundError
|
from app.services.config_file_service import FilterNotFoundError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.filter_config_service.get_filter",
|
"app.routers.config.config_file_service.get_filter",
|
||||||
AsyncMock(side_effect=FilterNotFoundError("missing")),
|
AsyncMock(side_effect=FilterNotFoundError("missing")),
|
||||||
):
|
):
|
||||||
resp = await config_client.get("/api/config/filters/missing")
|
resp = await config_client.get("/api/config/filters/missing")
|
||||||
@@ -1138,7 +1138,7 @@ class TestUpdateFilter:
|
|||||||
async def test_200_returns_updated_filter(self, config_client: AsyncClient) -> None:
|
async def test_200_returns_updated_filter(self, config_client: AsyncClient) -> None:
|
||||||
"""PUT /api/config/filters/sshd returns 200 with updated FilterConfig."""
|
"""PUT /api/config/filters/sshd returns 200 with updated FilterConfig."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.filter_config_service.update_filter",
|
"app.routers.config.config_file_service.update_filter",
|
||||||
AsyncMock(return_value=_make_filter_config("sshd")),
|
AsyncMock(return_value=_make_filter_config("sshd")),
|
||||||
):
|
):
|
||||||
resp = await config_client.put(
|
resp = await config_client.put(
|
||||||
@@ -1151,10 +1151,10 @@ class TestUpdateFilter:
|
|||||||
|
|
||||||
async def test_404_for_unknown_filter(self, config_client: AsyncClient) -> None:
|
async def test_404_for_unknown_filter(self, config_client: AsyncClient) -> None:
|
||||||
"""PUT /api/config/filters/missing returns 404."""
|
"""PUT /api/config/filters/missing returns 404."""
|
||||||
from app.services.filter_config_service import FilterNotFoundError
|
from app.services.config_file_service import FilterNotFoundError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.filter_config_service.update_filter",
|
"app.routers.config.config_file_service.update_filter",
|
||||||
AsyncMock(side_effect=FilterNotFoundError("missing")),
|
AsyncMock(side_effect=FilterNotFoundError("missing")),
|
||||||
):
|
):
|
||||||
resp = await config_client.put(
|
resp = await config_client.put(
|
||||||
@@ -1166,10 +1166,10 @@ class TestUpdateFilter:
|
|||||||
|
|
||||||
async def test_422_for_invalid_regex(self, config_client: AsyncClient) -> None:
|
async def test_422_for_invalid_regex(self, config_client: AsyncClient) -> None:
|
||||||
"""PUT /api/config/filters/sshd returns 422 for bad regex."""
|
"""PUT /api/config/filters/sshd returns 422 for bad regex."""
|
||||||
from app.services.filter_config_service import FilterInvalidRegexError
|
from app.services.config_file_service import FilterInvalidRegexError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.filter_config_service.update_filter",
|
"app.routers.config.config_file_service.update_filter",
|
||||||
AsyncMock(side_effect=FilterInvalidRegexError("[bad", "unterminated")),
|
AsyncMock(side_effect=FilterInvalidRegexError("[bad", "unterminated")),
|
||||||
):
|
):
|
||||||
resp = await config_client.put(
|
resp = await config_client.put(
|
||||||
@@ -1181,10 +1181,10 @@ class TestUpdateFilter:
|
|||||||
|
|
||||||
async def test_400_for_invalid_name(self, config_client: AsyncClient) -> None:
|
async def test_400_for_invalid_name(self, config_client: AsyncClient) -> None:
|
||||||
"""PUT /api/config/filters/... with bad name returns 400."""
|
"""PUT /api/config/filters/... with bad name returns 400."""
|
||||||
from app.services.filter_config_service import FilterNameError
|
from app.services.config_file_service import FilterNameError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.filter_config_service.update_filter",
|
"app.routers.config.config_file_service.update_filter",
|
||||||
AsyncMock(side_effect=FilterNameError("bad")),
|
AsyncMock(side_effect=FilterNameError("bad")),
|
||||||
):
|
):
|
||||||
resp = await config_client.put(
|
resp = await config_client.put(
|
||||||
@@ -1197,7 +1197,7 @@ class TestUpdateFilter:
|
|||||||
async def test_reload_query_param_passed(self, config_client: AsyncClient) -> None:
|
async def test_reload_query_param_passed(self, config_client: AsyncClient) -> None:
|
||||||
"""PUT /api/config/filters/sshd?reload=true passes do_reload=True."""
|
"""PUT /api/config/filters/sshd?reload=true passes do_reload=True."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.filter_config_service.update_filter",
|
"app.routers.config.config_file_service.update_filter",
|
||||||
AsyncMock(return_value=_make_filter_config("sshd")),
|
AsyncMock(return_value=_make_filter_config("sshd")),
|
||||||
) as mock_update:
|
) as mock_update:
|
||||||
resp = await config_client.put(
|
resp = await config_client.put(
|
||||||
@@ -1228,7 +1228,7 @@ class TestCreateFilter:
|
|||||||
async def test_201_creates_filter(self, config_client: AsyncClient) -> None:
|
async def test_201_creates_filter(self, config_client: AsyncClient) -> None:
|
||||||
"""POST /api/config/filters returns 201 with FilterConfig."""
|
"""POST /api/config/filters returns 201 with FilterConfig."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.filter_config_service.create_filter",
|
"app.routers.config.config_file_service.create_filter",
|
||||||
AsyncMock(return_value=_make_filter_config("my-custom")),
|
AsyncMock(return_value=_make_filter_config("my-custom")),
|
||||||
):
|
):
|
||||||
resp = await config_client.post(
|
resp = await config_client.post(
|
||||||
@@ -1241,10 +1241,10 @@ class TestCreateFilter:
|
|||||||
|
|
||||||
async def test_409_when_already_exists(self, config_client: AsyncClient) -> None:
|
async def test_409_when_already_exists(self, config_client: AsyncClient) -> None:
|
||||||
"""POST /api/config/filters returns 409 if filter exists."""
|
"""POST /api/config/filters returns 409 if filter exists."""
|
||||||
from app.services.filter_config_service import FilterAlreadyExistsError
|
from app.services.config_file_service import FilterAlreadyExistsError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.filter_config_service.create_filter",
|
"app.routers.config.config_file_service.create_filter",
|
||||||
AsyncMock(side_effect=FilterAlreadyExistsError("sshd")),
|
AsyncMock(side_effect=FilterAlreadyExistsError("sshd")),
|
||||||
):
|
):
|
||||||
resp = await config_client.post(
|
resp = await config_client.post(
|
||||||
@@ -1256,10 +1256,10 @@ class TestCreateFilter:
|
|||||||
|
|
||||||
async def test_422_for_invalid_regex(self, config_client: AsyncClient) -> None:
|
async def test_422_for_invalid_regex(self, config_client: AsyncClient) -> None:
|
||||||
"""POST /api/config/filters returns 422 for bad regex."""
|
"""POST /api/config/filters returns 422 for bad regex."""
|
||||||
from app.services.filter_config_service import FilterInvalidRegexError
|
from app.services.config_file_service import FilterInvalidRegexError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.filter_config_service.create_filter",
|
"app.routers.config.config_file_service.create_filter",
|
||||||
AsyncMock(side_effect=FilterInvalidRegexError("[bad", "unterminated")),
|
AsyncMock(side_effect=FilterInvalidRegexError("[bad", "unterminated")),
|
||||||
):
|
):
|
||||||
resp = await config_client.post(
|
resp = await config_client.post(
|
||||||
@@ -1271,10 +1271,10 @@ class TestCreateFilter:
|
|||||||
|
|
||||||
async def test_400_for_invalid_name(self, config_client: AsyncClient) -> None:
|
async def test_400_for_invalid_name(self, config_client: AsyncClient) -> None:
|
||||||
"""POST /api/config/filters returns 400 for invalid filter name."""
|
"""POST /api/config/filters returns 400 for invalid filter name."""
|
||||||
from app.services.filter_config_service import FilterNameError
|
from app.services.config_file_service import FilterNameError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.filter_config_service.create_filter",
|
"app.routers.config.config_file_service.create_filter",
|
||||||
AsyncMock(side_effect=FilterNameError("bad")),
|
AsyncMock(side_effect=FilterNameError("bad")),
|
||||||
):
|
):
|
||||||
resp = await config_client.post(
|
resp = await config_client.post(
|
||||||
@@ -1304,7 +1304,7 @@ class TestDeleteFilter:
|
|||||||
async def test_204_deletes_filter(self, config_client: AsyncClient) -> None:
|
async def test_204_deletes_filter(self, config_client: AsyncClient) -> None:
|
||||||
"""DELETE /api/config/filters/my-custom returns 204."""
|
"""DELETE /api/config/filters/my-custom returns 204."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.filter_config_service.delete_filter",
|
"app.routers.config.config_file_service.delete_filter",
|
||||||
AsyncMock(return_value=None),
|
AsyncMock(return_value=None),
|
||||||
):
|
):
|
||||||
resp = await config_client.delete("/api/config/filters/my-custom")
|
resp = await config_client.delete("/api/config/filters/my-custom")
|
||||||
@@ -1313,10 +1313,10 @@ class TestDeleteFilter:
|
|||||||
|
|
||||||
async def test_404_for_unknown_filter(self, config_client: AsyncClient) -> None:
|
async def test_404_for_unknown_filter(self, config_client: AsyncClient) -> None:
|
||||||
"""DELETE /api/config/filters/missing returns 404."""
|
"""DELETE /api/config/filters/missing returns 404."""
|
||||||
from app.services.filter_config_service import FilterNotFoundError
|
from app.services.config_file_service import FilterNotFoundError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.filter_config_service.delete_filter",
|
"app.routers.config.config_file_service.delete_filter",
|
||||||
AsyncMock(side_effect=FilterNotFoundError("missing")),
|
AsyncMock(side_effect=FilterNotFoundError("missing")),
|
||||||
):
|
):
|
||||||
resp = await config_client.delete("/api/config/filters/missing")
|
resp = await config_client.delete("/api/config/filters/missing")
|
||||||
@@ -1325,10 +1325,10 @@ class TestDeleteFilter:
|
|||||||
|
|
||||||
async def test_409_for_readonly_filter(self, config_client: AsyncClient) -> None:
|
async def test_409_for_readonly_filter(self, config_client: AsyncClient) -> None:
|
||||||
"""DELETE /api/config/filters/sshd returns 409 for shipped conf-only filter."""
|
"""DELETE /api/config/filters/sshd returns 409 for shipped conf-only filter."""
|
||||||
from app.services.filter_config_service import FilterReadonlyError
|
from app.services.config_file_service import FilterReadonlyError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.filter_config_service.delete_filter",
|
"app.routers.config.config_file_service.delete_filter",
|
||||||
AsyncMock(side_effect=FilterReadonlyError("sshd")),
|
AsyncMock(side_effect=FilterReadonlyError("sshd")),
|
||||||
):
|
):
|
||||||
resp = await config_client.delete("/api/config/filters/sshd")
|
resp = await config_client.delete("/api/config/filters/sshd")
|
||||||
@@ -1337,10 +1337,10 @@ class TestDeleteFilter:
|
|||||||
|
|
||||||
async def test_400_for_invalid_name(self, config_client: AsyncClient) -> None:
|
async def test_400_for_invalid_name(self, config_client: AsyncClient) -> None:
|
||||||
"""DELETE /api/config/filters/... with bad name returns 400."""
|
"""DELETE /api/config/filters/... with bad name returns 400."""
|
||||||
from app.services.filter_config_service import FilterNameError
|
from app.services.config_file_service import FilterNameError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.filter_config_service.delete_filter",
|
"app.routers.config.config_file_service.delete_filter",
|
||||||
AsyncMock(side_effect=FilterNameError("bad")),
|
AsyncMock(side_effect=FilterNameError("bad")),
|
||||||
):
|
):
|
||||||
resp = await config_client.delete("/api/config/filters/bad")
|
resp = await config_client.delete("/api/config/filters/bad")
|
||||||
@@ -1367,7 +1367,7 @@ class TestAssignFilterToJail:
|
|||||||
async def test_204_assigns_filter(self, config_client: AsyncClient) -> None:
|
async def test_204_assigns_filter(self, config_client: AsyncClient) -> None:
|
||||||
"""POST /api/config/jails/sshd/filter returns 204 on success."""
|
"""POST /api/config/jails/sshd/filter returns 204 on success."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.filter_config_service.assign_filter_to_jail",
|
"app.routers.config.config_file_service.assign_filter_to_jail",
|
||||||
AsyncMock(return_value=None),
|
AsyncMock(return_value=None),
|
||||||
):
|
):
|
||||||
resp = await config_client.post(
|
resp = await config_client.post(
|
||||||
@@ -1379,10 +1379,10 @@ class TestAssignFilterToJail:
|
|||||||
|
|
||||||
async def test_404_for_unknown_jail(self, config_client: AsyncClient) -> None:
|
async def test_404_for_unknown_jail(self, config_client: AsyncClient) -> None:
|
||||||
"""POST /api/config/jails/missing/filter returns 404."""
|
"""POST /api/config/jails/missing/filter returns 404."""
|
||||||
from app.services.jail_config_service import JailNotFoundInConfigError
|
from app.services.config_file_service import JailNotFoundInConfigError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.filter_config_service.assign_filter_to_jail",
|
"app.routers.config.config_file_service.assign_filter_to_jail",
|
||||||
AsyncMock(side_effect=JailNotFoundInConfigError("missing")),
|
AsyncMock(side_effect=JailNotFoundInConfigError("missing")),
|
||||||
):
|
):
|
||||||
resp = await config_client.post(
|
resp = await config_client.post(
|
||||||
@@ -1394,10 +1394,10 @@ class TestAssignFilterToJail:
|
|||||||
|
|
||||||
async def test_404_for_unknown_filter(self, config_client: AsyncClient) -> None:
|
async def test_404_for_unknown_filter(self, config_client: AsyncClient) -> None:
|
||||||
"""POST /api/config/jails/sshd/filter returns 404 when filter not found."""
|
"""POST /api/config/jails/sshd/filter returns 404 when filter not found."""
|
||||||
from app.services.filter_config_service import FilterNotFoundError
|
from app.services.config_file_service import FilterNotFoundError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.filter_config_service.assign_filter_to_jail",
|
"app.routers.config.config_file_service.assign_filter_to_jail",
|
||||||
AsyncMock(side_effect=FilterNotFoundError("missing-filter")),
|
AsyncMock(side_effect=FilterNotFoundError("missing-filter")),
|
||||||
):
|
):
|
||||||
resp = await config_client.post(
|
resp = await config_client.post(
|
||||||
@@ -1409,10 +1409,10 @@ class TestAssignFilterToJail:
|
|||||||
|
|
||||||
async def test_400_for_invalid_jail_name(self, config_client: AsyncClient) -> None:
|
async def test_400_for_invalid_jail_name(self, config_client: AsyncClient) -> None:
|
||||||
"""POST /api/config/jails/.../filter with bad jail name returns 400."""
|
"""POST /api/config/jails/.../filter with bad jail name returns 400."""
|
||||||
from app.services.jail_config_service import JailNameError
|
from app.services.config_file_service import JailNameError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.filter_config_service.assign_filter_to_jail",
|
"app.routers.config.config_file_service.assign_filter_to_jail",
|
||||||
AsyncMock(side_effect=JailNameError("bad")),
|
AsyncMock(side_effect=JailNameError("bad")),
|
||||||
):
|
):
|
||||||
resp = await config_client.post(
|
resp = await config_client.post(
|
||||||
@@ -1424,10 +1424,10 @@ class TestAssignFilterToJail:
|
|||||||
|
|
||||||
async def test_400_for_invalid_filter_name(self, config_client: AsyncClient) -> None:
|
async def test_400_for_invalid_filter_name(self, config_client: AsyncClient) -> None:
|
||||||
"""POST /api/config/jails/sshd/filter with bad filter name returns 400."""
|
"""POST /api/config/jails/sshd/filter with bad filter name returns 400."""
|
||||||
from app.services.filter_config_service import FilterNameError
|
from app.services.config_file_service import FilterNameError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.filter_config_service.assign_filter_to_jail",
|
"app.routers.config.config_file_service.assign_filter_to_jail",
|
||||||
AsyncMock(side_effect=FilterNameError("bad")),
|
AsyncMock(side_effect=FilterNameError("bad")),
|
||||||
):
|
):
|
||||||
resp = await config_client.post(
|
resp = await config_client.post(
|
||||||
@@ -1440,7 +1440,7 @@ class TestAssignFilterToJail:
|
|||||||
async def test_reload_query_param_passed(self, config_client: AsyncClient) -> None:
|
async def test_reload_query_param_passed(self, config_client: AsyncClient) -> None:
|
||||||
"""POST /api/config/jails/sshd/filter?reload=true passes do_reload=True."""
|
"""POST /api/config/jails/sshd/filter?reload=true passes do_reload=True."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.filter_config_service.assign_filter_to_jail",
|
"app.routers.config.config_file_service.assign_filter_to_jail",
|
||||||
AsyncMock(return_value=None),
|
AsyncMock(return_value=None),
|
||||||
) as mock_assign:
|
) as mock_assign:
|
||||||
resp = await config_client.post(
|
resp = await config_client.post(
|
||||||
@@ -1478,7 +1478,7 @@ class TestListActionsRouter:
|
|||||||
mock_response = ActionListResponse(actions=[mock_action], total=1)
|
mock_response = ActionListResponse(actions=[mock_action], total=1)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.action_config_service.list_actions",
|
"app.routers.config.config_file_service.list_actions",
|
||||||
AsyncMock(return_value=mock_response),
|
AsyncMock(return_value=mock_response),
|
||||||
):
|
):
|
||||||
resp = await config_client.get("/api/config/actions")
|
resp = await config_client.get("/api/config/actions")
|
||||||
@@ -1496,7 +1496,7 @@ class TestListActionsRouter:
|
|||||||
mock_response = ActionListResponse(actions=[inactive, active], total=2)
|
mock_response = ActionListResponse(actions=[inactive, active], total=2)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.action_config_service.list_actions",
|
"app.routers.config.config_file_service.list_actions",
|
||||||
AsyncMock(return_value=mock_response),
|
AsyncMock(return_value=mock_response),
|
||||||
):
|
):
|
||||||
resp = await config_client.get("/api/config/actions")
|
resp = await config_client.get("/api/config/actions")
|
||||||
@@ -1524,7 +1524,7 @@ class TestGetActionRouter:
|
|||||||
)
|
)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.action_config_service.get_action",
|
"app.routers.config.config_file_service.get_action",
|
||||||
AsyncMock(return_value=mock_action),
|
AsyncMock(return_value=mock_action),
|
||||||
):
|
):
|
||||||
resp = await config_client.get("/api/config/actions/iptables")
|
resp = await config_client.get("/api/config/actions/iptables")
|
||||||
@@ -1533,10 +1533,10 @@ class TestGetActionRouter:
|
|||||||
assert resp.json()["name"] == "iptables"
|
assert resp.json()["name"] == "iptables"
|
||||||
|
|
||||||
async def test_404_when_not_found(self, config_client: AsyncClient) -> None:
|
async def test_404_when_not_found(self, config_client: AsyncClient) -> None:
|
||||||
from app.services.action_config_service import ActionNotFoundError
|
from app.services.config_file_service import ActionNotFoundError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.action_config_service.get_action",
|
"app.routers.config.config_file_service.get_action",
|
||||||
AsyncMock(side_effect=ActionNotFoundError("missing")),
|
AsyncMock(side_effect=ActionNotFoundError("missing")),
|
||||||
):
|
):
|
||||||
resp = await config_client.get("/api/config/actions/missing")
|
resp = await config_client.get("/api/config/actions/missing")
|
||||||
@@ -1563,7 +1563,7 @@ class TestUpdateActionRouter:
|
|||||||
)
|
)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.action_config_service.update_action",
|
"app.routers.config.config_file_service.update_action",
|
||||||
AsyncMock(return_value=updated),
|
AsyncMock(return_value=updated),
|
||||||
):
|
):
|
||||||
resp = await config_client.put(
|
resp = await config_client.put(
|
||||||
@@ -1575,10 +1575,10 @@ class TestUpdateActionRouter:
|
|||||||
assert resp.json()["actionban"] == "echo ban"
|
assert resp.json()["actionban"] == "echo ban"
|
||||||
|
|
||||||
async def test_404_when_not_found(self, config_client: AsyncClient) -> None:
|
async def test_404_when_not_found(self, config_client: AsyncClient) -> None:
|
||||||
from app.services.action_config_service import ActionNotFoundError
|
from app.services.config_file_service import ActionNotFoundError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.action_config_service.update_action",
|
"app.routers.config.config_file_service.update_action",
|
||||||
AsyncMock(side_effect=ActionNotFoundError("missing")),
|
AsyncMock(side_effect=ActionNotFoundError("missing")),
|
||||||
):
|
):
|
||||||
resp = await config_client.put(
|
resp = await config_client.put(
|
||||||
@@ -1588,10 +1588,10 @@ class TestUpdateActionRouter:
|
|||||||
assert resp.status_code == 404
|
assert resp.status_code == 404
|
||||||
|
|
||||||
async def test_400_for_bad_name(self, config_client: AsyncClient) -> None:
|
async def test_400_for_bad_name(self, config_client: AsyncClient) -> None:
|
||||||
from app.services.action_config_service import ActionNameError
|
from app.services.config_file_service import ActionNameError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.action_config_service.update_action",
|
"app.routers.config.config_file_service.update_action",
|
||||||
AsyncMock(side_effect=ActionNameError()),
|
AsyncMock(side_effect=ActionNameError()),
|
||||||
):
|
):
|
||||||
resp = await config_client.put(
|
resp = await config_client.put(
|
||||||
@@ -1620,7 +1620,7 @@ class TestCreateActionRouter:
|
|||||||
)
|
)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.action_config_service.create_action",
|
"app.routers.config.config_file_service.create_action",
|
||||||
AsyncMock(return_value=created),
|
AsyncMock(return_value=created),
|
||||||
):
|
):
|
||||||
resp = await config_client.post(
|
resp = await config_client.post(
|
||||||
@@ -1632,10 +1632,10 @@ class TestCreateActionRouter:
|
|||||||
assert resp.json()["name"] == "custom"
|
assert resp.json()["name"] == "custom"
|
||||||
|
|
||||||
async def test_409_when_already_exists(self, config_client: AsyncClient) -> None:
|
async def test_409_when_already_exists(self, config_client: AsyncClient) -> None:
|
||||||
from app.services.action_config_service import ActionAlreadyExistsError
|
from app.services.config_file_service import ActionAlreadyExistsError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.action_config_service.create_action",
|
"app.routers.config.config_file_service.create_action",
|
||||||
AsyncMock(side_effect=ActionAlreadyExistsError("iptables")),
|
AsyncMock(side_effect=ActionAlreadyExistsError("iptables")),
|
||||||
):
|
):
|
||||||
resp = await config_client.post(
|
resp = await config_client.post(
|
||||||
@@ -1646,10 +1646,10 @@ class TestCreateActionRouter:
|
|||||||
assert resp.status_code == 409
|
assert resp.status_code == 409
|
||||||
|
|
||||||
async def test_400_for_bad_name(self, config_client: AsyncClient) -> None:
|
async def test_400_for_bad_name(self, config_client: AsyncClient) -> None:
|
||||||
from app.services.action_config_service import ActionNameError
|
from app.services.config_file_service import ActionNameError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.action_config_service.create_action",
|
"app.routers.config.config_file_service.create_action",
|
||||||
AsyncMock(side_effect=ActionNameError()),
|
AsyncMock(side_effect=ActionNameError()),
|
||||||
):
|
):
|
||||||
resp = await config_client.post(
|
resp = await config_client.post(
|
||||||
@@ -1671,7 +1671,7 @@ class TestCreateActionRouter:
|
|||||||
class TestDeleteActionRouter:
|
class TestDeleteActionRouter:
|
||||||
async def test_204_on_delete(self, config_client: AsyncClient) -> None:
|
async def test_204_on_delete(self, config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.action_config_service.delete_action",
|
"app.routers.config.config_file_service.delete_action",
|
||||||
AsyncMock(return_value=None),
|
AsyncMock(return_value=None),
|
||||||
):
|
):
|
||||||
resp = await config_client.delete("/api/config/actions/custom")
|
resp = await config_client.delete("/api/config/actions/custom")
|
||||||
@@ -1679,10 +1679,10 @@ class TestDeleteActionRouter:
|
|||||||
assert resp.status_code == 204
|
assert resp.status_code == 204
|
||||||
|
|
||||||
async def test_404_when_not_found(self, config_client: AsyncClient) -> None:
|
async def test_404_when_not_found(self, config_client: AsyncClient) -> None:
|
||||||
from app.services.action_config_service import ActionNotFoundError
|
from app.services.config_file_service import ActionNotFoundError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.action_config_service.delete_action",
|
"app.routers.config.config_file_service.delete_action",
|
||||||
AsyncMock(side_effect=ActionNotFoundError("missing")),
|
AsyncMock(side_effect=ActionNotFoundError("missing")),
|
||||||
):
|
):
|
||||||
resp = await config_client.delete("/api/config/actions/missing")
|
resp = await config_client.delete("/api/config/actions/missing")
|
||||||
@@ -1690,10 +1690,10 @@ class TestDeleteActionRouter:
|
|||||||
assert resp.status_code == 404
|
assert resp.status_code == 404
|
||||||
|
|
||||||
async def test_409_when_readonly(self, config_client: AsyncClient) -> None:
|
async def test_409_when_readonly(self, config_client: AsyncClient) -> None:
|
||||||
from app.services.action_config_service import ActionReadonlyError
|
from app.services.config_file_service import ActionReadonlyError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.action_config_service.delete_action",
|
"app.routers.config.config_file_service.delete_action",
|
||||||
AsyncMock(side_effect=ActionReadonlyError("iptables")),
|
AsyncMock(side_effect=ActionReadonlyError("iptables")),
|
||||||
):
|
):
|
||||||
resp = await config_client.delete("/api/config/actions/iptables")
|
resp = await config_client.delete("/api/config/actions/iptables")
|
||||||
@@ -1701,10 +1701,10 @@ class TestDeleteActionRouter:
|
|||||||
assert resp.status_code == 409
|
assert resp.status_code == 409
|
||||||
|
|
||||||
async def test_400_for_bad_name(self, config_client: AsyncClient) -> None:
|
async def test_400_for_bad_name(self, config_client: AsyncClient) -> None:
|
||||||
from app.services.action_config_service import ActionNameError
|
from app.services.config_file_service import ActionNameError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.action_config_service.delete_action",
|
"app.routers.config.config_file_service.delete_action",
|
||||||
AsyncMock(side_effect=ActionNameError()),
|
AsyncMock(side_effect=ActionNameError()),
|
||||||
):
|
):
|
||||||
resp = await config_client.delete("/api/config/actions/badname")
|
resp = await config_client.delete("/api/config/actions/badname")
|
||||||
@@ -1723,7 +1723,7 @@ class TestDeleteActionRouter:
|
|||||||
class TestAssignActionToJailRouter:
|
class TestAssignActionToJailRouter:
|
||||||
async def test_204_on_success(self, config_client: AsyncClient) -> None:
|
async def test_204_on_success(self, config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.action_config_service.assign_action_to_jail",
|
"app.routers.config.config_file_service.assign_action_to_jail",
|
||||||
AsyncMock(return_value=None),
|
AsyncMock(return_value=None),
|
||||||
):
|
):
|
||||||
resp = await config_client.post(
|
resp = await config_client.post(
|
||||||
@@ -1734,10 +1734,10 @@ class TestAssignActionToJailRouter:
|
|||||||
assert resp.status_code == 204
|
assert resp.status_code == 204
|
||||||
|
|
||||||
async def test_404_when_jail_not_found(self, config_client: AsyncClient) -> None:
|
async def test_404_when_jail_not_found(self, config_client: AsyncClient) -> None:
|
||||||
from app.services.jail_config_service import JailNotFoundInConfigError
|
from app.services.config_file_service import JailNotFoundInConfigError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.action_config_service.assign_action_to_jail",
|
"app.routers.config.config_file_service.assign_action_to_jail",
|
||||||
AsyncMock(side_effect=JailNotFoundInConfigError("missing")),
|
AsyncMock(side_effect=JailNotFoundInConfigError("missing")),
|
||||||
):
|
):
|
||||||
resp = await config_client.post(
|
resp = await config_client.post(
|
||||||
@@ -1748,10 +1748,10 @@ class TestAssignActionToJailRouter:
|
|||||||
assert resp.status_code == 404
|
assert resp.status_code == 404
|
||||||
|
|
||||||
async def test_404_when_action_not_found(self, config_client: AsyncClient) -> None:
|
async def test_404_when_action_not_found(self, config_client: AsyncClient) -> None:
|
||||||
from app.services.action_config_service import ActionNotFoundError
|
from app.services.config_file_service import ActionNotFoundError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.action_config_service.assign_action_to_jail",
|
"app.routers.config.config_file_service.assign_action_to_jail",
|
||||||
AsyncMock(side_effect=ActionNotFoundError("missing")),
|
AsyncMock(side_effect=ActionNotFoundError("missing")),
|
||||||
):
|
):
|
||||||
resp = await config_client.post(
|
resp = await config_client.post(
|
||||||
@@ -1762,10 +1762,10 @@ class TestAssignActionToJailRouter:
|
|||||||
assert resp.status_code == 404
|
assert resp.status_code == 404
|
||||||
|
|
||||||
async def test_400_for_bad_jail_name(self, config_client: AsyncClient) -> None:
|
async def test_400_for_bad_jail_name(self, config_client: AsyncClient) -> None:
|
||||||
from app.services.jail_config_service import JailNameError
|
from app.services.config_file_service import JailNameError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.action_config_service.assign_action_to_jail",
|
"app.routers.config.config_file_service.assign_action_to_jail",
|
||||||
AsyncMock(side_effect=JailNameError()),
|
AsyncMock(side_effect=JailNameError()),
|
||||||
):
|
):
|
||||||
resp = await config_client.post(
|
resp = await config_client.post(
|
||||||
@@ -1776,10 +1776,10 @@ class TestAssignActionToJailRouter:
|
|||||||
assert resp.status_code == 400
|
assert resp.status_code == 400
|
||||||
|
|
||||||
async def test_400_for_bad_action_name(self, config_client: AsyncClient) -> None:
|
async def test_400_for_bad_action_name(self, config_client: AsyncClient) -> None:
|
||||||
from app.services.action_config_service import ActionNameError
|
from app.services.config_file_service import ActionNameError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.action_config_service.assign_action_to_jail",
|
"app.routers.config.config_file_service.assign_action_to_jail",
|
||||||
AsyncMock(side_effect=ActionNameError()),
|
AsyncMock(side_effect=ActionNameError()),
|
||||||
):
|
):
|
||||||
resp = await config_client.post(
|
resp = await config_client.post(
|
||||||
@@ -1791,7 +1791,7 @@ class TestAssignActionToJailRouter:
|
|||||||
|
|
||||||
async def test_reload_param_passed(self, config_client: AsyncClient) -> None:
|
async def test_reload_param_passed(self, config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.action_config_service.assign_action_to_jail",
|
"app.routers.config.config_file_service.assign_action_to_jail",
|
||||||
AsyncMock(return_value=None),
|
AsyncMock(return_value=None),
|
||||||
) as mock_assign:
|
) as mock_assign:
|
||||||
resp = await config_client.post(
|
resp = await config_client.post(
|
||||||
@@ -1814,7 +1814,7 @@ class TestAssignActionToJailRouter:
|
|||||||
class TestRemoveActionFromJailRouter:
|
class TestRemoveActionFromJailRouter:
|
||||||
async def test_204_on_success(self, config_client: AsyncClient) -> None:
|
async def test_204_on_success(self, config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.action_config_service.remove_action_from_jail",
|
"app.routers.config.config_file_service.remove_action_from_jail",
|
||||||
AsyncMock(return_value=None),
|
AsyncMock(return_value=None),
|
||||||
):
|
):
|
||||||
resp = await config_client.delete(
|
resp = await config_client.delete(
|
||||||
@@ -1824,10 +1824,10 @@ class TestRemoveActionFromJailRouter:
|
|||||||
assert resp.status_code == 204
|
assert resp.status_code == 204
|
||||||
|
|
||||||
async def test_404_when_jail_not_found(self, config_client: AsyncClient) -> None:
|
async def test_404_when_jail_not_found(self, config_client: AsyncClient) -> None:
|
||||||
from app.services.jail_config_service import JailNotFoundInConfigError
|
from app.services.config_file_service import JailNotFoundInConfigError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.action_config_service.remove_action_from_jail",
|
"app.routers.config.config_file_service.remove_action_from_jail",
|
||||||
AsyncMock(side_effect=JailNotFoundInConfigError("missing")),
|
AsyncMock(side_effect=JailNotFoundInConfigError("missing")),
|
||||||
):
|
):
|
||||||
resp = await config_client.delete(
|
resp = await config_client.delete(
|
||||||
@@ -1837,10 +1837,10 @@ class TestRemoveActionFromJailRouter:
|
|||||||
assert resp.status_code == 404
|
assert resp.status_code == 404
|
||||||
|
|
||||||
async def test_400_for_bad_jail_name(self, config_client: AsyncClient) -> None:
|
async def test_400_for_bad_jail_name(self, config_client: AsyncClient) -> None:
|
||||||
from app.services.jail_config_service import JailNameError
|
from app.services.config_file_service import JailNameError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.action_config_service.remove_action_from_jail",
|
"app.routers.config.config_file_service.remove_action_from_jail",
|
||||||
AsyncMock(side_effect=JailNameError()),
|
AsyncMock(side_effect=JailNameError()),
|
||||||
):
|
):
|
||||||
resp = await config_client.delete(
|
resp = await config_client.delete(
|
||||||
@@ -1850,10 +1850,10 @@ class TestRemoveActionFromJailRouter:
|
|||||||
assert resp.status_code == 400
|
assert resp.status_code == 400
|
||||||
|
|
||||||
async def test_400_for_bad_action_name(self, config_client: AsyncClient) -> None:
|
async def test_400_for_bad_action_name(self, config_client: AsyncClient) -> None:
|
||||||
from app.services.action_config_service import ActionNameError
|
from app.services.config_file_service import ActionNameError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.action_config_service.remove_action_from_jail",
|
"app.routers.config.config_file_service.remove_action_from_jail",
|
||||||
AsyncMock(side_effect=ActionNameError()),
|
AsyncMock(side_effect=ActionNameError()),
|
||||||
):
|
):
|
||||||
resp = await config_client.delete(
|
resp = await config_client.delete(
|
||||||
@@ -1864,7 +1864,7 @@ class TestRemoveActionFromJailRouter:
|
|||||||
|
|
||||||
async def test_reload_param_passed(self, config_client: AsyncClient) -> None:
|
async def test_reload_param_passed(self, config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.action_config_service.remove_action_from_jail",
|
"app.routers.config.config_file_service.remove_action_from_jail",
|
||||||
AsyncMock(return_value=None),
|
AsyncMock(return_value=None),
|
||||||
) as mock_rm:
|
) as mock_rm:
|
||||||
resp = await config_client.delete(
|
resp = await config_client.delete(
|
||||||
@@ -2060,7 +2060,7 @@ class TestValidateJailEndpoint:
|
|||||||
jail_name="sshd", valid=True, issues=[]
|
jail_name="sshd", valid=True, issues=[]
|
||||||
)
|
)
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.jail_config_service.validate_jail_config",
|
"app.routers.config.config_file_service.validate_jail_config",
|
||||||
AsyncMock(return_value=mock_result),
|
AsyncMock(return_value=mock_result),
|
||||||
):
|
):
|
||||||
resp = await config_client.post("/api/config/jails/sshd/validate")
|
resp = await config_client.post("/api/config/jails/sshd/validate")
|
||||||
@@ -2080,7 +2080,7 @@ class TestValidateJailEndpoint:
|
|||||||
jail_name="sshd", valid=False, issues=[issue]
|
jail_name="sshd", valid=False, issues=[issue]
|
||||||
)
|
)
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.jail_config_service.validate_jail_config",
|
"app.routers.config.config_file_service.validate_jail_config",
|
||||||
AsyncMock(return_value=mock_result),
|
AsyncMock(return_value=mock_result),
|
||||||
):
|
):
|
||||||
resp = await config_client.post("/api/config/jails/sshd/validate")
|
resp = await config_client.post("/api/config/jails/sshd/validate")
|
||||||
@@ -2093,10 +2093,10 @@ class TestValidateJailEndpoint:
|
|||||||
|
|
||||||
async def test_400_for_invalid_jail_name(self, config_client: AsyncClient) -> None:
|
async def test_400_for_invalid_jail_name(self, config_client: AsyncClient) -> None:
|
||||||
"""POST /api/config/jails/bad-name/validate returns 400 on JailNameError."""
|
"""POST /api/config/jails/bad-name/validate returns 400 on JailNameError."""
|
||||||
from app.services.jail_config_service import JailNameError
|
from app.services.config_file_service import JailNameError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.jail_config_service.validate_jail_config",
|
"app.routers.config.config_file_service.validate_jail_config",
|
||||||
AsyncMock(side_effect=JailNameError("bad name")),
|
AsyncMock(side_effect=JailNameError("bad name")),
|
||||||
):
|
):
|
||||||
resp = await config_client.post("/api/config/jails/bad-name/validate")
|
resp = await config_client.post("/api/config/jails/bad-name/validate")
|
||||||
@@ -2188,7 +2188,7 @@ class TestRollbackEndpoint:
|
|||||||
message="Jail 'sshd' disabled and fail2ban restarted.",
|
message="Jail 'sshd' disabled and fail2ban restarted.",
|
||||||
)
|
)
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.jail_config_service.rollback_jail",
|
"app.routers.config.config_file_service.rollback_jail",
|
||||||
AsyncMock(return_value=mock_result),
|
AsyncMock(return_value=mock_result),
|
||||||
):
|
):
|
||||||
resp = await config_client.post("/api/config/jails/sshd/rollback")
|
resp = await config_client.post("/api/config/jails/sshd/rollback")
|
||||||
@@ -2225,7 +2225,7 @@ class TestRollbackEndpoint:
|
|||||||
message="fail2ban did not come back online.",
|
message="fail2ban did not come back online.",
|
||||||
)
|
)
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.jail_config_service.rollback_jail",
|
"app.routers.config.config_file_service.rollback_jail",
|
||||||
AsyncMock(return_value=mock_result),
|
AsyncMock(return_value=mock_result),
|
||||||
):
|
):
|
||||||
resp = await config_client.post("/api/config/jails/sshd/rollback")
|
resp = await config_client.post("/api/config/jails/sshd/rollback")
|
||||||
@@ -2238,10 +2238,10 @@ class TestRollbackEndpoint:
|
|||||||
|
|
||||||
async def test_400_for_invalid_jail_name(self, config_client: AsyncClient) -> None:
|
async def test_400_for_invalid_jail_name(self, config_client: AsyncClient) -> None:
|
||||||
"""POST /api/config/jails/bad/rollback returns 400 on JailNameError."""
|
"""POST /api/config/jails/bad/rollback returns 400 on JailNameError."""
|
||||||
from app.services.jail_config_service import JailNameError
|
from app.services.config_file_service import JailNameError
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.jail_config_service.rollback_jail",
|
"app.routers.config.config_file_service.rollback_jail",
|
||||||
AsyncMock(side_effect=JailNameError("bad")),
|
AsyncMock(side_effect=JailNameError("bad")),
|
||||||
):
|
):
|
||||||
resp = await config_client.post("/api/config/jails/bad/rollback")
|
resp = await config_client.post("/api/config/jails/bad/rollback")
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ from app.models.file_config import (
|
|||||||
JailConfigFileContent,
|
JailConfigFileContent,
|
||||||
JailConfigFilesResponse,
|
JailConfigFilesResponse,
|
||||||
)
|
)
|
||||||
from app.services.raw_config_io_service import (
|
from app.services.file_config_service import (
|
||||||
ConfigDirError,
|
ConfigDirError,
|
||||||
ConfigFileExistsError,
|
ConfigFileExistsError,
|
||||||
ConfigFileNameError,
|
ConfigFileNameError,
|
||||||
@@ -112,7 +112,7 @@ class TestListJailConfigFiles:
|
|||||||
self, file_config_client: AsyncClient
|
self, file_config_client: AsyncClient
|
||||||
) -> None:
|
) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.list_jail_config_files",
|
"app.routers.file_config.file_config_service.list_jail_config_files",
|
||||||
AsyncMock(return_value=_jail_files_resp()),
|
AsyncMock(return_value=_jail_files_resp()),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.get("/api/config/jail-files")
|
resp = await file_config_client.get("/api/config/jail-files")
|
||||||
@@ -126,7 +126,7 @@ class TestListJailConfigFiles:
|
|||||||
self, file_config_client: AsyncClient
|
self, file_config_client: AsyncClient
|
||||||
) -> None:
|
) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.list_jail_config_files",
|
"app.routers.file_config.file_config_service.list_jail_config_files",
|
||||||
AsyncMock(side_effect=ConfigDirError("not found")),
|
AsyncMock(side_effect=ConfigDirError("not found")),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.get("/api/config/jail-files")
|
resp = await file_config_client.get("/api/config/jail-files")
|
||||||
@@ -157,7 +157,7 @@ class TestGetJailConfigFile:
|
|||||||
content="[sshd]\nenabled = true\n",
|
content="[sshd]\nenabled = true\n",
|
||||||
)
|
)
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.get_jail_config_file",
|
"app.routers.file_config.file_config_service.get_jail_config_file",
|
||||||
AsyncMock(return_value=content),
|
AsyncMock(return_value=content),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.get("/api/config/jail-files/sshd.conf")
|
resp = await file_config_client.get("/api/config/jail-files/sshd.conf")
|
||||||
@@ -167,7 +167,7 @@ class TestGetJailConfigFile:
|
|||||||
|
|
||||||
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.get_jail_config_file",
|
"app.routers.file_config.file_config_service.get_jail_config_file",
|
||||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing.conf")),
|
AsyncMock(side_effect=ConfigFileNotFoundError("missing.conf")),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.get("/api/config/jail-files/missing.conf")
|
resp = await file_config_client.get("/api/config/jail-files/missing.conf")
|
||||||
@@ -178,7 +178,7 @@ class TestGetJailConfigFile:
|
|||||||
self, file_config_client: AsyncClient
|
self, file_config_client: AsyncClient
|
||||||
) -> None:
|
) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.get_jail_config_file",
|
"app.routers.file_config.file_config_service.get_jail_config_file",
|
||||||
AsyncMock(side_effect=ConfigFileNameError("bad name")),
|
AsyncMock(side_effect=ConfigFileNameError("bad name")),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.get("/api/config/jail-files/bad.txt")
|
resp = await file_config_client.get("/api/config/jail-files/bad.txt")
|
||||||
@@ -194,7 +194,7 @@ class TestGetJailConfigFile:
|
|||||||
class TestSetJailConfigEnabled:
|
class TestSetJailConfigEnabled:
|
||||||
async def test_204_on_success(self, file_config_client: AsyncClient) -> None:
|
async def test_204_on_success(self, file_config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.set_jail_config_enabled",
|
"app.routers.file_config.file_config_service.set_jail_config_enabled",
|
||||||
AsyncMock(return_value=None),
|
AsyncMock(return_value=None),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.put(
|
resp = await file_config_client.put(
|
||||||
@@ -206,7 +206,7 @@ class TestSetJailConfigEnabled:
|
|||||||
|
|
||||||
async def test_404_file_not_found(self, file_config_client: AsyncClient) -> None:
|
async def test_404_file_not_found(self, file_config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.set_jail_config_enabled",
|
"app.routers.file_config.file_config_service.set_jail_config_enabled",
|
||||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing.conf")),
|
AsyncMock(side_effect=ConfigFileNotFoundError("missing.conf")),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.put(
|
resp = await file_config_client.put(
|
||||||
@@ -232,7 +232,7 @@ class TestGetFilterFileRaw:
|
|||||||
|
|
||||||
async def test_200_returns_content(self, file_config_client: AsyncClient) -> None:
|
async def test_200_returns_content(self, file_config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.get_filter_file",
|
"app.routers.file_config.file_config_service.get_filter_file",
|
||||||
AsyncMock(return_value=_conf_file_content("nginx")),
|
AsyncMock(return_value=_conf_file_content("nginx")),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.get("/api/config/filters/nginx/raw")
|
resp = await file_config_client.get("/api/config/filters/nginx/raw")
|
||||||
@@ -242,7 +242,7 @@ class TestGetFilterFileRaw:
|
|||||||
|
|
||||||
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.get_filter_file",
|
"app.routers.file_config.file_config_service.get_filter_file",
|
||||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing")),
|
AsyncMock(side_effect=ConfigFileNotFoundError("missing")),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.get("/api/config/filters/missing/raw")
|
resp = await file_config_client.get("/api/config/filters/missing/raw")
|
||||||
@@ -258,7 +258,7 @@ class TestGetFilterFileRaw:
|
|||||||
class TestUpdateFilterFile:
|
class TestUpdateFilterFile:
|
||||||
async def test_204_on_success(self, file_config_client: AsyncClient) -> None:
|
async def test_204_on_success(self, file_config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.write_filter_file",
|
"app.routers.file_config.file_config_service.write_filter_file",
|
||||||
AsyncMock(return_value=None),
|
AsyncMock(return_value=None),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.put(
|
resp = await file_config_client.put(
|
||||||
@@ -270,7 +270,7 @@ class TestUpdateFilterFile:
|
|||||||
|
|
||||||
async def test_400_write_error(self, file_config_client: AsyncClient) -> None:
|
async def test_400_write_error(self, file_config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.write_filter_file",
|
"app.routers.file_config.file_config_service.write_filter_file",
|
||||||
AsyncMock(side_effect=ConfigFileWriteError("disk full")),
|
AsyncMock(side_effect=ConfigFileWriteError("disk full")),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.put(
|
resp = await file_config_client.put(
|
||||||
@@ -289,7 +289,7 @@ class TestUpdateFilterFile:
|
|||||||
class TestCreateFilterFile:
|
class TestCreateFilterFile:
|
||||||
async def test_201_creates_file(self, file_config_client: AsyncClient) -> None:
|
async def test_201_creates_file(self, file_config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.create_filter_file",
|
"app.routers.file_config.file_config_service.create_filter_file",
|
||||||
AsyncMock(return_value="myfilter.conf"),
|
AsyncMock(return_value="myfilter.conf"),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.post(
|
resp = await file_config_client.post(
|
||||||
@@ -302,7 +302,7 @@ class TestCreateFilterFile:
|
|||||||
|
|
||||||
async def test_409_conflict(self, file_config_client: AsyncClient) -> None:
|
async def test_409_conflict(self, file_config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.create_filter_file",
|
"app.routers.file_config.file_config_service.create_filter_file",
|
||||||
AsyncMock(side_effect=ConfigFileExistsError("myfilter.conf")),
|
AsyncMock(side_effect=ConfigFileExistsError("myfilter.conf")),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.post(
|
resp = await file_config_client.post(
|
||||||
@@ -314,7 +314,7 @@ class TestCreateFilterFile:
|
|||||||
|
|
||||||
async def test_400_invalid_name(self, file_config_client: AsyncClient) -> None:
|
async def test_400_invalid_name(self, file_config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.create_filter_file",
|
"app.routers.file_config.file_config_service.create_filter_file",
|
||||||
AsyncMock(side_effect=ConfigFileNameError("bad/../name")),
|
AsyncMock(side_effect=ConfigFileNameError("bad/../name")),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.post(
|
resp = await file_config_client.post(
|
||||||
@@ -342,7 +342,7 @@ class TestListActionFiles:
|
|||||||
)
|
)
|
||||||
resp_data = ActionListResponse(actions=[mock_action], total=1)
|
resp_data = ActionListResponse(actions=[mock_action], total=1)
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.action_config_service.list_actions",
|
"app.routers.config.config_file_service.list_actions",
|
||||||
AsyncMock(return_value=resp_data),
|
AsyncMock(return_value=resp_data),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.get("/api/config/actions")
|
resp = await file_config_client.get("/api/config/actions")
|
||||||
@@ -365,7 +365,7 @@ class TestCreateActionFile:
|
|||||||
actionban="echo ban <ip>",
|
actionban="echo ban <ip>",
|
||||||
)
|
)
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.config.action_config_service.create_action",
|
"app.routers.config.config_file_service.create_action",
|
||||||
AsyncMock(return_value=created),
|
AsyncMock(return_value=created),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.post(
|
resp = await file_config_client.post(
|
||||||
@@ -387,7 +387,7 @@ class TestGetActionFileRaw:
|
|||||||
|
|
||||||
async def test_200_returns_content(self, file_config_client: AsyncClient) -> None:
|
async def test_200_returns_content(self, file_config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.get_action_file",
|
"app.routers.file_config.file_config_service.get_action_file",
|
||||||
AsyncMock(return_value=_conf_file_content("iptables")),
|
AsyncMock(return_value=_conf_file_content("iptables")),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.get("/api/config/actions/iptables/raw")
|
resp = await file_config_client.get("/api/config/actions/iptables/raw")
|
||||||
@@ -397,7 +397,7 @@ class TestGetActionFileRaw:
|
|||||||
|
|
||||||
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.get_action_file",
|
"app.routers.file_config.file_config_service.get_action_file",
|
||||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing")),
|
AsyncMock(side_effect=ConfigFileNotFoundError("missing")),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.get("/api/config/actions/missing/raw")
|
resp = await file_config_client.get("/api/config/actions/missing/raw")
|
||||||
@@ -408,7 +408,7 @@ class TestGetActionFileRaw:
|
|||||||
self, file_config_client: AsyncClient
|
self, file_config_client: AsyncClient
|
||||||
) -> None:
|
) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.get_action_file",
|
"app.routers.file_config.file_config_service.get_action_file",
|
||||||
AsyncMock(side_effect=ConfigDirError("no dir")),
|
AsyncMock(side_effect=ConfigDirError("no dir")),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.get("/api/config/actions/iptables/raw")
|
resp = await file_config_client.get("/api/config/actions/iptables/raw")
|
||||||
@@ -426,7 +426,7 @@ class TestUpdateActionFileRaw:
|
|||||||
|
|
||||||
async def test_204_on_success(self, file_config_client: AsyncClient) -> None:
|
async def test_204_on_success(self, file_config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.write_action_file",
|
"app.routers.file_config.file_config_service.write_action_file",
|
||||||
AsyncMock(return_value=None),
|
AsyncMock(return_value=None),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.put(
|
resp = await file_config_client.put(
|
||||||
@@ -438,7 +438,7 @@ class TestUpdateActionFileRaw:
|
|||||||
|
|
||||||
async def test_400_write_error(self, file_config_client: AsyncClient) -> None:
|
async def test_400_write_error(self, file_config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.write_action_file",
|
"app.routers.file_config.file_config_service.write_action_file",
|
||||||
AsyncMock(side_effect=ConfigFileWriteError("disk full")),
|
AsyncMock(side_effect=ConfigFileWriteError("disk full")),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.put(
|
resp = await file_config_client.put(
|
||||||
@@ -450,7 +450,7 @@ class TestUpdateActionFileRaw:
|
|||||||
|
|
||||||
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.write_action_file",
|
"app.routers.file_config.file_config_service.write_action_file",
|
||||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing")),
|
AsyncMock(side_effect=ConfigFileNotFoundError("missing")),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.put(
|
resp = await file_config_client.put(
|
||||||
@@ -462,7 +462,7 @@ class TestUpdateActionFileRaw:
|
|||||||
|
|
||||||
async def test_400_invalid_name(self, file_config_client: AsyncClient) -> None:
|
async def test_400_invalid_name(self, file_config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.write_action_file",
|
"app.routers.file_config.file_config_service.write_action_file",
|
||||||
AsyncMock(side_effect=ConfigFileNameError("bad/../name")),
|
AsyncMock(side_effect=ConfigFileNameError("bad/../name")),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.put(
|
resp = await file_config_client.put(
|
||||||
@@ -481,7 +481,7 @@ class TestUpdateActionFileRaw:
|
|||||||
class TestCreateJailConfigFile:
|
class TestCreateJailConfigFile:
|
||||||
async def test_201_creates_file(self, file_config_client: AsyncClient) -> None:
|
async def test_201_creates_file(self, file_config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.create_jail_config_file",
|
"app.routers.file_config.file_config_service.create_jail_config_file",
|
||||||
AsyncMock(return_value="myjail.conf"),
|
AsyncMock(return_value="myjail.conf"),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.post(
|
resp = await file_config_client.post(
|
||||||
@@ -494,7 +494,7 @@ class TestCreateJailConfigFile:
|
|||||||
|
|
||||||
async def test_409_conflict(self, file_config_client: AsyncClient) -> None:
|
async def test_409_conflict(self, file_config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.create_jail_config_file",
|
"app.routers.file_config.file_config_service.create_jail_config_file",
|
||||||
AsyncMock(side_effect=ConfigFileExistsError("myjail.conf")),
|
AsyncMock(side_effect=ConfigFileExistsError("myjail.conf")),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.post(
|
resp = await file_config_client.post(
|
||||||
@@ -506,7 +506,7 @@ class TestCreateJailConfigFile:
|
|||||||
|
|
||||||
async def test_400_invalid_name(self, file_config_client: AsyncClient) -> None:
|
async def test_400_invalid_name(self, file_config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.create_jail_config_file",
|
"app.routers.file_config.file_config_service.create_jail_config_file",
|
||||||
AsyncMock(side_effect=ConfigFileNameError("bad/../name")),
|
AsyncMock(side_effect=ConfigFileNameError("bad/../name")),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.post(
|
resp = await file_config_client.post(
|
||||||
@@ -520,7 +520,7 @@ class TestCreateJailConfigFile:
|
|||||||
self, file_config_client: AsyncClient
|
self, file_config_client: AsyncClient
|
||||||
) -> None:
|
) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.create_jail_config_file",
|
"app.routers.file_config.file_config_service.create_jail_config_file",
|
||||||
AsyncMock(side_effect=ConfigDirError("no dir")),
|
AsyncMock(side_effect=ConfigDirError("no dir")),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.post(
|
resp = await file_config_client.post(
|
||||||
@@ -542,7 +542,7 @@ class TestGetParsedFilter:
|
|||||||
) -> None:
|
) -> None:
|
||||||
cfg = FilterConfig(name="nginx", filename="nginx.conf")
|
cfg = FilterConfig(name="nginx", filename="nginx.conf")
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.get_parsed_filter_file",
|
"app.routers.file_config.file_config_service.get_parsed_filter_file",
|
||||||
AsyncMock(return_value=cfg),
|
AsyncMock(return_value=cfg),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.get("/api/config/filters/nginx/parsed")
|
resp = await file_config_client.get("/api/config/filters/nginx/parsed")
|
||||||
@@ -554,7 +554,7 @@ class TestGetParsedFilter:
|
|||||||
|
|
||||||
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.get_parsed_filter_file",
|
"app.routers.file_config.file_config_service.get_parsed_filter_file",
|
||||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing")),
|
AsyncMock(side_effect=ConfigFileNotFoundError("missing")),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.get(
|
resp = await file_config_client.get(
|
||||||
@@ -567,7 +567,7 @@ class TestGetParsedFilter:
|
|||||||
self, file_config_client: AsyncClient
|
self, file_config_client: AsyncClient
|
||||||
) -> None:
|
) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.get_parsed_filter_file",
|
"app.routers.file_config.file_config_service.get_parsed_filter_file",
|
||||||
AsyncMock(side_effect=ConfigDirError("no dir")),
|
AsyncMock(side_effect=ConfigDirError("no dir")),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.get("/api/config/filters/nginx/parsed")
|
resp = await file_config_client.get("/api/config/filters/nginx/parsed")
|
||||||
@@ -583,7 +583,7 @@ class TestGetParsedFilter:
|
|||||||
class TestUpdateParsedFilter:
|
class TestUpdateParsedFilter:
|
||||||
async def test_204_on_success(self, file_config_client: AsyncClient) -> None:
|
async def test_204_on_success(self, file_config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.update_parsed_filter_file",
|
"app.routers.file_config.file_config_service.update_parsed_filter_file",
|
||||||
AsyncMock(return_value=None),
|
AsyncMock(return_value=None),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.put(
|
resp = await file_config_client.put(
|
||||||
@@ -595,7 +595,7 @@ class TestUpdateParsedFilter:
|
|||||||
|
|
||||||
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.update_parsed_filter_file",
|
"app.routers.file_config.file_config_service.update_parsed_filter_file",
|
||||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing")),
|
AsyncMock(side_effect=ConfigFileNotFoundError("missing")),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.put(
|
resp = await file_config_client.put(
|
||||||
@@ -607,7 +607,7 @@ class TestUpdateParsedFilter:
|
|||||||
|
|
||||||
async def test_400_write_error(self, file_config_client: AsyncClient) -> None:
|
async def test_400_write_error(self, file_config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.update_parsed_filter_file",
|
"app.routers.file_config.file_config_service.update_parsed_filter_file",
|
||||||
AsyncMock(side_effect=ConfigFileWriteError("disk full")),
|
AsyncMock(side_effect=ConfigFileWriteError("disk full")),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.put(
|
resp = await file_config_client.put(
|
||||||
@@ -629,7 +629,7 @@ class TestGetParsedAction:
|
|||||||
) -> None:
|
) -> None:
|
||||||
cfg = ActionConfig(name="iptables", filename="iptables.conf")
|
cfg = ActionConfig(name="iptables", filename="iptables.conf")
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.get_parsed_action_file",
|
"app.routers.file_config.file_config_service.get_parsed_action_file",
|
||||||
AsyncMock(return_value=cfg),
|
AsyncMock(return_value=cfg),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.get(
|
resp = await file_config_client.get(
|
||||||
@@ -643,7 +643,7 @@ class TestGetParsedAction:
|
|||||||
|
|
||||||
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.get_parsed_action_file",
|
"app.routers.file_config.file_config_service.get_parsed_action_file",
|
||||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing")),
|
AsyncMock(side_effect=ConfigFileNotFoundError("missing")),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.get(
|
resp = await file_config_client.get(
|
||||||
@@ -656,7 +656,7 @@ class TestGetParsedAction:
|
|||||||
self, file_config_client: AsyncClient
|
self, file_config_client: AsyncClient
|
||||||
) -> None:
|
) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.get_parsed_action_file",
|
"app.routers.file_config.file_config_service.get_parsed_action_file",
|
||||||
AsyncMock(side_effect=ConfigDirError("no dir")),
|
AsyncMock(side_effect=ConfigDirError("no dir")),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.get(
|
resp = await file_config_client.get(
|
||||||
@@ -674,7 +674,7 @@ class TestGetParsedAction:
|
|||||||
class TestUpdateParsedAction:
|
class TestUpdateParsedAction:
|
||||||
async def test_204_on_success(self, file_config_client: AsyncClient) -> None:
|
async def test_204_on_success(self, file_config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.update_parsed_action_file",
|
"app.routers.file_config.file_config_service.update_parsed_action_file",
|
||||||
AsyncMock(return_value=None),
|
AsyncMock(return_value=None),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.put(
|
resp = await file_config_client.put(
|
||||||
@@ -686,7 +686,7 @@ class TestUpdateParsedAction:
|
|||||||
|
|
||||||
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.update_parsed_action_file",
|
"app.routers.file_config.file_config_service.update_parsed_action_file",
|
||||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing")),
|
AsyncMock(side_effect=ConfigFileNotFoundError("missing")),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.put(
|
resp = await file_config_client.put(
|
||||||
@@ -698,7 +698,7 @@ class TestUpdateParsedAction:
|
|||||||
|
|
||||||
async def test_400_write_error(self, file_config_client: AsyncClient) -> None:
|
async def test_400_write_error(self, file_config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.update_parsed_action_file",
|
"app.routers.file_config.file_config_service.update_parsed_action_file",
|
||||||
AsyncMock(side_effect=ConfigFileWriteError("disk full")),
|
AsyncMock(side_effect=ConfigFileWriteError("disk full")),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.put(
|
resp = await file_config_client.put(
|
||||||
@@ -721,7 +721,7 @@ class TestGetParsedJailFile:
|
|||||||
section = JailSectionConfig(enabled=True, port="ssh")
|
section = JailSectionConfig(enabled=True, port="ssh")
|
||||||
cfg = JailFileConfig(filename="sshd.conf", jails={"sshd": section})
|
cfg = JailFileConfig(filename="sshd.conf", jails={"sshd": section})
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.get_parsed_jail_file",
|
"app.routers.file_config.file_config_service.get_parsed_jail_file",
|
||||||
AsyncMock(return_value=cfg),
|
AsyncMock(return_value=cfg),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.get(
|
resp = await file_config_client.get(
|
||||||
@@ -735,7 +735,7 @@ class TestGetParsedJailFile:
|
|||||||
|
|
||||||
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.get_parsed_jail_file",
|
"app.routers.file_config.file_config_service.get_parsed_jail_file",
|
||||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing.conf")),
|
AsyncMock(side_effect=ConfigFileNotFoundError("missing.conf")),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.get(
|
resp = await file_config_client.get(
|
||||||
@@ -748,7 +748,7 @@ class TestGetParsedJailFile:
|
|||||||
self, file_config_client: AsyncClient
|
self, file_config_client: AsyncClient
|
||||||
) -> None:
|
) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.get_parsed_jail_file",
|
"app.routers.file_config.file_config_service.get_parsed_jail_file",
|
||||||
AsyncMock(side_effect=ConfigDirError("no dir")),
|
AsyncMock(side_effect=ConfigDirError("no dir")),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.get(
|
resp = await file_config_client.get(
|
||||||
@@ -766,7 +766,7 @@ class TestGetParsedJailFile:
|
|||||||
class TestUpdateParsedJailFile:
|
class TestUpdateParsedJailFile:
|
||||||
async def test_204_on_success(self, file_config_client: AsyncClient) -> None:
|
async def test_204_on_success(self, file_config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.update_parsed_jail_file",
|
"app.routers.file_config.file_config_service.update_parsed_jail_file",
|
||||||
AsyncMock(return_value=None),
|
AsyncMock(return_value=None),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.put(
|
resp = await file_config_client.put(
|
||||||
@@ -778,7 +778,7 @@ class TestUpdateParsedJailFile:
|
|||||||
|
|
||||||
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.update_parsed_jail_file",
|
"app.routers.file_config.file_config_service.update_parsed_jail_file",
|
||||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing.conf")),
|
AsyncMock(side_effect=ConfigFileNotFoundError("missing.conf")),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.put(
|
resp = await file_config_client.put(
|
||||||
@@ -790,7 +790,7 @@ class TestUpdateParsedJailFile:
|
|||||||
|
|
||||||
async def test_400_write_error(self, file_config_client: AsyncClient) -> None:
|
async def test_400_write_error(self, file_config_client: AsyncClient) -> None:
|
||||||
with patch(
|
with patch(
|
||||||
"app.routers.file_config.raw_config_io_service.update_parsed_jail_file",
|
"app.routers.file_config.file_config_service.update_parsed_jail_file",
|
||||||
AsyncMock(side_effect=ConfigFileWriteError("disk full")),
|
AsyncMock(side_effect=ConfigFileWriteError("disk full")),
|
||||||
):
|
):
|
||||||
resp = await file_config_client.put(
|
resp = await file_config_client.put(
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from httpx import ASGITransport, AsyncClient
|
|||||||
from app.config import Settings
|
from app.config import Settings
|
||||||
from app.db import init_db
|
from app.db import init_db
|
||||||
from app.main import create_app
|
from app.main import create_app
|
||||||
from app.models.geo import GeoInfo
|
from app.services.geo_service import GeoInfo
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Fixtures
|
# Fixtures
|
||||||
@@ -70,7 +70,7 @@ class TestGeoLookup:
|
|||||||
async def test_200_with_geo_info(self, geo_client: AsyncClient) -> None:
|
async def test_200_with_geo_info(self, geo_client: AsyncClient) -> None:
|
||||||
"""GET /api/geo/lookup/{ip} returns 200 with enriched result."""
|
"""GET /api/geo/lookup/{ip} returns 200 with enriched result."""
|
||||||
geo = GeoInfo(country_code="DE", country_name="Germany", asn="12345", org="Acme")
|
geo = GeoInfo(country_code="DE", country_name="Germany", asn="12345", org="Acme")
|
||||||
result: dict[str, object] = {
|
result = {
|
||||||
"ip": "1.2.3.4",
|
"ip": "1.2.3.4",
|
||||||
"currently_banned_in": ["sshd"],
|
"currently_banned_in": ["sshd"],
|
||||||
"geo": geo,
|
"geo": geo,
|
||||||
@@ -92,7 +92,7 @@ class TestGeoLookup:
|
|||||||
|
|
||||||
async def test_200_when_not_banned(self, geo_client: AsyncClient) -> None:
|
async def test_200_when_not_banned(self, geo_client: AsyncClient) -> None:
|
||||||
"""GET /api/geo/lookup/{ip} returns empty list when IP is not banned anywhere."""
|
"""GET /api/geo/lookup/{ip} returns empty list when IP is not banned anywhere."""
|
||||||
result: dict[str, object] = {
|
result = {
|
||||||
"ip": "8.8.8.8",
|
"ip": "8.8.8.8",
|
||||||
"currently_banned_in": [],
|
"currently_banned_in": [],
|
||||||
"geo": GeoInfo(country_code="US", country_name="United States", asn=None, org=None),
|
"geo": GeoInfo(country_code="US", country_name="United States", asn=None, org=None),
|
||||||
@@ -108,7 +108,7 @@ class TestGeoLookup:
|
|||||||
|
|
||||||
async def test_200_with_no_geo(self, geo_client: AsyncClient) -> None:
|
async def test_200_with_no_geo(self, geo_client: AsyncClient) -> None:
|
||||||
"""GET /api/geo/lookup/{ip} returns null geo when enricher fails."""
|
"""GET /api/geo/lookup/{ip} returns null geo when enricher fails."""
|
||||||
result: dict[str, object] = {
|
result = {
|
||||||
"ip": "1.2.3.4",
|
"ip": "1.2.3.4",
|
||||||
"currently_banned_in": [],
|
"currently_banned_in": [],
|
||||||
"geo": None,
|
"geo": None,
|
||||||
@@ -144,7 +144,7 @@ class TestGeoLookup:
|
|||||||
|
|
||||||
async def test_ipv6_address(self, geo_client: AsyncClient) -> None:
|
async def test_ipv6_address(self, geo_client: AsyncClient) -> None:
|
||||||
"""GET /api/geo/lookup/{ip} handles IPv6 addresses."""
|
"""GET /api/geo/lookup/{ip} handles IPv6 addresses."""
|
||||||
result: dict[str, object] = {
|
result = {
|
||||||
"ip": "2001:db8::1",
|
"ip": "2001:db8::1",
|
||||||
"currently_banned_in": [],
|
"currently_banned_in": [],
|
||||||
"geo": None,
|
"geo": None,
|
||||||
|
|||||||
@@ -213,6 +213,18 @@ class TestHistoryList:
|
|||||||
_args, kwargs = mock_fn.call_args
|
_args, kwargs = mock_fn.call_args
|
||||||
assert kwargs.get("range_") == "7d"
|
assert kwargs.get("range_") == "7d"
|
||||||
|
|
||||||
|
async def test_forwards_origin_filter(self, history_client: AsyncClient) -> None:
|
||||||
|
"""The ``origin`` query parameter is forwarded to the service."""
|
||||||
|
mock_fn = AsyncMock(return_value=_make_history_list(n=0))
|
||||||
|
with patch(
|
||||||
|
"app.routers.history.history_service.list_history",
|
||||||
|
new=mock_fn,
|
||||||
|
):
|
||||||
|
await history_client.get("/api/history?origin=blocklist")
|
||||||
|
|
||||||
|
_args, kwargs = mock_fn.call_args
|
||||||
|
assert kwargs.get("origin") == "blocklist"
|
||||||
|
|
||||||
async def test_empty_result(self, history_client: AsyncClient) -> None:
|
async def test_empty_result(self, history_client: AsyncClient) -> None:
|
||||||
"""An empty history returns items=[] and total=0."""
|
"""An empty history returns items=[] and total=0."""
|
||||||
with patch(
|
with patch(
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ from httpx import ASGITransport, AsyncClient
|
|||||||
from app.config import Settings
|
from app.config import Settings
|
||||||
from app.db import init_db
|
from app.db import init_db
|
||||||
from app.main import create_app
|
from app.main import create_app
|
||||||
from app.models.ban import JailBannedIpsResponse
|
|
||||||
from app.models.jail import Jail, JailDetailResponse, JailListResponse, JailStatus, JailSummary
|
from app.models.jail import Jail, JailDetailResponse, JailListResponse, JailStatus, JailSummary
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -802,17 +801,17 @@ class TestGetJailBannedIps:
|
|||||||
def _mock_response(
|
def _mock_response(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
items: list[dict[str, str | None]] | None = None,
|
items: list[dict] | None = None,
|
||||||
total: int = 2,
|
total: int = 2,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
page_size: int = 25,
|
page_size: int = 25,
|
||||||
) -> JailBannedIpsResponse:
|
) -> "JailBannedIpsResponse": # type: ignore[name-defined]
|
||||||
from app.models.ban import ActiveBan, JailBannedIpsResponse
|
from app.models.ban import ActiveBan, JailBannedIpsResponse
|
||||||
|
|
||||||
ban_items = (
|
ban_items = (
|
||||||
[
|
[
|
||||||
ActiveBan(
|
ActiveBan(
|
||||||
ip=item.get("ip") or "1.2.3.4",
|
ip=item.get("ip", "1.2.3.4"),
|
||||||
jail="sshd",
|
jail="sshd",
|
||||||
banned_at=item.get("banned_at", "2025-01-01T10:00:00+00:00"),
|
banned_at=item.get("banned_at", "2025-01-01T10:00:00+00:00"),
|
||||||
expires_at=item.get("expires_at", "2025-01-01T10:10:00+00:00"),
|
expires_at=item.get("expires_at", "2025-01-01T10:10:00+00:00"),
|
||||||
|
|||||||
@@ -247,9 +247,9 @@ class TestSetupCompleteCaching:
|
|||||||
assert not getattr(app.state, "_setup_complete_cached", False)
|
assert not getattr(app.state, "_setup_complete_cached", False)
|
||||||
|
|
||||||
# First non-exempt request — middleware queries DB and sets the flag.
|
# First non-exempt request — middleware queries DB and sets the flag.
|
||||||
await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]})
|
await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]}) # type: ignore[call-overload]
|
||||||
|
|
||||||
assert app.state._setup_complete_cached is True
|
assert app.state._setup_complete_cached is True # type: ignore[attr-defined]
|
||||||
|
|
||||||
async def test_cached_path_skips_is_setup_complete(
|
async def test_cached_path_skips_is_setup_complete(
|
||||||
self,
|
self,
|
||||||
@@ -267,12 +267,12 @@ class TestSetupCompleteCaching:
|
|||||||
|
|
||||||
# Do setup and warm the cache.
|
# Do setup and warm the cache.
|
||||||
await client.post("/api/setup", json=_SETUP_PAYLOAD)
|
await client.post("/api/setup", json=_SETUP_PAYLOAD)
|
||||||
await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]})
|
await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]}) # type: ignore[call-overload]
|
||||||
assert app.state._setup_complete_cached is True
|
assert app.state._setup_complete_cached is True # type: ignore[attr-defined]
|
||||||
|
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def _counting(db: aiosqlite.Connection) -> bool:
|
async def _counting(db): # type: ignore[no-untyped-def]
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
return True
|
return True
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ class TestCheckPasswordAsync:
|
|||||||
auth_service._check_password("secret", hashed), # noqa: SLF001
|
auth_service._check_password("secret", hashed), # noqa: SLF001
|
||||||
auth_service._check_password("wrong", hashed), # noqa: SLF001
|
auth_service._check_password("wrong", hashed), # noqa: SLF001
|
||||||
)
|
)
|
||||||
assert tuple(results) == (True, False)
|
assert results == [True, False]
|
||||||
|
|
||||||
|
|
||||||
class TestLogin:
|
class TestLogin:
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ async def _create_f2b_db(path: str, rows: list[dict[str, Any]]) -> None:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def f2b_db_path(tmp_path: Path) -> str:
|
async def f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc]
|
||||||
"""Return the path to a test fail2ban SQLite database with several bans."""
|
"""Return the path to a test fail2ban SQLite database with several bans."""
|
||||||
path = str(tmp_path / "fail2ban_test.sqlite3")
|
path = str(tmp_path / "fail2ban_test.sqlite3")
|
||||||
await _create_f2b_db(
|
await _create_f2b_db(
|
||||||
@@ -103,7 +103,7 @@ async def f2b_db_path(tmp_path: Path) -> str:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def mixed_origin_db_path(tmp_path: Path) -> str:
|
async def mixed_origin_db_path(tmp_path: Path) -> str: # type: ignore[misc]
|
||||||
"""Return a database with bans from both blocklist-import and organic jails."""
|
"""Return a database with bans from both blocklist-import and organic jails."""
|
||||||
path = str(tmp_path / "fail2ban_mixed_origin.sqlite3")
|
path = str(tmp_path / "fail2ban_mixed_origin.sqlite3")
|
||||||
await _create_f2b_db(
|
await _create_f2b_db(
|
||||||
@@ -136,7 +136,7 @@ async def mixed_origin_db_path(tmp_path: Path) -> str:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def empty_f2b_db_path(tmp_path: Path) -> str:
|
async def empty_f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc]
|
||||||
"""Return the path to a fail2ban SQLite database with no ban records."""
|
"""Return the path to a fail2ban SQLite database with no ban records."""
|
||||||
path = str(tmp_path / "fail2ban_empty.sqlite3")
|
path = str(tmp_path / "fail2ban_empty.sqlite3")
|
||||||
await _create_f2b_db(path, [])
|
await _create_f2b_db(path, [])
|
||||||
@@ -154,7 +154,7 @@ class TestListBansHappyPath:
|
|||||||
async def test_returns_bans_in_range(self, f2b_db_path: str) -> None:
|
async def test_returns_bans_in_range(self, f2b_db_path: str) -> None:
|
||||||
"""Only bans within the selected range are returned."""
|
"""Only bans within the selected range are returned."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=f2b_db_path),
|
new=AsyncMock(return_value=f2b_db_path),
|
||||||
):
|
):
|
||||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||||
@@ -166,7 +166,7 @@ class TestListBansHappyPath:
|
|||||||
async def test_results_sorted_newest_first(self, f2b_db_path: str) -> None:
|
async def test_results_sorted_newest_first(self, f2b_db_path: str) -> None:
|
||||||
"""Items are ordered by ``banned_at`` descending (newest first)."""
|
"""Items are ordered by ``banned_at`` descending (newest first)."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=f2b_db_path),
|
new=AsyncMock(return_value=f2b_db_path),
|
||||||
):
|
):
|
||||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||||
@@ -177,7 +177,7 @@ class TestListBansHappyPath:
|
|||||||
async def test_ban_fields_present(self, f2b_db_path: str) -> None:
|
async def test_ban_fields_present(self, f2b_db_path: str) -> None:
|
||||||
"""Each item contains ip, jail, banned_at, ban_count."""
|
"""Each item contains ip, jail, banned_at, ban_count."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=f2b_db_path),
|
new=AsyncMock(return_value=f2b_db_path),
|
||||||
):
|
):
|
||||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||||
@@ -191,7 +191,7 @@ class TestListBansHappyPath:
|
|||||||
async def test_service_extracted_from_first_match(self, f2b_db_path: str) -> None:
|
async def test_service_extracted_from_first_match(self, f2b_db_path: str) -> None:
|
||||||
"""``service`` field is the first element of ``data.matches``."""
|
"""``service`` field is the first element of ``data.matches``."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=f2b_db_path),
|
new=AsyncMock(return_value=f2b_db_path),
|
||||||
):
|
):
|
||||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||||
@@ -203,7 +203,7 @@ class TestListBansHappyPath:
|
|||||||
async def test_service_is_none_when_no_matches(self, f2b_db_path: str) -> None:
|
async def test_service_is_none_when_no_matches(self, f2b_db_path: str) -> None:
|
||||||
"""``service`` is ``None`` when the ban has no stored matches."""
|
"""``service`` is ``None`` when the ban has no stored matches."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=f2b_db_path),
|
new=AsyncMock(return_value=f2b_db_path),
|
||||||
):
|
):
|
||||||
# Use 7d to include the older ban with no matches.
|
# Use 7d to include the older ban with no matches.
|
||||||
@@ -215,7 +215,7 @@ class TestListBansHappyPath:
|
|||||||
async def test_empty_db_returns_zero(self, empty_f2b_db_path: str) -> None:
|
async def test_empty_db_returns_zero(self, empty_f2b_db_path: str) -> None:
|
||||||
"""When no bans exist the result has total=0 and no items."""
|
"""When no bans exist the result has total=0 and no items."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||||
):
|
):
|
||||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||||
@@ -226,7 +226,7 @@ class TestListBansHappyPath:
|
|||||||
async def test_365d_range_includes_old_bans(self, f2b_db_path: str) -> None:
|
async def test_365d_range_includes_old_bans(self, f2b_db_path: str) -> None:
|
||||||
"""The ``365d`` range includes bans that are 2 days old."""
|
"""The ``365d`` range includes bans that are 2 days old."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=f2b_db_path),
|
new=AsyncMock(return_value=f2b_db_path),
|
||||||
):
|
):
|
||||||
result = await ban_service.list_bans("/fake/sock", "365d")
|
result = await ban_service.list_bans("/fake/sock", "365d")
|
||||||
@@ -246,7 +246,7 @@ class TestListBansGeoEnrichment:
|
|||||||
self, f2b_db_path: str
|
self, f2b_db_path: str
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Geo fields are populated when an enricher returns data."""
|
"""Geo fields are populated when an enricher returns data."""
|
||||||
from app.models.geo import GeoInfo
|
from app.services.geo_service import GeoInfo
|
||||||
|
|
||||||
async def fake_enricher(ip: str) -> GeoInfo:
|
async def fake_enricher(ip: str) -> GeoInfo:
|
||||||
return GeoInfo(
|
return GeoInfo(
|
||||||
@@ -257,7 +257,7 @@ class TestListBansGeoEnrichment:
|
|||||||
)
|
)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=f2b_db_path),
|
new=AsyncMock(return_value=f2b_db_path),
|
||||||
):
|
):
|
||||||
result = await ban_service.list_bans(
|
result = await ban_service.list_bans(
|
||||||
@@ -278,7 +278,7 @@ class TestListBansGeoEnrichment:
|
|||||||
raise RuntimeError("geo service down")
|
raise RuntimeError("geo service down")
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=f2b_db_path),
|
new=AsyncMock(return_value=f2b_db_path),
|
||||||
):
|
):
|
||||||
result = await ban_service.list_bans(
|
result = await ban_service.list_bans(
|
||||||
@@ -304,27 +304,25 @@ class TestListBansBatchGeoEnrichment:
|
|||||||
"""Geo fields are populated via lookup_batch when http_session is given."""
|
"""Geo fields are populated via lookup_batch when http_session is given."""
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
from app.models.geo import GeoInfo
|
from app.services.geo_service import GeoInfo
|
||||||
|
|
||||||
fake_session = MagicMock()
|
fake_session = MagicMock()
|
||||||
fake_geo_map = {
|
fake_geo_map = {
|
||||||
"1.2.3.4": GeoInfo(country_code="DE", country_name="Germany", asn="AS3320", org="Deutsche Telekom"),
|
"1.2.3.4": GeoInfo(country_code="DE", country_name="Germany", asn="AS3320", org="Deutsche Telekom"),
|
||||||
"5.6.7.8": GeoInfo(country_code="US", country_name="United States", asn="AS15169", org="Google"),
|
"5.6.7.8": GeoInfo(country_code="US", country_name="United States", asn="AS15169", org="Google"),
|
||||||
}
|
}
|
||||||
fake_geo_batch = AsyncMock(return_value=fake_geo_map)
|
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=f2b_db_path),
|
new=AsyncMock(return_value=f2b_db_path),
|
||||||
|
), patch(
|
||||||
|
"app.services.geo_service.lookup_batch",
|
||||||
|
new=AsyncMock(return_value=fake_geo_map),
|
||||||
):
|
):
|
||||||
result = await ban_service.list_bans(
|
result = await ban_service.list_bans(
|
||||||
"/fake/sock",
|
"/fake/sock", "24h", http_session=fake_session
|
||||||
"24h",
|
|
||||||
http_session=fake_session,
|
|
||||||
geo_batch_lookup=fake_geo_batch,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
fake_geo_batch.assert_awaited_once_with(["1.2.3.4", "5.6.7.8"], fake_session, db=None)
|
|
||||||
assert result.total == 2
|
assert result.total == 2
|
||||||
de_item = next(i for i in result.items if i.ip == "1.2.3.4")
|
de_item = next(i for i in result.items if i.ip == "1.2.3.4")
|
||||||
us_item = next(i for i in result.items if i.ip == "5.6.7.8")
|
us_item = next(i for i in result.items if i.ip == "5.6.7.8")
|
||||||
@@ -341,17 +339,15 @@ class TestListBansBatchGeoEnrichment:
|
|||||||
|
|
||||||
fake_session = MagicMock()
|
fake_session = MagicMock()
|
||||||
|
|
||||||
failing_geo_batch = AsyncMock(side_effect=RuntimeError("batch geo down"))
|
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=f2b_db_path),
|
new=AsyncMock(return_value=f2b_db_path),
|
||||||
|
), patch(
|
||||||
|
"app.services.geo_service.lookup_batch",
|
||||||
|
new=AsyncMock(side_effect=RuntimeError("batch geo down")),
|
||||||
):
|
):
|
||||||
result = await ban_service.list_bans(
|
result = await ban_service.list_bans(
|
||||||
"/fake/sock",
|
"/fake/sock", "24h", http_session=fake_session
|
||||||
"24h",
|
|
||||||
http_session=fake_session,
|
|
||||||
geo_batch_lookup=failing_geo_batch,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result.total == 2
|
assert result.total == 2
|
||||||
@@ -364,27 +360,28 @@ class TestListBansBatchGeoEnrichment:
|
|||||||
"""When both http_session and geo_enricher are provided, batch wins."""
|
"""When both http_session and geo_enricher are provided, batch wins."""
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
from app.models.geo import GeoInfo
|
from app.services.geo_service import GeoInfo
|
||||||
|
|
||||||
fake_session = MagicMock()
|
fake_session = MagicMock()
|
||||||
fake_geo_map = {
|
fake_geo_map = {
|
||||||
"1.2.3.4": GeoInfo(country_code="DE", country_name="Germany", asn=None, org=None),
|
"1.2.3.4": GeoInfo(country_code="DE", country_name="Germany", asn=None, org=None),
|
||||||
"5.6.7.8": GeoInfo(country_code="DE", country_name="Germany", asn=None, org=None),
|
"5.6.7.8": GeoInfo(country_code="DE", country_name="Germany", asn=None, org=None),
|
||||||
}
|
}
|
||||||
fake_geo_batch = AsyncMock(return_value=fake_geo_map)
|
|
||||||
|
|
||||||
async def enricher_should_not_be_called(ip: str) -> GeoInfo:
|
async def enricher_should_not_be_called(ip: str) -> GeoInfo:
|
||||||
raise AssertionError(f"geo_enricher was called for {ip!r} — should not happen")
|
raise AssertionError(f"geo_enricher was called for {ip!r} — should not happen")
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=f2b_db_path),
|
new=AsyncMock(return_value=f2b_db_path),
|
||||||
|
), patch(
|
||||||
|
"app.services.geo_service.lookup_batch",
|
||||||
|
new=AsyncMock(return_value=fake_geo_map),
|
||||||
):
|
):
|
||||||
result = await ban_service.list_bans(
|
result = await ban_service.list_bans(
|
||||||
"/fake/sock",
|
"/fake/sock",
|
||||||
"24h",
|
"24h",
|
||||||
http_session=fake_session,
|
http_session=fake_session,
|
||||||
geo_batch_lookup=fake_geo_batch,
|
|
||||||
geo_enricher=enricher_should_not_be_called,
|
geo_enricher=enricher_should_not_be_called,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -404,7 +401,7 @@ class TestListBansPagination:
|
|||||||
async def test_page_size_respected(self, f2b_db_path: str) -> None:
|
async def test_page_size_respected(self, f2b_db_path: str) -> None:
|
||||||
"""``page_size=1`` returns at most one item."""
|
"""``page_size=1`` returns at most one item."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=f2b_db_path),
|
new=AsyncMock(return_value=f2b_db_path),
|
||||||
):
|
):
|
||||||
result = await ban_service.list_bans("/fake/sock", "7d", page_size=1)
|
result = await ban_service.list_bans("/fake/sock", "7d", page_size=1)
|
||||||
@@ -415,7 +412,7 @@ class TestListBansPagination:
|
|||||||
async def test_page_2_returns_remaining_items(self, f2b_db_path: str) -> None:
|
async def test_page_2_returns_remaining_items(self, f2b_db_path: str) -> None:
|
||||||
"""The second page returns items not on the first page."""
|
"""The second page returns items not on the first page."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=f2b_db_path),
|
new=AsyncMock(return_value=f2b_db_path),
|
||||||
):
|
):
|
||||||
page1 = await ban_service.list_bans("/fake/sock", "7d", page=1, page_size=1)
|
page1 = await ban_service.list_bans("/fake/sock", "7d", page=1, page_size=1)
|
||||||
@@ -429,7 +426,7 @@ class TestListBansPagination:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""``total`` reports all matching records regardless of pagination."""
|
"""``total`` reports all matching records regardless of pagination."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=f2b_db_path),
|
new=AsyncMock(return_value=f2b_db_path),
|
||||||
):
|
):
|
||||||
result = await ban_service.list_bans("/fake/sock", "7d", page_size=1)
|
result = await ban_service.list_bans("/fake/sock", "7d", page_size=1)
|
||||||
@@ -450,7 +447,7 @@ class TestBanOriginDerivation:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Bans from ``blocklist-import`` jail carry ``origin == "blocklist"``."""
|
"""Bans from ``blocklist-import`` jail carry ``origin == "blocklist"``."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||||
):
|
):
|
||||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||||
@@ -464,7 +461,7 @@ class TestBanOriginDerivation:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Bans from organic jails (sshd, nginx, …) carry ``origin == "selfblock"``."""
|
"""Bans from organic jails (sshd, nginx, …) carry ``origin == "selfblock"``."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||||
):
|
):
|
||||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||||
@@ -479,7 +476,7 @@ class TestBanOriginDerivation:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Every returned item has an ``origin`` field with a valid value."""
|
"""Every returned item has an ``origin`` field with a valid value."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||||
):
|
):
|
||||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||||
@@ -492,7 +489,7 @@ class TestBanOriginDerivation:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""``bans_by_country`` also derives origin correctly for blocklist bans."""
|
"""``bans_by_country`` also derives origin correctly for blocklist bans."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||||
):
|
):
|
||||||
result = await ban_service.bans_by_country("/fake/sock", "24h")
|
result = await ban_service.bans_by_country("/fake/sock", "24h")
|
||||||
@@ -506,7 +503,7 @@ class TestBanOriginDerivation:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""``bans_by_country`` derives origin correctly for organic jails."""
|
"""``bans_by_country`` derives origin correctly for organic jails."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||||
):
|
):
|
||||||
result = await ban_service.bans_by_country("/fake/sock", "24h")
|
result = await ban_service.bans_by_country("/fake/sock", "24h")
|
||||||
@@ -530,7 +527,7 @@ class TestOriginFilter:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""``origin='blocklist'`` returns only blocklist-import jail bans."""
|
"""``origin='blocklist'`` returns only blocklist-import jail bans."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||||
):
|
):
|
||||||
result = await ban_service.list_bans(
|
result = await ban_service.list_bans(
|
||||||
@@ -547,7 +544,7 @@ class TestOriginFilter:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""``origin='selfblock'`` excludes the blocklist-import jail."""
|
"""``origin='selfblock'`` excludes the blocklist-import jail."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||||
):
|
):
|
||||||
result = await ban_service.list_bans(
|
result = await ban_service.list_bans(
|
||||||
@@ -565,7 +562,7 @@ class TestOriginFilter:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""``origin=None`` applies no jail restriction — all bans returned."""
|
"""``origin=None`` applies no jail restriction — all bans returned."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||||
):
|
):
|
||||||
result = await ban_service.list_bans("/fake/sock", "24h", origin=None)
|
result = await ban_service.list_bans("/fake/sock", "24h", origin=None)
|
||||||
@@ -577,7 +574,7 @@ class TestOriginFilter:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""``bans_by_country`` with ``origin='blocklist'`` counts only blocklist bans."""
|
"""``bans_by_country`` with ``origin='blocklist'`` counts only blocklist bans."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||||
):
|
):
|
||||||
result = await ban_service.bans_by_country(
|
result = await ban_service.bans_by_country(
|
||||||
@@ -592,7 +589,7 @@ class TestOriginFilter:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""``bans_by_country`` with ``origin='selfblock'`` excludes blocklist jails."""
|
"""``bans_by_country`` with ``origin='selfblock'`` excludes blocklist jails."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||||
):
|
):
|
||||||
result = await ban_service.bans_by_country(
|
result = await ban_service.bans_by_country(
|
||||||
@@ -607,7 +604,7 @@ class TestOriginFilter:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""``bans_by_country`` with ``origin=None`` returns all bans."""
|
"""``bans_by_country`` with ``origin=None`` returns all bans."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||||
):
|
):
|
||||||
result = await ban_service.bans_by_country(
|
result = await ban_service.bans_by_country(
|
||||||
@@ -635,19 +632,19 @@ class TestBansbyCountryBackground:
|
|||||||
from app.services import geo_service
|
from app.services import geo_service
|
||||||
|
|
||||||
# Pre-populate the cache for all three IPs in the fixture.
|
# Pre-populate the cache for all three IPs in the fixture.
|
||||||
geo_service._cache["10.0.0.1"] = geo_service.GeoInfo(
|
geo_service._cache["10.0.0.1"] = geo_service.GeoInfo( # type: ignore[attr-defined]
|
||||||
country_code="DE", country_name="Germany", asn=None, org=None
|
country_code="DE", country_name="Germany", asn=None, org=None
|
||||||
)
|
)
|
||||||
geo_service._cache["10.0.0.2"] = geo_service.GeoInfo(
|
geo_service._cache["10.0.0.2"] = geo_service.GeoInfo( # type: ignore[attr-defined]
|
||||||
country_code="US", country_name="United States", asn=None, org=None
|
country_code="US", country_name="United States", asn=None, org=None
|
||||||
)
|
)
|
||||||
geo_service._cache["10.0.0.3"] = geo_service.GeoInfo(
|
geo_service._cache["10.0.0.3"] = geo_service.GeoInfo( # type: ignore[attr-defined]
|
||||||
country_code="JP", country_name="Japan", asn=None, org=None
|
country_code="JP", country_name="Japan", asn=None, org=None
|
||||||
)
|
)
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
@@ -655,13 +652,8 @@ class TestBansbyCountryBackground:
|
|||||||
) as mock_create_task,
|
) as mock_create_task,
|
||||||
):
|
):
|
||||||
mock_session = AsyncMock()
|
mock_session = AsyncMock()
|
||||||
mock_batch = AsyncMock(return_value={})
|
|
||||||
result = await ban_service.bans_by_country(
|
result = await ban_service.bans_by_country(
|
||||||
"/fake/sock",
|
"/fake/sock", "24h", http_session=mock_session
|
||||||
"24h",
|
|
||||||
http_session=mock_session,
|
|
||||||
geo_cache_lookup=geo_service.lookup_cached_only,
|
|
||||||
geo_batch_lookup=mock_batch,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# All countries resolved from cache — no background task needed.
|
# All countries resolved from cache — no background task needed.
|
||||||
@@ -682,7 +674,7 @@ class TestBansbyCountryBackground:
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
@@ -690,13 +682,8 @@ class TestBansbyCountryBackground:
|
|||||||
) as mock_create_task,
|
) as mock_create_task,
|
||||||
):
|
):
|
||||||
mock_session = AsyncMock()
|
mock_session = AsyncMock()
|
||||||
mock_batch = AsyncMock(return_value={})
|
|
||||||
result = await ban_service.bans_by_country(
|
result = await ban_service.bans_by_country(
|
||||||
"/fake/sock",
|
"/fake/sock", "24h", http_session=mock_session
|
||||||
"24h",
|
|
||||||
http_session=mock_session,
|
|
||||||
geo_cache_lookup=geo_service.lookup_cached_only,
|
|
||||||
geo_batch_lookup=mock_batch,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Background task must have been scheduled for uncached IPs.
|
# Background task must have been scheduled for uncached IPs.
|
||||||
@@ -714,7 +701,7 @@ class TestBansbyCountryBackground:
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
@@ -740,7 +727,7 @@ class TestBanTrend:
|
|||||||
async def test_24h_returns_24_buckets(self, empty_f2b_db_path: str) -> None:
|
async def test_24h_returns_24_buckets(self, empty_f2b_db_path: str) -> None:
|
||||||
"""``range_='24h'`` always yields exactly 24 buckets."""
|
"""``range_='24h'`` always yields exactly 24 buckets."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||||
):
|
):
|
||||||
result = await ban_service.ban_trend("/fake/sock", "24h")
|
result = await ban_service.ban_trend("/fake/sock", "24h")
|
||||||
@@ -751,7 +738,7 @@ class TestBanTrend:
|
|||||||
async def test_7d_returns_28_buckets(self, empty_f2b_db_path: str) -> None:
|
async def test_7d_returns_28_buckets(self, empty_f2b_db_path: str) -> None:
|
||||||
"""``range_='7d'`` yields 28 six-hour buckets."""
|
"""``range_='7d'`` yields 28 six-hour buckets."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||||
):
|
):
|
||||||
result = await ban_service.ban_trend("/fake/sock", "7d")
|
result = await ban_service.ban_trend("/fake/sock", "7d")
|
||||||
@@ -762,7 +749,7 @@ class TestBanTrend:
|
|||||||
async def test_30d_returns_30_buckets(self, empty_f2b_db_path: str) -> None:
|
async def test_30d_returns_30_buckets(self, empty_f2b_db_path: str) -> None:
|
||||||
"""``range_='30d'`` yields 30 daily buckets."""
|
"""``range_='30d'`` yields 30 daily buckets."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||||
):
|
):
|
||||||
result = await ban_service.ban_trend("/fake/sock", "30d")
|
result = await ban_service.ban_trend("/fake/sock", "30d")
|
||||||
@@ -773,7 +760,7 @@ class TestBanTrend:
|
|||||||
async def test_365d_bucket_size_label(self, empty_f2b_db_path: str) -> None:
|
async def test_365d_bucket_size_label(self, empty_f2b_db_path: str) -> None:
|
||||||
"""``range_='365d'`` uses '7d' as the bucket size label."""
|
"""``range_='365d'`` uses '7d' as the bucket size label."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||||
):
|
):
|
||||||
result = await ban_service.ban_trend("/fake/sock", "365d")
|
result = await ban_service.ban_trend("/fake/sock", "365d")
|
||||||
@@ -784,7 +771,7 @@ class TestBanTrend:
|
|||||||
async def test_empty_db_all_buckets_zero(self, empty_f2b_db_path: str) -> None:
|
async def test_empty_db_all_buckets_zero(self, empty_f2b_db_path: str) -> None:
|
||||||
"""All bucket counts are zero when the database has no bans."""
|
"""All bucket counts are zero when the database has no bans."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||||
):
|
):
|
||||||
result = await ban_service.ban_trend("/fake/sock", "24h")
|
result = await ban_service.ban_trend("/fake/sock", "24h")
|
||||||
@@ -794,7 +781,7 @@ class TestBanTrend:
|
|||||||
async def test_buckets_are_time_ordered(self, empty_f2b_db_path: str) -> None:
|
async def test_buckets_are_time_ordered(self, empty_f2b_db_path: str) -> None:
|
||||||
"""Buckets are ordered chronologically (ascending timestamps)."""
|
"""Buckets are ordered chronologically (ascending timestamps)."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||||
):
|
):
|
||||||
result = await ban_service.ban_trend("/fake/sock", "7d")
|
result = await ban_service.ban_trend("/fake/sock", "7d")
|
||||||
@@ -817,7 +804,7 @@ class TestBanTrend:
|
|||||||
)
|
)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=path),
|
new=AsyncMock(return_value=path),
|
||||||
):
|
):
|
||||||
result = await ban_service.ban_trend("/fake/sock", "24h")
|
result = await ban_service.ban_trend("/fake/sock", "24h")
|
||||||
@@ -841,7 +828,7 @@ class TestBanTrend:
|
|||||||
)
|
)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=path),
|
new=AsyncMock(return_value=path),
|
||||||
):
|
):
|
||||||
result = await ban_service.ban_trend(
|
result = await ban_service.ban_trend(
|
||||||
@@ -867,7 +854,7 @@ class TestBanTrend:
|
|||||||
)
|
)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=path),
|
new=AsyncMock(return_value=path),
|
||||||
):
|
):
|
||||||
result = await ban_service.ban_trend(
|
result = await ban_service.ban_trend(
|
||||||
@@ -881,7 +868,7 @@ class TestBanTrend:
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||||
):
|
):
|
||||||
result = await ban_service.ban_trend("/fake/sock", "24h")
|
result = await ban_service.ban_trend("/fake/sock", "24h")
|
||||||
@@ -917,7 +904,7 @@ class TestBansByJail:
|
|||||||
)
|
)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=path),
|
new=AsyncMock(return_value=path),
|
||||||
):
|
):
|
||||||
result = await ban_service.bans_by_jail("/fake/sock", "24h")
|
result = await ban_service.bans_by_jail("/fake/sock", "24h")
|
||||||
@@ -944,7 +931,7 @@ class TestBansByJail:
|
|||||||
)
|
)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=path),
|
new=AsyncMock(return_value=path),
|
||||||
):
|
):
|
||||||
result = await ban_service.bans_by_jail("/fake/sock", "24h")
|
result = await ban_service.bans_by_jail("/fake/sock", "24h")
|
||||||
@@ -955,7 +942,7 @@ class TestBansByJail:
|
|||||||
async def test_empty_db_returns_empty_list(self, empty_f2b_db_path: str) -> None:
|
async def test_empty_db_returns_empty_list(self, empty_f2b_db_path: str) -> None:
|
||||||
"""An empty database returns an empty jails list with total zero."""
|
"""An empty database returns an empty jails list with total zero."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||||
):
|
):
|
||||||
result = await ban_service.bans_by_jail("/fake/sock", "24h")
|
result = await ban_service.bans_by_jail("/fake/sock", "24h")
|
||||||
@@ -967,7 +954,7 @@ class TestBansByJail:
|
|||||||
"""Bans older than the time window are not counted."""
|
"""Bans older than the time window are not counted."""
|
||||||
# f2b_db_path has one ban from _TWO_DAYS_AGO, which is outside "24h".
|
# f2b_db_path has one ban from _TWO_DAYS_AGO, which is outside "24h".
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=f2b_db_path),
|
new=AsyncMock(return_value=f2b_db_path),
|
||||||
):
|
):
|
||||||
result = await ban_service.bans_by_jail("/fake/sock", "24h")
|
result = await ban_service.bans_by_jail("/fake/sock", "24h")
|
||||||
@@ -978,7 +965,7 @@ class TestBansByJail:
|
|||||||
async def test_origin_filter_blocklist(self, mixed_origin_db_path: str) -> None:
|
async def test_origin_filter_blocklist(self, mixed_origin_db_path: str) -> None:
|
||||||
"""``origin='blocklist'`` returns only the blocklist-import jail."""
|
"""``origin='blocklist'`` returns only the blocklist-import jail."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||||
):
|
):
|
||||||
result = await ban_service.bans_by_jail(
|
result = await ban_service.bans_by_jail(
|
||||||
@@ -992,7 +979,7 @@ class TestBansByJail:
|
|||||||
async def test_origin_filter_selfblock(self, mixed_origin_db_path: str) -> None:
|
async def test_origin_filter_selfblock(self, mixed_origin_db_path: str) -> None:
|
||||||
"""``origin='selfblock'`` excludes the blocklist-import jail."""
|
"""``origin='selfblock'`` excludes the blocklist-import jail."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||||
):
|
):
|
||||||
result = await ban_service.bans_by_jail(
|
result = await ban_service.bans_by_jail(
|
||||||
@@ -1008,7 +995,7 @@ class TestBansByJail:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""``origin=None`` returns bans from all jails."""
|
"""``origin=None`` returns bans from all jails."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||||
):
|
):
|
||||||
result = await ban_service.bans_by_jail(
|
result = await ban_service.bans_by_jail(
|
||||||
@@ -1036,7 +1023,7 @@ class TestBansByJail:
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=path),
|
new=AsyncMock(return_value=path),
|
||||||
),
|
),
|
||||||
patch("app.services.ban_service.log") as mock_log,
|
patch("app.services.ban_service.log") as mock_log,
|
||||||
|
|||||||
@@ -19,8 +19,8 @@ from unittest.mock import AsyncMock, patch
|
|||||||
import aiosqlite
|
import aiosqlite
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.models.geo import GeoInfo
|
|
||||||
from app.services import ban_service, geo_service
|
from app.services import ban_service, geo_service
|
||||||
|
from app.services.geo_service import GeoInfo
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Constants
|
# Constants
|
||||||
@@ -114,13 +114,13 @@ async def _seed_f2b_db(path: str, n: int) -> list[str]:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def event_loop_policy() -> None:
|
def event_loop_policy() -> None: # type: ignore[misc]
|
||||||
"""Use the default event loop policy for module-scoped fixtures."""
|
"""Use the default event loop policy for module-scoped fixtures."""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
async def perf_db_path(tmp_path_factory: Any) -> str:
|
async def perf_db_path(tmp_path_factory: Any) -> str: # type: ignore[misc]
|
||||||
"""Return the path to a fail2ban DB seeded with 10 000 synthetic bans.
|
"""Return the path to a fail2ban DB seeded with 10 000 synthetic bans.
|
||||||
|
|
||||||
Module-scoped so the database is created only once for all perf tests.
|
Module-scoped so the database is created only once for all perf tests.
|
||||||
@@ -161,7 +161,7 @@ class TestBanServicePerformance:
|
|||||||
return geo_service._cache.get(ip) # noqa: SLF001
|
return geo_service._cache.get(ip) # noqa: SLF001
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=perf_db_path),
|
new=AsyncMock(return_value=perf_db_path),
|
||||||
):
|
):
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
@@ -191,7 +191,7 @@ class TestBanServicePerformance:
|
|||||||
return geo_service._cache.get(ip) # noqa: SLF001
|
return geo_service._cache.get(ip) # noqa: SLF001
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=perf_db_path),
|
new=AsyncMock(return_value=perf_db_path),
|
||||||
):
|
):
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
@@ -217,7 +217,7 @@ class TestBanServicePerformance:
|
|||||||
return geo_service._cache.get(ip) # noqa: SLF001
|
return geo_service._cache.get(ip) # noqa: SLF001
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=perf_db_path),
|
new=AsyncMock(return_value=perf_db_path),
|
||||||
):
|
):
|
||||||
result = await ban_service.list_bans(
|
result = await ban_service.list_bans(
|
||||||
@@ -241,7 +241,7 @@ class TestBanServicePerformance:
|
|||||||
return geo_service._cache.get(ip) # noqa: SLF001
|
return geo_service._cache.get(ip) # noqa: SLF001
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.ban_service.get_fail2ban_db_path",
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=perf_db_path),
|
new=AsyncMock(return_value=perf_db_path),
|
||||||
):
|
):
|
||||||
result = await ban_service.bans_by_country(
|
result = await ban_service.bans_by_country(
|
||||||
|
|||||||
@@ -203,15 +203,9 @@ class TestImport:
|
|||||||
call_count += 1
|
call_count += 1
|
||||||
raise JailNotFoundError(jail)
|
raise JailNotFoundError(jail)
|
||||||
|
|
||||||
with patch("app.services.jail_service.ban_ip", side_effect=_raise_jail_not_found) as mocked_ban_ip:
|
with patch("app.services.jail_service.ban_ip", side_effect=_raise_jail_not_found):
|
||||||
from app.services import jail_service
|
|
||||||
|
|
||||||
result = await blocklist_service.import_source(
|
result = await blocklist_service.import_source(
|
||||||
source,
|
source, session, "/tmp/fake.sock", db
|
||||||
session,
|
|
||||||
"/tmp/fake.sock",
|
|
||||||
db,
|
|
||||||
ban_ip=jail_service.ban_ip,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Must abort after the first JailNotFoundError — only one ban attempt.
|
# Must abort after the first JailNotFoundError — only one ban attempt.
|
||||||
@@ -232,14 +226,7 @@ class TestImport:
|
|||||||
with patch(
|
with patch(
|
||||||
"app.services.jail_service.ban_ip", new_callable=AsyncMock
|
"app.services.jail_service.ban_ip", new_callable=AsyncMock
|
||||||
):
|
):
|
||||||
from app.services import jail_service
|
result = await blocklist_service.import_all(db, session, "/tmp/fake.sock")
|
||||||
|
|
||||||
result = await blocklist_service.import_all(
|
|
||||||
db,
|
|
||||||
session,
|
|
||||||
"/tmp/fake.sock",
|
|
||||||
ban_ip=jail_service.ban_ip,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Only S1 is enabled, S2 is disabled.
|
# Only S1 is enabled, S2 is disabled.
|
||||||
assert len(result.results) == 1
|
assert len(result.results) == 1
|
||||||
@@ -328,15 +315,20 @@ class TestGeoPrewarmCacheFilter:
|
|||||||
def _mock_is_cached(ip: str) -> bool:
|
def _mock_is_cached(ip: str) -> bool:
|
||||||
return ip == "1.2.3.4"
|
return ip == "1.2.3.4"
|
||||||
|
|
||||||
mock_batch = AsyncMock(return_value={})
|
with (
|
||||||
with patch("app.services.jail_service.ban_ip", new_callable=AsyncMock):
|
patch("app.services.jail_service.ban_ip", new_callable=AsyncMock),
|
||||||
|
patch(
|
||||||
|
"app.services.geo_service.is_cached",
|
||||||
|
side_effect=_mock_is_cached,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"app.services.geo_service.lookup_batch",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value={},
|
||||||
|
) as mock_batch,
|
||||||
|
):
|
||||||
result = await blocklist_service.import_source(
|
result = await blocklist_service.import_source(
|
||||||
source,
|
source, session, "/tmp/fake.sock", db
|
||||||
session,
|
|
||||||
"/tmp/fake.sock",
|
|
||||||
db,
|
|
||||||
geo_is_cached=_mock_is_cached,
|
|
||||||
geo_batch_lookup=mock_batch,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result.ips_imported == 3
|
assert result.ips_imported == 3
|
||||||
@@ -345,40 +337,3 @@ class TestGeoPrewarmCacheFilter:
|
|||||||
call_ips = mock_batch.call_args[0][0]
|
call_ips = mock_batch.call_args[0][0]
|
||||||
assert "1.2.3.4" not in call_ips
|
assert "1.2.3.4" not in call_ips
|
||||||
assert set(call_ips) == {"5.6.7.8", "9.10.11.12"}
|
assert set(call_ips) == {"5.6.7.8", "9.10.11.12"}
|
||||||
|
|
||||||
|
|
||||||
class TestImportLogPagination:
|
|
||||||
async def test_list_import_logs_empty(self, db: aiosqlite.Connection) -> None:
|
|
||||||
"""list_import_logs returns an empty page when no logs exist."""
|
|
||||||
resp = await blocklist_service.list_import_logs(
|
|
||||||
db, source_id=None, page=1, page_size=10
|
|
||||||
)
|
|
||||||
assert resp.items == []
|
|
||||||
assert resp.total == 0
|
|
||||||
assert resp.page == 1
|
|
||||||
assert resp.page_size == 10
|
|
||||||
assert resp.total_pages == 1
|
|
||||||
|
|
||||||
async def test_list_import_logs_paginates(self, db: aiosqlite.Connection) -> None:
|
|
||||||
"""list_import_logs computes total pages and returns the correct subset."""
|
|
||||||
from app.repositories import import_log_repo
|
|
||||||
|
|
||||||
for i in range(3):
|
|
||||||
await import_log_repo.add_log(
|
|
||||||
db,
|
|
||||||
source_id=None,
|
|
||||||
source_url=f"https://example{i}.test/ips.txt",
|
|
||||||
ips_imported=1,
|
|
||||||
ips_skipped=0,
|
|
||||||
errors=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
resp = await blocklist_service.list_import_logs(
|
|
||||||
db, source_id=None, page=2, page_size=2
|
|
||||||
)
|
|
||||||
assert resp.total == 3
|
|
||||||
assert resp.total_pages == 2
|
|
||||||
assert resp.page == 2
|
|
||||||
assert resp.page_size == 2
|
|
||||||
assert len(resp.items) == 1
|
|
||||||
assert resp.items[0].source_url == "https://example0.test/ips.txt"
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.utils.conffile_parser import (
|
from app.services.conffile_parser import (
|
||||||
merge_action_update,
|
merge_action_update,
|
||||||
merge_filter_update,
|
merge_filter_update,
|
||||||
parse_action_file,
|
parse_action_file,
|
||||||
@@ -451,7 +451,7 @@ class TestParseJailFile:
|
|||||||
"""Unit tests for parse_jail_file."""
|
"""Unit tests for parse_jail_file."""
|
||||||
|
|
||||||
def test_minimal_parses_correctly(self) -> None:
|
def test_minimal_parses_correctly(self) -> None:
|
||||||
from app.utils.conffile_parser import parse_jail_file
|
from app.services.conffile_parser import parse_jail_file
|
||||||
|
|
||||||
cfg = parse_jail_file(MINIMAL_JAIL, filename="sshd.conf")
|
cfg = parse_jail_file(MINIMAL_JAIL, filename="sshd.conf")
|
||||||
assert cfg.filename == "sshd.conf"
|
assert cfg.filename == "sshd.conf"
|
||||||
@@ -463,7 +463,7 @@ class TestParseJailFile:
|
|||||||
assert jail.logpath == ["/var/log/auth.log"]
|
assert jail.logpath == ["/var/log/auth.log"]
|
||||||
|
|
||||||
def test_full_parses_multiple_jails(self) -> None:
|
def test_full_parses_multiple_jails(self) -> None:
|
||||||
from app.utils.conffile_parser import parse_jail_file
|
from app.services.conffile_parser import parse_jail_file
|
||||||
|
|
||||||
cfg = parse_jail_file(FULL_JAIL)
|
cfg = parse_jail_file(FULL_JAIL)
|
||||||
assert len(cfg.jails) == 2
|
assert len(cfg.jails) == 2
|
||||||
@@ -471,7 +471,7 @@ class TestParseJailFile:
|
|||||||
assert "nginx-botsearch" in cfg.jails
|
assert "nginx-botsearch" in cfg.jails
|
||||||
|
|
||||||
def test_full_jail_numeric_fields(self) -> None:
|
def test_full_jail_numeric_fields(self) -> None:
|
||||||
from app.utils.conffile_parser import parse_jail_file
|
from app.services.conffile_parser import parse_jail_file
|
||||||
|
|
||||||
jail = parse_jail_file(FULL_JAIL).jails["sshd"]
|
jail = parse_jail_file(FULL_JAIL).jails["sshd"]
|
||||||
assert jail.maxretry == 3
|
assert jail.maxretry == 3
|
||||||
@@ -479,7 +479,7 @@ class TestParseJailFile:
|
|||||||
assert jail.bantime == 3600
|
assert jail.bantime == 3600
|
||||||
|
|
||||||
def test_full_jail_multiline_logpath(self) -> None:
|
def test_full_jail_multiline_logpath(self) -> None:
|
||||||
from app.utils.conffile_parser import parse_jail_file
|
from app.services.conffile_parser import parse_jail_file
|
||||||
|
|
||||||
jail = parse_jail_file(FULL_JAIL).jails["sshd"]
|
jail = parse_jail_file(FULL_JAIL).jails["sshd"]
|
||||||
assert len(jail.logpath) == 2
|
assert len(jail.logpath) == 2
|
||||||
@@ -487,53 +487,53 @@ class TestParseJailFile:
|
|||||||
assert "/var/log/syslog" in jail.logpath
|
assert "/var/log/syslog" in jail.logpath
|
||||||
|
|
||||||
def test_full_jail_multiline_action(self) -> None:
|
def test_full_jail_multiline_action(self) -> None:
|
||||||
from app.utils.conffile_parser import parse_jail_file
|
from app.services.conffile_parser import parse_jail_file
|
||||||
|
|
||||||
jail = parse_jail_file(FULL_JAIL).jails["nginx-botsearch"]
|
jail = parse_jail_file(FULL_JAIL).jails["nginx-botsearch"]
|
||||||
assert len(jail.action) == 2
|
assert len(jail.action) == 2
|
||||||
assert "sendmail-whois" in jail.action
|
assert "sendmail-whois" in jail.action
|
||||||
|
|
||||||
def test_enabled_true(self) -> None:
|
def test_enabled_true(self) -> None:
|
||||||
from app.utils.conffile_parser import parse_jail_file
|
from app.services.conffile_parser import parse_jail_file
|
||||||
|
|
||||||
jail = parse_jail_file(FULL_JAIL).jails["sshd"]
|
jail = parse_jail_file(FULL_JAIL).jails["sshd"]
|
||||||
assert jail.enabled is True
|
assert jail.enabled is True
|
||||||
|
|
||||||
def test_enabled_false(self) -> None:
|
def test_enabled_false(self) -> None:
|
||||||
from app.utils.conffile_parser import parse_jail_file
|
from app.services.conffile_parser import parse_jail_file
|
||||||
|
|
||||||
jail = parse_jail_file(FULL_JAIL).jails["nginx-botsearch"]
|
jail = parse_jail_file(FULL_JAIL).jails["nginx-botsearch"]
|
||||||
assert jail.enabled is False
|
assert jail.enabled is False
|
||||||
|
|
||||||
def test_extra_keys_captured(self) -> None:
|
def test_extra_keys_captured(self) -> None:
|
||||||
from app.utils.conffile_parser import parse_jail_file
|
from app.services.conffile_parser import parse_jail_file
|
||||||
|
|
||||||
jail = parse_jail_file(JAIL_WITH_EXTRA).jails["sshd"]
|
jail = parse_jail_file(JAIL_WITH_EXTRA).jails["sshd"]
|
||||||
assert jail.extra["custom_key"] == "custom_value"
|
assert jail.extra["custom_key"] == "custom_value"
|
||||||
assert jail.extra["another_key"] == "42"
|
assert jail.extra["another_key"] == "42"
|
||||||
|
|
||||||
def test_extra_keys_not_in_named_fields(self) -> None:
|
def test_extra_keys_not_in_named_fields(self) -> None:
|
||||||
from app.utils.conffile_parser import parse_jail_file
|
from app.services.conffile_parser import parse_jail_file
|
||||||
|
|
||||||
jail = parse_jail_file(JAIL_WITH_EXTRA).jails["sshd"]
|
jail = parse_jail_file(JAIL_WITH_EXTRA).jails["sshd"]
|
||||||
assert "enabled" not in jail.extra
|
assert "enabled" not in jail.extra
|
||||||
assert "logpath" not in jail.extra
|
assert "logpath" not in jail.extra
|
||||||
|
|
||||||
def test_empty_file_yields_no_jails(self) -> None:
|
def test_empty_file_yields_no_jails(self) -> None:
|
||||||
from app.utils.conffile_parser import parse_jail_file
|
from app.services.conffile_parser import parse_jail_file
|
||||||
|
|
||||||
cfg = parse_jail_file("")
|
cfg = parse_jail_file("")
|
||||||
assert cfg.jails == {}
|
assert cfg.jails == {}
|
||||||
|
|
||||||
def test_invalid_ini_does_not_raise(self) -> None:
|
def test_invalid_ini_does_not_raise(self) -> None:
|
||||||
from app.utils.conffile_parser import parse_jail_file
|
from app.services.conffile_parser import parse_jail_file
|
||||||
|
|
||||||
# Should not raise; just parse what it can.
|
# Should not raise; just parse what it can.
|
||||||
cfg = parse_jail_file("@@@ not valid ini @@@", filename="bad.conf")
|
cfg = parse_jail_file("@@@ not valid ini @@@", filename="bad.conf")
|
||||||
assert isinstance(cfg.jails, dict)
|
assert isinstance(cfg.jails, dict)
|
||||||
|
|
||||||
def test_default_section_ignored(self) -> None:
|
def test_default_section_ignored(self) -> None:
|
||||||
from app.utils.conffile_parser import parse_jail_file
|
from app.services.conffile_parser import parse_jail_file
|
||||||
|
|
||||||
content = "[DEFAULT]\nignoreip = 127.0.0.1\n\n[sshd]\nenabled = true\n"
|
content = "[DEFAULT]\nignoreip = 127.0.0.1\n\n[sshd]\nenabled = true\n"
|
||||||
cfg = parse_jail_file(content)
|
cfg = parse_jail_file(content)
|
||||||
@@ -550,7 +550,7 @@ class TestJailFileRoundTrip:
|
|||||||
"""Tests that parse → serialize → parse preserves values."""
|
"""Tests that parse → serialize → parse preserves values."""
|
||||||
|
|
||||||
def test_minimal_round_trip(self) -> None:
|
def test_minimal_round_trip(self) -> None:
|
||||||
from app.utils.conffile_parser import parse_jail_file, serialize_jail_file_config
|
from app.services.conffile_parser import parse_jail_file, serialize_jail_file_config
|
||||||
|
|
||||||
original = parse_jail_file(MINIMAL_JAIL, filename="sshd.conf")
|
original = parse_jail_file(MINIMAL_JAIL, filename="sshd.conf")
|
||||||
serialized = serialize_jail_file_config(original)
|
serialized = serialize_jail_file_config(original)
|
||||||
@@ -560,7 +560,7 @@ class TestJailFileRoundTrip:
|
|||||||
assert restored.jails["sshd"].logpath == original.jails["sshd"].logpath
|
assert restored.jails["sshd"].logpath == original.jails["sshd"].logpath
|
||||||
|
|
||||||
def test_full_round_trip(self) -> None:
|
def test_full_round_trip(self) -> None:
|
||||||
from app.utils.conffile_parser import parse_jail_file, serialize_jail_file_config
|
from app.services.conffile_parser import parse_jail_file, serialize_jail_file_config
|
||||||
|
|
||||||
original = parse_jail_file(FULL_JAIL)
|
original = parse_jail_file(FULL_JAIL)
|
||||||
serialized = serialize_jail_file_config(original)
|
serialized = serialize_jail_file_config(original)
|
||||||
@@ -573,7 +573,7 @@ class TestJailFileRoundTrip:
|
|||||||
assert sorted(restored_jail.action) == sorted(jail.action)
|
assert sorted(restored_jail.action) == sorted(jail.action)
|
||||||
|
|
||||||
def test_extra_keys_round_trip(self) -> None:
|
def test_extra_keys_round_trip(self) -> None:
|
||||||
from app.utils.conffile_parser import parse_jail_file, serialize_jail_file_config
|
from app.services.conffile_parser import parse_jail_file, serialize_jail_file_config
|
||||||
|
|
||||||
original = parse_jail_file(JAIL_WITH_EXTRA)
|
original = parse_jail_file(JAIL_WITH_EXTRA)
|
||||||
serialized = serialize_jail_file_config(original)
|
serialized = serialize_jail_file_config(original)
|
||||||
@@ -591,7 +591,7 @@ class TestMergeJailFileUpdate:
|
|||||||
|
|
||||||
def test_none_update_returns_original(self) -> None:
|
def test_none_update_returns_original(self) -> None:
|
||||||
from app.models.config import JailFileConfigUpdate
|
from app.models.config import JailFileConfigUpdate
|
||||||
from app.utils.conffile_parser import merge_jail_file_update, parse_jail_file
|
from app.services.conffile_parser import merge_jail_file_update, parse_jail_file
|
||||||
|
|
||||||
cfg = parse_jail_file(FULL_JAIL)
|
cfg = parse_jail_file(FULL_JAIL)
|
||||||
update = JailFileConfigUpdate()
|
update = JailFileConfigUpdate()
|
||||||
@@ -600,7 +600,7 @@ class TestMergeJailFileUpdate:
|
|||||||
|
|
||||||
def test_update_replaces_jail(self) -> None:
|
def test_update_replaces_jail(self) -> None:
|
||||||
from app.models.config import JailFileConfigUpdate, JailSectionConfig
|
from app.models.config import JailFileConfigUpdate, JailSectionConfig
|
||||||
from app.utils.conffile_parser import merge_jail_file_update, parse_jail_file
|
from app.services.conffile_parser import merge_jail_file_update, parse_jail_file
|
||||||
|
|
||||||
cfg = parse_jail_file(FULL_JAIL)
|
cfg = parse_jail_file(FULL_JAIL)
|
||||||
new_sshd = JailSectionConfig(enabled=False, port="2222")
|
new_sshd = JailSectionConfig(enabled=False, port="2222")
|
||||||
@@ -613,7 +613,7 @@ class TestMergeJailFileUpdate:
|
|||||||
|
|
||||||
def test_update_adds_new_jail(self) -> None:
|
def test_update_adds_new_jail(self) -> None:
|
||||||
from app.models.config import JailFileConfigUpdate, JailSectionConfig
|
from app.models.config import JailFileConfigUpdate, JailSectionConfig
|
||||||
from app.utils.conffile_parser import merge_jail_file_update, parse_jail_file
|
from app.services.conffile_parser import merge_jail_file_update, parse_jail_file
|
||||||
|
|
||||||
cfg = parse_jail_file(MINIMAL_JAIL)
|
cfg = parse_jail_file(MINIMAL_JAIL)
|
||||||
new_jail = JailSectionConfig(enabled=True, port="443")
|
new_jail = JailSectionConfig(enabled=True, port="443")
|
||||||
|
|||||||
@@ -13,19 +13,15 @@ from app.services.config_file_service import (
|
|||||||
JailNameError,
|
JailNameError,
|
||||||
JailNotFoundInConfigError,
|
JailNotFoundInConfigError,
|
||||||
_build_inactive_jail,
|
_build_inactive_jail,
|
||||||
_extract_action_base_name,
|
|
||||||
_extract_filter_base_name,
|
|
||||||
_ordered_config_files,
|
_ordered_config_files,
|
||||||
_parse_jails_sync,
|
_parse_jails_sync,
|
||||||
_resolve_filter,
|
_resolve_filter,
|
||||||
_safe_jail_name,
|
_safe_jail_name,
|
||||||
_validate_jail_config_sync,
|
|
||||||
_write_local_override_sync,
|
_write_local_override_sync,
|
||||||
activate_jail,
|
activate_jail,
|
||||||
deactivate_jail,
|
deactivate_jail,
|
||||||
list_inactive_jails,
|
list_inactive_jails,
|
||||||
rollback_jail,
|
rollback_jail,
|
||||||
validate_jail_config,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -296,7 +292,9 @@ class TestBuildInactiveJail:
|
|||||||
|
|
||||||
def test_has_local_override_absent(self, tmp_path: Path) -> None:
|
def test_has_local_override_absent(self, tmp_path: Path) -> None:
|
||||||
"""has_local_override is False when no .local file exists."""
|
"""has_local_override is False when no .local file exists."""
|
||||||
jail = _build_inactive_jail("sshd", {}, "/etc/fail2ban/jail.d/sshd.conf", config_dir=tmp_path)
|
jail = _build_inactive_jail(
|
||||||
|
"sshd", {}, "/etc/fail2ban/jail.d/sshd.conf", config_dir=tmp_path
|
||||||
|
)
|
||||||
assert jail.has_local_override is False
|
assert jail.has_local_override is False
|
||||||
|
|
||||||
def test_has_local_override_present(self, tmp_path: Path) -> None:
|
def test_has_local_override_present(self, tmp_path: Path) -> None:
|
||||||
@@ -304,7 +302,9 @@ class TestBuildInactiveJail:
|
|||||||
local = tmp_path / "jail.d" / "sshd.local"
|
local = tmp_path / "jail.d" / "sshd.local"
|
||||||
local.parent.mkdir(parents=True, exist_ok=True)
|
local.parent.mkdir(parents=True, exist_ok=True)
|
||||||
local.write_text("[sshd]\nenabled = false\n")
|
local.write_text("[sshd]\nenabled = false\n")
|
||||||
jail = _build_inactive_jail("sshd", {}, "/etc/fail2ban/jail.d/sshd.conf", config_dir=tmp_path)
|
jail = _build_inactive_jail(
|
||||||
|
"sshd", {}, "/etc/fail2ban/jail.d/sshd.conf", config_dir=tmp_path
|
||||||
|
)
|
||||||
assert jail.has_local_override is True
|
assert jail.has_local_override is True
|
||||||
|
|
||||||
def test_has_local_override_no_config_dir(self) -> None:
|
def test_has_local_override_no_config_dir(self) -> None:
|
||||||
@@ -363,7 +363,9 @@ class TestWriteLocalOverrideSync:
|
|||||||
assert "2222" in content
|
assert "2222" in content
|
||||||
|
|
||||||
def test_override_logpath_list(self, tmp_path: Path) -> None:
|
def test_override_logpath_list(self, tmp_path: Path) -> None:
|
||||||
_write_local_override_sync(tmp_path, "sshd", True, {"logpath": ["/var/log/auth.log", "/var/log/secure"]})
|
_write_local_override_sync(
|
||||||
|
tmp_path, "sshd", True, {"logpath": ["/var/log/auth.log", "/var/log/secure"]}
|
||||||
|
)
|
||||||
content = (tmp_path / "jail.d" / "sshd.local").read_text()
|
content = (tmp_path / "jail.d" / "sshd.local").read_text()
|
||||||
assert "/var/log/auth.log" in content
|
assert "/var/log/auth.log" in content
|
||||||
assert "/var/log/secure" in content
|
assert "/var/log/secure" in content
|
||||||
@@ -445,7 +447,9 @@ class TestListInactiveJails:
|
|||||||
assert "sshd" in names
|
assert "sshd" in names
|
||||||
assert "apache-auth" in names
|
assert "apache-auth" in names
|
||||||
|
|
||||||
async def test_has_local_override_true_when_local_file_exists(self, tmp_path: Path) -> None:
|
async def test_has_local_override_true_when_local_file_exists(
|
||||||
|
self, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
"""has_local_override is True for a jail whose jail.d .local file exists."""
|
"""has_local_override is True for a jail whose jail.d .local file exists."""
|
||||||
_write(tmp_path / "jail.conf", JAIL_CONF)
|
_write(tmp_path / "jail.conf", JAIL_CONF)
|
||||||
local = tmp_path / "jail.d" / "apache-auth.local"
|
local = tmp_path / "jail.d" / "apache-auth.local"
|
||||||
@@ -459,7 +463,9 @@ class TestListInactiveJails:
|
|||||||
jail = next(j for j in result.jails if j.name == "apache-auth")
|
jail = next(j for j in result.jails if j.name == "apache-auth")
|
||||||
assert jail.has_local_override is True
|
assert jail.has_local_override is True
|
||||||
|
|
||||||
async def test_has_local_override_false_when_no_local_file(self, tmp_path: Path) -> None:
|
async def test_has_local_override_false_when_no_local_file(
|
||||||
|
self, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
"""has_local_override is False when no jail.d .local file exists."""
|
"""has_local_override is False when no jail.d .local file exists."""
|
||||||
_write(tmp_path / "jail.conf", JAIL_CONF)
|
_write(tmp_path / "jail.conf", JAIL_CONF)
|
||||||
with patch(
|
with patch(
|
||||||
@@ -602,8 +608,7 @@ class TestActivateJail:
|
|||||||
patch(
|
patch(
|
||||||
"app.services.config_file_service._get_active_jail_names",
|
"app.services.config_file_service._get_active_jail_names",
|
||||||
new=AsyncMock(return_value=set()),
|
new=AsyncMock(return_value=set()),
|
||||||
),
|
),pytest.raises(JailNotFoundInConfigError)
|
||||||
pytest.raises(JailNotFoundInConfigError),
|
|
||||||
):
|
):
|
||||||
await activate_jail(str(tmp_path), "/fake.sock", "nonexistent", req)
|
await activate_jail(str(tmp_path), "/fake.sock", "nonexistent", req)
|
||||||
|
|
||||||
@@ -616,8 +621,7 @@ class TestActivateJail:
|
|||||||
patch(
|
patch(
|
||||||
"app.services.config_file_service._get_active_jail_names",
|
"app.services.config_file_service._get_active_jail_names",
|
||||||
new=AsyncMock(return_value={"sshd"}),
|
new=AsyncMock(return_value={"sshd"}),
|
||||||
),
|
),pytest.raises(JailAlreadyActiveError)
|
||||||
pytest.raises(JailAlreadyActiveError),
|
|
||||||
):
|
):
|
||||||
await activate_jail(str(tmp_path), "/fake.sock", "sshd", req)
|
await activate_jail(str(tmp_path), "/fake.sock", "sshd", req)
|
||||||
|
|
||||||
@@ -687,8 +691,7 @@ class TestDeactivateJail:
|
|||||||
patch(
|
patch(
|
||||||
"app.services.config_file_service._get_active_jail_names",
|
"app.services.config_file_service._get_active_jail_names",
|
||||||
new=AsyncMock(return_value={"sshd"}),
|
new=AsyncMock(return_value={"sshd"}),
|
||||||
),
|
),pytest.raises(JailNotFoundInConfigError)
|
||||||
pytest.raises(JailNotFoundInConfigError),
|
|
||||||
):
|
):
|
||||||
await deactivate_jail(str(tmp_path), "/fake.sock", "nonexistent")
|
await deactivate_jail(str(tmp_path), "/fake.sock", "nonexistent")
|
||||||
|
|
||||||
@@ -698,8 +701,7 @@ class TestDeactivateJail:
|
|||||||
patch(
|
patch(
|
||||||
"app.services.config_file_service._get_active_jail_names",
|
"app.services.config_file_service._get_active_jail_names",
|
||||||
new=AsyncMock(return_value=set()),
|
new=AsyncMock(return_value=set()),
|
||||||
),
|
),pytest.raises(JailAlreadyInactiveError)
|
||||||
pytest.raises(JailAlreadyInactiveError),
|
|
||||||
):
|
):
|
||||||
await deactivate_jail(str(tmp_path), "/fake.sock", "apache-auth")
|
await deactivate_jail(str(tmp_path), "/fake.sock", "apache-auth")
|
||||||
|
|
||||||
@@ -708,6 +710,38 @@ class TestDeactivateJail:
|
|||||||
await deactivate_jail(str(tmp_path), "/fake.sock", "a/b")
|
await deactivate_jail(str(tmp_path), "/fake.sock", "a/b")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _extract_filter_base_name
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractFilterBaseName:
|
||||||
|
def test_simple_name(self) -> None:
|
||||||
|
from app.services.config_file_service import _extract_filter_base_name
|
||||||
|
|
||||||
|
assert _extract_filter_base_name("sshd") == "sshd"
|
||||||
|
|
||||||
|
def test_name_with_mode(self) -> None:
|
||||||
|
from app.services.config_file_service import _extract_filter_base_name
|
||||||
|
|
||||||
|
assert _extract_filter_base_name("sshd[mode=aggressive]") == "sshd"
|
||||||
|
|
||||||
|
def test_name_with_variable_mode(self) -> None:
|
||||||
|
from app.services.config_file_service import _extract_filter_base_name
|
||||||
|
|
||||||
|
assert _extract_filter_base_name("sshd[mode=%(mode)s]") == "sshd"
|
||||||
|
|
||||||
|
def test_whitespace_stripped(self) -> None:
|
||||||
|
from app.services.config_file_service import _extract_filter_base_name
|
||||||
|
|
||||||
|
assert _extract_filter_base_name(" nginx ") == "nginx"
|
||||||
|
|
||||||
|
def test_empty_string(self) -> None:
|
||||||
|
from app.services.config_file_service import _extract_filter_base_name
|
||||||
|
|
||||||
|
assert _extract_filter_base_name("") == ""
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# _build_filter_to_jails_map
|
# _build_filter_to_jails_map
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -723,7 +757,9 @@ class TestBuildFilterToJailsMap:
|
|||||||
def test_inactive_jail_not_included(self) -> None:
|
def test_inactive_jail_not_included(self) -> None:
|
||||||
from app.services.config_file_service import _build_filter_to_jails_map
|
from app.services.config_file_service import _build_filter_to_jails_map
|
||||||
|
|
||||||
result = _build_filter_to_jails_map({"apache-auth": {"filter": "apache-auth"}}, set())
|
result = _build_filter_to_jails_map(
|
||||||
|
{"apache-auth": {"filter": "apache-auth"}}, set()
|
||||||
|
)
|
||||||
assert result == {}
|
assert result == {}
|
||||||
|
|
||||||
def test_multiple_jails_sharing_filter(self) -> None:
|
def test_multiple_jails_sharing_filter(self) -> None:
|
||||||
@@ -739,7 +775,9 @@ class TestBuildFilterToJailsMap:
|
|||||||
def test_mode_suffix_stripped(self) -> None:
|
def test_mode_suffix_stripped(self) -> None:
|
||||||
from app.services.config_file_service import _build_filter_to_jails_map
|
from app.services.config_file_service import _build_filter_to_jails_map
|
||||||
|
|
||||||
result = _build_filter_to_jails_map({"sshd": {"filter": "sshd[mode=aggressive]"}}, {"sshd"})
|
result = _build_filter_to_jails_map(
|
||||||
|
{"sshd": {"filter": "sshd[mode=aggressive]"}}, {"sshd"}
|
||||||
|
)
|
||||||
assert "sshd" in result
|
assert "sshd" in result
|
||||||
|
|
||||||
def test_missing_filter_key_falls_back_to_jail_name(self) -> None:
|
def test_missing_filter_key_falls_back_to_jail_name(self) -> None:
|
||||||
@@ -950,13 +988,10 @@ class TestGetFilter:
|
|||||||
async def test_raises_filter_not_found(self, tmp_path: Path) -> None:
|
async def test_raises_filter_not_found(self, tmp_path: Path) -> None:
|
||||||
from app.services.config_file_service import FilterNotFoundError, get_filter
|
from app.services.config_file_service import FilterNotFoundError, get_filter
|
||||||
|
|
||||||
with (
|
with patch(
|
||||||
patch(
|
"app.services.config_file_service._get_active_jail_names",
|
||||||
"app.services.config_file_service._get_active_jail_names",
|
new=AsyncMock(return_value=set()),
|
||||||
new=AsyncMock(return_value=set()),
|
), pytest.raises(FilterNotFoundError):
|
||||||
),
|
|
||||||
pytest.raises(FilterNotFoundError),
|
|
||||||
):
|
|
||||||
await get_filter(str(tmp_path), "/fake.sock", "nonexistent")
|
await get_filter(str(tmp_path), "/fake.sock", "nonexistent")
|
||||||
|
|
||||||
async def test_has_local_override_detected(self, tmp_path: Path) -> None:
|
async def test_has_local_override_detected(self, tmp_path: Path) -> None:
|
||||||
@@ -1058,13 +1093,10 @@ class TestGetFilterLocalOnly:
|
|||||||
async def test_raises_when_neither_conf_nor_local(self, tmp_path: Path) -> None:
|
async def test_raises_when_neither_conf_nor_local(self, tmp_path: Path) -> None:
|
||||||
from app.services.config_file_service import FilterNotFoundError, get_filter
|
from app.services.config_file_service import FilterNotFoundError, get_filter
|
||||||
|
|
||||||
with (
|
with patch(
|
||||||
patch(
|
"app.services.config_file_service._get_active_jail_names",
|
||||||
"app.services.config_file_service._get_active_jail_names",
|
new=AsyncMock(return_value=set()),
|
||||||
new=AsyncMock(return_value=set()),
|
), pytest.raises(FilterNotFoundError):
|
||||||
),
|
|
||||||
pytest.raises(FilterNotFoundError),
|
|
||||||
):
|
|
||||||
await get_filter(str(tmp_path), "/fake.sock", "nonexistent")
|
await get_filter(str(tmp_path), "/fake.sock", "nonexistent")
|
||||||
|
|
||||||
async def test_accepts_local_extension(self, tmp_path: Path) -> None:
|
async def test_accepts_local_extension(self, tmp_path: Path) -> None:
|
||||||
@@ -1180,7 +1212,9 @@ class TestSetJailLocalKeySync:
|
|||||||
|
|
||||||
jail_d = tmp_path / "jail.d"
|
jail_d = tmp_path / "jail.d"
|
||||||
jail_d.mkdir()
|
jail_d.mkdir()
|
||||||
(jail_d / "sshd.local").write_text("[sshd]\nenabled = true\n")
|
(jail_d / "sshd.local").write_text(
|
||||||
|
"[sshd]\nenabled = true\n"
|
||||||
|
)
|
||||||
|
|
||||||
_set_jail_local_key_sync(tmp_path, "sshd", "filter", "newfilter")
|
_set_jail_local_key_sync(tmp_path, "sshd", "filter", "newfilter")
|
||||||
|
|
||||||
@@ -1266,13 +1300,10 @@ class TestUpdateFilter:
|
|||||||
from app.models.config import FilterUpdateRequest
|
from app.models.config import FilterUpdateRequest
|
||||||
from app.services.config_file_service import FilterNotFoundError, update_filter
|
from app.services.config_file_service import FilterNotFoundError, update_filter
|
||||||
|
|
||||||
with (
|
with patch(
|
||||||
patch(
|
"app.services.config_file_service._get_active_jail_names",
|
||||||
"app.services.config_file_service._get_active_jail_names",
|
new=AsyncMock(return_value=set()),
|
||||||
new=AsyncMock(return_value=set()),
|
), pytest.raises(FilterNotFoundError):
|
||||||
),
|
|
||||||
pytest.raises(FilterNotFoundError),
|
|
||||||
):
|
|
||||||
await update_filter(
|
await update_filter(
|
||||||
str(tmp_path),
|
str(tmp_path),
|
||||||
"/fake.sock",
|
"/fake.sock",
|
||||||
@@ -1290,13 +1321,10 @@ class TestUpdateFilter:
|
|||||||
filter_d = tmp_path / "filter.d"
|
filter_d = tmp_path / "filter.d"
|
||||||
_write(filter_d / "sshd.conf", _FILTER_CONF_WITH_REGEX)
|
_write(filter_d / "sshd.conf", _FILTER_CONF_WITH_REGEX)
|
||||||
|
|
||||||
with (
|
with patch(
|
||||||
patch(
|
"app.services.config_file_service._get_active_jail_names",
|
||||||
"app.services.config_file_service._get_active_jail_names",
|
new=AsyncMock(return_value=set()),
|
||||||
new=AsyncMock(return_value=set()),
|
), pytest.raises(FilterInvalidRegexError):
|
||||||
),
|
|
||||||
pytest.raises(FilterInvalidRegexError),
|
|
||||||
):
|
|
||||||
await update_filter(
|
await update_filter(
|
||||||
str(tmp_path),
|
str(tmp_path),
|
||||||
"/fake.sock",
|
"/fake.sock",
|
||||||
@@ -1323,16 +1351,13 @@ class TestUpdateFilter:
|
|||||||
filter_d = tmp_path / "filter.d"
|
filter_d = tmp_path / "filter.d"
|
||||||
_write(filter_d / "sshd.conf", _FILTER_CONF)
|
_write(filter_d / "sshd.conf", _FILTER_CONF)
|
||||||
|
|
||||||
with (
|
with patch(
|
||||||
patch(
|
"app.services.config_file_service._get_active_jail_names",
|
||||||
"app.services.config_file_service._get_active_jail_names",
|
new=AsyncMock(return_value=set()),
|
||||||
new=AsyncMock(return_value=set()),
|
), patch(
|
||||||
),
|
"app.services.config_file_service.jail_service.reload_all",
|
||||||
patch(
|
new=AsyncMock(),
|
||||||
"app.services.config_file_service.jail_service.reload_all",
|
) as mock_reload:
|
||||||
new=AsyncMock(),
|
|
||||||
) as mock_reload,
|
|
||||||
):
|
|
||||||
await update_filter(
|
await update_filter(
|
||||||
str(tmp_path),
|
str(tmp_path),
|
||||||
"/fake.sock",
|
"/fake.sock",
|
||||||
@@ -1380,13 +1405,10 @@ class TestCreateFilter:
|
|||||||
filter_d = tmp_path / "filter.d"
|
filter_d = tmp_path / "filter.d"
|
||||||
_write(filter_d / "sshd.conf", _FILTER_CONF)
|
_write(filter_d / "sshd.conf", _FILTER_CONF)
|
||||||
|
|
||||||
with (
|
with patch(
|
||||||
patch(
|
"app.services.config_file_service._get_active_jail_names",
|
||||||
"app.services.config_file_service._get_active_jail_names",
|
new=AsyncMock(return_value=set()),
|
||||||
new=AsyncMock(return_value=set()),
|
), pytest.raises(FilterAlreadyExistsError):
|
||||||
),
|
|
||||||
pytest.raises(FilterAlreadyExistsError),
|
|
||||||
):
|
|
||||||
await create_filter(
|
await create_filter(
|
||||||
str(tmp_path),
|
str(tmp_path),
|
||||||
"/fake.sock",
|
"/fake.sock",
|
||||||
@@ -1400,13 +1422,10 @@ class TestCreateFilter:
|
|||||||
filter_d = tmp_path / "filter.d"
|
filter_d = tmp_path / "filter.d"
|
||||||
_write(filter_d / "custom.local", "[Definition]\n")
|
_write(filter_d / "custom.local", "[Definition]\n")
|
||||||
|
|
||||||
with (
|
with patch(
|
||||||
patch(
|
"app.services.config_file_service._get_active_jail_names",
|
||||||
"app.services.config_file_service._get_active_jail_names",
|
new=AsyncMock(return_value=set()),
|
||||||
new=AsyncMock(return_value=set()),
|
), pytest.raises(FilterAlreadyExistsError):
|
||||||
),
|
|
||||||
pytest.raises(FilterAlreadyExistsError),
|
|
||||||
):
|
|
||||||
await create_filter(
|
await create_filter(
|
||||||
str(tmp_path),
|
str(tmp_path),
|
||||||
"/fake.sock",
|
"/fake.sock",
|
||||||
@@ -1417,13 +1436,10 @@ class TestCreateFilter:
|
|||||||
from app.models.config import FilterCreateRequest
|
from app.models.config import FilterCreateRequest
|
||||||
from app.services.config_file_service import FilterInvalidRegexError, create_filter
|
from app.services.config_file_service import FilterInvalidRegexError, create_filter
|
||||||
|
|
||||||
with (
|
with patch(
|
||||||
patch(
|
"app.services.config_file_service._get_active_jail_names",
|
||||||
"app.services.config_file_service._get_active_jail_names",
|
new=AsyncMock(return_value=set()),
|
||||||
new=AsyncMock(return_value=set()),
|
), pytest.raises(FilterInvalidRegexError):
|
||||||
),
|
|
||||||
pytest.raises(FilterInvalidRegexError),
|
|
||||||
):
|
|
||||||
await create_filter(
|
await create_filter(
|
||||||
str(tmp_path),
|
str(tmp_path),
|
||||||
"/fake.sock",
|
"/fake.sock",
|
||||||
@@ -1445,16 +1461,13 @@ class TestCreateFilter:
|
|||||||
from app.models.config import FilterCreateRequest
|
from app.models.config import FilterCreateRequest
|
||||||
from app.services.config_file_service import create_filter
|
from app.services.config_file_service import create_filter
|
||||||
|
|
||||||
with (
|
with patch(
|
||||||
patch(
|
"app.services.config_file_service._get_active_jail_names",
|
||||||
"app.services.config_file_service._get_active_jail_names",
|
new=AsyncMock(return_value=set()),
|
||||||
new=AsyncMock(return_value=set()),
|
), patch(
|
||||||
),
|
"app.services.config_file_service.jail_service.reload_all",
|
||||||
patch(
|
new=AsyncMock(),
|
||||||
"app.services.config_file_service.jail_service.reload_all",
|
) as mock_reload:
|
||||||
new=AsyncMock(),
|
|
||||||
) as mock_reload,
|
|
||||||
):
|
|
||||||
await create_filter(
|
await create_filter(
|
||||||
str(tmp_path),
|
str(tmp_path),
|
||||||
"/fake.sock",
|
"/fake.sock",
|
||||||
@@ -1472,7 +1485,9 @@ class TestCreateFilter:
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
class TestDeleteFilter:
|
class TestDeleteFilter:
|
||||||
async def test_deletes_local_file_when_conf_and_local_exist(self, tmp_path: Path) -> None:
|
async def test_deletes_local_file_when_conf_and_local_exist(
|
||||||
|
self, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
from app.services.config_file_service import delete_filter
|
from app.services.config_file_service import delete_filter
|
||||||
|
|
||||||
filter_d = tmp_path / "filter.d"
|
filter_d = tmp_path / "filter.d"
|
||||||
@@ -1509,7 +1524,9 @@ class TestDeleteFilter:
|
|||||||
with pytest.raises(FilterNotFoundError):
|
with pytest.raises(FilterNotFoundError):
|
||||||
await delete_filter(str(tmp_path), "nonexistent")
|
await delete_filter(str(tmp_path), "nonexistent")
|
||||||
|
|
||||||
async def test_accepts_filter_name_error_for_invalid_name(self, tmp_path: Path) -> None:
|
async def test_accepts_filter_name_error_for_invalid_name(
|
||||||
|
self, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
from app.services.config_file_service import FilterNameError, delete_filter
|
from app.services.config_file_service import FilterNameError, delete_filter
|
||||||
|
|
||||||
with pytest.raises(FilterNameError):
|
with pytest.raises(FilterNameError):
|
||||||
@@ -1590,7 +1607,9 @@ class TestAssignFilterToJail:
|
|||||||
AssignFilterRequest(filter_name="sshd"),
|
AssignFilterRequest(filter_name="sshd"),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def test_raises_filter_name_error_for_invalid_filter(self, tmp_path: Path) -> None:
|
async def test_raises_filter_name_error_for_invalid_filter(
|
||||||
|
self, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
from app.models.config import AssignFilterRequest
|
from app.models.config import AssignFilterRequest
|
||||||
from app.services.config_file_service import FilterNameError, assign_filter_to_jail
|
from app.services.config_file_service import FilterNameError, assign_filter_to_jail
|
||||||
|
|
||||||
@@ -1700,26 +1719,34 @@ class TestBuildActionToJailsMap:
|
|||||||
def test_active_jail_maps_to_action(self) -> None:
|
def test_active_jail_maps_to_action(self) -> None:
|
||||||
from app.services.config_file_service import _build_action_to_jails_map
|
from app.services.config_file_service import _build_action_to_jails_map
|
||||||
|
|
||||||
result = _build_action_to_jails_map({"sshd": {"action": "iptables-multiport"}}, {"sshd"})
|
result = _build_action_to_jails_map(
|
||||||
|
{"sshd": {"action": "iptables-multiport"}}, {"sshd"}
|
||||||
|
)
|
||||||
assert result == {"iptables-multiport": ["sshd"]}
|
assert result == {"iptables-multiport": ["sshd"]}
|
||||||
|
|
||||||
def test_inactive_jail_not_included(self) -> None:
|
def test_inactive_jail_not_included(self) -> None:
|
||||||
from app.services.config_file_service import _build_action_to_jails_map
|
from app.services.config_file_service import _build_action_to_jails_map
|
||||||
|
|
||||||
result = _build_action_to_jails_map({"sshd": {"action": "iptables-multiport"}}, set())
|
result = _build_action_to_jails_map(
|
||||||
|
{"sshd": {"action": "iptables-multiport"}}, set()
|
||||||
|
)
|
||||||
assert result == {}
|
assert result == {}
|
||||||
|
|
||||||
def test_multiple_actions_per_jail(self) -> None:
|
def test_multiple_actions_per_jail(self) -> None:
|
||||||
from app.services.config_file_service import _build_action_to_jails_map
|
from app.services.config_file_service import _build_action_to_jails_map
|
||||||
|
|
||||||
result = _build_action_to_jails_map({"sshd": {"action": "iptables-multiport\niptables-ipset"}}, {"sshd"})
|
result = _build_action_to_jails_map(
|
||||||
|
{"sshd": {"action": "iptables-multiport\niptables-ipset"}}, {"sshd"}
|
||||||
|
)
|
||||||
assert "iptables-multiport" in result
|
assert "iptables-multiport" in result
|
||||||
assert "iptables-ipset" in result
|
assert "iptables-ipset" in result
|
||||||
|
|
||||||
def test_parameter_block_stripped(self) -> None:
|
def test_parameter_block_stripped(self) -> None:
|
||||||
from app.services.config_file_service import _build_action_to_jails_map
|
from app.services.config_file_service import _build_action_to_jails_map
|
||||||
|
|
||||||
result = _build_action_to_jails_map({"sshd": {"action": "iptables[port=ssh, protocol=tcp]"}}, {"sshd"})
|
result = _build_action_to_jails_map(
|
||||||
|
{"sshd": {"action": "iptables[port=ssh, protocol=tcp]"}}, {"sshd"}
|
||||||
|
)
|
||||||
assert "iptables" in result
|
assert "iptables" in result
|
||||||
|
|
||||||
def test_multiple_jails_sharing_action(self) -> None:
|
def test_multiple_jails_sharing_action(self) -> None:
|
||||||
@@ -1974,13 +2001,10 @@ class TestGetAction:
|
|||||||
async def test_raises_for_unknown_action(self, tmp_path: Path) -> None:
|
async def test_raises_for_unknown_action(self, tmp_path: Path) -> None:
|
||||||
from app.services.config_file_service import ActionNotFoundError, get_action
|
from app.services.config_file_service import ActionNotFoundError, get_action
|
||||||
|
|
||||||
with (
|
with patch(
|
||||||
patch(
|
"app.services.config_file_service._get_active_jail_names",
|
||||||
"app.services.config_file_service._get_active_jail_names",
|
new=AsyncMock(return_value=set()),
|
||||||
new=AsyncMock(return_value=set()),
|
), pytest.raises(ActionNotFoundError):
|
||||||
),
|
|
||||||
pytest.raises(ActionNotFoundError),
|
|
||||||
):
|
|
||||||
await get_action(str(tmp_path), "/fake.sock", "nonexistent")
|
await get_action(str(tmp_path), "/fake.sock", "nonexistent")
|
||||||
|
|
||||||
async def test_local_only_action_returned(self, tmp_path: Path) -> None:
|
async def test_local_only_action_returned(self, tmp_path: Path) -> None:
|
||||||
@@ -2094,13 +2118,10 @@ class TestUpdateAction:
|
|||||||
from app.models.config import ActionUpdateRequest
|
from app.models.config import ActionUpdateRequest
|
||||||
from app.services.config_file_service import ActionNotFoundError, update_action
|
from app.services.config_file_service import ActionNotFoundError, update_action
|
||||||
|
|
||||||
with (
|
with patch(
|
||||||
patch(
|
"app.services.config_file_service._get_active_jail_names",
|
||||||
"app.services.config_file_service._get_active_jail_names",
|
new=AsyncMock(return_value=set()),
|
||||||
new=AsyncMock(return_value=set()),
|
), pytest.raises(ActionNotFoundError):
|
||||||
),
|
|
||||||
pytest.raises(ActionNotFoundError),
|
|
||||||
):
|
|
||||||
await update_action(
|
await update_action(
|
||||||
str(tmp_path),
|
str(tmp_path),
|
||||||
"/fake.sock",
|
"/fake.sock",
|
||||||
@@ -2566,7 +2587,9 @@ class TestRemoveActionFromJail:
|
|||||||
"app.services.config_file_service._get_active_jail_names",
|
"app.services.config_file_service._get_active_jail_names",
|
||||||
new=AsyncMock(return_value=set()),
|
new=AsyncMock(return_value=set()),
|
||||||
):
|
):
|
||||||
await remove_action_from_jail(str(tmp_path), "/fake.sock", "sshd", "iptables-multiport")
|
await remove_action_from_jail(
|
||||||
|
str(tmp_path), "/fake.sock", "sshd", "iptables-multiport"
|
||||||
|
)
|
||||||
|
|
||||||
content = (jail_d / "sshd.local").read_text()
|
content = (jail_d / "sshd.local").read_text()
|
||||||
assert "iptables-multiport" not in content
|
assert "iptables-multiport" not in content
|
||||||
@@ -2578,13 +2601,17 @@ class TestRemoveActionFromJail:
|
|||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(JailNotFoundInConfigError):
|
with pytest.raises(JailNotFoundInConfigError):
|
||||||
await remove_action_from_jail(str(tmp_path), "/fake.sock", "nonexistent", "iptables")
|
await remove_action_from_jail(
|
||||||
|
str(tmp_path), "/fake.sock", "nonexistent", "iptables"
|
||||||
|
)
|
||||||
|
|
||||||
async def test_raises_jail_name_error(self, tmp_path: Path) -> None:
|
async def test_raises_jail_name_error(self, tmp_path: Path) -> None:
|
||||||
from app.services.config_file_service import JailNameError, remove_action_from_jail
|
from app.services.config_file_service import JailNameError, remove_action_from_jail
|
||||||
|
|
||||||
with pytest.raises(JailNameError):
|
with pytest.raises(JailNameError):
|
||||||
await remove_action_from_jail(str(tmp_path), "/fake.sock", "../evil", "iptables")
|
await remove_action_from_jail(
|
||||||
|
str(tmp_path), "/fake.sock", "../evil", "iptables"
|
||||||
|
)
|
||||||
|
|
||||||
async def test_raises_action_name_error(self, tmp_path: Path) -> None:
|
async def test_raises_action_name_error(self, tmp_path: Path) -> None:
|
||||||
from app.services.config_file_service import ActionNameError, remove_action_from_jail
|
from app.services.config_file_service import ActionNameError, remove_action_from_jail
|
||||||
@@ -2592,7 +2619,9 @@ class TestRemoveActionFromJail:
|
|||||||
_write(tmp_path / "jail.conf", JAIL_CONF)
|
_write(tmp_path / "jail.conf", JAIL_CONF)
|
||||||
|
|
||||||
with pytest.raises(ActionNameError):
|
with pytest.raises(ActionNameError):
|
||||||
await remove_action_from_jail(str(tmp_path), "/fake.sock", "sshd", "../evil")
|
await remove_action_from_jail(
|
||||||
|
str(tmp_path), "/fake.sock", "sshd", "../evil"
|
||||||
|
)
|
||||||
|
|
||||||
async def test_triggers_reload_when_requested(self, tmp_path: Path) -> None:
|
async def test_triggers_reload_when_requested(self, tmp_path: Path) -> None:
|
||||||
from app.services.config_file_service import remove_action_from_jail
|
from app.services.config_file_service import remove_action_from_jail
|
||||||
@@ -2611,7 +2640,9 @@ class TestRemoveActionFromJail:
|
|||||||
new=AsyncMock(),
|
new=AsyncMock(),
|
||||||
) as mock_reload,
|
) as mock_reload,
|
||||||
):
|
):
|
||||||
await remove_action_from_jail(str(tmp_path), "/fake.sock", "sshd", "iptables", do_reload=True)
|
await remove_action_from_jail(
|
||||||
|
str(tmp_path), "/fake.sock", "sshd", "iptables", do_reload=True
|
||||||
|
)
|
||||||
|
|
||||||
mock_reload.assert_awaited_once()
|
mock_reload.assert_awaited_once()
|
||||||
|
|
||||||
@@ -2649,9 +2680,13 @@ class TestActivateJailReloadArgs:
|
|||||||
mock_js.reload_all = AsyncMock()
|
mock_js.reload_all = AsyncMock()
|
||||||
await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
|
await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
|
||||||
|
|
||||||
mock_js.reload_all.assert_awaited_once_with("/fake.sock", include_jails=["apache-auth"])
|
mock_js.reload_all.assert_awaited_once_with(
|
||||||
|
"/fake.sock", include_jails=["apache-auth"]
|
||||||
|
)
|
||||||
|
|
||||||
async def test_activate_returns_active_true_when_jail_starts(self, tmp_path: Path) -> None:
|
async def test_activate_returns_active_true_when_jail_starts(
|
||||||
|
self, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
"""activate_jail returns active=True when the jail appears in post-reload names."""
|
"""activate_jail returns active=True when the jail appears in post-reload names."""
|
||||||
_write(tmp_path / "jail.conf", JAIL_CONF)
|
_write(tmp_path / "jail.conf", JAIL_CONF)
|
||||||
from app.models.config import ActivateJailRequest, JailValidationResult
|
from app.models.config import ActivateJailRequest, JailValidationResult
|
||||||
@@ -2673,12 +2708,16 @@ class TestActivateJailReloadArgs:
|
|||||||
),
|
),
|
||||||
):
|
):
|
||||||
mock_js.reload_all = AsyncMock()
|
mock_js.reload_all = AsyncMock()
|
||||||
result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
|
result = await activate_jail(
|
||||||
|
str(tmp_path), "/fake.sock", "apache-auth", req
|
||||||
|
)
|
||||||
|
|
||||||
assert result.active is True
|
assert result.active is True
|
||||||
assert "activated" in result.message.lower()
|
assert "activated" in result.message.lower()
|
||||||
|
|
||||||
async def test_activate_returns_active_false_when_jail_does_not_start(self, tmp_path: Path) -> None:
|
async def test_activate_returns_active_false_when_jail_does_not_start(
|
||||||
|
self, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
"""activate_jail returns active=False when the jail is absent after reload.
|
"""activate_jail returns active=False when the jail is absent after reload.
|
||||||
|
|
||||||
This covers the Stage 3.1 requirement: if the jail config is invalid
|
This covers the Stage 3.1 requirement: if the jail config is invalid
|
||||||
@@ -2707,7 +2746,9 @@ class TestActivateJailReloadArgs:
|
|||||||
),
|
),
|
||||||
):
|
):
|
||||||
mock_js.reload_all = AsyncMock()
|
mock_js.reload_all = AsyncMock()
|
||||||
result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
|
result = await activate_jail(
|
||||||
|
str(tmp_path), "/fake.sock", "apache-auth", req
|
||||||
|
)
|
||||||
|
|
||||||
assert result.active is False
|
assert result.active is False
|
||||||
assert "apache-auth" in result.name
|
assert "apache-auth" in result.name
|
||||||
@@ -2735,13 +2776,23 @@ class TestDeactivateJailReloadArgs:
|
|||||||
mock_js.reload_all = AsyncMock()
|
mock_js.reload_all = AsyncMock()
|
||||||
await deactivate_jail(str(tmp_path), "/fake.sock", "sshd")
|
await deactivate_jail(str(tmp_path), "/fake.sock", "sshd")
|
||||||
|
|
||||||
mock_js.reload_all.assert_awaited_once_with("/fake.sock", exclude_jails=["sshd"])
|
mock_js.reload_all.assert_awaited_once_with(
|
||||||
|
"/fake.sock", exclude_jails=["sshd"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# _validate_jail_config_sync (Task 3)
|
# _validate_jail_config_sync (Task 3)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
from app.services.config_file_service import ( # noqa: E402 (added after block)
|
||||||
|
_validate_jail_config_sync,
|
||||||
|
_extract_filter_base_name,
|
||||||
|
_extract_action_base_name,
|
||||||
|
validate_jail_config,
|
||||||
|
rollback_jail,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestExtractFilterBaseName:
|
class TestExtractFilterBaseName:
|
||||||
def test_plain_name(self) -> None:
|
def test_plain_name(self) -> None:
|
||||||
@@ -2887,11 +2938,11 @@ class TestRollbackJail:
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"app.services.config_file_service.start_daemon",
|
"app.services.config_file_service._start_daemon",
|
||||||
new=AsyncMock(return_value=True),
|
new=AsyncMock(return_value=True),
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"app.services.config_file_service.wait_for_fail2ban",
|
"app.services.config_file_service._wait_for_fail2ban",
|
||||||
new=AsyncMock(return_value=True),
|
new=AsyncMock(return_value=True),
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
@@ -2899,7 +2950,9 @@ class TestRollbackJail:
|
|||||||
new=AsyncMock(return_value=set()),
|
new=AsyncMock(return_value=set()),
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
|
result = await rollback_jail(
|
||||||
|
str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]
|
||||||
|
)
|
||||||
|
|
||||||
assert result.disabled is True
|
assert result.disabled is True
|
||||||
assert result.fail2ban_running is True
|
assert result.fail2ban_running is True
|
||||||
@@ -2915,22 +2968,26 @@ class TestRollbackJail:
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"app.services.config_file_service.start_daemon",
|
"app.services.config_file_service._start_daemon",
|
||||||
new=AsyncMock(return_value=False),
|
new=AsyncMock(return_value=False),
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"app.services.config_file_service.wait_for_fail2ban",
|
"app.services.config_file_service._wait_for_fail2ban",
|
||||||
new=AsyncMock(return_value=False),
|
new=AsyncMock(return_value=False),
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
|
result = await rollback_jail(
|
||||||
|
str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]
|
||||||
|
)
|
||||||
|
|
||||||
assert result.fail2ban_running is False
|
assert result.fail2ban_running is False
|
||||||
assert result.disabled is True
|
assert result.disabled is True
|
||||||
|
|
||||||
async def test_rollback_rejects_unsafe_name(self, tmp_path: Path) -> None:
|
async def test_rollback_rejects_unsafe_name(self, tmp_path: Path) -> None:
|
||||||
with pytest.raises(JailNameError):
|
with pytest.raises(JailNameError):
|
||||||
await rollback_jail(str(tmp_path), "/fake.sock", "../evil", ["fail2ban-client", "start"])
|
await rollback_jail(
|
||||||
|
str(tmp_path), "/fake.sock", "../evil", ["fail2ban-client", "start"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -3039,7 +3096,9 @@ class TestActivateJailBlocking:
|
|||||||
class TestActivateJailRollback:
|
class TestActivateJailRollback:
|
||||||
"""Rollback logic in activate_jail restores the .local file and recovers."""
|
"""Rollback logic in activate_jail restores the .local file and recovers."""
|
||||||
|
|
||||||
async def test_activate_jail_rollback_on_reload_failure(self, tmp_path: Path) -> None:
|
async def test_activate_jail_rollback_on_reload_failure(
|
||||||
|
self, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
"""Rollback when reload_all raises on the activation reload.
|
"""Rollback when reload_all raises on the activation reload.
|
||||||
|
|
||||||
Expects:
|
Expects:
|
||||||
@@ -3076,17 +3135,23 @@ class TestActivateJailRollback:
|
|||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"app.services.config_file_service._validate_jail_config_sync",
|
"app.services.config_file_service._validate_jail_config_sync",
|
||||||
return_value=JailValidationResult(jail_name="apache-auth", valid=True),
|
return_value=JailValidationResult(
|
||||||
|
jail_name="apache-auth", valid=True
|
||||||
|
),
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
mock_js.reload_all = AsyncMock(side_effect=reload_side_effect)
|
mock_js.reload_all = AsyncMock(side_effect=reload_side_effect)
|
||||||
result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
|
result = await activate_jail(
|
||||||
|
str(tmp_path), "/fake.sock", "apache-auth", req
|
||||||
|
)
|
||||||
|
|
||||||
assert result.active is False
|
assert result.active is False
|
||||||
assert result.recovered is True
|
assert result.recovered is True
|
||||||
assert local_path.read_text() == original_local
|
assert local_path.read_text() == original_local
|
||||||
|
|
||||||
async def test_activate_jail_rollback_on_health_check_failure(self, tmp_path: Path) -> None:
|
async def test_activate_jail_rollback_on_health_check_failure(
|
||||||
|
self, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
"""Rollback when fail2ban is unreachable after the activation reload.
|
"""Rollback when fail2ban is unreachable after the activation reload.
|
||||||
|
|
||||||
Expects:
|
Expects:
|
||||||
@@ -3125,11 +3190,15 @@ class TestActivateJailRollback:
|
|||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"app.services.config_file_service._validate_jail_config_sync",
|
"app.services.config_file_service._validate_jail_config_sync",
|
||||||
return_value=JailValidationResult(jail_name="apache-auth", valid=True),
|
return_value=JailValidationResult(
|
||||||
|
jail_name="apache-auth", valid=True
|
||||||
|
),
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
mock_js.reload_all = AsyncMock()
|
mock_js.reload_all = AsyncMock()
|
||||||
result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
|
result = await activate_jail(
|
||||||
|
str(tmp_path), "/fake.sock", "apache-auth", req
|
||||||
|
)
|
||||||
|
|
||||||
assert result.active is False
|
assert result.active is False
|
||||||
assert result.recovered is True
|
assert result.recovered is True
|
||||||
@@ -3163,17 +3232,25 @@ class TestActivateJailRollback:
|
|||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"app.services.config_file_service._validate_jail_config_sync",
|
"app.services.config_file_service._validate_jail_config_sync",
|
||||||
return_value=JailValidationResult(jail_name="apache-auth", valid=True),
|
return_value=JailValidationResult(
|
||||||
|
jail_name="apache-auth", valid=True
|
||||||
|
),
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
# Both the activation reload and the recovery reload fail.
|
# Both the activation reload and the recovery reload fail.
|
||||||
mock_js.reload_all = AsyncMock(side_effect=RuntimeError("fail2ban unavailable"))
|
mock_js.reload_all = AsyncMock(
|
||||||
result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
|
side_effect=RuntimeError("fail2ban unavailable")
|
||||||
|
)
|
||||||
|
result = await activate_jail(
|
||||||
|
str(tmp_path), "/fake.sock", "apache-auth", req
|
||||||
|
)
|
||||||
|
|
||||||
assert result.active is False
|
assert result.active is False
|
||||||
assert result.recovered is False
|
assert result.recovered is False
|
||||||
|
|
||||||
async def test_activate_jail_rollback_on_jail_not_found_error(self, tmp_path: Path) -> None:
|
async def test_activate_jail_rollback_on_jail_not_found_error(
|
||||||
|
self, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
"""Rollback when reload_all raises JailNotFoundError (invalid config).
|
"""Rollback when reload_all raises JailNotFoundError (invalid config).
|
||||||
|
|
||||||
When fail2ban cannot create a jail due to invalid configuration
|
When fail2ban cannot create a jail due to invalid configuration
|
||||||
@@ -3217,12 +3294,16 @@ class TestActivateJailRollback:
|
|||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"app.services.config_file_service._validate_jail_config_sync",
|
"app.services.config_file_service._validate_jail_config_sync",
|
||||||
return_value=JailValidationResult(jail_name="apache-auth", valid=True),
|
return_value=JailValidationResult(
|
||||||
|
jail_name="apache-auth", valid=True
|
||||||
|
),
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
mock_js.reload_all = AsyncMock(side_effect=reload_side_effect)
|
mock_js.reload_all = AsyncMock(side_effect=reload_side_effect)
|
||||||
mock_js.JailNotFoundError = JailNotFoundError
|
mock_js.JailNotFoundError = JailNotFoundError
|
||||||
result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
|
result = await activate_jail(
|
||||||
|
str(tmp_path), "/fake.sock", "apache-auth", req
|
||||||
|
)
|
||||||
|
|
||||||
assert result.active is False
|
assert result.active is False
|
||||||
assert result.recovered is True
|
assert result.recovered is True
|
||||||
@@ -3230,7 +3311,9 @@ class TestActivateJailRollback:
|
|||||||
# Verify the error message mentions logpath issues.
|
# Verify the error message mentions logpath issues.
|
||||||
assert "logpath" in result.message.lower() or "check that all logpath" in result.message.lower()
|
assert "logpath" in result.message.lower() or "check that all logpath" in result.message.lower()
|
||||||
|
|
||||||
async def test_activate_jail_rollback_deletes_file_when_no_prior_local(self, tmp_path: Path) -> None:
|
async def test_activate_jail_rollback_deletes_file_when_no_prior_local(
|
||||||
|
self, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
"""Rollback deletes the .local file when none existed before activation.
|
"""Rollback deletes the .local file when none existed before activation.
|
||||||
|
|
||||||
When a jail had no .local override before activation, activate_jail
|
When a jail had no .local override before activation, activate_jail
|
||||||
@@ -3272,11 +3355,15 @@ class TestActivateJailRollback:
|
|||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"app.services.config_file_service._validate_jail_config_sync",
|
"app.services.config_file_service._validate_jail_config_sync",
|
||||||
return_value=JailValidationResult(jail_name="apache-auth", valid=True),
|
return_value=JailValidationResult(
|
||||||
|
jail_name="apache-auth", valid=True
|
||||||
|
),
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
mock_js.reload_all = AsyncMock(side_effect=reload_side_effect)
|
mock_js.reload_all = AsyncMock(side_effect=reload_side_effect)
|
||||||
result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
|
result = await activate_jail(
|
||||||
|
str(tmp_path), "/fake.sock", "apache-auth", req
|
||||||
|
)
|
||||||
|
|
||||||
assert result.active is False
|
assert result.active is False
|
||||||
assert result.recovered is True
|
assert result.recovered is True
|
||||||
@@ -3289,7 +3376,7 @@ class TestActivateJailRollback:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
class TestRollbackJailIntegration:
|
class TestRollbackJail:
|
||||||
"""Integration tests for :func:`~app.services.config_file_service.rollback_jail`."""
|
"""Integration tests for :func:`~app.services.config_file_service.rollback_jail`."""
|
||||||
|
|
||||||
async def test_local_file_written_enabled_false(self, tmp_path: Path) -> None:
|
async def test_local_file_written_enabled_false(self, tmp_path: Path) -> None:
|
||||||
@@ -3332,11 +3419,15 @@ class TestRollbackJailIntegration:
|
|||||||
AsyncMock(return_value={"other"}),
|
AsyncMock(return_value={"other"}),
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
|
await rollback_jail(
|
||||||
|
str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]
|
||||||
|
)
|
||||||
|
|
||||||
mock_start.assert_awaited_once_with(["fail2ban-client", "start"])
|
mock_start.assert_awaited_once_with(["fail2ban-client", "start"])
|
||||||
|
|
||||||
async def test_fail2ban_running_reflects_socket_probe_not_subprocess_exit(self, tmp_path: Path) -> None:
|
async def test_fail2ban_running_reflects_socket_probe_not_subprocess_exit(
|
||||||
|
self, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
"""fail2ban_running in the response reflects the socket probe result.
|
"""fail2ban_running in the response reflects the socket probe result.
|
||||||
|
|
||||||
Even when start_daemon returns True (subprocess exit 0), if the socket
|
Even when start_daemon returns True (subprocess exit 0), if the socket
|
||||||
@@ -3352,11 +3443,15 @@ class TestRollbackJailIntegration:
|
|||||||
AsyncMock(return_value=False), # socket still unresponsive
|
AsyncMock(return_value=False), # socket still unresponsive
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
|
result = await rollback_jail(
|
||||||
|
str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]
|
||||||
|
)
|
||||||
|
|
||||||
assert result.fail2ban_running is False
|
assert result.fail2ban_running is False
|
||||||
|
|
||||||
async def test_active_jails_zero_when_fail2ban_not_running(self, tmp_path: Path) -> None:
|
async def test_active_jails_zero_when_fail2ban_not_running(
|
||||||
|
self, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
"""active_jails is 0 in the response when fail2ban_running is False."""
|
"""active_jails is 0 in the response when fail2ban_running is False."""
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
@@ -3368,11 +3463,15 @@ class TestRollbackJailIntegration:
|
|||||||
AsyncMock(return_value=False),
|
AsyncMock(return_value=False),
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
|
result = await rollback_jail(
|
||||||
|
str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]
|
||||||
|
)
|
||||||
|
|
||||||
assert result.active_jails == 0
|
assert result.active_jails == 0
|
||||||
|
|
||||||
async def test_active_jails_count_from_socket_when_running(self, tmp_path: Path) -> None:
|
async def test_active_jails_count_from_socket_when_running(
|
||||||
|
self, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
"""active_jails reflects the actual jail count from the socket when fail2ban is up."""
|
"""active_jails reflects the actual jail count from the socket when fail2ban is up."""
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
@@ -3388,11 +3487,15 @@ class TestRollbackJailIntegration:
|
|||||||
AsyncMock(return_value={"sshd", "nginx", "apache-auth"}),
|
AsyncMock(return_value={"sshd", "nginx", "apache-auth"}),
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
|
result = await rollback_jail(
|
||||||
|
str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]
|
||||||
|
)
|
||||||
|
|
||||||
assert result.active_jails == 3
|
assert result.active_jails == 3
|
||||||
|
|
||||||
async def test_fail2ban_down_at_start_still_succeeds_file_write(self, tmp_path: Path) -> None:
|
async def test_fail2ban_down_at_start_still_succeeds_file_write(
|
||||||
|
self, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
"""rollback_jail writes the local file even when fail2ban is down at call time."""
|
"""rollback_jail writes the local file even when fail2ban is down at call time."""
|
||||||
# fail2ban is down: start_daemon fails and wait_for_fail2ban returns False.
|
# fail2ban is down: start_daemon fails and wait_for_fail2ban returns False.
|
||||||
with (
|
with (
|
||||||
@@ -3405,9 +3508,12 @@ class TestRollbackJailIntegration:
|
|||||||
AsyncMock(return_value=False),
|
AsyncMock(return_value=False),
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
|
result = await rollback_jail(
|
||||||
|
str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]
|
||||||
|
)
|
||||||
|
|
||||||
local = tmp_path / "jail.d" / "sshd.local"
|
local = tmp_path / "jail.d" / "sshd.local"
|
||||||
assert local.is_file(), "local file must be written even when fail2ban is down"
|
assert local.is_file(), "local file must be written even when fail2ban is down"
|
||||||
assert result.disabled is True
|
assert result.disabled is True
|
||||||
assert result.fail2ban_running is False
|
assert result.fail2ban_running is False
|
||||||
|
|
||||||
|
|||||||
@@ -256,6 +256,27 @@ class TestUpdateJailConfig:
|
|||||||
assert "bantime" in keys
|
assert "bantime" in keys
|
||||||
assert "maxretry" in keys
|
assert "maxretry" in keys
|
||||||
|
|
||||||
|
async def test_ignores_backend_field(self) -> None:
|
||||||
|
"""update_jail_config does not send a set command for backend."""
|
||||||
|
sent_commands: list[list[Any]] = []
|
||||||
|
|
||||||
|
async def _send(command: list[Any]) -> Any:
|
||||||
|
sent_commands.append(command)
|
||||||
|
return (0, "OK")
|
||||||
|
|
||||||
|
class _FakeClient:
|
||||||
|
def __init__(self, **_kw: Any) -> None:
|
||||||
|
self.send = AsyncMock(side_effect=_send)
|
||||||
|
|
||||||
|
from app.models.config import JailConfigUpdate
|
||||||
|
|
||||||
|
update = JailConfigUpdate(backend="polling")
|
||||||
|
with patch("app.services.config_service.Fail2BanClient", _FakeClient):
|
||||||
|
await config_service.update_jail_config(_SOCKET, "sshd", update)
|
||||||
|
|
||||||
|
keys = [cmd[2] for cmd in sent_commands if len(cmd) >= 3 and cmd[0] == "set"]
|
||||||
|
assert "backend" not in keys
|
||||||
|
|
||||||
async def test_raises_validation_error_on_bad_regex(self) -> None:
|
async def test_raises_validation_error_on_bad_regex(self) -> None:
|
||||||
"""update_jail_config raises ConfigValidationError for invalid regex."""
|
"""update_jail_config raises ConfigValidationError for invalid regex."""
|
||||||
from app.models.config import JailConfigUpdate
|
from app.models.config import JailConfigUpdate
|
||||||
@@ -721,11 +742,9 @@ class TestGetServiceStatus:
|
|||||||
def __init__(self, **_kw: Any) -> None:
|
def __init__(self, **_kw: Any) -> None:
|
||||||
self.send = AsyncMock(side_effect=_send)
|
self.send = AsyncMock(side_effect=_send)
|
||||||
|
|
||||||
with patch("app.services.config_service.Fail2BanClient", _FakeClient):
|
with patch("app.services.config_service.Fail2BanClient", _FakeClient), \
|
||||||
result = await config_service.get_service_status(
|
patch("app.services.health_service.probe", AsyncMock(return_value=online_status)):
|
||||||
_SOCKET,
|
result = await config_service.get_service_status(_SOCKET)
|
||||||
probe_fn=AsyncMock(return_value=online_status),
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.online is True
|
assert result.online is True
|
||||||
assert result.version == "1.0.0"
|
assert result.version == "1.0.0"
|
||||||
@@ -741,10 +760,8 @@ class TestGetServiceStatus:
|
|||||||
|
|
||||||
offline_status = ServerStatus(online=False)
|
offline_status = ServerStatus(online=False)
|
||||||
|
|
||||||
result = await config_service.get_service_status(
|
with patch("app.services.health_service.probe", AsyncMock(return_value=offline_status)):
|
||||||
_SOCKET,
|
result = await config_service.get_service_status(_SOCKET)
|
||||||
probe_fn=AsyncMock(return_value=offline_status),
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.online is False
|
assert result.online is False
|
||||||
assert result.jail_count == 0
|
assert result.jail_count == 0
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import pytest
|
|||||||
|
|
||||||
from app.models.config import ActionConfigUpdate, FilterConfigUpdate, JailFileConfigUpdate
|
from app.models.config import ActionConfigUpdate, FilterConfigUpdate, JailFileConfigUpdate
|
||||||
from app.models.file_config import ConfFileCreateRequest, ConfFileUpdateRequest
|
from app.models.file_config import ConfFileCreateRequest, ConfFileUpdateRequest
|
||||||
from app.services.raw_config_io_service import (
|
from app.services.file_config_service import (
|
||||||
ConfigDirError,
|
ConfigDirError,
|
||||||
ConfigFileExistsError,
|
ConfigFileExistsError,
|
||||||
ConfigFileNameError,
|
ConfigFileNameError,
|
||||||
|
|||||||
@@ -2,13 +2,12 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Mapping, Sequence
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.models.geo import GeoInfo
|
|
||||||
from app.services import geo_service
|
from app.services import geo_service
|
||||||
|
from app.services.geo_service import GeoInfo
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Helpers
|
# Helpers
|
||||||
@@ -45,7 +44,7 @@ def _make_session(response_json: dict[str, object], status: int = 200) -> MagicM
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def clear_geo_cache() -> None:
|
def clear_geo_cache() -> None: # type: ignore[misc]
|
||||||
"""Flush the module-level geo cache before every test."""
|
"""Flush the module-level geo cache before every test."""
|
||||||
geo_service.clear_cache()
|
geo_service.clear_cache()
|
||||||
|
|
||||||
@@ -69,7 +68,7 @@ class TestLookupSuccess:
|
|||||||
"org": "AS3320 Deutsche Telekom AG",
|
"org": "AS3320 Deutsche Telekom AG",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
result = await geo_service.lookup("1.2.3.4", session)
|
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.country_code == "DE"
|
assert result.country_code == "DE"
|
||||||
@@ -85,7 +84,7 @@ class TestLookupSuccess:
|
|||||||
"org": "Google LLC",
|
"org": "Google LLC",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
result = await geo_service.lookup("8.8.8.8", session)
|
result = await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type]
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.country_name == "United States"
|
assert result.country_name == "United States"
|
||||||
@@ -101,7 +100,7 @@ class TestLookupSuccess:
|
|||||||
"org": "Deutsche Telekom",
|
"org": "Deutsche Telekom",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
result = await geo_service.lookup("1.2.3.4", session)
|
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.asn == "AS3320"
|
assert result.asn == "AS3320"
|
||||||
@@ -117,7 +116,7 @@ class TestLookupSuccess:
|
|||||||
"org": "Google LLC",
|
"org": "Google LLC",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
result = await geo_service.lookup("8.8.8.8", session)
|
result = await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type]
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.org == "Google LLC"
|
assert result.org == "Google LLC"
|
||||||
@@ -143,8 +142,8 @@ class TestLookupCaching:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
await geo_service.lookup("1.2.3.4", session)
|
await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
||||||
await geo_service.lookup("1.2.3.4", session)
|
await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
||||||
|
|
||||||
# The session.get() should only have been called once.
|
# The session.get() should only have been called once.
|
||||||
assert session.get.call_count == 1
|
assert session.get.call_count == 1
|
||||||
@@ -161,9 +160,9 @@ class TestLookupCaching:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
await geo_service.lookup("2.3.4.5", session)
|
await geo_service.lookup("2.3.4.5", session) # type: ignore[arg-type]
|
||||||
geo_service.clear_cache()
|
geo_service.clear_cache()
|
||||||
await geo_service.lookup("2.3.4.5", session)
|
await geo_service.lookup("2.3.4.5", session) # type: ignore[arg-type]
|
||||||
|
|
||||||
assert session.get.call_count == 2
|
assert session.get.call_count == 2
|
||||||
|
|
||||||
@@ -173,8 +172,8 @@ class TestLookupCaching:
|
|||||||
{"status": "fail", "message": "reserved range"}
|
{"status": "fail", "message": "reserved range"}
|
||||||
)
|
)
|
||||||
|
|
||||||
await geo_service.lookup("192.168.1.1", session)
|
await geo_service.lookup("192.168.1.1", session) # type: ignore[arg-type]
|
||||||
await geo_service.lookup("192.168.1.1", session)
|
await geo_service.lookup("192.168.1.1", session) # type: ignore[arg-type]
|
||||||
|
|
||||||
# Second call is blocked by the negative cache — only one API hit.
|
# Second call is blocked by the negative cache — only one API hit.
|
||||||
assert session.get.call_count == 1
|
assert session.get.call_count == 1
|
||||||
@@ -191,7 +190,7 @@ class TestLookupFailures:
|
|||||||
async def test_non_200_response_returns_null_geo_info(self) -> None:
|
async def test_non_200_response_returns_null_geo_info(self) -> None:
|
||||||
"""A 429 or 500 status returns GeoInfo with null fields (not None)."""
|
"""A 429 or 500 status returns GeoInfo with null fields (not None)."""
|
||||||
session = _make_session({}, status=429)
|
session = _make_session({}, status=429)
|
||||||
result = await geo_service.lookup("1.2.3.4", session)
|
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert isinstance(result, GeoInfo)
|
assert isinstance(result, GeoInfo)
|
||||||
assert result.country_code is None
|
assert result.country_code is None
|
||||||
@@ -204,7 +203,7 @@ class TestLookupFailures:
|
|||||||
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||||
session.get = MagicMock(return_value=mock_ctx)
|
session.get = MagicMock(return_value=mock_ctx)
|
||||||
|
|
||||||
result = await geo_service.lookup("10.0.0.1", session)
|
result = await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type]
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert isinstance(result, GeoInfo)
|
assert isinstance(result, GeoInfo)
|
||||||
assert result.country_code is None
|
assert result.country_code is None
|
||||||
@@ -212,7 +211,7 @@ class TestLookupFailures:
|
|||||||
async def test_failed_status_returns_geo_info_with_nulls(self) -> None:
|
async def test_failed_status_returns_geo_info_with_nulls(self) -> None:
|
||||||
"""When ip-api returns ``status=fail`` a GeoInfo with null fields is returned (but not cached)."""
|
"""When ip-api returns ``status=fail`` a GeoInfo with null fields is returned (but not cached)."""
|
||||||
session = _make_session({"status": "fail", "message": "private range"})
|
session = _make_session({"status": "fail", "message": "private range"})
|
||||||
result = await geo_service.lookup("10.0.0.1", session)
|
result = await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type]
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert isinstance(result, GeoInfo)
|
assert isinstance(result, GeoInfo)
|
||||||
@@ -232,8 +231,8 @@ class TestNegativeCache:
|
|||||||
"""After a failed lookup the second call is served from the neg cache."""
|
"""After a failed lookup the second call is served from the neg cache."""
|
||||||
session = _make_session({"status": "fail", "message": "private range"})
|
session = _make_session({"status": "fail", "message": "private range"})
|
||||||
|
|
||||||
r1 = await geo_service.lookup("192.0.2.1", session)
|
r1 = await geo_service.lookup("192.0.2.1", session) # type: ignore[arg-type]
|
||||||
r2 = await geo_service.lookup("192.0.2.1", session)
|
r2 = await geo_service.lookup("192.0.2.1", session) # type: ignore[arg-type]
|
||||||
|
|
||||||
# Only one HTTP call should have been made; second served from neg cache.
|
# Only one HTTP call should have been made; second served from neg cache.
|
||||||
assert session.get.call_count == 1
|
assert session.get.call_count == 1
|
||||||
@@ -244,12 +243,12 @@ class TestNegativeCache:
|
|||||||
"""When the neg-cache entry is older than the TTL a new API call is made."""
|
"""When the neg-cache entry is older than the TTL a new API call is made."""
|
||||||
session = _make_session({"status": "fail", "message": "private range"})
|
session = _make_session({"status": "fail", "message": "private range"})
|
||||||
|
|
||||||
await geo_service.lookup("192.0.2.2", session)
|
await geo_service.lookup("192.0.2.2", session) # type: ignore[arg-type]
|
||||||
|
|
||||||
# Manually expire the neg-cache entry.
|
# Manually expire the neg-cache entry.
|
||||||
geo_service._neg_cache["192.0.2.2"] -= geo_service._NEG_CACHE_TTL + 1
|
geo_service._neg_cache["192.0.2.2"] -= geo_service._NEG_CACHE_TTL + 1 # type: ignore[attr-defined]
|
||||||
|
|
||||||
await geo_service.lookup("192.0.2.2", session)
|
await geo_service.lookup("192.0.2.2", session) # type: ignore[arg-type]
|
||||||
|
|
||||||
# Both calls should have hit the API.
|
# Both calls should have hit the API.
|
||||||
assert session.get.call_count == 2
|
assert session.get.call_count == 2
|
||||||
@@ -258,9 +257,9 @@ class TestNegativeCache:
|
|||||||
"""After clearing the neg cache the IP is eligible for a new API call."""
|
"""After clearing the neg cache the IP is eligible for a new API call."""
|
||||||
session = _make_session({"status": "fail", "message": "private range"})
|
session = _make_session({"status": "fail", "message": "private range"})
|
||||||
|
|
||||||
await geo_service.lookup("192.0.2.3", session)
|
await geo_service.lookup("192.0.2.3", session) # type: ignore[arg-type]
|
||||||
geo_service.clear_neg_cache()
|
geo_service.clear_neg_cache()
|
||||||
await geo_service.lookup("192.0.2.3", session)
|
await geo_service.lookup("192.0.2.3", session) # type: ignore[arg-type]
|
||||||
|
|
||||||
assert session.get.call_count == 2
|
assert session.get.call_count == 2
|
||||||
|
|
||||||
@@ -276,9 +275,9 @@ class TestNegativeCache:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
await geo_service.lookup("1.2.3.4", session)
|
await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
||||||
|
|
||||||
assert "1.2.3.4" not in geo_service._neg_cache
|
assert "1.2.3.4" not in geo_service._neg_cache # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -308,7 +307,7 @@ class TestGeoipFallback:
|
|||||||
mock_reader = self._make_geoip_reader("DE", "Germany")
|
mock_reader = self._make_geoip_reader("DE", "Germany")
|
||||||
|
|
||||||
with patch.object(geo_service, "_geoip_reader", mock_reader):
|
with patch.object(geo_service, "_geoip_reader", mock_reader):
|
||||||
result = await geo_service.lookup("1.2.3.4", session)
|
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
||||||
|
|
||||||
mock_reader.country.assert_called_once_with("1.2.3.4")
|
mock_reader.country.assert_called_once_with("1.2.3.4")
|
||||||
assert result is not None
|
assert result is not None
|
||||||
@@ -321,12 +320,12 @@ class TestGeoipFallback:
|
|||||||
mock_reader = self._make_geoip_reader("US", "United States")
|
mock_reader = self._make_geoip_reader("US", "United States")
|
||||||
|
|
||||||
with patch.object(geo_service, "_geoip_reader", mock_reader):
|
with patch.object(geo_service, "_geoip_reader", mock_reader):
|
||||||
await geo_service.lookup("8.8.8.8", session)
|
await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type]
|
||||||
# Second call must be served from positive cache without hitting API.
|
# Second call must be served from positive cache without hitting API.
|
||||||
await geo_service.lookup("8.8.8.8", session)
|
await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type]
|
||||||
|
|
||||||
assert session.get.call_count == 1
|
assert session.get.call_count == 1
|
||||||
assert "8.8.8.8" in geo_service._cache
|
assert "8.8.8.8" in geo_service._cache # type: ignore[attr-defined]
|
||||||
|
|
||||||
async def test_geoip_fallback_not_called_on_api_success(self) -> None:
|
async def test_geoip_fallback_not_called_on_api_success(self) -> None:
|
||||||
"""When ip-api succeeds, the geoip2 reader must not be consulted."""
|
"""When ip-api succeeds, the geoip2 reader must not be consulted."""
|
||||||
@@ -342,7 +341,7 @@ class TestGeoipFallback:
|
|||||||
mock_reader = self._make_geoip_reader("XX", "Nowhere")
|
mock_reader = self._make_geoip_reader("XX", "Nowhere")
|
||||||
|
|
||||||
with patch.object(geo_service, "_geoip_reader", mock_reader):
|
with patch.object(geo_service, "_geoip_reader", mock_reader):
|
||||||
result = await geo_service.lookup("1.2.3.4", session)
|
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
||||||
|
|
||||||
mock_reader.country.assert_not_called()
|
mock_reader.country.assert_not_called()
|
||||||
assert result is not None
|
assert result is not None
|
||||||
@@ -353,7 +352,7 @@ class TestGeoipFallback:
|
|||||||
session = _make_session({"status": "fail", "message": "private range"})
|
session = _make_session({"status": "fail", "message": "private range"})
|
||||||
|
|
||||||
with patch.object(geo_service, "_geoip_reader", None):
|
with patch.object(geo_service, "_geoip_reader", None):
|
||||||
result = await geo_service.lookup("10.0.0.1", session)
|
result = await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type]
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.country_code is None
|
assert result.country_code is None
|
||||||
@@ -364,7 +363,7 @@ class TestGeoipFallback:
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _make_batch_session(batch_response: Sequence[Mapping[str, object]]) -> MagicMock:
|
def _make_batch_session(batch_response: list[dict[str, object]]) -> MagicMock:
|
||||||
"""Build a mock aiohttp.ClientSession for batch POST calls.
|
"""Build a mock aiohttp.ClientSession for batch POST calls.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -413,7 +412,7 @@ class TestLookupBatchSingleCommit:
|
|||||||
session = _make_batch_session(batch_response)
|
session = _make_batch_session(batch_response)
|
||||||
db = _make_async_db()
|
db = _make_async_db()
|
||||||
|
|
||||||
await geo_service.lookup_batch(ips, session, db=db)
|
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type]
|
||||||
|
|
||||||
db.commit.assert_awaited_once()
|
db.commit.assert_awaited_once()
|
||||||
|
|
||||||
@@ -427,7 +426,7 @@ class TestLookupBatchSingleCommit:
|
|||||||
session = _make_batch_session(batch_response)
|
session = _make_batch_session(batch_response)
|
||||||
db = _make_async_db()
|
db = _make_async_db()
|
||||||
|
|
||||||
await geo_service.lookup_batch(ips, session, db=db)
|
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type]
|
||||||
|
|
||||||
db.commit.assert_awaited_once()
|
db.commit.assert_awaited_once()
|
||||||
|
|
||||||
@@ -453,13 +452,13 @@ class TestLookupBatchSingleCommit:
|
|||||||
|
|
||||||
async def test_no_commit_for_all_cached_ips(self) -> None:
|
async def test_no_commit_for_all_cached_ips(self) -> None:
|
||||||
"""When all IPs are already cached, no HTTP call and no commit occur."""
|
"""When all IPs are already cached, no HTTP call and no commit occur."""
|
||||||
geo_service._cache["5.5.5.5"] = GeoInfo(
|
geo_service._cache["5.5.5.5"] = GeoInfo( # type: ignore[attr-defined]
|
||||||
country_code="FR", country_name="France", asn="AS1", org="ISP"
|
country_code="FR", country_name="France", asn="AS1", org="ISP"
|
||||||
)
|
)
|
||||||
db = _make_async_db()
|
db = _make_async_db()
|
||||||
session = _make_batch_session([])
|
session = _make_batch_session([])
|
||||||
|
|
||||||
result = await geo_service.lookup_batch(["5.5.5.5"], session, db=db)
|
result = await geo_service.lookup_batch(["5.5.5.5"], session, db=db) # type: ignore[arg-type]
|
||||||
|
|
||||||
assert result["5.5.5.5"].country_code == "FR"
|
assert result["5.5.5.5"].country_code == "FR"
|
||||||
db.commit.assert_not_awaited()
|
db.commit.assert_not_awaited()
|
||||||
@@ -477,26 +476,26 @@ class TestDirtySetTracking:
|
|||||||
def test_successful_resolution_adds_to_dirty(self) -> None:
|
def test_successful_resolution_adds_to_dirty(self) -> None:
|
||||||
"""Storing a GeoInfo with a country_code adds the IP to _dirty."""
|
"""Storing a GeoInfo with a country_code adds the IP to _dirty."""
|
||||||
info = GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP")
|
info = GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP")
|
||||||
geo_service._store("1.2.3.4", info)
|
geo_service._store("1.2.3.4", info) # type: ignore[attr-defined]
|
||||||
|
|
||||||
assert "1.2.3.4" in geo_service._dirty
|
assert "1.2.3.4" in geo_service._dirty # type: ignore[attr-defined]
|
||||||
|
|
||||||
def test_null_country_does_not_add_to_dirty(self) -> None:
|
def test_null_country_does_not_add_to_dirty(self) -> None:
|
||||||
"""Storing a GeoInfo with country_code=None must not pollute _dirty."""
|
"""Storing a GeoInfo with country_code=None must not pollute _dirty."""
|
||||||
info = GeoInfo(country_code=None, country_name=None, asn=None, org=None)
|
info = GeoInfo(country_code=None, country_name=None, asn=None, org=None)
|
||||||
geo_service._store("10.0.0.1", info)
|
geo_service._store("10.0.0.1", info) # type: ignore[attr-defined]
|
||||||
|
|
||||||
assert "10.0.0.1" not in geo_service._dirty
|
assert "10.0.0.1" not in geo_service._dirty # type: ignore[attr-defined]
|
||||||
|
|
||||||
def test_clear_cache_also_clears_dirty(self) -> None:
|
def test_clear_cache_also_clears_dirty(self) -> None:
|
||||||
"""clear_cache() must discard any pending dirty entries."""
|
"""clear_cache() must discard any pending dirty entries."""
|
||||||
info = GeoInfo(country_code="US", country_name="United States", asn="AS1", org="ISP")
|
info = GeoInfo(country_code="US", country_name="United States", asn="AS1", org="ISP")
|
||||||
geo_service._store("8.8.8.8", info)
|
geo_service._store("8.8.8.8", info) # type: ignore[attr-defined]
|
||||||
assert geo_service._dirty
|
assert geo_service._dirty # type: ignore[attr-defined]
|
||||||
|
|
||||||
geo_service.clear_cache()
|
geo_service.clear_cache()
|
||||||
|
|
||||||
assert not geo_service._dirty
|
assert not geo_service._dirty # type: ignore[attr-defined]
|
||||||
|
|
||||||
async def test_lookup_batch_populates_dirty(self) -> None:
|
async def test_lookup_batch_populates_dirty(self) -> None:
|
||||||
"""After lookup_batch() with db=None, resolved IPs appear in _dirty."""
|
"""After lookup_batch() with db=None, resolved IPs appear in _dirty."""
|
||||||
@@ -510,7 +509,7 @@ class TestDirtySetTracking:
|
|||||||
await geo_service.lookup_batch(ips, session, db=None)
|
await geo_service.lookup_batch(ips, session, db=None)
|
||||||
|
|
||||||
for ip in ips:
|
for ip in ips:
|
||||||
assert ip in geo_service._dirty
|
assert ip in geo_service._dirty # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
class TestFlushDirty:
|
class TestFlushDirty:
|
||||||
@@ -519,8 +518,8 @@ class TestFlushDirty:
|
|||||||
async def test_flush_writes_and_clears_dirty(self) -> None:
|
async def test_flush_writes_and_clears_dirty(self) -> None:
|
||||||
"""flush_dirty() inserts all dirty IPs and clears _dirty afterwards."""
|
"""flush_dirty() inserts all dirty IPs and clears _dirty afterwards."""
|
||||||
info = GeoInfo(country_code="GB", country_name="United Kingdom", asn="AS2856", org="BT")
|
info = GeoInfo(country_code="GB", country_name="United Kingdom", asn="AS2856", org="BT")
|
||||||
geo_service._store("100.0.0.1", info)
|
geo_service._store("100.0.0.1", info) # type: ignore[attr-defined]
|
||||||
assert "100.0.0.1" in geo_service._dirty
|
assert "100.0.0.1" in geo_service._dirty # type: ignore[attr-defined]
|
||||||
|
|
||||||
db = _make_async_db()
|
db = _make_async_db()
|
||||||
count = await geo_service.flush_dirty(db)
|
count = await geo_service.flush_dirty(db)
|
||||||
@@ -528,7 +527,7 @@ class TestFlushDirty:
|
|||||||
assert count == 1
|
assert count == 1
|
||||||
db.executemany.assert_awaited_once()
|
db.executemany.assert_awaited_once()
|
||||||
db.commit.assert_awaited_once()
|
db.commit.assert_awaited_once()
|
||||||
assert "100.0.0.1" not in geo_service._dirty
|
assert "100.0.0.1" not in geo_service._dirty # type: ignore[attr-defined]
|
||||||
|
|
||||||
async def test_flush_returns_zero_when_nothing_dirty(self) -> None:
|
async def test_flush_returns_zero_when_nothing_dirty(self) -> None:
|
||||||
"""flush_dirty() returns 0 and makes no DB calls when _dirty is empty."""
|
"""flush_dirty() returns 0 and makes no DB calls when _dirty is empty."""
|
||||||
@@ -542,7 +541,7 @@ class TestFlushDirty:
|
|||||||
async def test_flush_re_adds_to_dirty_on_db_error(self) -> None:
|
async def test_flush_re_adds_to_dirty_on_db_error(self) -> None:
|
||||||
"""When the DB write fails, entries are re-added to _dirty for retry."""
|
"""When the DB write fails, entries are re-added to _dirty for retry."""
|
||||||
info = GeoInfo(country_code="AU", country_name="Australia", asn="AS1", org="ISP")
|
info = GeoInfo(country_code="AU", country_name="Australia", asn="AS1", org="ISP")
|
||||||
geo_service._store("200.0.0.1", info)
|
geo_service._store("200.0.0.1", info) # type: ignore[attr-defined]
|
||||||
|
|
||||||
db = _make_async_db()
|
db = _make_async_db()
|
||||||
db.executemany = AsyncMock(side_effect=OSError("disk full"))
|
db.executemany = AsyncMock(side_effect=OSError("disk full"))
|
||||||
@@ -550,7 +549,7 @@ class TestFlushDirty:
|
|||||||
count = await geo_service.flush_dirty(db)
|
count = await geo_service.flush_dirty(db)
|
||||||
|
|
||||||
assert count == 0
|
assert count == 0
|
||||||
assert "200.0.0.1" in geo_service._dirty
|
assert "200.0.0.1" in geo_service._dirty # type: ignore[attr-defined]
|
||||||
|
|
||||||
async def test_flush_batch_and_lookup_batch_integration(self) -> None:
|
async def test_flush_batch_and_lookup_batch_integration(self) -> None:
|
||||||
"""lookup_batch() populates _dirty; flush_dirty() then persists them."""
|
"""lookup_batch() populates _dirty; flush_dirty() then persists them."""
|
||||||
@@ -563,14 +562,14 @@ class TestFlushDirty:
|
|||||||
|
|
||||||
# Resolve without DB to populate only in-memory cache and _dirty.
|
# Resolve without DB to populate only in-memory cache and _dirty.
|
||||||
await geo_service.lookup_batch(ips, session, db=None)
|
await geo_service.lookup_batch(ips, session, db=None)
|
||||||
assert geo_service._dirty == set(ips)
|
assert geo_service._dirty == set(ips) # type: ignore[attr-defined]
|
||||||
|
|
||||||
# Now flush to the DB.
|
# Now flush to the DB.
|
||||||
db = _make_async_db()
|
db = _make_async_db()
|
||||||
count = await geo_service.flush_dirty(db)
|
count = await geo_service.flush_dirty(db)
|
||||||
|
|
||||||
assert count == 2
|
assert count == 2
|
||||||
assert not geo_service._dirty
|
assert not geo_service._dirty # type: ignore[attr-defined]
|
||||||
db.commit.assert_awaited_once()
|
db.commit.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
@@ -586,7 +585,7 @@ class TestLookupBatchThrottling:
|
|||||||
"""When more than _BATCH_SIZE IPs are sent, asyncio.sleep is called
|
"""When more than _BATCH_SIZE IPs are sent, asyncio.sleep is called
|
||||||
between consecutive batch HTTP calls with at least _BATCH_DELAY."""
|
between consecutive batch HTTP calls with at least _BATCH_DELAY."""
|
||||||
# Generate _BATCH_SIZE + 1 IPs so we get exactly 2 batch calls.
|
# Generate _BATCH_SIZE + 1 IPs so we get exactly 2 batch calls.
|
||||||
batch_size: int = geo_service._BATCH_SIZE
|
batch_size: int = geo_service._BATCH_SIZE # type: ignore[attr-defined]
|
||||||
ips = [f"10.0.{i // 256}.{i % 256}" for i in range(batch_size + 1)]
|
ips = [f"10.0.{i // 256}.{i % 256}" for i in range(batch_size + 1)]
|
||||||
|
|
||||||
def _make_result(chunk: list[str], _session: object) -> dict[str, GeoInfo]:
|
def _make_result(chunk: list[str], _session: object) -> dict[str, GeoInfo]:
|
||||||
@@ -609,7 +608,7 @@ class TestLookupBatchThrottling:
|
|||||||
assert mock_batch.call_count == 2
|
assert mock_batch.call_count == 2
|
||||||
mock_sleep.assert_awaited_once()
|
mock_sleep.assert_awaited_once()
|
||||||
delay_arg: float = mock_sleep.call_args[0][0]
|
delay_arg: float = mock_sleep.call_args[0][0]
|
||||||
assert delay_arg >= geo_service._BATCH_DELAY
|
assert delay_arg >= geo_service._BATCH_DELAY # type: ignore[attr-defined]
|
||||||
|
|
||||||
async def test_lookup_batch_retries_on_full_chunk_failure(self) -> None:
|
async def test_lookup_batch_retries_on_full_chunk_failure(self) -> None:
|
||||||
"""When a chunk returns all-None on first try, it retries and succeeds."""
|
"""When a chunk returns all-None on first try, it retries and succeeds."""
|
||||||
@@ -651,7 +650,7 @@ class TestLookupBatchThrottling:
|
|||||||
_empty = GeoInfo(country_code=None, country_name=None, asn=None, org=None)
|
_empty = GeoInfo(country_code=None, country_name=None, asn=None, org=None)
|
||||||
_failure: dict[str, GeoInfo] = dict.fromkeys(ips, _empty)
|
_failure: dict[str, GeoInfo] = dict.fromkeys(ips, _empty)
|
||||||
|
|
||||||
max_retries: int = geo_service._BATCH_MAX_RETRIES
|
max_retries: int = geo_service._BATCH_MAX_RETRIES # type: ignore[attr-defined]
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
@@ -668,11 +667,11 @@ class TestLookupBatchThrottling:
|
|||||||
# IP should have no country.
|
# IP should have no country.
|
||||||
assert result["9.9.9.9"].country_code is None
|
assert result["9.9.9.9"].country_code is None
|
||||||
# Negative cache should contain the IP.
|
# Negative cache should contain the IP.
|
||||||
assert "9.9.9.9" in geo_service._neg_cache
|
assert "9.9.9.9" in geo_service._neg_cache # type: ignore[attr-defined]
|
||||||
# Sleep called for each retry with exponential backoff.
|
# Sleep called for each retry with exponential backoff.
|
||||||
assert mock_sleep.call_count == max_retries
|
assert mock_sleep.call_count == max_retries
|
||||||
backoff_values = [call.args[0] for call in mock_sleep.call_args_list]
|
backoff_values = [call.args[0] for call in mock_sleep.call_args_list]
|
||||||
batch_delay: float = geo_service._BATCH_DELAY
|
batch_delay: float = geo_service._BATCH_DELAY # type: ignore[attr-defined]
|
||||||
for i, val in enumerate(backoff_values):
|
for i, val in enumerate(backoff_values):
|
||||||
expected = batch_delay * (2 ** (i + 1))
|
expected = batch_delay * (2 ** (i + 1))
|
||||||
assert val == pytest.approx(expected)
|
assert val == pytest.approx(expected)
|
||||||
@@ -710,7 +709,7 @@ class TestErrorLogging:
|
|||||||
import structlog.testing
|
import structlog.testing
|
||||||
|
|
||||||
with structlog.testing.capture_logs() as captured:
|
with structlog.testing.capture_logs() as captured:
|
||||||
result = await geo_service.lookup("197.221.98.153", session)
|
result = await geo_service.lookup("197.221.98.153", session) # type: ignore[arg-type]
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.country_code is None
|
assert result.country_code is None
|
||||||
@@ -734,7 +733,7 @@ class TestErrorLogging:
|
|||||||
import structlog.testing
|
import structlog.testing
|
||||||
|
|
||||||
with structlog.testing.capture_logs() as captured:
|
with structlog.testing.capture_logs() as captured:
|
||||||
await geo_service.lookup("10.0.0.1", session)
|
await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type]
|
||||||
|
|
||||||
request_failed = [e for e in captured if e.get("event") == "geo_lookup_request_failed"]
|
request_failed = [e for e in captured if e.get("event") == "geo_lookup_request_failed"]
|
||||||
assert len(request_failed) == 1
|
assert len(request_failed) == 1
|
||||||
@@ -758,7 +757,7 @@ class TestErrorLogging:
|
|||||||
import structlog.testing
|
import structlog.testing
|
||||||
|
|
||||||
with structlog.testing.capture_logs() as captured:
|
with structlog.testing.capture_logs() as captured:
|
||||||
result = await geo_service._batch_api_call(["1.2.3.4"], session)
|
result = await geo_service._batch_api_call(["1.2.3.4"], session) # type: ignore[attr-defined]
|
||||||
|
|
||||||
assert result["1.2.3.4"].country_code is None
|
assert result["1.2.3.4"].country_code is None
|
||||||
|
|
||||||
@@ -779,7 +778,7 @@ class TestLookupCachedOnly:
|
|||||||
|
|
||||||
def test_returns_cached_ips(self) -> None:
|
def test_returns_cached_ips(self) -> None:
|
||||||
"""IPs already in the cache are returned in the geo_map."""
|
"""IPs already in the cache are returned in the geo_map."""
|
||||||
geo_service._cache["1.1.1.1"] = GeoInfo(
|
geo_service._cache["1.1.1.1"] = GeoInfo( # type: ignore[attr-defined]
|
||||||
country_code="AU", country_name="Australia", asn="AS13335", org="Cloudflare"
|
country_code="AU", country_name="Australia", asn="AS13335", org="Cloudflare"
|
||||||
)
|
)
|
||||||
geo_map, uncached = geo_service.lookup_cached_only(["1.1.1.1"])
|
geo_map, uncached = geo_service.lookup_cached_only(["1.1.1.1"])
|
||||||
@@ -799,7 +798,7 @@ class TestLookupCachedOnly:
|
|||||||
"""IPs in the negative cache within TTL are not re-queued as uncached."""
|
"""IPs in the negative cache within TTL are not re-queued as uncached."""
|
||||||
import time
|
import time
|
||||||
|
|
||||||
geo_service._neg_cache["10.0.0.1"] = time.monotonic()
|
geo_service._neg_cache["10.0.0.1"] = time.monotonic() # type: ignore[attr-defined]
|
||||||
|
|
||||||
geo_map, uncached = geo_service.lookup_cached_only(["10.0.0.1"])
|
geo_map, uncached = geo_service.lookup_cached_only(["10.0.0.1"])
|
||||||
|
|
||||||
@@ -808,7 +807,7 @@ class TestLookupCachedOnly:
|
|||||||
|
|
||||||
def test_expired_neg_cache_requeued(self) -> None:
|
def test_expired_neg_cache_requeued(self) -> None:
|
||||||
"""IPs whose neg-cache entry has expired are listed as uncached."""
|
"""IPs whose neg-cache entry has expired are listed as uncached."""
|
||||||
geo_service._neg_cache["10.0.0.2"] = 0.0 # epoch 0 → expired
|
geo_service._neg_cache["10.0.0.2"] = 0.0 # epoch 0 → expired # type: ignore[attr-defined]
|
||||||
|
|
||||||
_geo_map, uncached = geo_service.lookup_cached_only(["10.0.0.2"])
|
_geo_map, uncached = geo_service.lookup_cached_only(["10.0.0.2"])
|
||||||
|
|
||||||
@@ -816,12 +815,12 @@ class TestLookupCachedOnly:
|
|||||||
|
|
||||||
def test_mixed_ips(self) -> None:
|
def test_mixed_ips(self) -> None:
|
||||||
"""A mix of cached, neg-cached, and unknown IPs is split correctly."""
|
"""A mix of cached, neg-cached, and unknown IPs is split correctly."""
|
||||||
geo_service._cache["1.2.3.4"] = GeoInfo(
|
geo_service._cache["1.2.3.4"] = GeoInfo( # type: ignore[attr-defined]
|
||||||
country_code="DE", country_name="Germany", asn=None, org=None
|
country_code="DE", country_name="Germany", asn=None, org=None
|
||||||
)
|
)
|
||||||
import time
|
import time
|
||||||
|
|
||||||
geo_service._neg_cache["5.5.5.5"] = time.monotonic()
|
geo_service._neg_cache["5.5.5.5"] = time.monotonic() # type: ignore[attr-defined]
|
||||||
|
|
||||||
geo_map, uncached = geo_service.lookup_cached_only(["1.2.3.4", "5.5.5.5", "9.9.9.9"])
|
geo_map, uncached = geo_service.lookup_cached_only(["1.2.3.4", "5.5.5.5", "9.9.9.9"])
|
||||||
|
|
||||||
@@ -830,7 +829,7 @@ class TestLookupCachedOnly:
|
|||||||
|
|
||||||
def test_deduplication(self) -> None:
|
def test_deduplication(self) -> None:
|
||||||
"""Duplicate IPs in the input appear at most once in the output."""
|
"""Duplicate IPs in the input appear at most once in the output."""
|
||||||
geo_service._cache["1.2.3.4"] = GeoInfo(
|
geo_service._cache["1.2.3.4"] = GeoInfo( # type: ignore[attr-defined]
|
||||||
country_code="US", country_name="United States", asn=None, org=None
|
country_code="US", country_name="United States", asn=None, org=None
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -867,7 +866,7 @@ class TestLookupBatchBulkWrites:
|
|||||||
session = _make_batch_session(batch_response)
|
session = _make_batch_session(batch_response)
|
||||||
db = _make_async_db()
|
db = _make_async_db()
|
||||||
|
|
||||||
await geo_service.lookup_batch(ips, session, db=db)
|
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type]
|
||||||
|
|
||||||
# One executemany for the positive rows.
|
# One executemany for the positive rows.
|
||||||
assert db.executemany.await_count >= 1
|
assert db.executemany.await_count >= 1
|
||||||
@@ -884,7 +883,7 @@ class TestLookupBatchBulkWrites:
|
|||||||
session = _make_batch_session(batch_response)
|
session = _make_batch_session(batch_response)
|
||||||
db = _make_async_db()
|
db = _make_async_db()
|
||||||
|
|
||||||
await geo_service.lookup_batch(ips, session, db=db)
|
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type]
|
||||||
|
|
||||||
assert db.executemany.await_count >= 1
|
assert db.executemany.await_count >= 1
|
||||||
db.execute.assert_not_awaited()
|
db.execute.assert_not_awaited()
|
||||||
@@ -906,7 +905,7 @@ class TestLookupBatchBulkWrites:
|
|||||||
session = _make_batch_session(batch_response)
|
session = _make_batch_session(batch_response)
|
||||||
db = _make_async_db()
|
db = _make_async_db()
|
||||||
|
|
||||||
await geo_service.lookup_batch(ips, session, db=db)
|
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type]
|
||||||
|
|
||||||
# One executemany for positives, one for negatives.
|
# One executemany for positives, one for negatives.
|
||||||
assert db.executemany.await_count == 2
|
assert db.executemany.await_count == 2
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ async def _create_f2b_db(path: str, rows: list[dict[str, Any]]) -> None:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def f2b_db_path(tmp_path: Path) -> str:
|
async def f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc]
|
||||||
"""Return the path to a test fail2ban SQLite database."""
|
"""Return the path to a test fail2ban SQLite database."""
|
||||||
path = str(tmp_path / "fail2ban_test.sqlite3")
|
path = str(tmp_path / "fail2ban_test.sqlite3")
|
||||||
await _create_f2b_db(
|
await _create_f2b_db(
|
||||||
@@ -123,7 +123,7 @@ class TestListHistory:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""No filter returns every record in the database."""
|
"""No filter returns every record in the database."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.history_service.get_fail2ban_db_path",
|
"app.services.history_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=f2b_db_path),
|
new=AsyncMock(return_value=f2b_db_path),
|
||||||
):
|
):
|
||||||
result = await history_service.list_history("fake_socket")
|
result = await history_service.list_history("fake_socket")
|
||||||
@@ -135,7 +135,7 @@ class TestListHistory:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""The ``range_`` filter excludes bans older than the window."""
|
"""The ``range_`` filter excludes bans older than the window."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.history_service.get_fail2ban_db_path",
|
"app.services.history_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=f2b_db_path),
|
new=AsyncMock(return_value=f2b_db_path),
|
||||||
):
|
):
|
||||||
# "24h" window should include only the two recent bans
|
# "24h" window should include only the two recent bans
|
||||||
@@ -147,7 +147,7 @@ class TestListHistory:
|
|||||||
async def test_jail_filter(self, f2b_db_path: str) -> None:
|
async def test_jail_filter(self, f2b_db_path: str) -> None:
|
||||||
"""Jail filter restricts results to bans from that jail."""
|
"""Jail filter restricts results to bans from that jail."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.history_service.get_fail2ban_db_path",
|
"app.services.history_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=f2b_db_path),
|
new=AsyncMock(return_value=f2b_db_path),
|
||||||
):
|
):
|
||||||
result = await history_service.list_history("fake_socket", jail="nginx")
|
result = await history_service.list_history("fake_socket", jail="nginx")
|
||||||
@@ -157,7 +157,7 @@ class TestListHistory:
|
|||||||
async def test_ip_prefix_filter(self, f2b_db_path: str) -> None:
|
async def test_ip_prefix_filter(self, f2b_db_path: str) -> None:
|
||||||
"""IP prefix filter restricts results to matching IPs."""
|
"""IP prefix filter restricts results to matching IPs."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.history_service.get_fail2ban_db_path",
|
"app.services.history_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=f2b_db_path),
|
new=AsyncMock(return_value=f2b_db_path),
|
||||||
):
|
):
|
||||||
result = await history_service.list_history(
|
result = await history_service.list_history(
|
||||||
@@ -170,7 +170,7 @@ class TestListHistory:
|
|||||||
async def test_combined_filters(self, f2b_db_path: str) -> None:
|
async def test_combined_filters(self, f2b_db_path: str) -> None:
|
||||||
"""Jail + IP prefix filters applied together narrow the result set."""
|
"""Jail + IP prefix filters applied together narrow the result set."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.history_service.get_fail2ban_db_path",
|
"app.services.history_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=f2b_db_path),
|
new=AsyncMock(return_value=f2b_db_path),
|
||||||
):
|
):
|
||||||
result = await history_service.list_history(
|
result = await history_service.list_history(
|
||||||
@@ -182,7 +182,7 @@ class TestListHistory:
|
|||||||
async def test_unknown_ip_returns_empty(self, f2b_db_path: str) -> None:
|
async def test_unknown_ip_returns_empty(self, f2b_db_path: str) -> None:
|
||||||
"""Filtering by a non-existent IP returns an empty result set."""
|
"""Filtering by a non-existent IP returns an empty result set."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.history_service.get_fail2ban_db_path",
|
"app.services.history_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=f2b_db_path),
|
new=AsyncMock(return_value=f2b_db_path),
|
||||||
):
|
):
|
||||||
result = await history_service.list_history(
|
result = await history_service.list_history(
|
||||||
@@ -196,7 +196,7 @@ class TestListHistory:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""``failures`` field is parsed from the JSON ``data`` column."""
|
"""``failures`` field is parsed from the JSON ``data`` column."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.history_service.get_fail2ban_db_path",
|
"app.services.history_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=f2b_db_path),
|
new=AsyncMock(return_value=f2b_db_path),
|
||||||
):
|
):
|
||||||
result = await history_service.list_history(
|
result = await history_service.list_history(
|
||||||
@@ -210,7 +210,7 @@ class TestListHistory:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""``matches`` list is parsed from the JSON ``data`` column."""
|
"""``matches`` list is parsed from the JSON ``data`` column."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.history_service.get_fail2ban_db_path",
|
"app.services.history_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=f2b_db_path),
|
new=AsyncMock(return_value=f2b_db_path),
|
||||||
):
|
):
|
||||||
result = await history_service.list_history(
|
result = await history_service.list_history(
|
||||||
@@ -226,7 +226,7 @@ class TestListHistory:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Records with ``data=NULL`` produce failures=0 and matches=[]."""
|
"""Records with ``data=NULL`` produce failures=0 and matches=[]."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.history_service.get_fail2ban_db_path",
|
"app.services.history_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=f2b_db_path),
|
new=AsyncMock(return_value=f2b_db_path),
|
||||||
):
|
):
|
||||||
result = await history_service.list_history(
|
result = await history_service.list_history(
|
||||||
@@ -240,7 +240,7 @@ class TestListHistory:
|
|||||||
async def test_pagination(self, f2b_db_path: str) -> None:
|
async def test_pagination(self, f2b_db_path: str) -> None:
|
||||||
"""Pagination returns the correct slice."""
|
"""Pagination returns the correct slice."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.history_service.get_fail2ban_db_path",
|
"app.services.history_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=f2b_db_path),
|
new=AsyncMock(return_value=f2b_db_path),
|
||||||
):
|
):
|
||||||
result = await history_service.list_history(
|
result = await history_service.list_history(
|
||||||
@@ -265,7 +265,7 @@ class TestGetIpDetail:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Returns ``None`` when the IP has no records in the database."""
|
"""Returns ``None`` when the IP has no records in the database."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.history_service.get_fail2ban_db_path",
|
"app.services.history_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=f2b_db_path),
|
new=AsyncMock(return_value=f2b_db_path),
|
||||||
):
|
):
|
||||||
result = await history_service.get_ip_detail("fake_socket", "99.99.99.99")
|
result = await history_service.get_ip_detail("fake_socket", "99.99.99.99")
|
||||||
@@ -276,7 +276,7 @@ class TestGetIpDetail:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Returns an IpDetailResponse with correct totals for a known IP."""
|
"""Returns an IpDetailResponse with correct totals for a known IP."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.history_service.get_fail2ban_db_path",
|
"app.services.history_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=f2b_db_path),
|
new=AsyncMock(return_value=f2b_db_path),
|
||||||
):
|
):
|
||||||
result = await history_service.get_ip_detail("fake_socket", "1.2.3.4")
|
result = await history_service.get_ip_detail("fake_socket", "1.2.3.4")
|
||||||
@@ -291,7 +291,7 @@ class TestGetIpDetail:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Timeline events are ordered newest-first."""
|
"""Timeline events are ordered newest-first."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.history_service.get_fail2ban_db_path",
|
"app.services.history_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=f2b_db_path),
|
new=AsyncMock(return_value=f2b_db_path),
|
||||||
):
|
):
|
||||||
result = await history_service.get_ip_detail("fake_socket", "1.2.3.4")
|
result = await history_service.get_ip_detail("fake_socket", "1.2.3.4")
|
||||||
@@ -304,7 +304,7 @@ class TestGetIpDetail:
|
|||||||
async def test_last_ban_at_is_most_recent(self, f2b_db_path: str) -> None:
|
async def test_last_ban_at_is_most_recent(self, f2b_db_path: str) -> None:
|
||||||
"""``last_ban_at`` matches the banned_at of the first timeline event."""
|
"""``last_ban_at`` matches the banned_at of the first timeline event."""
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.history_service.get_fail2ban_db_path",
|
"app.services.history_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=f2b_db_path),
|
new=AsyncMock(return_value=f2b_db_path),
|
||||||
):
|
):
|
||||||
result = await history_service.get_ip_detail("fake_socket", "1.2.3.4")
|
result = await history_service.get_ip_detail("fake_socket", "1.2.3.4")
|
||||||
@@ -316,7 +316,7 @@ class TestGetIpDetail:
|
|||||||
self, f2b_db_path: str
|
self, f2b_db_path: str
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Geolocation is applied when a geo_enricher is provided."""
|
"""Geolocation is applied when a geo_enricher is provided."""
|
||||||
from app.models.geo import GeoInfo
|
from app.services.geo_service import GeoInfo
|
||||||
|
|
||||||
mock_geo = GeoInfo(
|
mock_geo = GeoInfo(
|
||||||
country_code="US",
|
country_code="US",
|
||||||
@@ -327,7 +327,7 @@ class TestGetIpDetail:
|
|||||||
fake_enricher = AsyncMock(return_value=mock_geo)
|
fake_enricher = AsyncMock(return_value=mock_geo)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.history_service.get_fail2ban_db_path",
|
"app.services.history_service._get_fail2ban_db_path",
|
||||||
new=AsyncMock(return_value=f2b_db_path),
|
new=AsyncMock(return_value=f2b_db_path),
|
||||||
):
|
):
|
||||||
result = await history_service.get_ip_detail(
|
result = await history_service.get_ip_detail(
|
||||||
|
|||||||
@@ -635,7 +635,7 @@ class TestGetActiveBans:
|
|||||||
|
|
||||||
async def test_http_session_triggers_lookup_batch(self) -> None:
|
async def test_http_session_triggers_lookup_batch(self) -> None:
|
||||||
"""When http_session is provided, geo_service.lookup_batch is used."""
|
"""When http_session is provided, geo_service.lookup_batch is used."""
|
||||||
from app.models.geo import GeoInfo
|
from app.services.geo_service import GeoInfo
|
||||||
|
|
||||||
responses = {
|
responses = {
|
||||||
"status": _make_global_status("sshd"),
|
"status": _make_global_status("sshd"),
|
||||||
@@ -645,14 +645,17 @@ class TestGetActiveBans:
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
mock_geo = {"1.2.3.4": GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP")}
|
mock_geo = {"1.2.3.4": GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP")}
|
||||||
mock_batch = AsyncMock(return_value=mock_geo)
|
|
||||||
|
|
||||||
with _patch_client(responses):
|
with (
|
||||||
|
_patch_client(responses),
|
||||||
|
patch(
|
||||||
|
"app.services.geo_service.lookup_batch",
|
||||||
|
new=AsyncMock(return_value=mock_geo),
|
||||||
|
) as mock_batch,
|
||||||
|
):
|
||||||
mock_session = AsyncMock()
|
mock_session = AsyncMock()
|
||||||
result = await jail_service.get_active_bans(
|
result = await jail_service.get_active_bans(
|
||||||
_SOCKET,
|
_SOCKET, http_session=mock_session
|
||||||
http_session=mock_session,
|
|
||||||
geo_batch_lookup=mock_batch,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_batch.assert_awaited_once()
|
mock_batch.assert_awaited_once()
|
||||||
@@ -669,14 +672,16 @@ class TestGetActiveBans:
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
failing_batch = AsyncMock(side_effect=RuntimeError("geo down"))
|
with (
|
||||||
|
_patch_client(responses),
|
||||||
with _patch_client(responses):
|
patch(
|
||||||
|
"app.services.geo_service.lookup_batch",
|
||||||
|
new=AsyncMock(side_effect=RuntimeError("geo down")),
|
||||||
|
),
|
||||||
|
):
|
||||||
mock_session = AsyncMock()
|
mock_session = AsyncMock()
|
||||||
result = await jail_service.get_active_bans(
|
result = await jail_service.get_active_bans(
|
||||||
_SOCKET,
|
_SOCKET, http_session=mock_session
|
||||||
http_session=mock_session,
|
|
||||||
geo_batch_lookup=failing_batch,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result.total == 1
|
assert result.total == 1
|
||||||
@@ -684,7 +689,7 @@ class TestGetActiveBans:
|
|||||||
|
|
||||||
async def test_geo_enricher_still_used_without_http_session(self) -> None:
|
async def test_geo_enricher_still_used_without_http_session(self) -> None:
|
||||||
"""Legacy geo_enricher is still called when http_session is not provided."""
|
"""Legacy geo_enricher is still called when http_session is not provided."""
|
||||||
from app.models.geo import GeoInfo
|
from app.services.geo_service import GeoInfo
|
||||||
|
|
||||||
responses = {
|
responses = {
|
||||||
"status": _make_global_status("sshd"),
|
"status": _make_global_status("sshd"),
|
||||||
@@ -982,7 +987,6 @@ class TestGetJailBannedIps:
|
|||||||
page=1,
|
page=1,
|
||||||
page_size=2,
|
page_size=2,
|
||||||
http_session=http_session,
|
http_session=http_session,
|
||||||
geo_batch_lookup=geo_service.lookup_batch,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Only the 2-IP page slice should be passed to geo enrichment.
|
# Only the 2-IP page slice should be passed to geo enrichment.
|
||||||
@@ -992,6 +996,9 @@ class TestGetJailBannedIps:
|
|||||||
|
|
||||||
async def test_unknown_jail_raises_jail_not_found_error(self) -> None:
|
async def test_unknown_jail_raises_jail_not_found_error(self) -> None:
|
||||||
"""get_jail_banned_ips raises JailNotFoundError for unknown jail."""
|
"""get_jail_banned_ips raises JailNotFoundError for unknown jail."""
|
||||||
|
responses = {
|
||||||
|
"status|ghost|short": (0, pytest.raises), # will be overridden
|
||||||
|
}
|
||||||
# Simulate fail2ban returning an "unknown jail" error.
|
# Simulate fail2ban returning an "unknown jail" error.
|
||||||
class _FakeClient:
|
class _FakeClient:
|
||||||
def __init__(self, **_kw: Any) -> None:
|
def __init__(self, **_kw: Any) -> None:
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.models.geo import GeoInfo
|
from app.services.geo_service import GeoInfo
|
||||||
from app.tasks.geo_re_resolve import _run_re_resolve
|
from app.tasks.geo_re_resolve import _run_re_resolve
|
||||||
|
|
||||||
|
|
||||||
@@ -79,8 +79,6 @@ async def test_run_re_resolve_no_unresolved_ips_skips() -> None:
|
|||||||
app = _make_app(unresolved_ips=[])
|
app = _make_app(unresolved_ips=[])
|
||||||
|
|
||||||
with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo:
|
with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo:
|
||||||
mock_geo.get_unresolved_ips = AsyncMock(return_value=[])
|
|
||||||
|
|
||||||
await _run_re_resolve(app)
|
await _run_re_resolve(app)
|
||||||
|
|
||||||
mock_geo.clear_neg_cache.assert_not_called()
|
mock_geo.clear_neg_cache.assert_not_called()
|
||||||
@@ -98,7 +96,6 @@ async def test_run_re_resolve_clears_neg_cache() -> None:
|
|||||||
app = _make_app(unresolved_ips=ips, lookup_result=result)
|
app = _make_app(unresolved_ips=ips, lookup_result=result)
|
||||||
|
|
||||||
with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo:
|
with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo:
|
||||||
mock_geo.get_unresolved_ips = AsyncMock(return_value=ips)
|
|
||||||
mock_geo.lookup_batch = AsyncMock(return_value=result)
|
mock_geo.lookup_batch = AsyncMock(return_value=result)
|
||||||
|
|
||||||
await _run_re_resolve(app)
|
await _run_re_resolve(app)
|
||||||
@@ -117,7 +114,6 @@ async def test_run_re_resolve_calls_lookup_batch_with_db() -> None:
|
|||||||
app = _make_app(unresolved_ips=ips, lookup_result=result)
|
app = _make_app(unresolved_ips=ips, lookup_result=result)
|
||||||
|
|
||||||
with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo:
|
with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo:
|
||||||
mock_geo.get_unresolved_ips = AsyncMock(return_value=ips)
|
|
||||||
mock_geo.lookup_batch = AsyncMock(return_value=result)
|
mock_geo.lookup_batch = AsyncMock(return_value=result)
|
||||||
|
|
||||||
await _run_re_resolve(app)
|
await _run_re_resolve(app)
|
||||||
@@ -141,7 +137,6 @@ async def test_run_re_resolve_logs_correct_counts(caplog: Any) -> None:
|
|||||||
app = _make_app(unresolved_ips=ips, lookup_result=result)
|
app = _make_app(unresolved_ips=ips, lookup_result=result)
|
||||||
|
|
||||||
with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo:
|
with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo:
|
||||||
mock_geo.get_unresolved_ips = AsyncMock(return_value=ips)
|
|
||||||
mock_geo.lookup_batch = AsyncMock(return_value=result)
|
mock_geo.lookup_batch = AsyncMock(return_value=result)
|
||||||
|
|
||||||
await _run_re_resolve(app)
|
await _run_re_resolve(app)
|
||||||
@@ -164,7 +159,6 @@ async def test_run_re_resolve_handles_all_resolved() -> None:
|
|||||||
app = _make_app(unresolved_ips=ips, lookup_result=result)
|
app = _make_app(unresolved_ips=ips, lookup_result=result)
|
||||||
|
|
||||||
with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo:
|
with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo:
|
||||||
mock_geo.get_unresolved_ips = AsyncMock(return_value=ips)
|
|
||||||
mock_geo.lookup_batch = AsyncMock(return_value=result)
|
mock_geo.lookup_batch = AsyncMock(return_value=result)
|
||||||
|
|
||||||
await _run_re_resolve(app)
|
await _run_re_resolve(app)
|
||||||
|
|||||||
@@ -270,7 +270,7 @@ class TestCrashDetection:
|
|||||||
async def test_crash_within_window_creates_pending_recovery(self) -> None:
|
async def test_crash_within_window_creates_pending_recovery(self) -> None:
|
||||||
"""An online→offline transition within 60 s of activation must set pending_recovery."""
|
"""An online→offline transition within 60 s of activation must set pending_recovery."""
|
||||||
app = _make_app(prev_online=True)
|
app = _make_app(prev_online=True)
|
||||||
now = datetime.datetime.now(tz=datetime.UTC)
|
now = datetime.datetime.now(tz=datetime.timezone.utc)
|
||||||
app.state.last_activation = {
|
app.state.last_activation = {
|
||||||
"jail_name": "sshd",
|
"jail_name": "sshd",
|
||||||
"at": now - datetime.timedelta(seconds=10),
|
"at": now - datetime.timedelta(seconds=10),
|
||||||
@@ -297,7 +297,7 @@ class TestCrashDetection:
|
|||||||
app = _make_app(prev_online=True)
|
app = _make_app(prev_online=True)
|
||||||
app.state.last_activation = {
|
app.state.last_activation = {
|
||||||
"jail_name": "sshd",
|
"jail_name": "sshd",
|
||||||
"at": datetime.datetime.now(tz=datetime.UTC)
|
"at": datetime.datetime.now(tz=datetime.timezone.utc)
|
||||||
- datetime.timedelta(seconds=120),
|
- datetime.timedelta(seconds=120),
|
||||||
}
|
}
|
||||||
app.state.pending_recovery = None
|
app.state.pending_recovery = None
|
||||||
@@ -315,8 +315,8 @@ class TestCrashDetection:
|
|||||||
async def test_came_online_marks_pending_recovery_resolved(self) -> None:
|
async def test_came_online_marks_pending_recovery_resolved(self) -> None:
|
||||||
"""An offline→online transition must mark an existing pending_recovery as recovered."""
|
"""An offline→online transition must mark an existing pending_recovery as recovered."""
|
||||||
app = _make_app(prev_online=False)
|
app = _make_app(prev_online=False)
|
||||||
activated_at = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(seconds=30)
|
activated_at = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(seconds=30)
|
||||||
detected_at = datetime.datetime.now(tz=datetime.UTC)
|
detected_at = datetime.datetime.now(tz=datetime.timezone.utc)
|
||||||
app.state.pending_recovery = PendingRecovery(
|
app.state.pending_recovery = PendingRecovery(
|
||||||
jail_name="sshd",
|
jail_name="sshd",
|
||||||
activated_at=activated_at,
|
activated_at=activated_at,
|
||||||
|
|||||||
@@ -65,6 +65,10 @@ class TestEnsureJailConfigs:
|
|||||||
content = _read(jail_d, conf_file)
|
content = _read(jail_d, conf_file)
|
||||||
assert "enabled = false" in content
|
assert "enabled = false" in content
|
||||||
|
|
||||||
|
# Blocklist-import jail must have a 24-hour ban time
|
||||||
|
blocklist_conf = _read(jail_d, _BLOCKLIST_CONF)
|
||||||
|
assert "bantime = 86400" in blocklist_conf
|
||||||
|
|
||||||
# .local files must set enabled = true and nothing else
|
# .local files must set enabled = true and nothing else
|
||||||
for local_file in (_MANUAL_LOCAL, _BLOCKLIST_LOCAL):
|
for local_file in (_MANUAL_LOCAL, _BLOCKLIST_LOCAL):
|
||||||
content = _read(jail_d, local_file)
|
content = _read(jail_d, local_file)
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
{
|
{
|
||||||
"name": "bangui-frontend",
|
"name": "bangui-frontend",
|
||||||
"private": true,
|
"private": true,
|
||||||
"version": "0.9.4",
|
"version": "0.9.5",
|
||||||
"description": "BanGUI frontend — fail2ban web management interface",
|
"description": "BanGUI frontend — fail2ban web management interface",
|
||||||
"type": "module",
|
"type": "module",
|
||||||
"scripts": {
|
"scripts": {
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ import { AuthProvider } from "./providers/AuthProvider";
|
|||||||
import { TimezoneProvider } from "./providers/TimezoneProvider";
|
import { TimezoneProvider } from "./providers/TimezoneProvider";
|
||||||
import { RequireAuth } from "./components/RequireAuth";
|
import { RequireAuth } from "./components/RequireAuth";
|
||||||
import { SetupGuard } from "./components/SetupGuard";
|
import { SetupGuard } from "./components/SetupGuard";
|
||||||
import { ErrorBoundary } from "./components/ErrorBoundary";
|
|
||||||
import { MainLayout } from "./layouts/MainLayout";
|
import { MainLayout } from "./layouts/MainLayout";
|
||||||
import { SetupPage } from "./pages/SetupPage";
|
import { SetupPage } from "./pages/SetupPage";
|
||||||
import { LoginPage } from "./pages/LoginPage";
|
import { LoginPage } from "./pages/LoginPage";
|
||||||
@@ -44,10 +43,9 @@ import { BlocklistsPage } from "./pages/BlocklistsPage";
|
|||||||
function App(): React.JSX.Element {
|
function App(): React.JSX.Element {
|
||||||
return (
|
return (
|
||||||
<FluentProvider theme={lightTheme}>
|
<FluentProvider theme={lightTheme}>
|
||||||
<ErrorBoundary>
|
<BrowserRouter future={{ v7_startTransition: true, v7_relativeSplatPath: true }}>
|
||||||
<BrowserRouter future={{ v7_startTransition: true, v7_relativeSplatPath: true }}>
|
<AuthProvider>
|
||||||
<AuthProvider>
|
<Routes>
|
||||||
<Routes>
|
|
||||||
{/* Setup wizard — always accessible; redirects to /login if already done */}
|
{/* Setup wizard — always accessible; redirects to /login if already done */}
|
||||||
<Route path="/setup" element={<SetupPage />} />
|
<Route path="/setup" element={<SetupPage />} />
|
||||||
|
|
||||||
@@ -87,7 +85,6 @@ function App(): React.JSX.Element {
|
|||||||
</Routes>
|
</Routes>
|
||||||
</AuthProvider>
|
</AuthProvider>
|
||||||
</BrowserRouter>
|
</BrowserRouter>
|
||||||
</ErrorBoundary>
|
|
||||||
</FluentProvider>
|
</FluentProvider>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ export async function fetchHistory(
|
|||||||
): Promise<HistoryListResponse> {
|
): Promise<HistoryListResponse> {
|
||||||
const params = new URLSearchParams();
|
const params = new URLSearchParams();
|
||||||
if (query.range) params.set("range", query.range);
|
if (query.range) params.set("range", query.range);
|
||||||
|
if (query.origin) params.set("origin", query.origin);
|
||||||
if (query.jail) params.set("jail", query.jail);
|
if (query.jail) params.set("jail", query.jail);
|
||||||
if (query.ip) params.set("ip", query.ip);
|
if (query.ip) params.set("ip", query.ip);
|
||||||
if (query.page !== undefined) params.set("page", String(query.page));
|
if (query.page !== undefined) params.set("page", String(query.page));
|
||||||
|
|||||||
@@ -27,7 +27,6 @@ import {
|
|||||||
import { PageEmpty, PageError, PageLoading } from "./PageFeedback";
|
import { PageEmpty, PageError, PageLoading } from "./PageFeedback";
|
||||||
import { ChevronLeftRegular, ChevronRightRegular } from "@fluentui/react-icons";
|
import { ChevronLeftRegular, ChevronRightRegular } from "@fluentui/react-icons";
|
||||||
import { useBans } from "../hooks/useBans";
|
import { useBans } from "../hooks/useBans";
|
||||||
import { formatTimestamp } from "../utils/formatDate";
|
|
||||||
import type { DashboardBanItem, TimeRange, BanOriginFilter } from "../types/ban";
|
import type { DashboardBanItem, TimeRange, BanOriginFilter } from "../types/ban";
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
@@ -91,6 +90,31 @@ const useStyles = makeStyles({
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Helpers
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Format an ISO 8601 timestamp for display.
|
||||||
|
*
|
||||||
|
* @param iso - ISO 8601 UTC string.
|
||||||
|
* @returns Localised date+time string.
|
||||||
|
*/
|
||||||
|
function formatTimestamp(iso: string): string {
|
||||||
|
try {
|
||||||
|
return new Date(iso).toLocaleString(undefined, {
|
||||||
|
year: "numeric",
|
||||||
|
month: "2-digit",
|
||||||
|
day: "2-digit",
|
||||||
|
hour: "2-digit",
|
||||||
|
minute: "2-digit",
|
||||||
|
second: "2-digit",
|
||||||
|
});
|
||||||
|
} catch {
|
||||||
|
return iso;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
// Column definitions
|
// Column definitions
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ import {
|
|||||||
makeStyles,
|
makeStyles,
|
||||||
tokens,
|
tokens,
|
||||||
} from "@fluentui/react-components";
|
} from "@fluentui/react-components";
|
||||||
import { useCardStyles } from "../theme/commonStyles";
|
|
||||||
import type { BanOriginFilter, TimeRange } from "../types/ban";
|
import type { BanOriginFilter, TimeRange } from "../types/ban";
|
||||||
import {
|
import {
|
||||||
BAN_ORIGIN_FILTER_LABELS,
|
BAN_ORIGIN_FILTER_LABELS,
|
||||||
@@ -58,6 +57,20 @@ const useStyles = makeStyles({
|
|||||||
alignItems: "center",
|
alignItems: "center",
|
||||||
flexWrap: "wrap",
|
flexWrap: "wrap",
|
||||||
gap: tokens.spacingVerticalS,
|
gap: tokens.spacingVerticalS,
|
||||||
|
backgroundColor: tokens.colorNeutralBackground1,
|
||||||
|
borderRadius: tokens.borderRadiusMedium,
|
||||||
|
borderTopWidth: "1px",
|
||||||
|
borderTopStyle: "solid",
|
||||||
|
borderTopColor: tokens.colorNeutralStroke2,
|
||||||
|
borderRightWidth: "1px",
|
||||||
|
borderRightStyle: "solid",
|
||||||
|
borderRightColor: tokens.colorNeutralStroke2,
|
||||||
|
borderBottomWidth: "1px",
|
||||||
|
borderBottomStyle: "solid",
|
||||||
|
borderBottomColor: tokens.colorNeutralStroke2,
|
||||||
|
borderLeftWidth: "1px",
|
||||||
|
borderLeftStyle: "solid",
|
||||||
|
borderLeftColor: tokens.colorNeutralStroke2,
|
||||||
paddingTop: tokens.spacingVerticalS,
|
paddingTop: tokens.spacingVerticalS,
|
||||||
paddingBottom: tokens.spacingVerticalS,
|
paddingBottom: tokens.spacingVerticalS,
|
||||||
paddingLeft: tokens.spacingHorizontalM,
|
paddingLeft: tokens.spacingHorizontalM,
|
||||||
@@ -94,10 +107,9 @@ export function DashboardFilterBar({
|
|||||||
onOriginFilterChange,
|
onOriginFilterChange,
|
||||||
}: DashboardFilterBarProps): React.JSX.Element {
|
}: DashboardFilterBarProps): React.JSX.Element {
|
||||||
const styles = useStyles();
|
const styles = useStyles();
|
||||||
const cardStyles = useCardStyles();
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className={`${styles.container} ${cardStyles.card}`}>
|
<div className={styles.container}>
|
||||||
{/* Time-range group */}
|
{/* Time-range group */}
|
||||||
<div className={styles.group}>
|
<div className={styles.group}>
|
||||||
<Text weight="semibold" size={300}>
|
<Text weight="semibold" size={300}>
|
||||||
|
|||||||
@@ -1,62 +0,0 @@
|
|||||||
/**
|
|
||||||
* React error boundary component.
|
|
||||||
*
|
|
||||||
* Catches render-time exceptions in child components and shows a fallback UI.
|
|
||||||
*/
|
|
||||||
import React from "react";
|
|
||||||
|
|
||||||
interface ErrorBoundaryState {
|
|
||||||
hasError: boolean;
|
|
||||||
errorMessage: string | null;
|
|
||||||
}
|
|
||||||
|
|
||||||
interface ErrorBoundaryProps {
|
|
||||||
children: React.ReactNode;
|
|
||||||
}
|
|
||||||
|
|
||||||
export class ErrorBoundary extends React.Component<ErrorBoundaryProps, ErrorBoundaryState> {
|
|
||||||
constructor(props: ErrorBoundaryProps) {
|
|
||||||
super(props);
|
|
||||||
this.state = { hasError: false, errorMessage: null };
|
|
||||||
this.handleReload = this.handleReload.bind(this);
|
|
||||||
}
|
|
||||||
|
|
||||||
static getDerivedStateFromError(error: Error): ErrorBoundaryState {
|
|
||||||
return { hasError: true, errorMessage: error.message || "Unknown error" };
|
|
||||||
}
|
|
||||||
|
|
||||||
componentDidCatch(error: Error, errorInfo: React.ErrorInfo): void {
|
|
||||||
console.error("ErrorBoundary caught an error", { error, errorInfo });
|
|
||||||
}
|
|
||||||
|
|
||||||
handleReload(): void {
|
|
||||||
window.location.reload();
|
|
||||||
}
|
|
||||||
|
|
||||||
render(): React.ReactNode {
|
|
||||||
if (this.state.hasError) {
|
|
||||||
return (
|
|
||||||
<div
|
|
||||||
style={{
|
|
||||||
display: "flex",
|
|
||||||
flexDirection: "column",
|
|
||||||
alignItems: "center",
|
|
||||||
justifyContent: "center",
|
|
||||||
minHeight: "100vh",
|
|
||||||
padding: "24px",
|
|
||||||
textAlign: "center",
|
|
||||||
}}
|
|
||||||
role="alert"
|
|
||||||
>
|
|
||||||
<h1>Something went wrong</h1>
|
|
||||||
<p>{this.state.errorMessage ?? "Please try reloading the page."}</p>
|
|
||||||
<button type="button" onClick={this.handleReload} style={{ marginTop: "16px" }}>
|
|
||||||
Reload
|
|
||||||
</button>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
return this.props.children;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -18,7 +18,6 @@ import {
|
|||||||
tokens,
|
tokens,
|
||||||
Tooltip,
|
Tooltip,
|
||||||
} from "@fluentui/react-components";
|
} from "@fluentui/react-components";
|
||||||
import { useCardStyles } from "../theme/commonStyles";
|
|
||||||
import { ArrowClockwiseRegular, ShieldRegular } from "@fluentui/react-icons";
|
import { ArrowClockwiseRegular, ShieldRegular } from "@fluentui/react-icons";
|
||||||
import { useServerStatus } from "../hooks/useServerStatus";
|
import { useServerStatus } from "../hooks/useServerStatus";
|
||||||
|
|
||||||
@@ -32,6 +31,20 @@ const useStyles = makeStyles({
|
|||||||
alignItems: "center",
|
alignItems: "center",
|
||||||
gap: tokens.spacingHorizontalL,
|
gap: tokens.spacingHorizontalL,
|
||||||
padding: `${tokens.spacingVerticalS} ${tokens.spacingHorizontalL}`,
|
padding: `${tokens.spacingVerticalS} ${tokens.spacingHorizontalL}`,
|
||||||
|
backgroundColor: tokens.colorNeutralBackground1,
|
||||||
|
borderRadius: tokens.borderRadiusMedium,
|
||||||
|
borderTopWidth: "1px",
|
||||||
|
borderTopStyle: "solid",
|
||||||
|
borderTopColor: tokens.colorNeutralStroke2,
|
||||||
|
borderRightWidth: "1px",
|
||||||
|
borderRightStyle: "solid",
|
||||||
|
borderRightColor: tokens.colorNeutralStroke2,
|
||||||
|
borderBottomWidth: "1px",
|
||||||
|
borderBottomStyle: "solid",
|
||||||
|
borderBottomColor: tokens.colorNeutralStroke2,
|
||||||
|
borderLeftWidth: "1px",
|
||||||
|
borderLeftStyle: "solid",
|
||||||
|
borderLeftColor: tokens.colorNeutralStroke2,
|
||||||
marginBottom: tokens.spacingVerticalL,
|
marginBottom: tokens.spacingVerticalL,
|
||||||
flexWrap: "wrap",
|
flexWrap: "wrap",
|
||||||
},
|
},
|
||||||
@@ -72,10 +85,8 @@ export function ServerStatusBar(): React.JSX.Element {
|
|||||||
const styles = useStyles();
|
const styles = useStyles();
|
||||||
const { status, loading, error, refresh } = useServerStatus();
|
const { status, loading, error, refresh } = useServerStatus();
|
||||||
|
|
||||||
const cardStyles = useCardStyles();
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className={`${cardStyles.card} ${styles.bar}`} role="status" aria-label="fail2ban server status">
|
<div className={styles.bar} role="status" aria-label="fail2ban server status">
|
||||||
{/* ---------------------------------------------------------------- */}
|
{/* ---------------------------------------------------------------- */}
|
||||||
{/* Online / Offline badge */}
|
{/* Online / Offline badge */}
|
||||||
{/* ---------------------------------------------------------------- */}
|
{/* ---------------------------------------------------------------- */}
|
||||||
|
|||||||
@@ -6,13 +6,12 @@
|
|||||||
* While the status is loading a full-screen spinner is shown.
|
* While the status is loading a full-screen spinner is shown.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
import { useEffect, useState } from "react";
|
||||||
import { Navigate } from "react-router-dom";
|
import { Navigate } from "react-router-dom";
|
||||||
import { Spinner } from "@fluentui/react-components";
|
import { Spinner } from "@fluentui/react-components";
|
||||||
import { useSetup } from "../hooks/useSetup";
|
import { getSetupStatus } from "../api/setup";
|
||||||
|
|
||||||
/**
|
type Status = "loading" | "done" | "pending";
|
||||||
* Component is intentionally simple; status load is handled by the hook.
|
|
||||||
*/
|
|
||||||
|
|
||||||
interface SetupGuardProps {
|
interface SetupGuardProps {
|
||||||
/** The protected content to render when setup is complete. */
|
/** The protected content to render when setup is complete. */
|
||||||
@@ -25,9 +24,25 @@ interface SetupGuardProps {
|
|||||||
* Redirects to `/setup` if setup is still pending.
|
* Redirects to `/setup` if setup is still pending.
|
||||||
*/
|
*/
|
||||||
export function SetupGuard({ children }: SetupGuardProps): React.JSX.Element {
|
export function SetupGuard({ children }: SetupGuardProps): React.JSX.Element {
|
||||||
const { status, loading } = useSetup();
|
const [status, setStatus] = useState<Status>("loading");
|
||||||
|
|
||||||
if (loading) {
|
useEffect(() => {
|
||||||
|
let cancelled = false;
|
||||||
|
getSetupStatus()
|
||||||
|
.then((res): void => {
|
||||||
|
if (!cancelled) setStatus(res.completed ? "done" : "pending");
|
||||||
|
})
|
||||||
|
.catch((): void => {
|
||||||
|
// A failed check conservatively redirects to /setup — a crashed
|
||||||
|
// backend cannot serve protected routes anyway.
|
||||||
|
if (!cancelled) setStatus("pending");
|
||||||
|
});
|
||||||
|
return (): void => {
|
||||||
|
cancelled = true;
|
||||||
|
};
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
if (status === "loading") {
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
style={{
|
style={{
|
||||||
@@ -42,7 +57,7 @@ export function SetupGuard({ children }: SetupGuardProps): React.JSX.Element {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!status?.completed) {
|
if (status === "pending") {
|
||||||
return <Navigate to="/setup" replace />;
|
return <Navigate to="/setup" replace />;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,6 @@
|
|||||||
import { useCallback, useState } from "react";
|
import { useCallback, useState } from "react";
|
||||||
import { ComposableMap, ZoomableGroup, Geography, useGeographies } from "react-simple-maps";
|
import { ComposableMap, ZoomableGroup, Geography, useGeographies } from "react-simple-maps";
|
||||||
import { Button, makeStyles, tokens } from "@fluentui/react-components";
|
import { Button, makeStyles, tokens } from "@fluentui/react-components";
|
||||||
import { useCardStyles } from "../theme/commonStyles";
|
|
||||||
import type { GeoPermissibleObjects } from "d3-geo";
|
import type { GeoPermissibleObjects } from "d3-geo";
|
||||||
import { ISO_NUMERIC_TO_ALPHA2 } from "../data/isoNumericToAlpha2";
|
import { ISO_NUMERIC_TO_ALPHA2 } from "../data/isoNumericToAlpha2";
|
||||||
import { getBanCountColor } from "../utils/mapColors";
|
import { getBanCountColor } from "../utils/mapColors";
|
||||||
@@ -30,6 +29,9 @@ const useStyles = makeStyles({
|
|||||||
mapWrapper: {
|
mapWrapper: {
|
||||||
width: "100%",
|
width: "100%",
|
||||||
position: "relative",
|
position: "relative",
|
||||||
|
backgroundColor: tokens.colorNeutralBackground2,
|
||||||
|
borderRadius: tokens.borderRadiusMedium,
|
||||||
|
border: `1px solid ${tokens.colorNeutralStroke1}`,
|
||||||
overflow: "hidden",
|
overflow: "hidden",
|
||||||
},
|
},
|
||||||
countLabel: {
|
countLabel: {
|
||||||
@@ -209,7 +211,6 @@ export function WorldMap({
|
|||||||
thresholdHigh = 100,
|
thresholdHigh = 100,
|
||||||
}: WorldMapProps): React.JSX.Element {
|
}: WorldMapProps): React.JSX.Element {
|
||||||
const styles = useStyles();
|
const styles = useStyles();
|
||||||
const cardStyles = useCardStyles();
|
|
||||||
const [zoom, setZoom] = useState<number>(1);
|
const [zoom, setZoom] = useState<number>(1);
|
||||||
const [center, setCenter] = useState<[number, number]>([0, 0]);
|
const [center, setCenter] = useState<[number, number]>([0, 0]);
|
||||||
|
|
||||||
@@ -228,7 +229,7 @@ export function WorldMap({
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
className={`${cardStyles.card} ${styles.mapWrapper}`}
|
className={styles.mapWrapper}
|
||||||
role="img"
|
role="img"
|
||||||
aria-label="World map showing banned IP counts by country. Click a country to filter the table below."
|
aria-label="World map showing banned IP counts by country. Click a country to filter the table below."
|
||||||
>
|
>
|
||||||
|
|||||||
@@ -1,33 +0,0 @@
|
|||||||
import { describe, it, expect } from "vitest";
|
|
||||||
import { render, screen } from "@testing-library/react";
|
|
||||||
import { ErrorBoundary } from "../ErrorBoundary";
|
|
||||||
|
|
||||||
function ExplodingChild(): React.ReactElement {
|
|
||||||
throw new Error("boom");
|
|
||||||
}
|
|
||||||
|
|
||||||
describe("ErrorBoundary", () => {
|
|
||||||
it("renders the fallback UI when a child throws", () => {
|
|
||||||
render(
|
|
||||||
<ErrorBoundary>
|
|
||||||
<ExplodingChild />
|
|
||||||
</ErrorBoundary>,
|
|
||||||
);
|
|
||||||
|
|
||||||
expect(screen.getByRole("alert")).toBeInTheDocument();
|
|
||||||
expect(screen.getByText("Something went wrong")).toBeInTheDocument();
|
|
||||||
expect(screen.getByText(/boom/i)).toBeInTheDocument();
|
|
||||||
expect(screen.getByRole("button", { name: /reload/i })).toBeInTheDocument();
|
|
||||||
});
|
|
||||||
|
|
||||||
it("renders children normally when no error occurs", () => {
|
|
||||||
render(
|
|
||||||
<ErrorBoundary>
|
|
||||||
<div data-testid="safe-child">safe</div>
|
|
||||||
</ErrorBoundary>,
|
|
||||||
);
|
|
||||||
|
|
||||||
expect(screen.getByTestId("safe-child")).toBeInTheDocument();
|
|
||||||
expect(screen.queryByRole("alert")).not.toBeInTheDocument();
|
|
||||||
});
|
|
||||||
});
|
|
||||||
@@ -1,105 +0,0 @@
|
|||||||
import { Button, Badge, Table, TableBody, TableCell, TableCellLayout, TableHeader, TableHeaderCell, TableRow, Text, MessageBar, MessageBarBody, Spinner } from "@fluentui/react-components";
|
|
||||||
import { ArrowClockwiseRegular } from "@fluentui/react-icons";
|
|
||||||
import { useCommonSectionStyles } from "../../theme/commonStyles";
|
|
||||||
import { useImportLog } from "../../hooks/useBlocklist";
|
|
||||||
import { useBlocklistStyles } from "./blocklistStyles";
|
|
||||||
|
|
||||||
export function BlocklistImportLogSection(): React.JSX.Element {
|
|
||||||
const styles = useBlocklistStyles();
|
|
||||||
const sectionStyles = useCommonSectionStyles();
|
|
||||||
const { data, loading, error, page, setPage, refresh } = useImportLog(undefined, 20);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<div className={sectionStyles.section}>
|
|
||||||
<div className={sectionStyles.sectionHeader}>
|
|
||||||
<Text size={500} weight="semibold">
|
|
||||||
Import Log
|
|
||||||
</Text>
|
|
||||||
<Button icon={<ArrowClockwiseRegular />} appearance="secondary" onClick={refresh}>
|
|
||||||
Refresh
|
|
||||||
</Button>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{error && (
|
|
||||||
<MessageBar intent="error">
|
|
||||||
<MessageBarBody>{error}</MessageBarBody>
|
|
||||||
</MessageBar>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{loading ? (
|
|
||||||
<div className={styles.centred}>
|
|
||||||
<Spinner label="Loading log…" />
|
|
||||||
</div>
|
|
||||||
) : !data || data.items.length === 0 ? (
|
|
||||||
<div className={styles.centred}>
|
|
||||||
<Text>No import runs recorded yet.</Text>
|
|
||||||
</div>
|
|
||||||
) : (
|
|
||||||
<>
|
|
||||||
<div className={styles.tableWrapper}>
|
|
||||||
<Table>
|
|
||||||
<TableHeader>
|
|
||||||
<TableRow>
|
|
||||||
<TableHeaderCell>Timestamp</TableHeaderCell>
|
|
||||||
<TableHeaderCell>Source URL</TableHeaderCell>
|
|
||||||
<TableHeaderCell>Imported</TableHeaderCell>
|
|
||||||
<TableHeaderCell>Skipped</TableHeaderCell>
|
|
||||||
<TableHeaderCell>Status</TableHeaderCell>
|
|
||||||
</TableRow>
|
|
||||||
</TableHeader>
|
|
||||||
<TableBody>
|
|
||||||
{data.items.map((entry) => (
|
|
||||||
<TableRow key={entry.id} className={entry.errors ? styles.errorRow : undefined}>
|
|
||||||
<TableCell>
|
|
||||||
<TableCellLayout>
|
|
||||||
<span className={styles.mono}>{entry.timestamp}</span>
|
|
||||||
</TableCellLayout>
|
|
||||||
</TableCell>
|
|
||||||
<TableCell>
|
|
||||||
<TableCellLayout>
|
|
||||||
<span className={styles.mono}>{entry.source_url}</span>
|
|
||||||
</TableCellLayout>
|
|
||||||
</TableCell>
|
|
||||||
<TableCell>
|
|
||||||
<TableCellLayout>{entry.ips_imported}</TableCellLayout>
|
|
||||||
</TableCell>
|
|
||||||
<TableCell>
|
|
||||||
<TableCellLayout>{entry.ips_skipped}</TableCellLayout>
|
|
||||||
</TableCell>
|
|
||||||
<TableCell>
|
|
||||||
<TableCellLayout>
|
|
||||||
{entry.errors ? (
|
|
||||||
<Badge appearance="filled" color="danger">
|
|
||||||
Error
|
|
||||||
</Badge>
|
|
||||||
) : (
|
|
||||||
<Badge appearance="filled" color="success">
|
|
||||||
OK
|
|
||||||
</Badge>
|
|
||||||
)}
|
|
||||||
</TableCellLayout>
|
|
||||||
</TableCell>
|
|
||||||
</TableRow>
|
|
||||||
))}
|
|
||||||
</TableBody>
|
|
||||||
</Table>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{data.total_pages > 1 && (
|
|
||||||
<div className={styles.pagination}>
|
|
||||||
<Button size="small" appearance="secondary" disabled={page <= 1} onClick={() => { setPage(page - 1); }}>
|
|
||||||
Previous
|
|
||||||
</Button>
|
|
||||||
<Text size={200}>
|
|
||||||
Page {page} of {data.total_pages}
|
|
||||||
</Text>
|
|
||||||
<Button size="small" appearance="secondary" disabled={page >= data.total_pages} onClick={() => { setPage(page + 1); }}>
|
|
||||||
Next
|
|
||||||
</Button>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
</>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,175 +0,0 @@
|
|||||||
import { useCallback, useState } from "react";
|
|
||||||
import { Button, Field, Input, MessageBar, MessageBarBody, Select, Spinner, Text } from "@fluentui/react-components";
|
|
||||||
import { PlayRegular } from "@fluentui/react-icons";
|
|
||||||
import { useCommonSectionStyles } from "../../theme/commonStyles";
|
|
||||||
import { useSchedule } from "../../hooks/useBlocklist";
|
|
||||||
import { useBlocklistStyles } from "./blocklistStyles";
|
|
||||||
import type { ScheduleConfig, ScheduleFrequency } from "../../types/blocklist";
|
|
||||||
|
|
||||||
const FREQUENCY_LABELS: Record<ScheduleFrequency, string> = {
|
|
||||||
hourly: "Every N hours",
|
|
||||||
daily: "Daily",
|
|
||||||
weekly: "Weekly",
|
|
||||||
};
|
|
||||||
|
|
||||||
const DAYS = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"];
|
|
||||||
|
|
||||||
interface ScheduleSectionProps {
|
|
||||||
onRunImport: () => void;
|
|
||||||
runImportRunning: boolean;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function BlocklistScheduleSection({ onRunImport, runImportRunning }: ScheduleSectionProps): React.JSX.Element {
|
|
||||||
const styles = useBlocklistStyles();
|
|
||||||
const sectionStyles = useCommonSectionStyles();
|
|
||||||
const { info, loading, error, saveSchedule } = useSchedule();
|
|
||||||
const [saving, setSaving] = useState(false);
|
|
||||||
const [saveMsg, setSaveMsg] = useState<string | null>(null);
|
|
||||||
|
|
||||||
const config = info?.config ?? {
|
|
||||||
frequency: "daily" as ScheduleFrequency,
|
|
||||||
interval_hours: 24,
|
|
||||||
hour: 3,
|
|
||||||
minute: 0,
|
|
||||||
day_of_week: 0,
|
|
||||||
};
|
|
||||||
|
|
||||||
const [draft, setDraft] = useState<ScheduleConfig>(config);
|
|
||||||
|
|
||||||
const handleSave = useCallback((): void => {
|
|
||||||
setSaving(true);
|
|
||||||
saveSchedule(draft)
|
|
||||||
.then(() => {
|
|
||||||
setSaveMsg("Schedule saved.");
|
|
||||||
setSaving(false);
|
|
||||||
setTimeout(() => { setSaveMsg(null); }, 3000);
|
|
||||||
})
|
|
||||||
.catch((err: unknown) => {
|
|
||||||
setSaveMsg(err instanceof Error ? err.message : "Failed to save schedule");
|
|
||||||
setSaving(false);
|
|
||||||
});
|
|
||||||
}, [draft, saveSchedule]);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<div className={sectionStyles.section}>
|
|
||||||
<div className={sectionStyles.sectionHeader}>
|
|
||||||
<Text size={500} weight="semibold">
|
|
||||||
Import Schedule
|
|
||||||
</Text>
|
|
||||||
<Button icon={<PlayRegular />} appearance="secondary" onClick={onRunImport} disabled={runImportRunning}>
|
|
||||||
{runImportRunning ? <Spinner size="tiny" /> : "Run Now"}
|
|
||||||
</Button>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{error && (
|
|
||||||
<MessageBar intent="error">
|
|
||||||
<MessageBarBody>{error}</MessageBarBody>
|
|
||||||
</MessageBar>
|
|
||||||
)}
|
|
||||||
{saveMsg && (
|
|
||||||
<MessageBar intent={saveMsg === "Schedule saved." ? "success" : "error"}>
|
|
||||||
<MessageBarBody>{saveMsg}</MessageBarBody>
|
|
||||||
</MessageBar>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{loading ? (
|
|
||||||
<div className={styles.centred}>
|
|
||||||
<Spinner label="Loading schedule…" />
|
|
||||||
</div>
|
|
||||||
) : (
|
|
||||||
<>
|
|
||||||
<div className={styles.scheduleForm}>
|
|
||||||
<Field label="Frequency" className={styles.scheduleField}>
|
|
||||||
<Select
|
|
||||||
value={draft.frequency}
|
|
||||||
onChange={(_ev, d) => { setDraft((p) => ({ ...p, frequency: d.value as ScheduleFrequency })); }}
|
|
||||||
>
|
|
||||||
{(["hourly", "daily", "weekly"] as ScheduleFrequency[]).map((f) => (
|
|
||||||
<option key={f} value={f}>
|
|
||||||
{FREQUENCY_LABELS[f]}
|
|
||||||
</option>
|
|
||||||
))}
|
|
||||||
</Select>
|
|
||||||
</Field>
|
|
||||||
|
|
||||||
{draft.frequency === "hourly" && (
|
|
||||||
<Field label="Every (hours)" className={styles.scheduleField}>
|
|
||||||
<Input
|
|
||||||
type="number"
|
|
||||||
value={String(draft.interval_hours)}
|
|
||||||
onChange={(_ev, d) => { setDraft((p) => ({ ...p, interval_hours: Math.max(1, parseInt(d.value, 10) || 1) })); }}
|
|
||||||
min={1}
|
|
||||||
max={168}
|
|
||||||
/>
|
|
||||||
</Field>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{draft.frequency !== "hourly" && (
|
|
||||||
<>
|
|
||||||
{draft.frequency === "weekly" && (
|
|
||||||
<Field label="Day of week" className={styles.scheduleField}>
|
|
||||||
<Select
|
|
||||||
value={String(draft.day_of_week)}
|
|
||||||
onChange={(_ev, d) => { setDraft((p) => ({ ...p, day_of_week: parseInt(d.value, 10) })); }}
|
|
||||||
>
|
|
||||||
{DAYS.map((day, i) => (
|
|
||||||
<option key={day} value={i}>
|
|
||||||
{day}
|
|
||||||
</option>
|
|
||||||
))}
|
|
||||||
</Select>
|
|
||||||
</Field>
|
|
||||||
)}
|
|
||||||
|
|
||||||
<Field label="Hour (UTC)" className={styles.scheduleField}>
|
|
||||||
<Select
|
|
||||||
value={String(draft.hour)}
|
|
||||||
onChange={(_ev, d) => { setDraft((p) => ({ ...p, hour: parseInt(d.value, 10) })); }}
|
|
||||||
>
|
|
||||||
{Array.from({ length: 24 }, (_, i) => (
|
|
||||||
<option key={i} value={i}>
|
|
||||||
{String(i).padStart(2, "0")}:00
|
|
||||||
</option>
|
|
||||||
))}
|
|
||||||
</Select>
|
|
||||||
</Field>
|
|
||||||
|
|
||||||
<Field label="Minute" className={styles.scheduleField}>
|
|
||||||
<Select
|
|
||||||
value={String(draft.minute)}
|
|
||||||
onChange={(_ev, d) => { setDraft((p) => ({ ...p, minute: parseInt(d.value, 10) })); }}
|
|
||||||
>
|
|
||||||
{[0, 15, 30, 45].map((m) => (
|
|
||||||
<option key={m} value={m}>
|
|
||||||
{String(m).padStart(2, "0")}
|
|
||||||
</option>
|
|
||||||
))}
|
|
||||||
</Select>
|
|
||||||
</Field>
|
|
||||||
</>
|
|
||||||
)}
|
|
||||||
|
|
||||||
<Button appearance="primary" onClick={handleSave} disabled={saving} style={{ alignSelf: "flex-end" }}>
|
|
||||||
{saving ? <Spinner size="tiny" /> : "Save Schedule"}
|
|
||||||
</Button>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div className={styles.metaRow}>
|
|
||||||
<div className={styles.metaItem}>
|
|
||||||
<Text size={200} weight="semibold">
|
|
||||||
Last run
|
|
||||||
</Text>
|
|
||||||
<Text size={200}>{info?.last_run_at ?? "Never"}</Text>
|
|
||||||
</div>
|
|
||||||
<div className={styles.metaItem}>
|
|
||||||
<Text size={200} weight="semibold">
|
|
||||||
Next run
|
|
||||||
</Text>
|
|
||||||
<Text size={200}>{info?.next_run_at ?? "Not scheduled"}</Text>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,392 +0,0 @@
|
|||||||
import { useCallback, useState } from "react";
|
|
||||||
import {
|
|
||||||
Button,
|
|
||||||
Dialog,
|
|
||||||
DialogActions,
|
|
||||||
DialogBody,
|
|
||||||
DialogContent,
|
|
||||||
DialogSurface,
|
|
||||||
DialogTitle,
|
|
||||||
Field,
|
|
||||||
Input,
|
|
||||||
MessageBar,
|
|
||||||
MessageBarBody,
|
|
||||||
Spinner,
|
|
||||||
Switch,
|
|
||||||
Table,
|
|
||||||
TableBody,
|
|
||||||
TableCell,
|
|
||||||
TableCellLayout,
|
|
||||||
TableHeader,
|
|
||||||
TableHeaderCell,
|
|
||||||
TableRow,
|
|
||||||
Text,
|
|
||||||
} from "@fluentui/react-components";
|
|
||||||
import { useCommonSectionStyles } from "../../theme/commonStyles";
|
|
||||||
import {
|
|
||||||
AddRegular,
|
|
||||||
ArrowClockwiseRegular,
|
|
||||||
DeleteRegular,
|
|
||||||
EditRegular,
|
|
||||||
EyeRegular,
|
|
||||||
PlayRegular,
|
|
||||||
} from "@fluentui/react-icons";
|
|
||||||
import { useBlocklists } from "../../hooks/useBlocklist";
|
|
||||||
import type { BlocklistSource, PreviewResponse } from "../../types/blocklist";
|
|
||||||
import { useBlocklistStyles } from "./blocklistStyles";
|
|
||||||
|
|
||||||
interface SourceFormValues {
|
|
||||||
name: string;
|
|
||||||
url: string;
|
|
||||||
enabled: boolean;
|
|
||||||
}
|
|
||||||
|
|
||||||
interface SourceFormDialogProps {
|
|
||||||
open: boolean;
|
|
||||||
mode: "add" | "edit";
|
|
||||||
initial: SourceFormValues;
|
|
||||||
saving: boolean;
|
|
||||||
error: string | null;
|
|
||||||
onClose: () => void;
|
|
||||||
onSubmit: (values: SourceFormValues) => void;
|
|
||||||
}
|
|
||||||
|
|
||||||
function SourceFormDialog({
|
|
||||||
open,
|
|
||||||
mode,
|
|
||||||
initial,
|
|
||||||
saving,
|
|
||||||
error,
|
|
||||||
onClose,
|
|
||||||
onSubmit,
|
|
||||||
}: SourceFormDialogProps): React.JSX.Element {
|
|
||||||
const styles = useBlocklistStyles();
|
|
||||||
const [values, setValues] = useState<SourceFormValues>(initial);
|
|
||||||
|
|
||||||
const handleOpen = useCallback((): void => {
|
|
||||||
setValues(initial);
|
|
||||||
}, [initial]);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Dialog
|
|
||||||
open={open}
|
|
||||||
onOpenChange={(_ev, data) => {
|
|
||||||
if (!data.open) onClose();
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<DialogSurface onAnimationEnd={open ? handleOpen : undefined}>
|
|
||||||
<DialogBody>
|
|
||||||
<DialogTitle>{mode === "add" ? "Add Blocklist Source" : "Edit Blocklist Source"}</DialogTitle>
|
|
||||||
<DialogContent>
|
|
||||||
<div className={styles.dialogForm}>
|
|
||||||
{error && (
|
|
||||||
<MessageBar intent="error">
|
|
||||||
<MessageBarBody>{error}</MessageBarBody>
|
|
||||||
</MessageBar>
|
|
||||||
)}
|
|
||||||
<Field label="Name" required>
|
|
||||||
<Input
|
|
||||||
value={values.name}
|
|
||||||
onChange={(_ev, d) => { setValues((p) => ({ ...p, name: d.value })); }}
|
|
||||||
placeholder="e.g. Blocklist.de — All"
|
|
||||||
/>
|
|
||||||
</Field>
|
|
||||||
<Field label="URL" required>
|
|
||||||
<Input
|
|
||||||
value={values.url}
|
|
||||||
onChange={(_ev, d) => { setValues((p) => ({ ...p, url: d.value })); }}
|
|
||||||
placeholder="https://lists.blocklist.de/lists/all.txt"
|
|
||||||
/>
|
|
||||||
</Field>
|
|
||||||
<Switch
|
|
||||||
label="Enabled"
|
|
||||||
checked={values.enabled}
|
|
||||||
onChange={(_ev, d) => { setValues((p) => ({ ...p, enabled: d.checked })); }}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
</DialogContent>
|
|
||||||
<DialogActions>
|
|
||||||
<Button appearance="secondary" onClick={onClose} disabled={saving}>
|
|
||||||
Cancel
|
|
||||||
</Button>
|
|
||||||
<Button
|
|
||||||
appearance="primary"
|
|
||||||
disabled={saving || !values.name.trim() || !values.url.trim()}
|
|
||||||
onClick={() => { onSubmit(values); }}
|
|
||||||
>
|
|
||||||
{saving ? <Spinner size="tiny" /> : mode === "add" ? "Add" : "Save"}
|
|
||||||
</Button>
|
|
||||||
</DialogActions>
|
|
||||||
</DialogBody>
|
|
||||||
</DialogSurface>
|
|
||||||
</Dialog>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
interface PreviewDialogProps {
|
|
||||||
open: boolean;
|
|
||||||
source: BlocklistSource | null;
|
|
||||||
onClose: () => void;
|
|
||||||
fetchPreview: (id: number) => Promise<PreviewResponse>;
|
|
||||||
}
|
|
||||||
|
|
||||||
function PreviewDialog({ open, source, onClose, fetchPreview }: PreviewDialogProps): React.JSX.Element {
|
|
||||||
const styles = useBlocklistStyles();
|
|
||||||
const [data, setData] = useState<PreviewResponse | null>(null);
|
|
||||||
const [loading, setLoading] = useState(false);
|
|
||||||
const [error, setError] = useState<string | null>(null);
|
|
||||||
|
|
||||||
const handleOpen = useCallback((): void => {
|
|
||||||
if (!source) return;
|
|
||||||
setData(null);
|
|
||||||
setError(null);
|
|
||||||
setLoading(true);
|
|
||||||
fetchPreview(source.id)
|
|
||||||
.then((result) => {
|
|
||||||
setData(result);
|
|
||||||
setLoading(false);
|
|
||||||
})
|
|
||||||
.catch((err: unknown) => {
|
|
||||||
setError(err instanceof Error ? err.message : "Failed to fetch preview");
|
|
||||||
setLoading(false);
|
|
||||||
});
|
|
||||||
}, [source, fetchPreview]);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Dialog
|
|
||||||
open={open}
|
|
||||||
onOpenChange={(_ev, d) => {
|
|
||||||
if (!d.open) onClose();
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<DialogSurface onAnimationEnd={open ? handleOpen : undefined}>
|
|
||||||
<DialogBody>
|
|
||||||
<DialogTitle>Preview — {source?.name ?? ""}</DialogTitle>
|
|
||||||
<DialogContent>
|
|
||||||
{loading && (
|
|
||||||
<div style={{ textAlign: "center", padding: "16px" }}>
|
|
||||||
<Spinner label="Downloading…" />
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
{error && (
|
|
||||||
<MessageBar intent="error">
|
|
||||||
<MessageBarBody>{error}</MessageBarBody>
|
|
||||||
</MessageBar>
|
|
||||||
)}
|
|
||||||
{data && (
|
|
||||||
<div style={{ display: "flex", flexDirection: "column", gap: "8px" }}>
|
|
||||||
<Text size={300}>
|
|
||||||
{data.valid_count} valid IPs / {data.skipped_count} skipped of {data.total_lines} total lines. Showing first {data.entries.length}:
|
|
||||||
</Text>
|
|
||||||
<div className={styles.previewList}>
|
|
||||||
{data.entries.map((entry) => (
|
|
||||||
<div key={entry}>{entry}</div>
|
|
||||||
))}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
</DialogContent>
|
|
||||||
<DialogActions>
|
|
||||||
<Button appearance="secondary" onClick={onClose}>
|
|
||||||
Close
|
|
||||||
</Button>
|
|
||||||
</DialogActions>
|
|
||||||
</DialogBody>
|
|
||||||
</DialogSurface>
|
|
||||||
</Dialog>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
interface SourcesSectionProps {
|
|
||||||
onRunImport: () => void;
|
|
||||||
runImportRunning: boolean;
|
|
||||||
}
|
|
||||||
|
|
||||||
const EMPTY_SOURCE: SourceFormValues = { name: "", url: "", enabled: true };
|
|
||||||
|
|
||||||
export function BlocklistSourcesSection({ onRunImport, runImportRunning }: SourcesSectionProps): React.JSX.Element {
|
|
||||||
const styles = useBlocklistStyles();
|
|
||||||
const sectionStyles = useCommonSectionStyles();
|
|
||||||
const { sources, loading, error, refresh, createSource, updateSource, removeSource, previewSource } = useBlocklists();
|
|
||||||
|
|
||||||
const [dialogOpen, setDialogOpen] = useState(false);
|
|
||||||
const [dialogMode, setDialogMode] = useState<"add" | "edit">("add");
|
|
||||||
const [dialogInitial, setDialogInitial] = useState<SourceFormValues>(EMPTY_SOURCE);
|
|
||||||
const [editingId, setEditingId] = useState<number | null>(null);
|
|
||||||
const [saving, setSaving] = useState(false);
|
|
||||||
const [saveError, setSaveError] = useState<string | null>(null);
|
|
||||||
const [previewOpen, setPreviewOpen] = useState(false);
|
|
||||||
const [previewSourceItem, setPreviewSourceItem] = useState<BlocklistSource | null>(null);
|
|
||||||
|
|
||||||
const openAdd = useCallback((): void => {
|
|
||||||
setDialogMode("add");
|
|
||||||
setDialogInitial(EMPTY_SOURCE);
|
|
||||||
setEditingId(null);
|
|
||||||
setSaveError(null);
|
|
||||||
setDialogOpen(true);
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
const openEdit = useCallback((source: BlocklistSource): void => {
|
|
||||||
setDialogMode("edit");
|
|
||||||
setDialogInitial({ name: source.name, url: source.url, enabled: source.enabled });
|
|
||||||
setEditingId(source.id);
|
|
||||||
setSaveError(null);
|
|
||||||
setDialogOpen(true);
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
const handleSubmit = useCallback(
|
|
||||||
(values: SourceFormValues): void => {
|
|
||||||
setSaving(true);
|
|
||||||
setSaveError(null);
|
|
||||||
const op =
|
|
||||||
dialogMode === "add"
|
|
||||||
? createSource({ name: values.name, url: values.url, enabled: values.enabled })
|
|
||||||
: updateSource(editingId ?? -1, { name: values.name, url: values.url, enabled: values.enabled });
|
|
||||||
op
|
|
||||||
.then(() => {
|
|
||||||
setSaving(false);
|
|
||||||
setDialogOpen(false);
|
|
||||||
})
|
|
||||||
.catch((err: unknown) => {
|
|
||||||
setSaving(false);
|
|
||||||
setSaveError(err instanceof Error ? err.message : "Failed to save source");
|
|
||||||
});
|
|
||||||
},
|
|
||||||
[dialogMode, editingId, createSource, updateSource],
|
|
||||||
);
|
|
||||||
|
|
||||||
const handleToggleEnabled = useCallback(
|
|
||||||
(source: BlocklistSource): void => {
|
|
||||||
void updateSource(source.id, { enabled: !source.enabled });
|
|
||||||
},
|
|
||||||
[updateSource],
|
|
||||||
);
|
|
||||||
|
|
||||||
const handleDelete = useCallback(
|
|
||||||
(source: BlocklistSource): void => {
|
|
||||||
void removeSource(source.id);
|
|
||||||
},
|
|
||||||
[removeSource],
|
|
||||||
);
|
|
||||||
|
|
||||||
const handlePreview = useCallback((source: BlocklistSource): void => {
|
|
||||||
setPreviewSourceItem(source);
|
|
||||||
setPreviewOpen(true);
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<div className={sectionStyles.section}>
|
|
||||||
<div className={sectionStyles.sectionHeader}>
|
|
||||||
<Text size={500} weight="semibold">
|
|
||||||
Blocklist Sources
|
|
||||||
</Text>
|
|
||||||
<div style={{ display: "flex", gap: "8px" }}>
|
|
||||||
<Button icon={<PlayRegular />} appearance="secondary" onClick={onRunImport} disabled={runImportRunning}>
|
|
||||||
{runImportRunning ? <Spinner size="tiny" /> : "Run Now"}
|
|
||||||
</Button>
|
|
||||||
<Button icon={<ArrowClockwiseRegular />} appearance="secondary" onClick={refresh}>
|
|
||||||
Refresh
|
|
||||||
</Button>
|
|
||||||
<Button icon={<AddRegular />} appearance="primary" onClick={openAdd}>
|
|
||||||
Add Source
|
|
||||||
</Button>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{error && (
|
|
||||||
<MessageBar intent="error">
|
|
||||||
<MessageBarBody>{error}</MessageBarBody>
|
|
||||||
</MessageBar>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{loading ? (
|
|
||||||
<div className={styles.centred}>
|
|
||||||
<Spinner label="Loading sources…" />
|
|
||||||
</div>
|
|
||||||
) : sources.length === 0 ? (
|
|
||||||
<div className={styles.centred}>
|
|
||||||
<Text>No blocklist sources configured. Click "Add Source" to get started.</Text>
|
|
||||||
</div>
|
|
||||||
) : (
|
|
||||||
<div className={styles.tableWrapper}>
|
|
||||||
<Table>
|
|
||||||
<TableHeader>
|
|
||||||
<TableRow>
|
|
||||||
<TableHeaderCell>Name</TableHeaderCell>
|
|
||||||
<TableHeaderCell>URL</TableHeaderCell>
|
|
||||||
<TableHeaderCell>Enabled</TableHeaderCell>
|
|
||||||
<TableHeaderCell>Actions</TableHeaderCell>
|
|
||||||
</TableRow>
|
|
||||||
</TableHeader>
|
|
||||||
<TableBody>
|
|
||||||
{sources.map((source) => (
|
|
||||||
<TableRow key={source.id}>
|
|
||||||
<TableCell>
|
|
||||||
<TableCellLayout>{source.name}</TableCellLayout>
|
|
||||||
</TableCell>
|
|
||||||
<TableCell>
|
|
||||||
<TableCellLayout>
|
|
||||||
<span className={styles.mono}>{source.url}</span>
|
|
||||||
</TableCellLayout>
|
|
||||||
</TableCell>
|
|
||||||
<TableCell>
|
|
||||||
<Switch
|
|
||||||
checked={source.enabled}
|
|
||||||
onChange={() => { handleToggleEnabled(source); }}
|
|
||||||
label={source.enabled ? "On" : "Off"}
|
|
||||||
/>
|
|
||||||
</TableCell>
|
|
||||||
<TableCell>
|
|
||||||
<div className={styles.actionsCell}>
|
|
||||||
<Button
|
|
||||||
icon={<EyeRegular />}
|
|
||||||
size="small"
|
|
||||||
appearance="secondary"
|
|
||||||
onClick={() => { handlePreview(source); }}
|
|
||||||
>
|
|
||||||
Preview
|
|
||||||
</Button>
|
|
||||||
<Button
|
|
||||||
icon={<EditRegular />}
|
|
||||||
size="small"
|
|
||||||
appearance="secondary"
|
|
||||||
onClick={() => { openEdit(source); }}
|
|
||||||
>
|
|
||||||
Edit
|
|
||||||
</Button>
|
|
||||||
<Button
|
|
||||||
icon={<DeleteRegular />}
|
|
||||||
size="small"
|
|
||||||
appearance="secondary"
|
|
||||||
onClick={() => { handleDelete(source); }}
|
|
||||||
>
|
|
||||||
Delete
|
|
||||||
</Button>
|
|
||||||
</div>
|
|
||||||
</TableCell>
|
|
||||||
</TableRow>
|
|
||||||
))}
|
|
||||||
</TableBody>
|
|
||||||
</Table>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
|
|
||||||
<SourceFormDialog
|
|
||||||
open={dialogOpen}
|
|
||||||
mode={dialogMode}
|
|
||||||
initial={dialogInitial}
|
|
||||||
saving={saving}
|
|
||||||
error={saveError}
|
|
||||||
onClose={() => { setDialogOpen(false); }}
|
|
||||||
onSubmit={handleSubmit}
|
|
||||||
/>
|
|
||||||
|
|
||||||
<PreviewDialog
|
|
||||||
open={previewOpen}
|
|
||||||
source={previewSourceItem}
|
|
||||||
onClose={() => { setPreviewOpen(false); }}
|
|
||||||
fetchPreview={previewSource}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,62 +0,0 @@
|
|||||||
import { makeStyles, tokens } from "@fluentui/react-components";
|
|
||||||
|
|
||||||
export const useBlocklistStyles = makeStyles({
|
|
||||||
root: {
|
|
||||||
display: "flex",
|
|
||||||
flexDirection: "column",
|
|
||||||
gap: tokens.spacingVerticalXL,
|
|
||||||
},
|
|
||||||
|
|
||||||
tableWrapper: { overflowX: "auto" },
|
|
||||||
actionsCell: { display: "flex", gap: tokens.spacingHorizontalS, flexWrap: "wrap" },
|
|
||||||
mono: { fontFamily: "Consolas, 'Courier New', monospace", fontSize: "12px" },
|
|
||||||
centred: {
|
|
||||||
display: "flex",
|
|
||||||
justifyContent: "center",
|
|
||||||
padding: tokens.spacingVerticalL,
|
|
||||||
},
|
|
||||||
scheduleForm: {
|
|
||||||
display: "flex",
|
|
||||||
flexWrap: "wrap",
|
|
||||||
gap: tokens.spacingHorizontalM,
|
|
||||||
alignItems: "flex-end",
|
|
||||||
},
|
|
||||||
scheduleField: { minWidth: "140px" },
|
|
||||||
metaRow: {
|
|
||||||
display: "flex",
|
|
||||||
gap: tokens.spacingHorizontalL,
|
|
||||||
flexWrap: "wrap",
|
|
||||||
paddingTop: tokens.spacingVerticalS,
|
|
||||||
},
|
|
||||||
metaItem: { display: "flex", flexDirection: "column", gap: "2px" },
|
|
||||||
runResult: {
|
|
||||||
display: "flex",
|
|
||||||
flexDirection: "column",
|
|
||||||
gap: tokens.spacingVerticalXS,
|
|
||||||
maxHeight: "320px",
|
|
||||||
overflowY: "auto",
|
|
||||||
},
|
|
||||||
pagination: {
|
|
||||||
display: "flex",
|
|
||||||
justifyContent: "flex-end",
|
|
||||||
gap: tokens.spacingHorizontalS,
|
|
||||||
alignItems: "center",
|
|
||||||
paddingTop: tokens.spacingVerticalS,
|
|
||||||
},
|
|
||||||
dialogForm: {
|
|
||||||
display: "flex",
|
|
||||||
flexDirection: "column",
|
|
||||||
gap: tokens.spacingVerticalM,
|
|
||||||
minWidth: "380px",
|
|
||||||
},
|
|
||||||
previewList: {
|
|
||||||
fontFamily: "Consolas, 'Courier New', monospace",
|
|
||||||
fontSize: "12px",
|
|
||||||
maxHeight: "280px",
|
|
||||||
overflowY: "auto",
|
|
||||||
backgroundColor: tokens.colorNeutralBackground3,
|
|
||||||
padding: tokens.spacingVerticalS,
|
|
||||||
borderRadius: tokens.borderRadiusMedium,
|
|
||||||
},
|
|
||||||
errorRow: { backgroundColor: tokens.colorStatusDangerBackground1 },
|
|
||||||
});
|
|
||||||
@@ -216,7 +216,6 @@ function JailConfigDetail({
|
|||||||
ignore_regex: ignoreRegex,
|
ignore_regex: ignoreRegex,
|
||||||
date_pattern: datePattern !== "" ? datePattern : null,
|
date_pattern: datePattern !== "" ? datePattern : null,
|
||||||
dns_mode: dnsMode,
|
dns_mode: dnsMode,
|
||||||
backend,
|
|
||||||
log_encoding: logEncoding,
|
log_encoding: logEncoding,
|
||||||
prefregex: prefRegex !== "" ? prefRegex : null,
|
prefregex: prefRegex !== "" ? prefRegex : null,
|
||||||
bantime_escalation: {
|
bantime_escalation: {
|
||||||
@@ -231,7 +230,7 @@ function JailConfigDetail({
|
|||||||
}),
|
}),
|
||||||
[
|
[
|
||||||
banTime, findTime, maxRetry, failRegex, ignoreRegex, datePattern,
|
banTime, findTime, maxRetry, failRegex, ignoreRegex, datePattern,
|
||||||
dnsMode, backend, logEncoding, prefRegex, escEnabled, escFactor,
|
dnsMode, logEncoding, prefRegex, escEnabled, escFactor,
|
||||||
escFormula, escMultipliers, escMaxTime, escRndTime, escOverallJails,
|
escFormula, escMultipliers, escMaxTime, escRndTime, escOverallJails,
|
||||||
jail.ban_time, jail.find_time, jail.max_retry,
|
jail.ban_time, jail.find_time, jail.max_retry,
|
||||||
],
|
],
|
||||||
@@ -758,7 +757,12 @@ function InactiveJailDetail({
|
|||||||
*
|
*
|
||||||
* @returns JSX element.
|
* @returns JSX element.
|
||||||
*/
|
*/
|
||||||
export function JailsTab(): React.JSX.Element {
|
interface JailsTabProps {
|
||||||
|
/** Jail name to pre-select when the component mounts. */
|
||||||
|
initialJail?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function JailsTab({ initialJail }: JailsTabProps): React.JSX.Element {
|
||||||
const styles = useConfigStyles();
|
const styles = useConfigStyles();
|
||||||
const { jails, loading, error, refresh, updateJail } =
|
const { jails, loading, error, refresh, updateJail } =
|
||||||
useJailConfigs();
|
useJailConfigs();
|
||||||
@@ -819,6 +823,13 @@ export function JailsTab(): React.JSX.Element {
|
|||||||
return [...activeItems, ...inactiveItems];
|
return [...activeItems, ...inactiveItems];
|
||||||
}, [jails, inactiveJails]);
|
}, [jails, inactiveJails]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (!initialJail || selectedName) return;
|
||||||
|
if (listItems.some((item) => item.name === initialJail)) {
|
||||||
|
setSelectedName(initialJail);
|
||||||
|
}
|
||||||
|
}, [initialJail, listItems, selectedName]);
|
||||||
|
|
||||||
const activeJailMap = useMemo(
|
const activeJailMap = useMemo(
|
||||||
() => new Map(jails.map((j) => [j.name, j])),
|
() => new Map(jails.map((j) => [j.name, j])),
|
||||||
[jails],
|
[jails],
|
||||||
|
|||||||
@@ -25,10 +25,15 @@ import {
|
|||||||
ArrowSync24Regular,
|
ArrowSync24Regular,
|
||||||
} from "@fluentui/react-icons";
|
} from "@fluentui/react-icons";
|
||||||
import { ApiError } from "../../api/client";
|
import { ApiError } from "../../api/client";
|
||||||
import type { ServerSettingsUpdate, MapColorThresholdsUpdate } from "../../types/config";
|
import type { ServerSettingsUpdate, MapColorThresholdsResponse, MapColorThresholdsUpdate } from "../../types/config";
|
||||||
import { useServerSettings } from "../../hooks/useConfig";
|
import { useServerSettings } from "../../hooks/useConfig";
|
||||||
import { useAutoSave } from "../../hooks/useAutoSave";
|
import { useAutoSave } from "../../hooks/useAutoSave";
|
||||||
import { useMapColorThresholds } from "../../hooks/useMapColorThresholds";
|
import {
|
||||||
|
fetchMapColorThresholds,
|
||||||
|
updateMapColorThresholds,
|
||||||
|
reloadConfig,
|
||||||
|
restartFail2Ban,
|
||||||
|
} from "../../api/config";
|
||||||
import { AutoSaveIndicator } from "./AutoSaveIndicator";
|
import { AutoSaveIndicator } from "./AutoSaveIndicator";
|
||||||
import { ServerHealthSection } from "./ServerHealthSection";
|
import { ServerHealthSection } from "./ServerHealthSection";
|
||||||
import { useConfigStyles } from "./configStyles";
|
import { useConfigStyles } from "./configStyles";
|
||||||
@@ -43,7 +48,7 @@ const LOG_LEVELS = ["CRITICAL", "ERROR", "WARNING", "NOTICE", "INFO", "DEBUG"];
|
|||||||
*/
|
*/
|
||||||
export function ServerTab(): React.JSX.Element {
|
export function ServerTab(): React.JSX.Element {
|
||||||
const styles = useConfigStyles();
|
const styles = useConfigStyles();
|
||||||
const { settings, loading, error, updateSettings, flush, reload, restart } =
|
const { settings, loading, error, updateSettings, flush } =
|
||||||
useServerSettings();
|
useServerSettings();
|
||||||
const [logLevel, setLogLevel] = useState("");
|
const [logLevel, setLogLevel] = useState("");
|
||||||
const [logTarget, setLogTarget] = useState("");
|
const [logTarget, setLogTarget] = useState("");
|
||||||
@@ -57,15 +62,11 @@ export function ServerTab(): React.JSX.Element {
|
|||||||
const [isRestarting, setIsRestarting] = useState(false);
|
const [isRestarting, setIsRestarting] = useState(false);
|
||||||
|
|
||||||
// Map color thresholds
|
// Map color thresholds
|
||||||
const {
|
const [mapThresholds, setMapThresholds] = useState<MapColorThresholdsResponse | null>(null);
|
||||||
thresholds: mapThresholds,
|
|
||||||
error: mapThresholdsError,
|
|
||||||
refresh: refreshMapThresholds,
|
|
||||||
updateThresholds: updateMapThresholds,
|
|
||||||
} = useMapColorThresholds();
|
|
||||||
const [mapThresholdHigh, setMapThresholdHigh] = useState("");
|
const [mapThresholdHigh, setMapThresholdHigh] = useState("");
|
||||||
const [mapThresholdMedium, setMapThresholdMedium] = useState("");
|
const [mapThresholdMedium, setMapThresholdMedium] = useState("");
|
||||||
const [mapThresholdLow, setMapThresholdLow] = useState("");
|
const [mapThresholdLow, setMapThresholdLow] = useState("");
|
||||||
|
const [mapLoadError, setMapLoadError] = useState<string | null>(null);
|
||||||
|
|
||||||
const effectiveLogLevel = logLevel || settings?.log_level || "";
|
const effectiveLogLevel = logLevel || settings?.log_level || "";
|
||||||
const effectiveLogTarget = logTarget || settings?.log_target || "";
|
const effectiveLogTarget = logTarget || settings?.log_target || "";
|
||||||
@@ -104,11 +105,11 @@ export function ServerTab(): React.JSX.Element {
|
|||||||
}
|
}
|
||||||
}, [flush]);
|
}, [flush]);
|
||||||
|
|
||||||
const handleReload = async (): Promise<void> => {
|
const handleReload = useCallback(async () => {
|
||||||
setIsReloading(true);
|
setIsReloading(true);
|
||||||
setMsg(null);
|
setMsg(null);
|
||||||
try {
|
try {
|
||||||
await reload();
|
await reloadConfig();
|
||||||
setMsg({ text: "fail2ban reloaded successfully", ok: true });
|
setMsg({ text: "fail2ban reloaded successfully", ok: true });
|
||||||
} catch (err: unknown) {
|
} catch (err: unknown) {
|
||||||
setMsg({
|
setMsg({
|
||||||
@@ -118,13 +119,13 @@ export function ServerTab(): React.JSX.Element {
|
|||||||
} finally {
|
} finally {
|
||||||
setIsReloading(false);
|
setIsReloading(false);
|
||||||
}
|
}
|
||||||
};
|
}, []);
|
||||||
|
|
||||||
const handleRestart = async (): Promise<void> => {
|
const handleRestart = useCallback(async () => {
|
||||||
setIsRestarting(true);
|
setIsRestarting(true);
|
||||||
setMsg(null);
|
setMsg(null);
|
||||||
try {
|
try {
|
||||||
await restart();
|
await restartFail2Ban();
|
||||||
setMsg({ text: "fail2ban restart initiated", ok: true });
|
setMsg({ text: "fail2ban restart initiated", ok: true });
|
||||||
} catch (err: unknown) {
|
} catch (err: unknown) {
|
||||||
setMsg({
|
setMsg({
|
||||||
@@ -134,15 +135,27 @@ export function ServerTab(): React.JSX.Element {
|
|||||||
} finally {
|
} finally {
|
||||||
setIsRestarting(false);
|
setIsRestarting(false);
|
||||||
}
|
}
|
||||||
};
|
}, []);
|
||||||
|
|
||||||
|
// Load map color thresholds on mount.
|
||||||
|
const loadMapThresholds = useCallback(async (): Promise<void> => {
|
||||||
|
try {
|
||||||
|
const data = await fetchMapColorThresholds();
|
||||||
|
setMapThresholds(data);
|
||||||
|
setMapThresholdHigh(String(data.threshold_high));
|
||||||
|
setMapThresholdMedium(String(data.threshold_medium));
|
||||||
|
setMapThresholdLow(String(data.threshold_low));
|
||||||
|
setMapLoadError(null);
|
||||||
|
} catch (err) {
|
||||||
|
setMapLoadError(
|
||||||
|
err instanceof ApiError ? err.message : "Failed to load map color thresholds",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}, []);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!mapThresholds) return;
|
void loadMapThresholds();
|
||||||
|
}, [loadMapThresholds]);
|
||||||
setMapThresholdHigh(String(mapThresholds.threshold_high));
|
|
||||||
setMapThresholdMedium(String(mapThresholds.threshold_medium));
|
|
||||||
setMapThresholdLow(String(mapThresholds.threshold_low));
|
|
||||||
}, [mapThresholds]);
|
|
||||||
|
|
||||||
// Map threshold validation and auto-save.
|
// Map threshold validation and auto-save.
|
||||||
const mapHigh = Number(mapThresholdHigh);
|
const mapHigh = Number(mapThresholdHigh);
|
||||||
@@ -177,10 +190,9 @@ export function ServerTab(): React.JSX.Element {
|
|||||||
|
|
||||||
const saveMapThresholds = useCallback(
|
const saveMapThresholds = useCallback(
|
||||||
async (payload: MapColorThresholdsUpdate): Promise<void> => {
|
async (payload: MapColorThresholdsUpdate): Promise<void> => {
|
||||||
await updateMapThresholds(payload);
|
await updateMapColorThresholds(payload);
|
||||||
await refreshMapThresholds();
|
|
||||||
},
|
},
|
||||||
[refreshMapThresholds, updateMapThresholds],
|
[],
|
||||||
);
|
);
|
||||||
|
|
||||||
const { status: mapSaveStatus, errorText: mapSaveErrorText, retry: retryMapSave } =
|
const { status: mapSaveStatus, errorText: mapSaveErrorText, retry: retryMapSave } =
|
||||||
@@ -320,10 +332,10 @@ export function ServerTab(): React.JSX.Element {
|
|||||||
</div>
|
</div>
|
||||||
|
|
||||||
{/* Map Color Thresholds section */}
|
{/* Map Color Thresholds section */}
|
||||||
{mapThresholdsError ? (
|
{mapLoadError ? (
|
||||||
<div className={styles.sectionCard}>
|
<div className={styles.sectionCard}>
|
||||||
<MessageBar intent="error">
|
<MessageBar intent="error">
|
||||||
<MessageBarBody>{mapThresholdsError}</MessageBarBody>
|
<MessageBarBody>{mapLoadError}</MessageBarBody>
|
||||||
</MessageBar>
|
</MessageBar>
|
||||||
</div>
|
</div>
|
||||||
) : mapThresholds ? (
|
) : mapThresholds ? (
|
||||||
|
|||||||
77
frontend/src/components/config/__tests__/JailsTab.test.tsx
Normal file
77
frontend/src/components/config/__tests__/JailsTab.test.tsx
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
import { describe, expect, it, vi } from "vitest";
|
||||||
|
import { render } from "@testing-library/react";
|
||||||
|
import { FluentProvider, webLightTheme } from "@fluentui/react-components";
|
||||||
|
|
||||||
|
import { JailsTab } from "../JailsTab";
|
||||||
|
import type { JailConfig } from "../../../types/config";
|
||||||
|
import { useAutoSave } from "../../../hooks/useAutoSave";
|
||||||
|
import { useJailConfigs } from "../../../hooks/useConfig";
|
||||||
|
import { useConfigActiveStatus } from "../../../hooks/useConfigActiveStatus";
|
||||||
|
|
||||||
|
vi.mock("../../../hooks/useAutoSave");
|
||||||
|
vi.mock("../../../hooks/useConfig");
|
||||||
|
vi.mock("../../../hooks/useConfigActiveStatus");
|
||||||
|
vi.mock("../../../api/config", () => ({
|
||||||
|
fetchInactiveJails: vi.fn().mockResolvedValue({ jails: [] }),
|
||||||
|
deactivateJail: vi.fn(),
|
||||||
|
deleteJailLocalOverride: vi.fn(),
|
||||||
|
addLogPath: vi.fn(),
|
||||||
|
deleteLogPath: vi.fn(),
|
||||||
|
fetchJailConfigFileContent: vi.fn(),
|
||||||
|
updateJailConfigFile: vi.fn(),
|
||||||
|
validateJailConfig: vi.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
const mockUseAutoSave = vi.mocked(useAutoSave);
|
||||||
|
const mockUseJailConfigs = vi.mocked(useJailConfigs);
|
||||||
|
const mockUseConfigActiveStatus = vi.mocked(useConfigActiveStatus);
|
||||||
|
|
||||||
|
const basicJail: JailConfig = {
|
||||||
|
name: "sshd",
|
||||||
|
ban_time: 600,
|
||||||
|
max_retry: 5,
|
||||||
|
find_time: 600,
|
||||||
|
fail_regex: [],
|
||||||
|
ignore_regex: [],
|
||||||
|
log_paths: [],
|
||||||
|
date_pattern: null,
|
||||||
|
log_encoding: "auto",
|
||||||
|
backend: "polling",
|
||||||
|
use_dns: "warn",
|
||||||
|
prefregex: "",
|
||||||
|
actions: [],
|
||||||
|
bantime_escalation: null,
|
||||||
|
};
|
||||||
|
|
||||||
|
describe("JailsTab", () => {
|
||||||
|
it("does not include backend in auto-save payload", () => {
|
||||||
|
const autoSavePayloads: Array<Record<string, unknown>> = [];
|
||||||
|
mockUseAutoSave.mockImplementation((value) => {
|
||||||
|
autoSavePayloads.push(value as Record<string, unknown>);
|
||||||
|
return { status: "idle", errorText: null, retry: vi.fn() };
|
||||||
|
});
|
||||||
|
|
||||||
|
mockUseJailConfigs.mockReturnValue({
|
||||||
|
jails: [basicJail],
|
||||||
|
total: 1,
|
||||||
|
loading: false,
|
||||||
|
error: null,
|
||||||
|
refresh: vi.fn(),
|
||||||
|
updateJail: vi.fn(),
|
||||||
|
reloadAll: vi.fn(),
|
||||||
|
});
|
||||||
|
|
||||||
|
mockUseConfigActiveStatus.mockReturnValue({ activeJails: [] });
|
||||||
|
|
||||||
|
render(
|
||||||
|
<FluentProvider theme={webLightTheme}>
|
||||||
|
<JailsTab initialJail="sshd" />
|
||||||
|
</FluentProvider>,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(autoSavePayloads.length).toBeGreaterThan(0);
|
||||||
|
const lastPayload = autoSavePayloads[autoSavePayloads.length - 1];
|
||||||
|
|
||||||
|
expect(lastPayload).not.toHaveProperty("backend");
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -9,6 +9,7 @@
|
|||||||
* remains fast even when a jail contains thousands of banned IPs.
|
* remains fast even when a jail contains thousands of banned IPs.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
import { useCallback, useEffect, useRef, useState } from "react";
|
||||||
import {
|
import {
|
||||||
Badge,
|
Badge,
|
||||||
Button,
|
Button,
|
||||||
@@ -32,8 +33,6 @@ import {
|
|||||||
type TableColumnDefinition,
|
type TableColumnDefinition,
|
||||||
createTableColumn,
|
createTableColumn,
|
||||||
} from "@fluentui/react-components";
|
} from "@fluentui/react-components";
|
||||||
import { useCommonSectionStyles } from "../../theme/commonStyles";
|
|
||||||
import { formatTimestamp } from "../../utils/formatDate";
|
|
||||||
import {
|
import {
|
||||||
ArrowClockwiseRegular,
|
ArrowClockwiseRegular,
|
||||||
ChevronLeftRegular,
|
ChevronLeftRegular,
|
||||||
@@ -41,12 +40,17 @@ import {
|
|||||||
DismissRegular,
|
DismissRegular,
|
||||||
SearchRegular,
|
SearchRegular,
|
||||||
} from "@fluentui/react-icons";
|
} from "@fluentui/react-icons";
|
||||||
|
import { fetchJailBannedIps, unbanIp } from "../../api/jails";
|
||||||
import type { ActiveBan } from "../../types/jail";
|
import type { ActiveBan } from "../../types/jail";
|
||||||
|
import { ApiError } from "../../api/client";
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
// Constants
|
// Constants
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/** Debounce delay in milliseconds for the search input. */
|
||||||
|
const SEARCH_DEBOUNCE_MS = 300;
|
||||||
|
|
||||||
/** Available page-size options. */
|
/** Available page-size options. */
|
||||||
const PAGE_SIZE_OPTIONS = [10, 25, 50, 100] as const;
|
const PAGE_SIZE_OPTIONS = [10, 25, 50, 100] as const;
|
||||||
|
|
||||||
@@ -55,6 +59,26 @@ const PAGE_SIZE_OPTIONS = [10, 25, 50, 100] as const;
|
|||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
const useStyles = makeStyles({
|
const useStyles = makeStyles({
|
||||||
|
root: {
|
||||||
|
display: "flex",
|
||||||
|
flexDirection: "column",
|
||||||
|
gap: tokens.spacingVerticalS,
|
||||||
|
backgroundColor: tokens.colorNeutralBackground1,
|
||||||
|
borderRadius: tokens.borderRadiusMedium,
|
||||||
|
borderTopWidth: "1px",
|
||||||
|
borderTopStyle: "solid",
|
||||||
|
borderTopColor: tokens.colorNeutralStroke2,
|
||||||
|
borderRightWidth: "1px",
|
||||||
|
borderRightStyle: "solid",
|
||||||
|
borderRightColor: tokens.colorNeutralStroke2,
|
||||||
|
borderBottomWidth: "1px",
|
||||||
|
borderBottomStyle: "solid",
|
||||||
|
borderBottomColor: tokens.colorNeutralStroke2,
|
||||||
|
borderLeftWidth: "1px",
|
||||||
|
borderLeftStyle: "solid",
|
||||||
|
borderLeftColor: tokens.colorNeutralStroke2,
|
||||||
|
padding: tokens.spacingVerticalM,
|
||||||
|
},
|
||||||
header: {
|
header: {
|
||||||
display: "flex",
|
display: "flex",
|
||||||
alignItems: "center",
|
alignItems: "center",
|
||||||
@@ -108,6 +132,31 @@ const useStyles = makeStyles({
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Helpers
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Format an ISO 8601 timestamp for compact display.
|
||||||
|
*
|
||||||
|
* @param iso - ISO 8601 string or `null`.
|
||||||
|
* @returns A locale time string, or `"—"` when `null`.
|
||||||
|
*/
|
||||||
|
function fmtTime(iso: string | null): string {
|
||||||
|
if (!iso) return "—";
|
||||||
|
try {
|
||||||
|
return new Date(iso).toLocaleString(undefined, {
|
||||||
|
year: "numeric",
|
||||||
|
month: "2-digit",
|
||||||
|
day: "2-digit",
|
||||||
|
hour: "2-digit",
|
||||||
|
minute: "2-digit",
|
||||||
|
});
|
||||||
|
} catch {
|
||||||
|
return iso;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
// Column definitions
|
// Column definitions
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
@@ -115,7 +164,7 @@ const useStyles = makeStyles({
|
|||||||
/** A row item augmented with an `onUnban` callback for the row action. */
|
/** A row item augmented with an `onUnban` callback for the row action. */
|
||||||
interface BanRow {
|
interface BanRow {
|
||||||
ban: ActiveBan;
|
ban: ActiveBan;
|
||||||
onUnban: (ip: string) => Promise<void>;
|
onUnban: (ip: string) => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
const columns: TableColumnDefinition<BanRow>[] = [
|
const columns: TableColumnDefinition<BanRow>[] = [
|
||||||
@@ -148,16 +197,12 @@ const columns: TableColumnDefinition<BanRow>[] = [
|
|||||||
createTableColumn<BanRow>({
|
createTableColumn<BanRow>({
|
||||||
columnId: "banned_at",
|
columnId: "banned_at",
|
||||||
renderHeaderCell: () => "Banned At",
|
renderHeaderCell: () => "Banned At",
|
||||||
renderCell: ({ ban }) => (
|
renderCell: ({ ban }) => <Text size={200}>{fmtTime(ban.banned_at)}</Text>,
|
||||||
<Text size={200}>{ban.banned_at ? formatTimestamp(ban.banned_at) : "—"}</Text>
|
|
||||||
),
|
|
||||||
}),
|
}),
|
||||||
createTableColumn<BanRow>({
|
createTableColumn<BanRow>({
|
||||||
columnId: "expires_at",
|
columnId: "expires_at",
|
||||||
renderHeaderCell: () => "Expires At",
|
renderHeaderCell: () => "Expires At",
|
||||||
renderCell: ({ ban }) => (
|
renderCell: ({ ban }) => <Text size={200}>{fmtTime(ban.expires_at)}</Text>,
|
||||||
<Text size={200}>{ban.expires_at ? formatTimestamp(ban.expires_at) : "—"}</Text>
|
|
||||||
),
|
|
||||||
}),
|
}),
|
||||||
createTableColumn<BanRow>({
|
createTableColumn<BanRow>({
|
||||||
columnId: "actions",
|
columnId: "actions",
|
||||||
@@ -168,7 +213,9 @@ const columns: TableColumnDefinition<BanRow>[] = [
|
|||||||
size="small"
|
size="small"
|
||||||
appearance="subtle"
|
appearance="subtle"
|
||||||
icon={<DismissRegular />}
|
icon={<DismissRegular />}
|
||||||
onClick={() => { void onUnban(ban.ip); }}
|
onClick={() => {
|
||||||
|
onUnban(ban.ip);
|
||||||
|
}}
|
||||||
aria-label={`Unban ${ban.ip}`}
|
aria-label={`Unban ${ban.ip}`}
|
||||||
/>
|
/>
|
||||||
</Tooltip>
|
</Tooltip>
|
||||||
@@ -182,19 +229,8 @@ const columns: TableColumnDefinition<BanRow>[] = [
|
|||||||
|
|
||||||
/** Props for {@link BannedIpsSection}. */
|
/** Props for {@link BannedIpsSection}. */
|
||||||
export interface BannedIpsSectionProps {
|
export interface BannedIpsSectionProps {
|
||||||
items: ActiveBan[];
|
/** The jail name whose banned IPs are displayed. */
|
||||||
total: number;
|
jailName: string;
|
||||||
page: number;
|
|
||||||
pageSize: number;
|
|
||||||
search: string;
|
|
||||||
loading: boolean;
|
|
||||||
error: string | null;
|
|
||||||
opError: string | null;
|
|
||||||
onSearch: (term: string) => void;
|
|
||||||
onPageChange: (page: number) => void;
|
|
||||||
onPageSizeChange: (size: number) => void;
|
|
||||||
onRefresh: () => Promise<void>;
|
|
||||||
onUnban: (ip: string) => Promise<void>;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
@@ -206,33 +242,87 @@ export interface BannedIpsSectionProps {
|
|||||||
*
|
*
|
||||||
* @param props - {@link BannedIpsSectionProps}
|
* @param props - {@link BannedIpsSectionProps}
|
||||||
*/
|
*/
|
||||||
export function BannedIpsSection({
|
export function BannedIpsSection({ jailName }: BannedIpsSectionProps): React.JSX.Element {
|
||||||
items,
|
|
||||||
total,
|
|
||||||
page,
|
|
||||||
pageSize,
|
|
||||||
search,
|
|
||||||
loading,
|
|
||||||
error,
|
|
||||||
opError,
|
|
||||||
onSearch,
|
|
||||||
onPageChange,
|
|
||||||
onPageSizeChange,
|
|
||||||
onRefresh,
|
|
||||||
onUnban,
|
|
||||||
}: BannedIpsSectionProps): React.JSX.Element {
|
|
||||||
const styles = useStyles();
|
const styles = useStyles();
|
||||||
const sectionStyles = useCommonSectionStyles();
|
|
||||||
|
const [items, setItems] = useState<ActiveBan[]>([]);
|
||||||
|
const [total, setTotal] = useState(0);
|
||||||
|
const [page, setPage] = useState(1);
|
||||||
|
const [pageSize, setPageSize] = useState<number>(25);
|
||||||
|
const [search, setSearch] = useState("");
|
||||||
|
const [debouncedSearch, setDebouncedSearch] = useState("");
|
||||||
|
const [loading, setLoading] = useState(false);
|
||||||
|
const [error, setError] = useState<string | null>(null);
|
||||||
|
const [opError, setOpError] = useState<string | null>(null);
|
||||||
|
|
||||||
|
const debounceRef = useRef<ReturnType<typeof setTimeout> | null>(null);
|
||||||
|
|
||||||
|
// Debounce the search input so we don't spam the backend on every keystroke.
|
||||||
|
useEffect(() => {
|
||||||
|
if (debounceRef.current !== null) {
|
||||||
|
clearTimeout(debounceRef.current);
|
||||||
|
}
|
||||||
|
debounceRef.current = setTimeout((): void => {
|
||||||
|
setDebouncedSearch(search);
|
||||||
|
setPage(1);
|
||||||
|
}, SEARCH_DEBOUNCE_MS);
|
||||||
|
return (): void => {
|
||||||
|
if (debounceRef.current !== null) clearTimeout(debounceRef.current);
|
||||||
|
};
|
||||||
|
}, [search]);
|
||||||
|
|
||||||
|
const load = useCallback(() => {
|
||||||
|
setLoading(true);
|
||||||
|
setError(null);
|
||||||
|
fetchJailBannedIps(jailName, page, pageSize, debouncedSearch || undefined)
|
||||||
|
.then((resp) => {
|
||||||
|
setItems(resp.items);
|
||||||
|
setTotal(resp.total);
|
||||||
|
})
|
||||||
|
.catch((err: unknown) => {
|
||||||
|
const msg =
|
||||||
|
err instanceof ApiError
|
||||||
|
? `${String(err.status)}: ${err.body}`
|
||||||
|
: err instanceof Error
|
||||||
|
? err.message
|
||||||
|
: String(err);
|
||||||
|
setError(msg);
|
||||||
|
})
|
||||||
|
.finally(() => {
|
||||||
|
setLoading(false);
|
||||||
|
});
|
||||||
|
}, [jailName, page, pageSize, debouncedSearch]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
load();
|
||||||
|
}, [load]);
|
||||||
|
|
||||||
|
const handleUnban = (ip: string): void => {
|
||||||
|
setOpError(null);
|
||||||
|
unbanIp(ip, jailName)
|
||||||
|
.then(() => {
|
||||||
|
load();
|
||||||
|
})
|
||||||
|
.catch((err: unknown) => {
|
||||||
|
const msg =
|
||||||
|
err instanceof ApiError
|
||||||
|
? `${String(err.status)}: ${err.body}`
|
||||||
|
: err instanceof Error
|
||||||
|
? err.message
|
||||||
|
: String(err);
|
||||||
|
setOpError(msg);
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
const rows: BanRow[] = items.map((ban) => ({
|
const rows: BanRow[] = items.map((ban) => ({
|
||||||
ban,
|
ban,
|
||||||
onUnban,
|
onUnban: handleUnban,
|
||||||
}));
|
}));
|
||||||
|
|
||||||
const totalPages = pageSize > 0 ? Math.ceil(total / pageSize) : 1;
|
const totalPages = pageSize > 0 ? Math.ceil(total / pageSize) : 1;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className={sectionStyles.section}>
|
<div className={styles.root}>
|
||||||
{/* Section header */}
|
{/* Section header */}
|
||||||
<div className={styles.header}>
|
<div className={styles.header}>
|
||||||
<div className={styles.headerLeft}>
|
<div className={styles.headerLeft}>
|
||||||
@@ -245,7 +335,7 @@ export function BannedIpsSection({
|
|||||||
size="small"
|
size="small"
|
||||||
appearance="subtle"
|
appearance="subtle"
|
||||||
icon={<ArrowClockwiseRegular />}
|
icon={<ArrowClockwiseRegular />}
|
||||||
onClick={() => { void onRefresh(); }}
|
onClick={load}
|
||||||
aria-label="Refresh banned IPs"
|
aria-label="Refresh banned IPs"
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
@@ -260,7 +350,7 @@ export function BannedIpsSection({
|
|||||||
placeholder="e.g. 192.168"
|
placeholder="e.g. 192.168"
|
||||||
value={search}
|
value={search}
|
||||||
onChange={(_, d) => {
|
onChange={(_, d) => {
|
||||||
onSearch(d.value);
|
setSearch(d.value);
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
</Field>
|
</Field>
|
||||||
@@ -330,8 +420,8 @@ export function BannedIpsSection({
|
|||||||
onOptionSelect={(_, d) => {
|
onOptionSelect={(_, d) => {
|
||||||
const newSize = Number(d.optionValue);
|
const newSize = Number(d.optionValue);
|
||||||
if (!Number.isNaN(newSize)) {
|
if (!Number.isNaN(newSize)) {
|
||||||
onPageSizeChange(newSize);
|
setPageSize(newSize);
|
||||||
onPageChange(1);
|
setPage(1);
|
||||||
}
|
}
|
||||||
}}
|
}}
|
||||||
style={{ minWidth: "80px" }}
|
style={{ minWidth: "80px" }}
|
||||||
@@ -355,7 +445,7 @@ export function BannedIpsSection({
|
|||||||
icon={<ChevronLeftRegular />}
|
icon={<ChevronLeftRegular />}
|
||||||
disabled={page <= 1}
|
disabled={page <= 1}
|
||||||
onClick={() => {
|
onClick={() => {
|
||||||
onPageChange(Math.max(1, page - 1));
|
setPage((p) => Math.max(1, p - 1));
|
||||||
}}
|
}}
|
||||||
aria-label="Previous page"
|
aria-label="Previous page"
|
||||||
/>
|
/>
|
||||||
@@ -365,7 +455,7 @@ export function BannedIpsSection({
|
|||||||
icon={<ChevronRightRegular />}
|
icon={<ChevronRightRegular />}
|
||||||
disabled={page >= totalPages}
|
disabled={page >= totalPages}
|
||||||
onClick={() => {
|
onClick={() => {
|
||||||
onPageChange(page + 1);
|
setPage((p) => p + 1);
|
||||||
}}
|
}}
|
||||||
aria-label="Next page"
|
aria-label="Next page"
|
||||||
/>
|
/>
|
||||||
|
|||||||
@@ -1,11 +1,52 @@
|
|||||||
import { describe, it, expect, vi } from "vitest";
|
/**
|
||||||
import { render, screen } from "@testing-library/react";
|
* Tests for the `BannedIpsSection` component.
|
||||||
|
*
|
||||||
|
* Verifies:
|
||||||
|
* - Renders the section header and total count badge.
|
||||||
|
* - Shows a spinner while loading.
|
||||||
|
* - Renders a table with IP rows on success.
|
||||||
|
* - Shows an empty-state message when there are no banned IPs.
|
||||||
|
* - Displays an error message bar when the API call fails.
|
||||||
|
* - Search input re-fetches with the search parameter after debounce.
|
||||||
|
* - Unban button calls `unbanIp` and refreshes the list.
|
||||||
|
* - Pagination buttons are shown and change the page.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { describe, it, expect, vi, beforeEach } from "vitest";
|
||||||
|
import { render, screen, waitFor, act, fireEvent } from "@testing-library/react";
|
||||||
import userEvent from "@testing-library/user-event";
|
import userEvent from "@testing-library/user-event";
|
||||||
import { FluentProvider, webLightTheme } from "@fluentui/react-components";
|
import { FluentProvider, webLightTheme } from "@fluentui/react-components";
|
||||||
import { BannedIpsSection, type BannedIpsSectionProps } from "../BannedIpsSection";
|
import { BannedIpsSection } from "../BannedIpsSection";
|
||||||
import type { ActiveBan } from "../../../types/jail";
|
import type { JailBannedIpsResponse } from "../../../types/jail";
|
||||||
|
|
||||||
function makeBan(ip: string): ActiveBan {
|
// ---------------------------------------------------------------------------
|
||||||
|
// Module mocks
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
const { mockFetchJailBannedIps, mockUnbanIp } = vi.hoisted(() => ({
|
||||||
|
mockFetchJailBannedIps: vi.fn<
|
||||||
|
(
|
||||||
|
jailName: string,
|
||||||
|
page?: number,
|
||||||
|
pageSize?: number,
|
||||||
|
search?: string,
|
||||||
|
) => Promise<JailBannedIpsResponse>
|
||||||
|
>(),
|
||||||
|
mockUnbanIp: vi.fn<
|
||||||
|
(ip: string, jail?: string) => Promise<{ message: string; jail: string }>
|
||||||
|
>(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
vi.mock("../../../api/jails", () => ({
|
||||||
|
fetchJailBannedIps: mockFetchJailBannedIps,
|
||||||
|
unbanIp: mockUnbanIp,
|
||||||
|
}));
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Fixtures
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
function makeBan(ip: string) {
|
||||||
return {
|
return {
|
||||||
ip,
|
ip,
|
||||||
jail: "sshd",
|
jail: "sshd",
|
||||||
@@ -16,65 +57,195 @@ function makeBan(ip: string): ActiveBan {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
function renderWithProps(props: Partial<BannedIpsSectionProps> = {}) {
|
function makeResponse(
|
||||||
const defaults: BannedIpsSectionProps = {
|
ips: string[] = ["1.2.3.4", "5.6.7.8"],
|
||||||
items: [makeBan("1.2.3.4"), makeBan("5.6.7.8")],
|
total = 2,
|
||||||
total: 2,
|
): JailBannedIpsResponse {
|
||||||
|
return {
|
||||||
|
items: ips.map(makeBan),
|
||||||
|
total,
|
||||||
page: 1,
|
page: 1,
|
||||||
pageSize: 25,
|
page_size: 25,
|
||||||
search: "",
|
|
||||||
loading: false,
|
|
||||||
error: null,
|
|
||||||
opError: null,
|
|
||||||
onSearch: vi.fn(),
|
|
||||||
onPageChange: vi.fn(),
|
|
||||||
onPageSizeChange: vi.fn(),
|
|
||||||
onRefresh: vi.fn(),
|
|
||||||
onUnban: vi.fn(),
|
|
||||||
};
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
const EMPTY_RESPONSE: JailBannedIpsResponse = {
|
||||||
|
items: [],
|
||||||
|
total: 0,
|
||||||
|
page: 1,
|
||||||
|
page_size: 25,
|
||||||
|
};
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Helpers
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
function renderSection(jailName = "sshd") {
|
||||||
return render(
|
return render(
|
||||||
<FluentProvider theme={webLightTheme}>
|
<FluentProvider theme={webLightTheme}>
|
||||||
<BannedIpsSection {...defaults} {...props} />
|
<BannedIpsSection jailName={jailName} />
|
||||||
</FluentProvider>,
|
</FluentProvider>,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Tests
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
describe("BannedIpsSection", () => {
|
describe("BannedIpsSection", () => {
|
||||||
it("shows the table rows and total count", () => {
|
beforeEach(() => {
|
||||||
renderWithProps();
|
vi.clearAllMocks();
|
||||||
expect(screen.getByText("Currently Banned IPs")).toBeTruthy();
|
vi.useRealTimers();
|
||||||
expect(screen.getByText("1.2.3.4")).toBeTruthy();
|
mockUnbanIp.mockResolvedValue({ message: "ok", jail: "sshd" });
|
||||||
expect(screen.getByText("5.6.7.8")).toBeTruthy();
|
});
|
||||||
|
|
||||||
|
it("renders section header with 'Currently Banned IPs' title", async () => {
|
||||||
|
mockFetchJailBannedIps.mockResolvedValue(makeResponse());
|
||||||
|
renderSection();
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText("Currently Banned IPs")).toBeTruthy();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it("shows the total count badge", async () => {
|
||||||
|
mockFetchJailBannedIps.mockResolvedValue(makeResponse(["1.2.3.4", "5.6.7.8"], 2));
|
||||||
|
renderSection();
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText("2")).toBeTruthy();
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
it("shows a spinner while loading", () => {
|
it("shows a spinner while loading", () => {
|
||||||
renderWithProps({ loading: true, items: [] });
|
// Never resolves during this test so we see the spinner.
|
||||||
|
mockFetchJailBannedIps.mockReturnValue(new Promise(() => void 0));
|
||||||
|
renderSection();
|
||||||
expect(screen.getByText("Loading banned IPs…")).toBeTruthy();
|
expect(screen.getByText("Loading banned IPs…")).toBeTruthy();
|
||||||
});
|
});
|
||||||
|
|
||||||
it("shows error message when error is present", () => {
|
it("renders IP rows when banned IPs exist", async () => {
|
||||||
renderWithProps({ error: "Failed to load" });
|
mockFetchJailBannedIps.mockResolvedValue(makeResponse(["1.2.3.4", "5.6.7.8"]));
|
||||||
expect(screen.getByText(/Failed to load/i)).toBeTruthy();
|
renderSection();
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText("1.2.3.4")).toBeTruthy();
|
||||||
|
expect(screen.getByText("5.6.7.8")).toBeTruthy();
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
it("triggers onUnban for IP row button", async () => {
|
it("shows empty-state message when no IPs are banned", async () => {
|
||||||
const onUnban = vi.fn();
|
mockFetchJailBannedIps.mockResolvedValue(EMPTY_RESPONSE);
|
||||||
renderWithProps({ onUnban });
|
renderSection();
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(
|
||||||
|
screen.getByText("No IPs currently banned in this jail."),
|
||||||
|
).toBeTruthy();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it("shows an error message bar on API failure", async () => {
|
||||||
|
mockFetchJailBannedIps.mockRejectedValue(new Error("socket dead"));
|
||||||
|
renderSection();
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText(/socket dead/i)).toBeTruthy();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it("calls fetchJailBannedIps with the jail name", async () => {
|
||||||
|
mockFetchJailBannedIps.mockResolvedValue(makeResponse());
|
||||||
|
renderSection("nginx");
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(mockFetchJailBannedIps).toHaveBeenCalledWith(
|
||||||
|
"nginx",
|
||||||
|
expect.any(Number),
|
||||||
|
expect.any(Number),
|
||||||
|
undefined,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it("search input re-fetches after debounce with the search term", async () => {
|
||||||
|
vi.useFakeTimers();
|
||||||
|
mockFetchJailBannedIps.mockResolvedValue(makeResponse());
|
||||||
|
renderSection();
|
||||||
|
// Flush pending async work from the initial render (no timer advancement needed).
|
||||||
|
await act(async () => {});
|
||||||
|
|
||||||
|
mockFetchJailBannedIps.mockClear();
|
||||||
|
mockFetchJailBannedIps.mockResolvedValue(
|
||||||
|
makeResponse(["1.2.3.4"], 1),
|
||||||
|
);
|
||||||
|
|
||||||
|
// fireEvent is synchronous — avoids hanging with fake timers.
|
||||||
|
const input = screen.getByPlaceholderText("e.g. 192.168");
|
||||||
|
act(() => {
|
||||||
|
fireEvent.change(input, { target: { value: "1.2.3" } });
|
||||||
|
});
|
||||||
|
|
||||||
|
// Advance just past the 300ms debounce delay and flush promises.
|
||||||
|
await act(async () => {
|
||||||
|
await vi.advanceTimersByTimeAsync(350);
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(mockFetchJailBannedIps).toHaveBeenLastCalledWith(
|
||||||
|
"sshd",
|
||||||
|
expect.any(Number),
|
||||||
|
expect.any(Number),
|
||||||
|
"1.2.3",
|
||||||
|
);
|
||||||
|
|
||||||
|
vi.useRealTimers();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("calls unbanIp when the unban button is clicked", async () => {
|
||||||
|
mockFetchJailBannedIps.mockResolvedValue(makeResponse(["1.2.3.4"]));
|
||||||
|
renderSection();
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText("1.2.3.4")).toBeTruthy();
|
||||||
|
});
|
||||||
|
|
||||||
const unbanBtn = screen.getByLabelText("Unban 1.2.3.4");
|
const unbanBtn = screen.getByLabelText("Unban 1.2.3.4");
|
||||||
await userEvent.click(unbanBtn);
|
await userEvent.click(unbanBtn);
|
||||||
|
|
||||||
expect(onUnban).toHaveBeenCalledWith("1.2.3.4");
|
expect(mockUnbanIp).toHaveBeenCalledWith("1.2.3.4", "sshd");
|
||||||
});
|
});
|
||||||
|
|
||||||
it("calls onSearch when the search input changes", async () => {
|
it("refreshes list after successful unban", async () => {
|
||||||
const onSearch = vi.fn();
|
mockFetchJailBannedIps
|
||||||
renderWithProps({ onSearch });
|
.mockResolvedValueOnce(makeResponse(["1.2.3.4"]))
|
||||||
|
.mockResolvedValue(EMPTY_RESPONSE);
|
||||||
|
mockUnbanIp.mockResolvedValue({ message: "ok", jail: "sshd" });
|
||||||
|
|
||||||
const input = screen.getByPlaceholderText("e.g. 192.168");
|
renderSection();
|
||||||
await userEvent.type(input, "1.2.3");
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText("1.2.3.4")).toBeTruthy();
|
||||||
|
});
|
||||||
|
|
||||||
expect(onSearch).toHaveBeenCalled();
|
const unbanBtn = screen.getByLabelText("Unban 1.2.3.4");
|
||||||
|
await userEvent.click(unbanBtn);
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(mockFetchJailBannedIps).toHaveBeenCalledTimes(2);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it("shows pagination controls when total > 0", async () => {
|
||||||
|
mockFetchJailBannedIps.mockResolvedValue(
|
||||||
|
makeResponse(["1.2.3.4", "5.6.7.8"], 50),
|
||||||
|
);
|
||||||
|
renderSection();
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByLabelText("Next page")).toBeTruthy();
|
||||||
|
expect(screen.getByLabelText("Previous page")).toBeTruthy();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it("previous page button is disabled on page 1", async () => {
|
||||||
|
mockFetchJailBannedIps.mockResolvedValue(
|
||||||
|
makeResponse(["1.2.3.4"], 50),
|
||||||
|
);
|
||||||
|
renderSection();
|
||||||
|
await waitFor(() => {
|
||||||
|
const prevBtn = screen.getByLabelText("Previous page");
|
||||||
|
expect(prevBtn).toHaveAttribute("disabled");
|
||||||
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -1,88 +0,0 @@
|
|||||||
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
|
|
||||||
import { renderHook, act } from "@testing-library/react";
|
|
||||||
import { useConfigItem } from "../useConfigItem";
|
|
||||||
|
|
||||||
describe("useConfigItem", () => {
|
|
||||||
beforeEach(() => {
|
|
||||||
vi.useFakeTimers();
|
|
||||||
});
|
|
||||||
|
|
||||||
afterEach(() => {
|
|
||||||
vi.useRealTimers();
|
|
||||||
vi.clearAllMocks();
|
|
||||||
});
|
|
||||||
|
|
||||||
it("loads data and sets loading state", async () => {
|
|
||||||
const fetchFn = vi.fn().mockResolvedValue("hello");
|
|
||||||
const saveFn = vi.fn().mockResolvedValue(undefined);
|
|
||||||
|
|
||||||
const { result } = renderHook(() => useConfigItem<string, string>({ fetchFn, saveFn }));
|
|
||||||
|
|
||||||
expect(result.current.loading).toBe(true);
|
|
||||||
await act(async () => {
|
|
||||||
await Promise.resolve();
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(fetchFn).toHaveBeenCalled();
|
|
||||||
expect(result.current.data).toBe("hello");
|
|
||||||
expect(result.current.loading).toBe(false);
|
|
||||||
});
|
|
||||||
|
|
||||||
it("sets error if fetch rejects", async () => {
|
|
||||||
const fetchFn = vi.fn().mockRejectedValue(new Error("nope"));
|
|
||||||
const saveFn = vi.fn().mockResolvedValue(undefined);
|
|
||||||
|
|
||||||
const { result } = renderHook(() => useConfigItem<string, string>({ fetchFn, saveFn }));
|
|
||||||
|
|
||||||
await act(async () => {
|
|
||||||
await Promise.resolve();
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(result.current.error).toBe("nope");
|
|
||||||
expect(result.current.loading).toBe(false);
|
|
||||||
});
|
|
||||||
|
|
||||||
it("save updates data when mergeOnSave is provided", async () => {
|
|
||||||
const fetchFn = vi.fn().mockResolvedValue({ value: 1 });
|
|
||||||
const saveFn = vi.fn().mockResolvedValue(undefined);
|
|
||||||
|
|
||||||
const { result } = renderHook(() =>
|
|
||||||
useConfigItem<{ value: number }, { delta: number }>({
|
|
||||||
fetchFn,
|
|
||||||
saveFn,
|
|
||||||
mergeOnSave: (prev, update) =>
|
|
||||||
prev ? { ...prev, value: prev.value + update.delta } : prev,
|
|
||||||
})
|
|
||||||
);
|
|
||||||
|
|
||||||
await act(async () => {
|
|
||||||
await Promise.resolve();
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(result.current.data).toEqual({ value: 1 });
|
|
||||||
|
|
||||||
await act(async () => {
|
|
||||||
await result.current.save({ delta: 2 });
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(saveFn).toHaveBeenCalledWith({ delta: 2 });
|
|
||||||
expect(result.current.data).toEqual({ value: 3 });
|
|
||||||
});
|
|
||||||
|
|
||||||
it("saveError is set when save fails", async () => {
|
|
||||||
const fetchFn = vi.fn().mockResolvedValue("ok");
|
|
||||||
const saveFn = vi.fn().mockRejectedValue(new Error("save failed"));
|
|
||||||
|
|
||||||
const { result } = renderHook(() => useConfigItem<string, string>({ fetchFn, saveFn }));
|
|
||||||
|
|
||||||
await act(async () => {
|
|
||||||
await Promise.resolve();
|
|
||||||
});
|
|
||||||
|
|
||||||
await act(async () => {
|
|
||||||
await expect(result.current.save("test")).rejects.toThrow("save failed");
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(result.current.saveError).toBe("save failed");
|
|
||||||
});
|
|
||||||
});
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
import { describe, it, expect, vi } from "vitest";
|
|
||||||
import { renderHook, act, waitFor } from "@testing-library/react";
|
|
||||||
import { useJailBannedIps } from "../useJails";
|
|
||||||
import * as api from "../../api/jails";
|
|
||||||
|
|
||||||
vi.mock("../../api/jails");
|
|
||||||
|
|
||||||
describe("useJailBannedIps", () => {
|
|
||||||
it("loads bans and allows unban", async () => {
|
|
||||||
const fetchMock = vi.mocked(api.fetchJailBannedIps);
|
|
||||||
const unbanMock = vi.mocked(api.unbanIp);
|
|
||||||
|
|
||||||
fetchMock.mockResolvedValue({ items: [{ ip: "1.2.3.4", jail: "sshd", banned_at: "2025-01-01T10:00:00+00:00", expires_at: "2025-01-01T10:10:00+00:00", ban_count: 1, country: "US" }], total: 1, page: 1, page_size: 25 });
|
|
||||||
unbanMock.mockResolvedValue({ message: "ok", jail: "sshd" });
|
|
||||||
|
|
||||||
const { result } = renderHook(() => useJailBannedIps("sshd"));
|
|
||||||
await waitFor(() => {
|
|
||||||
expect(result.current.loading).toBe(false);
|
|
||||||
});
|
|
||||||
expect(result.current.items.length).toBe(1);
|
|
||||||
|
|
||||||
await act(async () => {
|
|
||||||
await result.current.unban("1.2.3.4");
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(unbanMock).toHaveBeenCalledWith("1.2.3.4", "sshd");
|
|
||||||
expect(fetchMock).toHaveBeenCalledTimes(2);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
@@ -1,207 +0,0 @@
|
|||||||
import { describe, it, expect, vi, beforeEach } from "vitest";
|
|
||||||
import { renderHook, act } from "@testing-library/react";
|
|
||||||
import * as jailsApi from "../../api/jails";
|
|
||||||
import { useJailDetail } from "../useJails";
|
|
||||||
import type { Jail } from "../../types/jail";
|
|
||||||
|
|
||||||
// Mock the API module
|
|
||||||
vi.mock("../../api/jails");
|
|
||||||
|
|
||||||
const mockJail: Jail = {
|
|
||||||
name: "sshd",
|
|
||||||
running: true,
|
|
||||||
idle: false,
|
|
||||||
backend: "pyinotify",
|
|
||||||
log_paths: ["/var/log/auth.log"],
|
|
||||||
fail_regex: ["^\\[.*\\]\\s.*Failed password"],
|
|
||||||
ignore_regex: [],
|
|
||||||
date_pattern: "%b %d %H:%M:%S",
|
|
||||||
log_encoding: "UTF-8",
|
|
||||||
actions: [],
|
|
||||||
find_time: 600,
|
|
||||||
ban_time: 600,
|
|
||||||
max_retry: 5,
|
|
||||||
status: null,
|
|
||||||
bantime_escalation: null,
|
|
||||||
};
|
|
||||||
|
|
||||||
describe("useJailDetail control methods", () => {
|
|
||||||
beforeEach(() => {
|
|
||||||
vi.clearAllMocks();
|
|
||||||
vi.mocked(jailsApi.fetchJail).mockResolvedValue({
|
|
||||||
jail: mockJail,
|
|
||||||
ignore_list: [],
|
|
||||||
ignore_self: false,
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
it("calls start() and refetches jail data", async () => {
|
|
||||||
vi.mocked(jailsApi.startJail).mockResolvedValue(undefined);
|
|
||||||
|
|
||||||
const { result } = renderHook(() => useJailDetail("sshd"));
|
|
||||||
|
|
||||||
// Wait for initial fetch
|
|
||||||
await act(async () => {
|
|
||||||
await new Promise((r) => setTimeout(r, 0));
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(result.current.jail?.name).toBe("sshd");
|
|
||||||
expect(jailsApi.startJail).not.toHaveBeenCalled();
|
|
||||||
|
|
||||||
// Call start()
|
|
||||||
await act(async () => {
|
|
||||||
await result.current.start();
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(jailsApi.startJail).toHaveBeenCalledWith("sshd");
|
|
||||||
expect(jailsApi.fetchJail).toHaveBeenCalledTimes(2); // Initial fetch + refetch after start
|
|
||||||
});
|
|
||||||
|
|
||||||
it("calls stop() and refetches jail data", async () => {
|
|
||||||
vi.mocked(jailsApi.stopJail).mockResolvedValue(undefined);
|
|
||||||
|
|
||||||
const { result } = renderHook(() => useJailDetail("sshd"));
|
|
||||||
|
|
||||||
// Wait for initial fetch
|
|
||||||
await act(async () => {
|
|
||||||
await new Promise((r) => setTimeout(r, 0));
|
|
||||||
});
|
|
||||||
|
|
||||||
// Call stop()
|
|
||||||
await act(async () => {
|
|
||||||
await result.current.stop();
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(jailsApi.stopJail).toHaveBeenCalledWith("sshd");
|
|
||||||
expect(jailsApi.fetchJail).toHaveBeenCalledTimes(2); // Initial fetch + refetch after stop
|
|
||||||
});
|
|
||||||
|
|
||||||
it("calls reload() and refetches jail data", async () => {
|
|
||||||
vi.mocked(jailsApi.reloadJail).mockResolvedValue(undefined);
|
|
||||||
|
|
||||||
const { result } = renderHook(() => useJailDetail("sshd"));
|
|
||||||
|
|
||||||
// Wait for initial fetch
|
|
||||||
await act(async () => {
|
|
||||||
await new Promise((r) => setTimeout(r, 0));
|
|
||||||
});
|
|
||||||
|
|
||||||
// Call reload()
|
|
||||||
await act(async () => {
|
|
||||||
await result.current.reload();
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(jailsApi.reloadJail).toHaveBeenCalledWith("sshd");
|
|
||||||
expect(jailsApi.fetchJail).toHaveBeenCalledTimes(2); // Initial fetch + refetch after reload
|
|
||||||
});
|
|
||||||
|
|
||||||
it("calls setIdle() with correct parameter and refetches jail data", async () => {
|
|
||||||
vi.mocked(jailsApi.setJailIdle).mockResolvedValue(undefined);
|
|
||||||
|
|
||||||
const { result } = renderHook(() => useJailDetail("sshd"));
|
|
||||||
|
|
||||||
// Wait for initial fetch
|
|
||||||
await act(async () => {
|
|
||||||
await new Promise((r) => setTimeout(r, 0));
|
|
||||||
});
|
|
||||||
|
|
||||||
// Call setIdle(true)
|
|
||||||
await act(async () => {
|
|
||||||
await result.current.setIdle(true);
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(jailsApi.setJailIdle).toHaveBeenCalledWith("sshd", true);
|
|
||||||
expect(jailsApi.fetchJail).toHaveBeenCalledTimes(2);
|
|
||||||
|
|
||||||
// Reset mock to verify second call
|
|
||||||
vi.mocked(jailsApi.setJailIdle).mockClear();
|
|
||||||
vi.mocked(jailsApi.fetchJail).mockResolvedValue({
|
|
||||||
jail: { ...mockJail, idle: true },
|
|
||||||
ignore_list: [],
|
|
||||||
ignore_self: false,
|
|
||||||
});
|
|
||||||
|
|
||||||
// Call setIdle(false)
|
|
||||||
await act(async () => {
|
|
||||||
await result.current.setIdle(false);
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(jailsApi.setJailIdle).toHaveBeenCalledWith("sshd", false);
|
|
||||||
});
|
|
||||||
|
|
||||||
it("propagates errors from start()", async () => {
|
|
||||||
const error = new Error("Failed to start jail");
|
|
||||||
vi.mocked(jailsApi.startJail).mockRejectedValue(error);
|
|
||||||
|
|
||||||
const { result } = renderHook(() => useJailDetail("sshd"));
|
|
||||||
|
|
||||||
// Wait for initial fetch
|
|
||||||
await act(async () => {
|
|
||||||
await new Promise((r) => setTimeout(r, 0));
|
|
||||||
});
|
|
||||||
|
|
||||||
// Call start() and expect it to throw
|
|
||||||
await expect(
|
|
||||||
act(async () => {
|
|
||||||
await result.current.start();
|
|
||||||
}),
|
|
||||||
).rejects.toThrow("Failed to start jail");
|
|
||||||
});
|
|
||||||
|
|
||||||
it("propagates errors from stop()", async () => {
|
|
||||||
const error = new Error("Failed to stop jail");
|
|
||||||
vi.mocked(jailsApi.stopJail).mockRejectedValue(error);
|
|
||||||
|
|
||||||
const { result } = renderHook(() => useJailDetail("sshd"));
|
|
||||||
|
|
||||||
// Wait for initial fetch
|
|
||||||
await act(async () => {
|
|
||||||
await new Promise((r) => setTimeout(r, 0));
|
|
||||||
});
|
|
||||||
|
|
||||||
// Call stop() and expect it to throw
|
|
||||||
await expect(
|
|
||||||
act(async () => {
|
|
||||||
await result.current.stop();
|
|
||||||
}),
|
|
||||||
).rejects.toThrow("Failed to stop jail");
|
|
||||||
});
|
|
||||||
|
|
||||||
it("propagates errors from reload()", async () => {
|
|
||||||
const error = new Error("Failed to reload jail");
|
|
||||||
vi.mocked(jailsApi.reloadJail).mockRejectedValue(error);
|
|
||||||
|
|
||||||
const { result } = renderHook(() => useJailDetail("sshd"));
|
|
||||||
|
|
||||||
// Wait for initial fetch
|
|
||||||
await act(async () => {
|
|
||||||
await new Promise((r) => setTimeout(r, 0));
|
|
||||||
});
|
|
||||||
|
|
||||||
// Call reload() and expect it to throw
|
|
||||||
await expect(
|
|
||||||
act(async () => {
|
|
||||||
await result.current.reload();
|
|
||||||
}),
|
|
||||||
).rejects.toThrow("Failed to reload jail");
|
|
||||||
});
|
|
||||||
|
|
||||||
it("propagates errors from setIdle()", async () => {
|
|
||||||
const error = new Error("Failed to set idle mode");
|
|
||||||
vi.mocked(jailsApi.setJailIdle).mockRejectedValue(error);
|
|
||||||
|
|
||||||
const { result } = renderHook(() => useJailDetail("sshd"));
|
|
||||||
|
|
||||||
// Wait for initial fetch
|
|
||||||
await act(async () => {
|
|
||||||
await new Promise((r) => setTimeout(r, 0));
|
|
||||||
});
|
|
||||||
|
|
||||||
// Call setIdle() and expect it to throw
|
|
||||||
await expect(
|
|
||||||
act(async () => {
|
|
||||||
await result.current.setIdle(true);
|
|
||||||
}),
|
|
||||||
).rejects.toThrow("Failed to set idle mode");
|
|
||||||
});
|
|
||||||
});
|
|
||||||
@@ -1,41 +0,0 @@
|
|||||||
import { describe, it, expect, vi } from "vitest";
|
|
||||||
import { renderHook, act, waitFor } from "@testing-library/react";
|
|
||||||
import { useMapColorThresholds } from "../useMapColorThresholds";
|
|
||||||
import * as api from "../../api/config";
|
|
||||||
|
|
||||||
vi.mock("../../api/config");
|
|
||||||
|
|
||||||
describe("useMapColorThresholds", () => {
|
|
||||||
it("loads thresholds and exposes values", async () => {
|
|
||||||
const mocked = vi.mocked(api.fetchMapColorThresholds);
|
|
||||||
mocked.mockResolvedValue({ threshold_low: 10, threshold_medium: 20, threshold_high: 50 });
|
|
||||||
|
|
||||||
const { result } = renderHook(() => useMapColorThresholds());
|
|
||||||
|
|
||||||
expect(result.current.loading).toBe(true);
|
|
||||||
await waitFor(() => {
|
|
||||||
expect(result.current.loading).toBe(false);
|
|
||||||
});
|
|
||||||
expect(result.current.thresholds).toEqual({ threshold_low: 10, threshold_medium: 20, threshold_high: 50 });
|
|
||||||
expect(result.current.error).toBeNull();
|
|
||||||
});
|
|
||||||
|
|
||||||
it("updates thresholds via callback", async () => {
|
|
||||||
const fetchMock = vi.mocked(api.fetchMapColorThresholds);
|
|
||||||
const updateMock = vi.mocked(api.updateMapColorThresholds);
|
|
||||||
|
|
||||||
fetchMock.mockResolvedValue({ threshold_low: 10, threshold_medium: 20, threshold_high: 50 });
|
|
||||||
updateMock.mockResolvedValue({ threshold_low: 15, threshold_medium: 25, threshold_high: 75 });
|
|
||||||
|
|
||||||
const { result } = renderHook(() => useMapColorThresholds());
|
|
||||||
await waitFor(() => {
|
|
||||||
expect(result.current.loading).toBe(false);
|
|
||||||
});
|
|
||||||
|
|
||||||
await act(async () => {
|
|
||||||
await result.current.updateThresholds({ threshold_low: 15, threshold_medium: 25, threshold_high: 75 });
|
|
||||||
});
|
|
||||||
|
|
||||||
expect(result.current.thresholds).toEqual({ threshold_low: 15, threshold_medium: 25, threshold_high: 75 });
|
|
||||||
});
|
|
||||||
});
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user