refactor: complete Task 2/3 geo decouple + exceptions centralization; mark as done

This commit is contained in:
2026-03-21 17:15:02 +01:00
parent 452901913f
commit 5a49106f4d
28 changed files with 803 additions and 571 deletions

View File

@@ -1,238 +1,195 @@
# BanGUI — Refactoring Instructions for AI Agents # BanGUI — Architecture Issues & Refactoring Plan
This document is the single source of truth for any AI agent performing a refactoring task on the BanGUI codebase. This document catalogues architecture violations, code smells, and structural issues found during a full project review. Issues are grouped by category and prioritised.
Read it in full before writing a single line of code.
The authoritative description of every module, its responsibilities, and the allowed dependency direction is in [Architekture.md](Architekture.md). Always cross-reference it.
--- ---
## 0. Golden Rules ## 1. Service-to-Service Coupling (Backend)
1. **Architecture first.** Every change must comply with the layered architecture defined in [Architekture.md §2](Architekture.md). Dependencies flow inward: `routers → services → repositories`. Never add an import that reverses this direction. The architecture mandates that dependencies flow **routers → services → repositories**, yet **15 service-to-service imports** exist, with 7 using lazy imports to work around circular dependencies.
2. **One concern per file.** Each module has an explicitly stated purpose in [Architekture.md](Architekture.md). Do not add responsibilities to a module that do not belong there.
3. **No behaviour change.** Refactoring must preserve all existing behaviour. If a function's public signature, return value, or side-effects must change, that is a feature — create a separate task for it. | Source Service | Imports From | Line | Mechanism |
4. **Tests stay green.** Run the full test suite (`pytest backend/`) before and after every change. Do not submit work that introduces new failures. |---|---|---|---|
5. **Smallest diff wins.** Prefer targeted edits. Do not rewrite a file when a few lines suffice. | `history_service` | `ban_service` | L29 | Direct import of 3 **private** functions: `_get_fail2ban_db_path`, `_parse_data_json`, `_ts_to_iso` |
| `auth_service` | `setup_service` | L23 | Top-level import |
| `config_service` | `setup_service` | L47 | Top-level import |
| `config_service` | `health_service` | L891 | Lazy import inside function |
| `config_file_service` | `jail_service` | L5758 | Top-level import + re-export of `JailNotFoundError` |
| `blocklist_service` | `jail_service` | L299 | Lazy import |
| `blocklist_service` | `geo_service` | L343 | Lazy import |
| `jail_service` | `geo_service` | L860, L1047 | Lazy import (2 sites) |
| `ban_service` | `geo_service` | L251, L392 | Lazy import (2 sites) |
| `history_service` | `geo_service` | L19 | TYPE_CHECKING import |
**Impact**: Circular dependency risk; lazy imports hide coupling; private function imports create fragile links between services.
**Recommendation**:
- Extract `_get_fail2ban_db_path()`, `_parse_data_json()`, `_ts_to_iso()` from `ban_service` to `app/utils/fail2ban_db_utils.py` (shared utility).
- Pass geo-enrichment as a callback parameter instead of each service importing `geo_service` directly. The router or dependency layer should wire this.
- Where services depend on another service's domain exceptions (e.g., `JailNotFoundError`), move exceptions to `app/models/` or a shared `app/exceptions.py`.
--- ---
## 1. Before You Start ## 2. God Modules (Backend)
### 1.1 Understand the project Several service files far exceed a reasonable size for a single-domain module:
Read the following documents in order: | File | Lines | Functions | Issue |
|---|---|---|---|
| `config_file_service.py` | **3105** | **73** | Handles jails, filters, and actions — three distinct domains crammed into one file |
| `jail_service.py` | **1382** | **34** | Manages jail listing, status, controls, banned-IP queries, and geo enrichment |
| `config_service.py` | **921** | ~20 | Socket-based config, log preview, regex testing, and service status |
| `file_config_service.py` | **1011** | ~20 | Raw file I/O for jails, filters, and actions |
1. [Architekture.md](Architekture.md) — full system overview, component map, module purposes, dependency rules. **Recommendation**:
2. [Docs/Backend-Development.md](Backend-Development.md) — coding conventions, testing strategy, environment setup. - Split `config_file_service.py` into `filter_config_service.py`, `action_config_service.py`, and a slimmed-down `jail_config_service.py`.
3. [Docs/Tasks.md](Tasks.md) — open issues and planned work; avoid touching areas that have pending conflicting changes. - Extract log-preview / regex-test functionality from `config_service.py` into a dedicated `log_service.py`.
### 1.2 Map the code to the architecture ---
Before editing, locate every file that is in scope: ## 3. Confusing Config Service Naming (Backend)
``` Three services with overlapping names handle different aspects of configuration, causing developer confusion:
backend/app/
routers/ HTTP layer — zero business logic
services/ Business logic — orchestrates repositories + clients
repositories/ Data access — raw SQL only
models/ Pydantic schemas
tasks/ APScheduler jobs
utils/ Pure helpers, no framework deps
main.py App factory, lifespan, middleware
config.py Pydantic settings
dependencies.py FastAPI Depends() wiring
frontend/src/ | Current Name | Purpose |
api/ Typed fetch wrappers + endpoint constants |---|---|
components/ Presentational UI, no API calls | `config_service` | Read/write via the fail2ban **socket** |
hooks/ All state, side-effects, API calls | `config_file_service` | Parse/activate/deactivate jails from **files on disk** |
pages/ Route components — orchestration only | `file_config_service` | **Raw file I/O** for jail/filter/action `.conf` files |
providers/ React context
types/ TypeScript interfaces `config_file_service` vs `file_config_service` differ only by word order, making it easy to import the wrong one.
utils/ Pure helpers
**Recommendation**: Rename for clarity:
- `config_service` → keep (socket-based)
- `config_file_service``jail_activation_service` (its main job is activating/deactivating jails)
- `file_config_service``raw_config_io_service` or merge into `config_file_service`
---
## 4. Architecture Doc Drift
The architecture doc does not fully reflect the current codebase:
| Category | In Architecture Doc | Actually Exists | Notes |
|---|---|---|---|
| Repositories | 4 listed | **6 files** | `fail2ban_db_repo.py` and `geo_cache_repo.py` are missing from the doc |
| Utils | 4 listed | **8 files** | `conffile_parser.py`, `config_parser.py`, `config_writer.py`, `jail_config.py` are undocumented |
| Tasks | 3 listed | **4 files** | `geo_re_resolve.py` is missing from the doc |
| Services | `conffile_parser` listed as a service | Actually in `app/utils/` | Doc says "Services" but the file is in `utils/` |
| Routers | `file_config.py` not listed | Exists | Missing from router table |
**Recommendation**: Update the Architecture doc to reflect the actual file inventory.
---
## 5. Shared Private Functions Cross Service Boundary (Backend)
`history_service.py` imports three **underscore-prefixed** ("private") functions from `ban_service.py`:
```python
from app.services.ban_service import _get_fail2ban_db_path, _parse_data_json, _ts_to_iso
``` ```
Confirm which layer every file you intend to touch belongs to. If unsure, consult [Architekture.md §2.2](Architekture.md) (backend) or [Architekture.md §3.2](Architekture.md) (frontend). These are implementation details of `ban_service` that should not be consumed externally. Their `_` prefix signals they are not part of the public API.
### 1.3 Run the baseline **Recommendation**: Move these to `app/utils/fail2ban_db_utils.py` as public functions and import from there in both services.
```bash
# Backend
pytest backend/ -x --tb=short
# Frontend
cd frontend && npm run test
```
Record the number of passing tests. After refactoring, that number must be equal or higher.
--- ---
## 2. Backend Refactoring ## 6. Missing Error Boundaries (Frontend)
### 2.1 Routers (`app/routers/`) No React Error Boundary component exists anywhere in the frontend. A single unhandled exception in any component will crash the entire application with a white screen.
**Allowed content:** request parsing, response serialisation, dependency injection via `Depends()`, delegation to a service, HTTP error mapping. **Recommendation**: Add an `<ErrorBoundary>` wrapper in `MainLayout.tsx` or `App.tsx` with a fallback UI that shows the error and offers a retry/reload.
**Forbidden content:** SQL queries, business logic, direct use of `fail2ban_client`, any logic that would also make sense in a unit test without an HTTP request.
Checklist:
- [ ] Every handler calls exactly one service method per logical operation.
- [ ] No `if`/`elif` chains that implement business rules — move these to the service.
- [ ] No raw SQL or repository imports.
- [ ] All response models are Pydantic schemas from `app/models/`.
- [ ] HTTP status codes are consistent with API conventions (200 OK, 201 Created, 204 No Content, 400/422 for client errors, 404 for missing resources, 500 only for unexpected failures).
### 2.2 Services (`app/services/`)
**Allowed content:** business rules, coordination between repositories and external clients, validation that goes beyond Pydantic, fail2ban command orchestration.
**Forbidden content:** raw SQL, direct aiosqlite calls, FastAPI `HTTPException` (raise domain exceptions instead and let the router or exception handler convert them).
Checklist:
- [ ] Service classes / functions accept plain Python types or domain models — not `Request` or `Response` objects.
- [ ] No direct `aiosqlite` usage — go through a repository.
- [ ] No `HTTPException` — raise a custom domain exception or a plain `ValueError`/`RuntimeError` with a clear message.
- [ ] No circular imports between services — if two services need each other's logic, extract the shared logic to a utility or a third service.
### 2.3 Repositories (`app/repositories/`)
**Allowed content:** SQL queries, result mapping to domain models, transaction management.
**Forbidden content:** business logic, fail2ban calls, HTTP concerns, logging beyond debug-level traces.
Checklist:
- [ ] Every public method accepts a `db: aiosqlite.Connection` parameter — sessions are not managed internally.
- [ ] Methods return typed domain models or plain Python primitives, never raw `aiosqlite.Row` objects exposed to callers.
- [ ] No business rules (e.g., no "if this setting is missing, create a default" logic — that belongs in the service).
### 2.4 Models (`app/models/`)
- Keep **Request**, **Response**, and **Domain** model types clearly separated (see [Architekture.md §2.2](Architekture.md)).
- Do not use response models as function arguments inside service or repository code.
- Validators (`@field_validator`, `@model_validator`) belong in models only when they concern data shape, not business rules.
### 2.5 Tasks (`app/tasks/`)
- Tasks must be thin: fetch inputs → call one service method → log result.
- Error handling must be inside the task (APScheduler swallows unhandled exceptions — log them explicitly).
- No direct repository or `fail2ban_client` use; go through a service.
### 2.6 Utils (`app/utils/`)
- Must have zero framework dependencies (no FastAPI, no aiosqlite imports).
- Must be pure or near-pure functions.
- `fail2ban_client.py` is the single exception — it wraps the socket protocol but still has no service-layer logic.
### 2.7 Dependencies (`app/dependencies.py`)
- This file is the **only** place where service constructors are called and injected.
- Do not construct services inside router handlers; always receive them via `Depends()`.
--- ---
## 3. Frontend Refactoring ## 7. Duplicated Formatting Functions (Frontend)
### 3.1 Pages (`src/pages/`) Several formatting functions are independently defined in multiple files instead of being shared:
**Allowed content:** composing components and hooks, layout decisions, routing. | Function | Defined In | Also In |
**Forbidden content:** direct `fetch`/`axios` calls, inline business logic, state management beyond what is needed to coordinate child components.
Checklist:
- [ ] All data fetching goes through a hook from `src/hooks/`.
- [ ] No API function from `src/api/` is called directly inside a page component.
### 3.2 Components (`src/components/`)
**Allowed content:** rendering, styling, event handlers that call prop callbacks.
**Forbidden content:** API calls, hook-level state (prefer lifting state to the page or a dedicated hook), direct use of `src/api/`.
Checklist:
- [ ] Components receive all data via props.
- [ ] Components emit changes via callback props (`onXxx`).
- [ ] No `useEffect` that calls an API function — that belongs in a hook.
### 3.3 Hooks (`src/hooks/`)
**Allowed content:** `useState`, `useEffect`, `useCallback`, `useRef`; calls to `src/api/`; local state derivation.
**Forbidden content:** JSX rendering, Fluent UI components.
Checklist:
- [ ] Each hook has a single, focused concern matching its name (e.g., `useBans` only manages ban data).
- [ ] Hooks return a stable interface: `{ data, loading, error, refetch }` or equivalent.
- [ ] Shared logic between hooks is extracted to `src/utils/` (pure) or a parent hook (stateful).
### 3.4 API layer (`src/api/`)
- `client.ts` is the only place that calls `fetch`. All other api files call `client.ts`.
- `endpoints.ts` is the single source of truth for URL strings.
- API functions must be typed: explicit request and response TypeScript interfaces from `src/types/`.
### 3.5 Types (`src/types/`)
- Interfaces must match the backend Pydantic response schemas exactly (field names, optionality).
- Do not use `any`. Use `unknown` and narrow with type guards when the shape is genuinely unknown.
---
## 4. General Code Quality Rules
### Naming
- Python: `snake_case` for variables/functions, `PascalCase` for classes.
- TypeScript: `camelCase` for variables/functions, `PascalCase` for components and types.
- File names must match the primary export they contain.
### Error handling
- Backend: raise typed exceptions; map them to HTTP status codes in `main.py` exception handlers or in the router — nowhere else.
- Frontend: all API call error states are represented in hook return values; never swallow errors silently.
### Logging (backend)
- Use `structlog` with bound context loggers — never bare `print()`.
- Log at `debug` in repositories, `info` in services for meaningful events, `warning`/`error` in tasks and exception handlers.
- Never log sensitive data (passwords, session tokens, raw IP lists larger than a handful of entries).
### Async correctness (backend)
- Every function that touches I/O (database, fail2ban socket, HTTP) must be `async def`.
- Never call `asyncio.run()` inside a running event loop.
- Do not use `time.sleep()` — use `await asyncio.sleep()`.
---
## 5. Refactoring Workflow
Follow this sequence for every refactoring task:
1. **Read** the relevant section of [Architekture.md](Architekture.md) for the files you will touch.
2. **Run** the full test suite to confirm the baseline.
3. **Identify** the violation or smell: which rule from this document does it break?
4. **Plan** the minimal change: what is the smallest edit that fixes the violation?
5. **Edit** the code. One logical change per commit.
6. **Verify** imports: nothing new violates the dependency direction.
7. **Run** the full test suite. All previously passing tests must still pass.
8. **Update** any affected docstrings or inline comments to reflect the new structure.
9. **Do not** update `Architekture.md` unless the refactor changes the documented structure — that requires a separate review.
---
## 6. Common Violations to Look For
| Violation | Where it typically appears | Fix |
|---|---|---| |---|---|---|
| Business logic in a router handler | `app/routers/*.py` | Extract logic to the corresponding service | | `formatTimestamp()` | `BanTable.tsx` (L103) | — (but `fmtTime()` in `BannedIpsSection.tsx` does the same thing) |
| Direct `aiosqlite` calls in a service | `app/services/*.py` | Move the query into the matching repository | | `fmtSeconds()` | `JailDetailPage.tsx` (L152) | `JailsPage.tsx` (L147) — identical |
| `HTTPException` raised inside a service | `app/services/*.py` | Raise a domain exception; catch and convert it in the router or exception handler | | `fmtTime()` | `BannedIpsSection.tsx` (L139) | — |
| API call inside a React component | `src/components/*.tsx` | Move to a hook; pass data via props |
| Hardcoded URL string in a hook or component | `src/hooks/*.ts`, `src/components/*.tsx` | Use the constant from `src/api/endpoints.ts` | **Recommendation**: Consolidate into `src/utils/formatDate.ts` and import from there.
| `any` type in TypeScript | anywhere in `src/` | Replace with a concrete interface from `src/types/` |
| `print()` statements in production code | `backend/app/**/*.py` | Replace with `structlog` logger |
| Synchronous I/O in an async function | `backend/app/**/*.py` | Use the async equivalent |
| A repository method that contains an `if` with a business rule | `app/repositories/*.py` | Move the rule to the service layer |
--- ---
## 7. Out of Scope ## 8. Duplicated Hook Logic (Frontend)
Do not make the following changes unless explicitly instructed in a separate task: Three hooks follow an identical fetch-then-save pattern with near-identical code:
- Adding new API endpoints or pages. | Hook | Lines | Pattern |
- Changing database schema or migration files. |---|---|---|
- Upgrading dependencies. | `useFilterConfig.ts` | 91 | Load item → expose save → handle abort |
- Altering Docker or CI configuration. | `useActionConfig.ts` | 89 | Load item → expose save → handle abort |
- Modifying `Architekture.md` or `Tasks.md`. | `useJailFileConfig.ts` | 76 | Load item → expose save → handle abort |
**Recommendation**: Create a generic `useConfigItem<T>()` hook that takes `fetchFn` and `saveFn` parameters and eliminates the triplication.
---
## 9. Inconsistent Error Handling in Hooks (Frontend)
Hooks handle errors differently:
- Some filter out `AbortError` (e.g., `useHistory`, `useMapData`)
- Others catch all errors indiscriminately (e.g., `useBans`, `useBlocklist`)
This means some hooks surface spurious "request aborted" errors to the UI while others don't.
**Recommendation**: Standardise a shared error-catching pattern, e.g. a `handleFetchError(err, setError)` utility that always filters `AbortError`.
---
## 10. No Global Request State / Caching (Frontend)
Each hook manages its own loading/error/data state independently. There is:
- No request deduplication (two components fetching the same data trigger two requests)
- No stale-while-revalidate caching
- No automatic background refetching
**Recommendation**: Consider adopting React Query (TanStack Query) or SWR for data-fetching hooks. This would eliminate boilerplate in every hook (abort handling, loading state, error state, caching) and provide automatic deduplication.
---
## 11. Large Frontend Components
| Component | Lines | Issue |
|---|---|---|
| `BlocklistsPage.tsx` | 968 | Page does a lot: source list, add/edit dialogs, import log, schedule config |
| `JailsTab.tsx` | 939 | Combines jail list, config editing, raw config, validation, activate/deactivate |
| `JailsPage.tsx` | 691 | Mixes jail table, detail drawer, ban/unban forms |
| `JailDetailPage.tsx` | 663 | Full detail view with multiple sections |
**Recommendation**: Extract sub-sections into focused child components. For example, `JailsTab.tsx` could delegate to `<JailConfigEditor>`, `<JailValidation>`, and `<JailActivateDialog>`.
---
## 12. Duplicated Section Styles (Frontend)
The same card/section styling pattern (`backgroundColor`, `borderRadius`, `border`, `padding` using Fluent UI tokens) is repeated across 13+ files. Each page recreates it in its own `makeStyles` block.
**Recommendation**: Define a shared `useCardStyles()` or export a `sectionStyle` in `src/theme/commonStyles.ts` and import it.
---
## Summary by Priority
| Priority | Issue | Section |
|---|---|---|
| **High** | Service-to-service coupling / circular deps | §1 |
| **High** | God module `config_file_service.py` (3105 lines, 73 functions) | §2 |
| **High** | Shared private function imports across services | §5 |
| **Medium** | Confusing config service naming | §3 |
| **Medium** | Architecture doc drift | §4 |
| **Medium** | Missing error boundaries (frontend) | §6 |
| **Medium** | No global request state / caching (frontend) | §10 |
| **Low** | Duplicated formatting functions (frontend) | §7 |
| **Low** | Duplicated hook logic (frontend) | §8 |
| **Low** | Inconsistent error handling in hooks (frontend) | §9 |
| **Low** | Large frontend components | §11 |
| **Low** | Duplicated section styles (frontend) | §12 |

View File

@@ -2,6 +2,314 @@
This document breaks the entire BanGUI project into development stages, ordered so that each stage builds on the previous one. Every task is described in prose with enough detail for a developer to begin work. References point to the relevant documentation. This document breaks the entire BanGUI project into development stages, ordered so that each stage builds on the previous one. Every task is described in prose with enough detail for a developer to begin work. References point to the relevant documentation.
Reference: `Docs/Refactoring.md` for full analysis of each issue.
--- ---
## Open Issues ## Open Issues
---
### Task 1 — Extract shared private functions to a utility module (✅ completed)
**Priority**: High
**Refactoring ref**: Refactoring.md §1, §5
**Affected files**:
- `backend/app/services/ban_service.py` (defines `_get_fail2ban_db_path` ~L117, `_parse_data_json` ~L152, `_ts_to_iso` ~L105)
- `backend/app/services/history_service.py` (imports these three private functions from `ban_service`)
**What to do**:
1. Create a new file `backend/app/utils/fail2ban_db_utils.py`.
2. Move the three functions `_get_fail2ban_db_path()`, `_parse_data_json()`, and `_ts_to_iso()` from `backend/app/services/ban_service.py` into the new utility file. Rename them to remove the leading underscore (they are now public utilities): `get_fail2ban_db_path()`, `parse_data_json()`, `ts_to_iso()`.
3. In `backend/app/services/ban_service.py`, replace the function bodies with imports from the new utility: `from app.utils.fail2ban_db_utils import get_fail2ban_db_path, parse_data_json, ts_to_iso`. Update all internal call sites within `ban_service.py` that reference the old `_`-prefixed names.
4. In `backend/app/services/history_service.py`, replace the import `from app.services.ban_service import _get_fail2ban_db_path, _parse_data_json, _ts_to_iso` with `from app.utils.fail2ban_db_utils import get_fail2ban_db_path, parse_data_json, ts_to_iso`. Update all call sites in `history_service.py`.
5. Search the entire `backend/` tree for any other references to the old `_`-prefixed names and update them.
6. Run existing tests: `cd backend && python -m pytest tests/` — all tests must pass.
**Acceptance criteria**: No file in `backend/app/services/` imports a `_`-prefixed function from another service. The three functions live in `backend/app/utils/fail2ban_db_utils.py` and are imported from there.
---
### Task 2 — Decouple geo-enrichment from services (✅ completed)
**Priority**: High
**Refactoring ref**: Refactoring.md §1
**Affected files**:
- `backend/app/services/jail_service.py` (lazy imports `geo_service` at ~L860, ~L1047)
- `backend/app/services/ban_service.py` (lazy imports `geo_service` at ~L251, ~L392)
- `backend/app/services/blocklist_service.py` (lazy imports `geo_service` at ~L343)
- `backend/app/services/history_service.py` (TYPE_CHECKING import of `geo_service` at ~L19)
- `backend/app/services/geo_service.py` (the service being imported)
- Router files that call these services: `backend/app/routers/jails.py`, `backend/app/routers/bans.py`, `backend/app/routers/dashboard.py`, `backend/app/routers/history.py`, `backend/app/routers/blocklist.py`
**What to do**:
1. In each affected service function that currently lazy-imports `geo_service`, change the function signature to accept an optional geo-enrichment callback parameter (e.g., `enrich_geo: Callable | None = None`). The callback signature should match what `geo_service` provides (typically `async def enrich(ip: str) -> GeoInfo | None`).
2. Remove all lazy imports of `geo_service` from `jail_service.py`, `ban_service.py`, `blocklist_service.py`, and `history_service.py`.
3. In the corresponding router files, import `geo_service` and pass its enrichment function as the callback when calling the service functions. The router layer is where wiring belongs.
4. Run existing tests: `cd backend && python -m pytest tests/` — all tests must pass. If tests mock `geo_service` inside a service, update mocks to inject the callback instead.
**Acceptance criteria**: No service file imports `geo_service` (directly or lazily). Geo-enrichment is injected from routers via callback parameters.
---
### Task 3 — Move shared domain exceptions to a central module (✅ completed)
**Priority**: High
**Refactoring ref**: Refactoring.md §1
**Affected files**:
- `backend/app/services/config_file_service.py` (defines `JailNotFoundError` and other domain exceptions)
- `backend/app/services/jail_service.py` (may define or re-export exceptions)
- Any service or router that imports exceptions cross-service (e.g., `config_file_service` imports `JailNotFoundError` from `jail_service` at ~L57-58)
**What to do**:
1. Create `backend/app/exceptions.py`.
2. Grep the entire `backend/app/services/` directory for all custom exception class definitions (classes inheriting from `Exception` or `HTTPException`). Collect every exception that is imported by more than one module.
3. Move those shared exception classes into `backend/app/exceptions.py`.
4. Update all import statements across `backend/app/services/`, `backend/app/routers/`, and `backend/app/` to import from `backend/app/exceptions.py`.
5. Exception classes used only within a single service may remain in that service file.
6. Run existing tests: `cd backend && python -m pytest tests/` — all tests must pass.
**Acceptance criteria**: `backend/app/exceptions.py` exists and contains all cross-service exceptions. No service imports an exception class from another service module.
---
### Task 4 — Split `config_file_service.py` (god module)
**Priority**: High
**Refactoring ref**: Refactoring.md §2
**Affected files**:
- `backend/app/services/config_file_service.py` (~2232 lines, ~73 functions)
- `backend/app/routers/` files that import from `config_file_service`
**What to do**:
1. Read `backend/app/services/config_file_service.py` and categorise every function into one of three domains:
- **Jail config** — functions dealing with jail activation, deactivation, listing jail configs
- **Filter config** — functions dealing with fail2ban filter files (reading, writing, listing filters)
- **Action config** — functions dealing with fail2ban action files (reading, writing, listing actions)
2. Create three new service files:
- `backend/app/services/jail_config_service.py` — jail-related functions
- `backend/app/services/filter_config_service.py` — filter-related functions
- `backend/app/services/action_config_service.py` — action-related functions
3. Move functions from `config_file_service.py` into the appropriate new file. Any truly shared helpers used across all three domains should remain in `config_file_service.py` (renamed to a shared helper) or move to `backend/app/utils/`.
4. Delete `config_file_service.py` once empty (or keep it as a thin re-export layer for backward compatibility during transition).
5. Update all imports in `backend/app/routers/` and `backend/app/services/` that referenced `config_file_service`.
6. Run existing tests: `cd backend && python -m pytest tests/` — all tests must pass.
**Acceptance criteria**: No single service file exceeds ~800 lines. The three new files each handle one domain. All routers import from the correct new module.
---
### Task 5 — Extract log-preview / regex-test from `config_service.py`
**Priority**: Medium
**Refactoring ref**: Refactoring.md §2
**Affected files**:
- `backend/app/services/config_service.py` (~1845 lines)
- `backend/app/routers/config.py` (routes that call log-preview / regex-test functions)
**What to do**:
1. Read `backend/app/services/config_service.py` and identify all functions related to log-preview and regex-testing (these are distinct from the core socket-based config reading/writing functions).
2. Create `backend/app/services/log_service.py`.
3. Move the log-preview and regex-test functions into `log_service.py`.
4. Update imports in `backend/app/routers/config.py` (or create a new `backend/app/routers/log.py` if the endpoints are logically separate).
5. Run existing tests: `cd backend && python -m pytest tests/` — all tests must pass.
**Acceptance criteria**: `config_service.py` no longer contains log-preview or regex-test logic. `log_service.py` exists and is used by the appropriate router.
---
### Task 6 — Rename confusing config service files
**Priority**: Medium
**Refactoring ref**: Refactoring.md §3
**Affected files**:
- `backend/app/services/config_file_service.py` → rename to `jail_activation_service.py` (or the split modules from Task 4)
- `backend/app/services/file_config_service.py` → rename to `raw_config_io_service.py`
- All files importing from the old names
**Note**: This task depends on Task 4 being completed first. If Task 4 splits `config_file_service.py`, this task only needs to rename `file_config_service.py`.
**What to do**:
1. Rename `backend/app/services/file_config_service.py` to `backend/app/services/raw_config_io_service.py`.
2. Update all import statements across the codebase (`backend/app/services/`, `backend/app/routers/`, `backend/app/tasks/`, tests) that reference `file_config_service` to reference `raw_config_io_service`.
3. Also rename the corresponding router if one exists: check `backend/app/routers/file_config.py` and rename accordingly.
4. Run existing tests: `cd backend && python -m pytest tests/` — all tests must pass.
**Acceptance criteria**: No file named `file_config_service.py` exists. The new name `raw_config_io_service.py` is used everywhere.
---
### Task 7 — Remove remaining service-to-service coupling
**Priority**: Medium
**Refactoring ref**: Refactoring.md §1
**Affected files**:
- `backend/app/services/auth_service.py` (imports `setup_service` at ~L23)
- `backend/app/services/config_service.py` (imports `setup_service` at ~L47, lazy-imports `health_service` at ~L891)
- `backend/app/services/blocklist_service.py` (lazy-imports `jail_service` at ~L299)
**What to do**:
1. For each remaining service-to-service import, determine why the dependency exists (read the calling code).
2. Refactor using one of these strategies:
- **Dependency injection**: The router passes the needed data or function from service A when calling service B.
- **Shared utility**: If the imported function is a pure utility, move it to `backend/app/utils/`.
- **Event / callback**: The service accepts a callback parameter instead of importing another service directly.
3. Remove all direct and lazy imports between service modules.
4. Run existing tests: `cd backend && python -m pytest tests/` — all tests must pass.
**Acceptance criteria**: Running `grep -r "from app.services" backend/app/services/` returns zero results (no service imports another service). All wiring happens in the router or dependency-injection layer.
---
### Task 8 — Update Architecture documentation
**Priority**: Medium
**Refactoring ref**: Refactoring.md §4
**Affected files**:
- `Docs/Architekture.md`
**What to do**:
1. Read `Docs/Architekture.md` and the actual file listings below.
2. Add the following missing items to the appropriate sections:
- **Repositories**: Add `fail2ban_db_repo.py` and `geo_cache_repo.py` (in `backend/app/repositories/`)
- **Utils**: Add `conffile_parser.py`, `config_parser.py`, `config_writer.py`, `jail_config.py` (in `backend/app/utils/`)
- **Tasks**: Add `geo_re_resolve.py` (in `backend/app/tasks/`)
- **Services**: Correct the entry that lists `conffile_parser` as a service — it is in `app/utils/`
- **Routers**: Add `file_config.py` (in `backend/app/routers/`)
3. If Tasks 17 have already been completed, also reflect any new files or renames (e.g., `fail2ban_db_utils.py`, `exceptions.py`, the split service files, renamed services).
4. Verify no other files exist that are missing from the doc by comparing the doc's file lists against `ls backend/app/*/`.
**Acceptance criteria**: Every `.py` file under `backend/app/` (excluding `__init__.py` and `__pycache__`) is mentioned in the Architecture doc.
---
### Task 9 — Add React Error Boundary to the frontend
**Priority**: Medium
**Refactoring ref**: Refactoring.md §6
**Affected files**:
- New file: `frontend/src/components/ErrorBoundary.tsx`
- `frontend/src/App.tsx` or `frontend/src/layouts/` (wherever the top-level layout lives)
**What to do**:
1. Create `frontend/src/components/ErrorBoundary.tsx` — a React class component implementing `componentDidCatch` and `getDerivedStateFromError`. It should:
- Catch any rendering error in its children.
- Display a user-friendly fallback UI (e.g., "Something went wrong" message with a "Reload" button that calls `window.location.reload()`).
- Log the error (console.error is sufficient for now).
2. Read `frontend/src/App.tsx` to find the main layout/route wrapper.
3. Wrap the main content (inside `<App>` or `<MainLayout>`) with `<ErrorBoundary>` so that any component crash shows the fallback instead of a white screen.
4. Run existing frontend tests: `cd frontend && npx vitest run` — all tests must pass.
**Acceptance criteria**: An `<ErrorBoundary>` component exists and wraps the application's main content. A component throwing during render shows a fallback UI instead of crashing the whole app.
---
### Task 10 — Consolidate duplicated formatting functions (frontend)
**Priority**: Low
**Refactoring ref**: Refactoring.md §7
**Affected files**:
- `frontend/src/components/BanTable.tsx` (has `formatTimestamp()` ~L103)
- `frontend/src/components/jail/BannedIpsSection.tsx` (has `fmtTime()` ~L139)
- `frontend/src/pages/JailDetailPage.tsx` (has `fmtSeconds()` ~L152)
- `frontend/src/pages/JailsPage.tsx` (has `fmtSeconds()` ~L147)
**What to do**:
1. Create `frontend/src/utils/formatDate.ts`.
2. Define three exported functions:
- `formatTimestamp(ts: string): string` — consolidation of `formatTimestamp` and `fmtTime`
- `formatSeconds(seconds: number): string` — consolidation of the two identical `fmtSeconds` functions
3. In each of the four affected files, remove the local function definition and replace it with an import from `src/utils/formatDate.ts`. Adjust call sites if the function name changed.
4. Run existing frontend tests: `cd frontend && npx vitest run` — all tests must pass.
**Acceptance criteria**: No formatting function for dates/times is defined locally in a component or page file. All import from `src/utils/formatDate.ts`.
---
### Task 11 — Create generic `useConfigItem<T>` hook (frontend)
**Priority**: Low
**Refactoring ref**: Refactoring.md §8
**Affected files**:
- `frontend/src/hooks/useFilterConfig.ts` (~91 lines)
- `frontend/src/hooks/useActionConfig.ts` (~88 lines)
- `frontend/src/hooks/useJailFileConfig.ts` (~76 lines)
**What to do**:
1. Read all three hook files. Identify the common pattern: load item via fetch → store in state → expose save function → handle abort controller cleanup.
2. Create `frontend/src/hooks/useConfigItem.ts` with a generic hook:
```ts
function useConfigItem<T>(
fetchFn: (signal: AbortSignal) => Promise<T>,
saveFn: (data: T) => Promise<void>
): { data: T | null; loading: boolean; error: string | null; save: (data: T) => Promise<void> }
```
3. Rewrite `useFilterConfig.ts`, `useActionConfig.ts`, and `useJailFileConfig.ts` to be thin wrappers around `useConfigItem<T>` — each file should be <20 lines, just providing the specific fetch/save functions.
4. Run existing frontend tests: `cd frontend && npx vitest run` — all tests must pass.
**Acceptance criteria**: `useConfigItem.ts` exists. The three original hooks use it and each reduced to <20 lines of domain-specific glue.
---
### Task 12 — Standardise error handling in frontend hooks
**Priority**: Low
**Refactoring ref**: Refactoring.md §9
**Affected files**:
- All hook files in `frontend/src/hooks/` that do fetch calls (at least: `useHistory.ts`, `useMapData.ts`, `useBans.ts`, `useBlocklist.ts`, and others)
**What to do**:
1. Create a utility function in `frontend/src/utils/fetchError.ts`:
```ts
export function handleFetchError(err: unknown, setError: (msg: string | null) => void): void {
if (err instanceof DOMException && err.name === "AbortError") return;
setError(err instanceof Error ? err.message : "Unknown error");
}
```
2. Grep all hook files for `catch` blocks. In every hook that catches fetch errors:
- Replace the catch body with a call to `handleFetchError(err, setError)`.
3. Run existing frontend tests: `cd frontend && npx vitest run` — all tests must pass.
**Acceptance criteria**: Every hook that fetches data uses `handleFetchError()` in its catch block. No hook surfaces `AbortError` to the UI.
---
### Task 13 — Extract sub-components from large frontend pages
**Priority**: Low
**Refactoring ref**: Refactoring.md §11
**Affected files**:
- `frontend/src/pages/BlocklistsPage.tsx` (~968 lines)
- `frontend/src/components/config/JailsTab.tsx` (~939 lines)
- `frontend/src/pages/JailsPage.tsx` (~691 lines)
- `frontend/src/pages/JailDetailPage.tsx` (~663 lines)
**What to do**:
1. For each large file, identify logical UI sections that can be extracted into their own component files.
2. Suggested splits (adjust after reading the actual code):
- **BlocklistsPage.tsx**: Extract `<BlocklistSourceList>`, `<BlocklistAddEditDialog>`, `<BlocklistImportLog>`, `<BlocklistScheduleConfig>`.
- **JailsTab.tsx**: Extract `<JailConfigEditor>`, `<JailRawConfig>`, `<JailValidation>`, `<JailActivateDialog>`.
- **JailsPage.tsx**: Extract `<JailDetailDrawer>`, `<BanUnbanForm>`.
- **JailDetailPage.tsx**: Extract logical sections (examine the JSX to identify).
3. Place extracted components in appropriate directories (e.g., `frontend/src/components/blocklist/`, `frontend/src/components/jail/`).
4. Each parent page should import and compose the new child components. Props should be passed down — avoid prop drilling deeper than 2 levels (use context if needed).
5. Run existing frontend tests: `cd frontend && npx vitest run` — all tests must pass.
**Acceptance criteria**: No single page or component file exceeds ~400 lines. Each extracted component is in its own file.
---
### Task 14 — Consolidate duplicated section/card styles (frontend)
**Priority**: Low
**Refactoring ref**: Refactoring.md §12
**Affected files**:
- 13+ files across `frontend/src/pages/` and `frontend/src/components/` that define identical card/section styles using `makeStyles` with `backgroundColor`, `borderRadius`, `border`, `padding` using Fluent UI tokens.
**What to do**:
1. Grep the frontend codebase for `makeStyles` calls that contain `backgroundColor` and `borderRadius` together. Identify the common pattern.
2. Create `frontend/src/theme/commonStyles.ts` with a shared `useCardStyles()` hook that exports the common section/card style class.
3. In each of the 13+ files, remove the local `makeStyles` definition for the card/section style and import `useCardStyles` from `commonStyles.ts` instead. Keep any file-specific style overrides local.
4. Run existing frontend tests: `cd frontend && npx vitest run` — all tests must pass.
**Acceptance criteria**: A shared `useCardStyles()` exists in `frontend/src/theme/commonStyles.ts`. At least 10 files import it instead of defining their own card styles.

23
backend/app/exceptions.py Normal file
View File

@@ -0,0 +1,23 @@
"""Shared domain exception classes used across routers and services."""
from __future__ import annotations
class JailNotFoundError(Exception):
"""Raised when a requested jail name does not exist."""
class JailOperationError(Exception):
"""Raised when a fail2ban jail operation fails."""
class ConfigValidationError(Exception):
"""Raised when config values fail validation before applying."""
class ConfigOperationError(Exception):
"""Raised when a config payload update or command fails."""
class ServerOperationError(Exception):
"""Raised when a server control command (e.g. refresh) fails."""

View File

@@ -3,8 +3,18 @@
Response models for the ``GET /api/geo/lookup/{ip}`` endpoint. Response models for the ``GET /api/geo/lookup/{ip}`` endpoint.
""" """
from __future__ import annotations
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
if TYPE_CHECKING:
import aiohttp
import aiosqlite
class GeoDetail(BaseModel): class GeoDetail(BaseModel):
"""Enriched geolocation data for an IP address. """Enriched geolocation data for an IP address.
@@ -64,3 +74,26 @@ class IpLookupResponse(BaseModel):
default=None, default=None,
description="Enriched geographical and network information.", description="Enriched geographical and network information.",
) )
# ---------------------------------------------------------------------------
# shared service types
# ---------------------------------------------------------------------------
@dataclass
class GeoInfo:
"""Geo resolution result used throughout backend services."""
country_code: str | None
country_name: str | None
asn: str | None
org: str | None
GeoEnricher = Callable[[str], Awaitable[GeoInfo | None]]
GeoBatchLookup = Callable[
[list[str], "aiohttp.ClientSession", "aiosqlite.Connection | None"],
Awaitable[dict[str, GeoInfo]],
]
GeoCacheLookup = Callable[[list[str]], tuple[dict[str, GeoInfo], list[str]]]

View File

@@ -20,8 +20,8 @@ from fastapi import APIRouter, HTTPException, Request, status
from app.dependencies import AuthDep from app.dependencies import AuthDep
from app.models.ban import ActiveBanListResponse, BanRequest, UnbanAllResponse, UnbanRequest from app.models.ban import ActiveBanListResponse, BanRequest, UnbanAllResponse, UnbanRequest
from app.models.jail import JailCommandResponse from app.models.jail import JailCommandResponse
from app.services import jail_service from app.services import geo_service, jail_service
from app.services.jail_service import JailNotFoundError, JailOperationError from app.exceptions import JailNotFoundError, JailOperationError
from app.utils.fail2ban_client import Fail2BanConnectionError from app.utils.fail2ban_client import Fail2BanConnectionError
router: APIRouter = APIRouter(prefix="/api/bans", tags=["Bans"]) router: APIRouter = APIRouter(prefix="/api/bans", tags=["Bans"])
@@ -73,6 +73,7 @@ async def get_active_bans(
try: try:
return await jail_service.get_active_bans( return await jail_service.get_active_bans(
socket_path, socket_path,
geo_batch_lookup=geo_service.lookup_batch,
http_session=http_session, http_session=http_session,
app_db=app_db, app_db=app_db,
) )

View File

@@ -42,7 +42,7 @@ from app.models.blocklist import (
ScheduleConfig, ScheduleConfig,
ScheduleInfo, ScheduleInfo,
) )
from app.services import blocklist_service from app.services import blocklist_service, geo_service
from app.tasks import blocklist_import as blocklist_import_task from app.tasks import blocklist_import as blocklist_import_task
router: APIRouter = APIRouter(prefix="/api/blocklists", tags=["Blocklists"]) router: APIRouter = APIRouter(prefix="/api/blocklists", tags=["Blocklists"])
@@ -131,7 +131,13 @@ async def run_import_now(
""" """
http_session: aiohttp.ClientSession = request.app.state.http_session http_session: aiohttp.ClientSession = request.app.state.http_session
socket_path: str = request.app.state.settings.fail2ban_socket socket_path: str = request.app.state.settings.fail2ban_socket
return await blocklist_service.import_all(db, http_session, socket_path) return await blocklist_service.import_all(
db,
http_session,
socket_path,
geo_is_cached=geo_service.is_cached,
geo_batch_lookup=geo_service.lookup_batch,
)
@router.get( @router.get(

View File

@@ -93,12 +93,7 @@ from app.services.config_file_service import (
JailNameError, JailNameError,
JailNotFoundInConfigError, JailNotFoundInConfigError,
) )
from app.services.config_service import ( from app.exceptions import ConfigOperationError, ConfigValidationError, JailNotFoundError, JailOperationError
ConfigOperationError,
ConfigValidationError,
JailNotFoundError,
)
from app.services.jail_service import JailOperationError
from app.tasks.health_check import _run_probe from app.tasks.health_check import _run_probe
from app.utils.fail2ban_client import Fail2BanConnectionError from app.utils.fail2ban_client import Fail2BanConnectionError

View File

@@ -29,7 +29,7 @@ from app.models.ban import (
TimeRange, TimeRange,
) )
from app.models.server import ServerStatus, ServerStatusResponse from app.models.server import ServerStatus, ServerStatusResponse
from app.services import ban_service from app.services import ban_service, geo_service
router: APIRouter = APIRouter(prefix="/api/dashboard", tags=["Dashboard"]) router: APIRouter = APIRouter(prefix="/api/dashboard", tags=["Dashboard"])
@@ -119,6 +119,7 @@ async def get_dashboard_bans(
page_size=page_size, page_size=page_size,
http_session=http_session, http_session=http_session,
app_db=None, app_db=None,
geo_batch_lookup=geo_service.lookup_batch,
origin=origin, origin=origin,
) )
@@ -162,6 +163,8 @@ async def get_bans_by_country(
socket_path, socket_path,
range, range,
http_session=http_session, http_session=http_session,
geo_cache_lookup=geo_service.lookup_cached_only,
geo_batch_lookup=geo_service.lookup_batch,
app_db=None, app_db=None,
origin=origin, origin=origin,
) )

View File

@@ -19,9 +19,8 @@ import aiosqlite
from fastapi import APIRouter, Depends, HTTPException, Path, Request, status from fastapi import APIRouter, Depends, HTTPException, Path, Request, status
from app.dependencies import AuthDep, get_db from app.dependencies import AuthDep, get_db
from app.models.geo import GeoCacheStatsResponse, GeoDetail, IpLookupResponse from app.models.geo import GeoCacheStatsResponse, GeoDetail, GeoInfo, IpLookupResponse
from app.services import geo_service, jail_service from app.services import geo_service, jail_service
from app.services.geo_service import GeoInfo
from app.utils.fail2ban_client import Fail2BanConnectionError from app.utils.fail2ban_client import Fail2BanConnectionError
router: APIRouter = APIRouter(prefix="/api/geo", tags=["Geo"]) router: APIRouter = APIRouter(prefix="/api/geo", tags=["Geo"])

View File

@@ -31,8 +31,8 @@ from app.models.jail import (
JailDetailResponse, JailDetailResponse,
JailListResponse, JailListResponse,
) )
from app.services import jail_service from app.services import geo_service, jail_service
from app.services.jail_service import JailNotFoundError, JailOperationError from app.exceptions import JailNotFoundError, JailOperationError
from app.utils.fail2ban_client import Fail2BanConnectionError from app.utils.fail2ban_client import Fail2BanConnectionError
router: APIRouter = APIRouter(prefix="/api/jails", tags=["Jails"]) router: APIRouter = APIRouter(prefix="/api/jails", tags=["Jails"])
@@ -606,6 +606,7 @@ async def get_jail_banned_ips(
page=page, page=page,
page_size=page_size, page_size=page_size,
search=search, search=search,
geo_batch_lookup=geo_service.lookup_batch,
http_session=http_session, http_session=http_session,
app_db=app_db, app_db=app_db,
) )

View File

@@ -15,7 +15,7 @@ from fastapi import APIRouter, HTTPException, Request, status
from app.dependencies import AuthDep from app.dependencies import AuthDep
from app.models.server import ServerSettingsResponse, ServerSettingsUpdate from app.models.server import ServerSettingsResponse, ServerSettingsUpdate
from app.services import server_service from app.services import server_service
from app.services.server_service import ServerOperationError from app.exceptions import ServerOperationError
from app.utils.fail2ban_client import Fail2BanConnectionError from app.utils.fail2ban_client import Fail2BanConnectionError
router: APIRouter = APIRouter(prefix="/api/server", tags=["Server"]) router: APIRouter = APIRouter(prefix="/api/server", tags=["Server"])

View File

@@ -11,11 +11,8 @@ so BanGUI never modifies or locks the fail2ban database.
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import json
import time import time
from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING
from datetime import UTC, datetime
from typing import TYPE_CHECKING, cast
import structlog import structlog
@@ -39,18 +36,16 @@ from app.models.ban import (
JailBanCount as JailBanCountModel, JailBanCount as JailBanCountModel,
) )
from app.repositories import fail2ban_db_repo from app.repositories import fail2ban_db_repo
from app.utils.fail2ban_client import Fail2BanClient, Fail2BanResponse from app.utils.fail2ban_db_utils import get_fail2ban_db_path, parse_data_json, ts_to_iso
if TYPE_CHECKING: if TYPE_CHECKING:
import aiohttp import aiohttp
import aiosqlite import aiosqlite
from app.services.geo_service import GeoInfo from app.models.geo import GeoBatchLookup, GeoCacheLookup, GeoEnricher, GeoInfo
log: structlog.stdlib.BoundLogger = structlog.get_logger() log: structlog.stdlib.BoundLogger = structlog.get_logger()
type GeoEnricher = Callable[[str], Awaitable[GeoInfo | None]]
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Constants # Constants
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -102,98 +97,6 @@ def _since_unix(range_: TimeRange) -> int:
return int(time.time()) - seconds return int(time.time()) - seconds
def _ts_to_iso(unix_ts: int) -> str:
"""Convert a Unix timestamp to an ISO 8601 UTC string.
Args:
unix_ts: Seconds since the Unix epoch.
Returns:
ISO 8601 UTC timestamp, e.g. ``"2026-03-01T12:00:00+00:00"``.
"""
return datetime.fromtimestamp(unix_ts, tz=UTC).isoformat()
async def _get_fail2ban_db_path(socket_path: str) -> str:
"""Query fail2ban for the path to its SQLite database.
Sends the ``get dbfile`` command via the fail2ban socket and returns
the value of the ``dbfile`` setting.
Args:
socket_path: Path to the fail2ban Unix domain socket.
Returns:
Absolute path to the fail2ban SQLite database file.
Raises:
RuntimeError: If fail2ban reports that no database is configured
or if the socket response is unexpected.
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
cannot be reached.
"""
async with Fail2BanClient(socket_path, timeout=_SOCKET_TIMEOUT) as client:
response = await client.send(["get", "dbfile"])
try:
code, data = cast("Fail2BanResponse", response)
except (TypeError, ValueError) as exc:
raise RuntimeError(f"Unexpected response from fail2ban: {response!r}") from exc
if code != 0:
raise RuntimeError(f"fail2ban error code {code}: {data!r}")
if data is None:
raise RuntimeError("fail2ban has no database configured (dbfile is None)")
return str(data)
def _parse_data_json(raw: object) -> tuple[list[str], int]:
"""Extract matches and failure count from the ``bans.data`` column.
The ``data`` column stores a JSON blob with optional keys:
* ``matches`` — list of raw matched log lines.
* ``failures`` — total failure count that triggered the ban.
Args:
raw: The raw ``data`` column value (string, dict, or ``None``).
Returns:
A ``(matches, failures)`` tuple. Both default to empty/zero when
parsing fails or the column is absent.
"""
if raw is None:
return [], 0
obj: dict[str, object] = {}
if isinstance(raw, str):
try:
parsed: object = json.loads(raw)
if isinstance(parsed, dict):
obj = parsed
# json.loads("null") → None, or other non-dict — treat as empty
except json.JSONDecodeError:
return [], 0
elif isinstance(raw, dict):
obj = raw
raw_matches = obj.get("matches")
if isinstance(raw_matches, list):
matches: list[str] = [str(m) for m in raw_matches]
else:
matches = []
raw_failures = obj.get("failures")
failures: int = 0
if isinstance(raw_failures, (int, float, str)):
try:
failures = int(raw_failures)
except (ValueError, TypeError):
failures = 0
return matches, failures
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -209,6 +112,7 @@ async def list_bans(
page_size: int = _DEFAULT_PAGE_SIZE, page_size: int = _DEFAULT_PAGE_SIZE,
http_session: aiohttp.ClientSession | None = None, http_session: aiohttp.ClientSession | None = None,
app_db: aiosqlite.Connection | None = None, app_db: aiosqlite.Connection | None = None,
geo_batch_lookup: GeoBatchLookup | None = None,
geo_enricher: GeoEnricher | None = None, geo_enricher: GeoEnricher | None = None,
origin: BanOrigin | None = None, origin: BanOrigin | None = None,
) -> DashboardBanListResponse: ) -> DashboardBanListResponse:
@@ -248,14 +152,13 @@ async def list_bans(
:class:`~app.models.ban.DashboardBanListResponse` containing the :class:`~app.models.ban.DashboardBanListResponse` containing the
paginated items and total count. paginated items and total count.
""" """
from app.services import geo_service # noqa: PLC0415
since: int = _since_unix(range_) since: int = _since_unix(range_)
effective_page_size: int = min(page_size, _MAX_PAGE_SIZE) effective_page_size: int = min(page_size, _MAX_PAGE_SIZE)
offset: int = (page - 1) * effective_page_size offset: int = (page - 1) * effective_page_size
origin_clause, origin_params = _origin_sql_filter(origin) origin_clause, origin_params = _origin_sql_filter(origin)
db_path: str = await _get_fail2ban_db_path(socket_path) db_path: str = await get_fail2ban_db_path(socket_path)
log.info( log.info(
"ban_service_list_bans", "ban_service_list_bans",
db_path=db_path, db_path=db_path,
@@ -276,10 +179,10 @@ async def list_bans(
# This avoids hitting the 45 req/min single-IP rate limit when the # This avoids hitting the 45 req/min single-IP rate limit when the
# page contains many bans (e.g. after a large blocklist import). # page contains many bans (e.g. after a large blocklist import).
geo_map: dict[str, GeoInfo] = {} geo_map: dict[str, GeoInfo] = {}
if http_session is not None and rows: if http_session is not None and rows and geo_batch_lookup is not None:
page_ips: list[str] = [r.ip for r in rows] page_ips: list[str] = [r.ip for r in rows]
try: try:
geo_map = await geo_service.lookup_batch(page_ips, http_session, db=app_db) geo_map = await geo_batch_lookup(page_ips, http_session, db=app_db)
except Exception: # noqa: BLE001 except Exception: # noqa: BLE001
log.warning("ban_service_batch_geo_failed_list_bans") log.warning("ban_service_batch_geo_failed_list_bans")
@@ -287,9 +190,9 @@ async def list_bans(
for row in rows: for row in rows:
jail: str = row.jail jail: str = row.jail
ip: str = row.ip ip: str = row.ip
banned_at: str = _ts_to_iso(row.timeofban) banned_at: str = ts_to_iso(row.timeofban)
ban_count: int = row.bancount ban_count: int = row.bancount
matches, _ = _parse_data_json(row.data) matches, _ = parse_data_json(row.data)
service: str | None = matches[0] if matches else None service: str | None = matches[0] if matches else None
country_code: str | None = None country_code: str | None = None
@@ -350,6 +253,8 @@ async def bans_by_country(
socket_path: str, socket_path: str,
range_: TimeRange, range_: TimeRange,
http_session: aiohttp.ClientSession | None = None, http_session: aiohttp.ClientSession | None = None,
geo_cache_lookup: GeoCacheLookup | None = None,
geo_batch_lookup: GeoBatchLookup | None = None,
geo_enricher: GeoEnricher | None = None, geo_enricher: GeoEnricher | None = None,
app_db: aiosqlite.Connection | None = None, app_db: aiosqlite.Connection | None = None,
origin: BanOrigin | None = None, origin: BanOrigin | None = None,
@@ -389,11 +294,10 @@ async def bans_by_country(
:class:`~app.models.ban.BansByCountryResponse` with per-country :class:`~app.models.ban.BansByCountryResponse` with per-country
aggregation and the companion ban list. aggregation and the companion ban list.
""" """
from app.services import geo_service # noqa: PLC0415
since: int = _since_unix(range_) since: int = _since_unix(range_)
origin_clause, origin_params = _origin_sql_filter(origin) origin_clause, origin_params = _origin_sql_filter(origin)
db_path: str = await _get_fail2ban_db_path(socket_path) db_path: str = await get_fail2ban_db_path(socket_path)
log.info( log.info(
"ban_service_bans_by_country", "ban_service_bans_by_country",
db_path=db_path, db_path=db_path,
@@ -429,23 +333,24 @@ async def bans_by_country(
unique_ips: list[str] = [r.ip for r in agg_rows] unique_ips: list[str] = [r.ip for r in agg_rows]
geo_map: dict[str, GeoInfo] = {} geo_map: dict[str, GeoInfo] = {}
if http_session is not None and unique_ips: if http_session is not None and unique_ips and geo_cache_lookup is not None:
# Serve only what is already in the in-memory cache — no API calls on # Serve only what is already in the in-memory cache — no API calls on
# the hot path. Uncached IPs are resolved asynchronously in the # the hot path. Uncached IPs are resolved asynchronously in the
# background so subsequent requests benefit from a warmer cache. # background so subsequent requests benefit from a warmer cache.
geo_map, uncached = geo_service.lookup_cached_only(unique_ips) geo_map, uncached = geo_cache_lookup(unique_ips)
if uncached: if uncached:
log.info( log.info(
"ban_service_geo_background_scheduled", "ban_service_geo_background_scheduled",
uncached=len(uncached), uncached=len(uncached),
cached=len(geo_map), cached=len(geo_map),
) )
# Fire-and-forget: lookup_batch handles rate-limiting / retries. if geo_batch_lookup is not None:
# The dirty-set flush task persists results to the DB. # Fire-and-forget: lookup_batch handles rate-limiting / retries.
asyncio.create_task( # noqa: RUF006 # The dirty-set flush task persists results to the DB.
geo_service.lookup_batch(uncached, http_session, db=app_db), asyncio.create_task( # noqa: RUF006
name="geo_bans_by_country", geo_batch_lookup(uncached, http_session, db=app_db),
) name="geo_bans_by_country",
)
elif geo_enricher is not None and unique_ips: elif geo_enricher is not None and unique_ips:
# Fallback: legacy per-IP enricher (used in tests / older callers). # Fallback: legacy per-IP enricher (used in tests / older callers).
async def _safe_lookup(ip: str) -> tuple[str, GeoInfo | None]: async def _safe_lookup(ip: str) -> tuple[str, GeoInfo | None]:
@@ -483,13 +388,13 @@ async def bans_by_country(
cn = geo.country_name if geo else None cn = geo.country_name if geo else None
asn: str | None = geo.asn if geo else None asn: str | None = geo.asn if geo else None
org: str | None = geo.org if geo else None org: str | None = geo.org if geo else None
matches, _ = _parse_data_json(companion_row.data) matches, _ = parse_data_json(companion_row.data)
bans.append( bans.append(
DashboardBanItem( DashboardBanItem(
ip=ip, ip=ip,
jail=companion_row.jail, jail=companion_row.jail,
banned_at=_ts_to_iso(companion_row.timeofban), banned_at=ts_to_iso(companion_row.timeofban),
service=matches[0] if matches else None, service=matches[0] if matches else None,
country_code=cc, country_code=cc,
country_name=cn, country_name=cn,
@@ -550,7 +455,7 @@ async def ban_trend(
num_buckets: int = bucket_count(range_) num_buckets: int = bucket_count(range_)
origin_clause, origin_params = _origin_sql_filter(origin) origin_clause, origin_params = _origin_sql_filter(origin)
db_path: str = await _get_fail2ban_db_path(socket_path) db_path: str = await get_fail2ban_db_path(socket_path)
log.info( log.info(
"ban_service_ban_trend", "ban_service_ban_trend",
db_path=db_path, db_path=db_path,
@@ -571,7 +476,7 @@ async def ban_trend(
buckets: list[BanTrendBucket] = [ buckets: list[BanTrendBucket] = [
BanTrendBucket( BanTrendBucket(
timestamp=_ts_to_iso(since + i * bucket_secs), timestamp=ts_to_iso(since + i * bucket_secs),
count=counts[i], count=counts[i],
) )
for i in range(num_buckets) for i in range(num_buckets)
@@ -615,12 +520,12 @@ async def bans_by_jail(
since: int = _since_unix(range_) since: int = _since_unix(range_)
origin_clause, origin_params = _origin_sql_filter(origin) origin_clause, origin_params = _origin_sql_filter(origin)
db_path: str = await _get_fail2ban_db_path(socket_path) db_path: str = await get_fail2ban_db_path(socket_path)
log.debug( log.debug(
"ban_service_bans_by_jail", "ban_service_bans_by_jail",
db_path=db_path, db_path=db_path,
since=since, since=since,
since_iso=_ts_to_iso(since), since_iso=ts_to_iso(since),
range=range_, range=range_,
origin=origin, origin=origin,
) )

View File

@@ -33,9 +33,13 @@ from app.repositories import blocklist_repo, import_log_repo, settings_repo
from app.utils.ip_utils import is_valid_ip, is_valid_network from app.utils.ip_utils import is_valid_ip, is_valid_network
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Callable
import aiohttp import aiohttp
import aiosqlite import aiosqlite
from app.models.geo import GeoBatchLookup
log: structlog.stdlib.BoundLogger = structlog.get_logger() log: structlog.stdlib.BoundLogger = structlog.get_logger()
#: Settings key used to persist the schedule config. #: Settings key used to persist the schedule config.
@@ -238,6 +242,8 @@ async def import_source(
http_session: aiohttp.ClientSession, http_session: aiohttp.ClientSession,
socket_path: str, socket_path: str,
db: aiosqlite.Connection, db: aiosqlite.Connection,
geo_is_cached: Callable[[str], bool] | None = None,
geo_batch_lookup: GeoBatchLookup | None = None,
) -> ImportSourceResult: ) -> ImportSourceResult:
"""Download and apply bans from a single blocklist source. """Download and apply bans from a single blocklist source.
@@ -339,12 +345,8 @@ async def import_source(
) )
# --- Pre-warm geo cache for newly imported IPs --- # --- Pre-warm geo cache for newly imported IPs ---
if imported_ips: if imported_ips and geo_is_cached is not None:
from app.services import geo_service # noqa: PLC0415 uncached_ips: list[str] = [ip for ip in imported_ips if not geo_is_cached(ip)]
uncached_ips: list[str] = [
ip for ip in imported_ips if not geo_service.is_cached(ip)
]
skipped_geo: int = len(imported_ips) - len(uncached_ips) skipped_geo: int = len(imported_ips) - len(uncached_ips)
if skipped_geo > 0: if skipped_geo > 0:
@@ -355,9 +357,9 @@ async def import_source(
to_lookup=len(uncached_ips), to_lookup=len(uncached_ips),
) )
if uncached_ips: if uncached_ips and geo_batch_lookup is not None:
try: try:
await geo_service.lookup_batch(uncached_ips, http_session, db=db) await geo_batch_lookup(uncached_ips, http_session, db=db)
log.info( log.info(
"blocklist_geo_prewarm_complete", "blocklist_geo_prewarm_complete",
source_id=source.id, source_id=source.id,
@@ -383,6 +385,8 @@ async def import_all(
db: aiosqlite.Connection, db: aiosqlite.Connection,
http_session: aiohttp.ClientSession, http_session: aiohttp.ClientSession,
socket_path: str, socket_path: str,
geo_is_cached: Callable[[str], bool] | None = None,
geo_batch_lookup: GeoBatchLookup | None = None,
) -> ImportRunResult: ) -> ImportRunResult:
"""Import all enabled blocklist sources. """Import all enabled blocklist sources.
@@ -406,7 +410,14 @@ async def import_all(
for row in sources: for row in sources:
source = _row_to_source(row) source = _row_to_source(row)
result = await import_source(source, http_session, socket_path, db) result = await import_source(
source,
http_session,
socket_path,
db,
geo_is_cached=geo_is_cached,
geo_batch_lookup=geo_batch_lookup,
)
results.append(result) results.append(result)
total_imported += result.ips_imported total_imported += result.ips_imported
total_skipped += result.ips_skipped total_skipped += result.ips_skipped

View File

@@ -54,8 +54,8 @@ from app.models.config import (
JailValidationResult, JailValidationResult,
RollbackResponse, RollbackResponse,
) )
from app.exceptions import JailNotFoundError
from app.services import jail_service from app.services import jail_service
from app.services.jail_service import JailNotFoundError as JailNotFoundError
from app.utils import conffile_parser from app.utils import conffile_parser
from app.utils.fail2ban_client import ( from app.utils.fail2ban_client import (
Fail2BanClient, Fail2BanClient,

View File

@@ -44,6 +44,7 @@ from app.models.config import (
RegexTestResponse, RegexTestResponse,
ServiceStatusResponse, ServiceStatusResponse,
) )
from app.exceptions import ConfigOperationError, ConfigValidationError, JailNotFoundError
from app.services import setup_service from app.services import setup_service
from app.utils.fail2ban_client import Fail2BanClient from app.utils.fail2ban_client import Fail2BanClient
@@ -55,26 +56,7 @@ _SOCKET_TIMEOUT: float = 10.0
# Custom exceptions # Custom exceptions
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# (exceptions are now defined in app.exceptions and imported above)
class JailNotFoundError(Exception):
"""Raised when a requested jail name does not exist in fail2ban."""
def __init__(self, name: str) -> None:
"""Initialise with the jail name that was not found.
Args:
name: The jail name that could not be located.
"""
self.name: str = name
super().__init__(f"Jail not found: {name!r}")
class ConfigValidationError(Exception):
"""Raised when a configuration value fails validation before writing."""
class ConfigOperationError(Exception):
"""Raised when a configuration write command fails."""
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View File

@@ -41,13 +41,12 @@ from __future__ import annotations
import asyncio import asyncio
import time import time
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import aiohttp import aiohttp
import structlog import structlog
from app.models.geo import GeoInfo
from app.repositories import geo_cache_repo from app.repositories import geo_cache_repo
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -94,40 +93,6 @@ _BATCH_DELAY: float = 1.5
#: transient error (e.g. connection reset due to rate limiting). #: transient error (e.g. connection reset due to rate limiting).
_BATCH_MAX_RETRIES: int = 2 _BATCH_MAX_RETRIES: int = 2
# ---------------------------------------------------------------------------
# Domain model
# ---------------------------------------------------------------------------
@dataclass
class GeoInfo:
"""Geographical and network metadata for a single IP address.
All fields default to ``None`` when the information is unavailable or
the lookup fails gracefully.
"""
country_code: str | None
"""ISO 3166-1 alpha-2 country code, e.g. ``"DE"``."""
country_name: str | None
"""Human-readable country name, e.g. ``"Germany"``."""
asn: str | None
"""Autonomous System Number string, e.g. ``"AS3320"``."""
org: str | None
"""Organisation name associated with the IP, e.g. ``"Deutsche Telekom"``."""
type GeoEnricher = Callable[[str], Awaitable[GeoInfo | None]]
"""Async callable used to enrich IPs with :class:`~app.services.geo_service.GeoInfo`.
This is a shared type alias used by services that optionally accept a geo
lookup callable (for example, :mod:`app.services.history_service`).
"""
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Internal cache # Internal cache
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View File

@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING
import structlog import structlog
if TYPE_CHECKING: if TYPE_CHECKING:
from app.services.geo_service import GeoEnricher from app.models.geo import GeoEnricher
from app.models.ban import TIME_RANGE_SECONDS, TimeRange from app.models.ban import TIME_RANGE_SECONDS, TimeRange
from app.models.history import ( from app.models.history import (
@@ -26,7 +26,7 @@ from app.models.history import (
IpTimelineEvent, IpTimelineEvent,
) )
from app.repositories import fail2ban_db_repo from app.repositories import fail2ban_db_repo
from app.services.ban_service import _get_fail2ban_db_path, _parse_data_json, _ts_to_iso from app.utils.fail2ban_db_utils import get_fail2ban_db_path, parse_data_json, ts_to_iso
log: structlog.stdlib.BoundLogger = structlog.get_logger() log: structlog.stdlib.BoundLogger = structlog.get_logger()
@@ -93,7 +93,7 @@ async def list_history(
if range_ is not None: if range_ is not None:
since = _since_unix(range_) since = _since_unix(range_)
db_path: str = await _get_fail2ban_db_path(socket_path) db_path: str = await get_fail2ban_db_path(socket_path)
log.info( log.info(
"history_service_list", "history_service_list",
db_path=db_path, db_path=db_path,
@@ -116,9 +116,9 @@ async def list_history(
for row in rows: for row in rows:
jail_name: str = row.jail jail_name: str = row.jail
ip: str = row.ip ip: str = row.ip
banned_at: str = _ts_to_iso(row.timeofban) banned_at: str = ts_to_iso(row.timeofban)
ban_count: int = row.bancount ban_count: int = row.bancount
matches, failures = _parse_data_json(row.data) matches, failures = parse_data_json(row.data)
country_code: str | None = None country_code: str | None = None
country_name: str | None = None country_name: str | None = None
@@ -180,7 +180,7 @@ async def get_ip_detail(
:class:`~app.models.history.IpDetailResponse` if any records exist :class:`~app.models.history.IpDetailResponse` if any records exist
for *ip*, or ``None`` if the IP has no history in the database. for *ip*, or ``None`` if the IP has no history in the database.
""" """
db_path: str = await _get_fail2ban_db_path(socket_path) db_path: str = await get_fail2ban_db_path(socket_path)
log.info("history_service_ip_detail", db_path=db_path, ip=ip) log.info("history_service_ip_detail", db_path=db_path, ip=ip)
rows = await fail2ban_db_repo.get_history_for_ip(db_path=db_path, ip=ip) rows = await fail2ban_db_repo.get_history_for_ip(db_path=db_path, ip=ip)
@@ -193,9 +193,9 @@ async def get_ip_detail(
for row in rows: for row in rows:
jail_name: str = row.jail jail_name: str = row.jail
banned_at: str = _ts_to_iso(row.timeofban) banned_at: str = ts_to_iso(row.timeofban)
ban_count: int = row.bancount ban_count: int = row.bancount
matches, failures = _parse_data_json(row.data) matches, failures = parse_data_json(row.data)
total_failures += failures total_failures += failures
timeline.append( timeline.append(
IpTimelineEvent( IpTimelineEvent(

View File

@@ -14,11 +14,11 @@ from __future__ import annotations
import asyncio import asyncio
import contextlib import contextlib
import ipaddress import ipaddress
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, TypedDict, cast from typing import TYPE_CHECKING, TypedDict, cast
import structlog import structlog
from app.exceptions import JailNotFoundError, JailOperationError
from app.models.ban import ActiveBan, ActiveBanListResponse, JailBannedIpsResponse from app.models.ban import ActiveBan, ActiveBanListResponse, JailBannedIpsResponse
from app.models.config import BantimeEscalation from app.models.config import BantimeEscalation
from app.models.jail import ( from app.models.jail import (
@@ -28,7 +28,6 @@ from app.models.jail import (
JailStatus, JailStatus,
JailSummary, JailSummary,
) )
from app.services.geo_service import GeoInfo
from app.utils.fail2ban_client import ( from app.utils.fail2ban_client import (
Fail2BanClient, Fail2BanClient,
Fail2BanCommand, Fail2BanCommand,
@@ -38,9 +37,13 @@ from app.utils.fail2ban_client import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Awaitable
import aiohttp import aiohttp
import aiosqlite import aiosqlite
from app.models.geo import GeoBatchLookup, GeoEnricher, GeoInfo
log: structlog.stdlib.BoundLogger = structlog.get_logger() log: structlog.stdlib.BoundLogger = structlog.get_logger()
class IpLookupResult(TypedDict): class IpLookupResult(TypedDict):
@@ -55,8 +58,6 @@ class IpLookupResult(TypedDict):
geo: GeoInfo | None geo: GeoInfo | None
GeoEnricher = Callable[[str], Awaitable[GeoInfo | None]]
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Constants # Constants
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -81,23 +82,6 @@ _backend_cmd_lock: asyncio.Lock = asyncio.Lock()
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class JailNotFoundError(Exception):
"""Raised when a requested jail name does not exist in fail2ban."""
def __init__(self, name: str) -> None:
"""Initialise with the jail name that was not found.
Args:
name: The jail name that could not be located.
"""
self.name: str = name
super().__init__(f"Jail not found: {name!r}")
class JailOperationError(Exception):
"""Raised when a jail control command fails for a non-auth reason."""
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Internal helpers # Internal helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -820,6 +804,7 @@ async def unban_ip(
async def get_active_bans( async def get_active_bans(
socket_path: str, socket_path: str,
geo_batch_lookup: GeoBatchLookup | None = None,
geo_enricher: GeoEnricher | None = None, geo_enricher: GeoEnricher | None = None,
http_session: aiohttp.ClientSession | None = None, http_session: aiohttp.ClientSession | None = None,
app_db: aiosqlite.Connection | None = None, app_db: aiosqlite.Connection | None = None,
@@ -857,7 +842,6 @@ async def get_active_bans(
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket ~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
cannot be reached. cannot be reached.
""" """
from app.services import geo_service # noqa: PLC0415
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT) client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
@@ -905,10 +889,10 @@ async def get_active_bans(
bans.append(ban) bans.append(ban)
# Enrich with geo data — prefer batch lookup over per-IP enricher. # Enrich with geo data — prefer batch lookup over per-IP enricher.
if http_session is not None and bans: if http_session is not None and bans and geo_batch_lookup is not None:
all_ips: list[str] = [ban.ip for ban in bans] all_ips: list[str] = [ban.ip for ban in bans]
try: try:
geo_map = await geo_service.lookup_batch(all_ips, http_session, db=app_db) geo_map = await geo_batch_lookup(all_ips, http_session, db=app_db)
except Exception: # noqa: BLE001 except Exception: # noqa: BLE001
log.warning("active_bans_batch_geo_failed") log.warning("active_bans_batch_geo_failed")
geo_map = {} geo_map = {}
@@ -1017,6 +1001,7 @@ async def get_jail_banned_ips(
page: int = 1, page: int = 1,
page_size: int = 25, page_size: int = 25,
search: str | None = None, search: str | None = None,
geo_batch_lookup: GeoBatchLookup | None = None,
http_session: aiohttp.ClientSession | None = None, http_session: aiohttp.ClientSession | None = None,
app_db: aiosqlite.Connection | None = None, app_db: aiosqlite.Connection | None = None,
) -> JailBannedIpsResponse: ) -> JailBannedIpsResponse:
@@ -1044,8 +1029,6 @@ async def get_jail_banned_ips(
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket is ~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket is
unreachable. unreachable.
""" """
from app.services import geo_service # noqa: PLC0415
# Clamp page_size to the allowed maximum. # Clamp page_size to the allowed maximum.
page_size = min(page_size, _MAX_PAGE_SIZE) page_size = min(page_size, _MAX_PAGE_SIZE)
@@ -1086,10 +1069,10 @@ async def get_jail_banned_ips(
page_bans = all_bans[start : start + page_size] page_bans = all_bans[start : start + page_size]
# Geo-enrich only the page slice. # Geo-enrich only the page slice.
if http_session is not None and page_bans: if http_session is not None and page_bans and geo_batch_lookup is not None:
page_ips = [b.ip for b in page_bans] page_ips = [b.ip for b in page_bans]
try: try:
geo_map = await geo_service.lookup_batch(page_ips, http_session, db=app_db) geo_map = await geo_batch_lookup(page_ips, http_session, db=app_db)
except Exception: # noqa: BLE001 except Exception: # noqa: BLE001
log.warning("jail_banned_ips_geo_failed", jail=jail_name) log.warning("jail_banned_ips_geo_failed", jail=jail_name)
geo_map = {} geo_map = {}

View File

@@ -14,6 +14,8 @@ from typing import cast
import structlog import structlog
from app.exceptions import ServerOperationError
from app.exceptions import ServerOperationError
from app.models.server import ServerSettings, ServerSettingsResponse, ServerSettingsUpdate from app.models.server import ServerSettings, ServerSettingsResponse, ServerSettingsUpdate
from app.utils.fail2ban_client import Fail2BanClient, Fail2BanCommand, Fail2BanResponse from app.utils.fail2ban_client import Fail2BanClient, Fail2BanCommand, Fail2BanResponse
@@ -54,15 +56,6 @@ def _to_str(value: object | None, default: str) -> str:
return str(value) return str(value)
# ---------------------------------------------------------------------------
# Custom exceptions
# ---------------------------------------------------------------------------
class ServerOperationError(Exception):
"""Raised when a server-level set command fails."""
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Internal helpers # Internal helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View File

@@ -0,0 +1,63 @@
"""Utilities shared by fail2ban-related services."""
from __future__ import annotations
import json
from datetime import UTC, datetime
def ts_to_iso(unix_ts: int) -> str:
"""Convert a Unix timestamp to an ISO 8601 UTC string."""
return datetime.fromtimestamp(unix_ts, tz=UTC).isoformat()
async def get_fail2ban_db_path(socket_path: str) -> str:
"""Query fail2ban for the path to its SQLite database file."""
from app.utils.fail2ban_client import Fail2BanClient # pragma: no cover
socket_timeout: float = 5.0
async with Fail2BanClient(socket_path, timeout=socket_timeout) as client:
response = await client.send(["get", "dbfile"])
if not isinstance(response, tuple) or len(response) != 2:
raise RuntimeError(f"Unexpected response from fail2ban: {response!r}")
code, data = response
if code != 0:
raise RuntimeError(f"fail2ban error code {code}: {data!r}")
if data is None:
raise RuntimeError("fail2ban has no database configured (dbfile is None)")
return str(data)
def parse_data_json(raw: object) -> tuple[list[str], int]:
"""Extract matches and failure count from the fail2ban bans.data value."""
if raw is None:
return [], 0
obj: dict[str, object] = {}
if isinstance(raw, str):
try:
parsed = json.loads(raw)
if isinstance(parsed, dict):
obj = parsed
except json.JSONDecodeError:
return [], 0
elif isinstance(raw, dict):
obj = raw
raw_matches = obj.get("matches")
matches = [str(m) for m in raw_matches] if isinstance(raw_matches, list) else []
raw_failures = obj.get("failures")
failures = 0
if isinstance(raw_failures, (int, float, str)):
try:
failures = int(raw_failures)
except (ValueError, TypeError):
failures = 0
return matches, failures

View File

@@ -12,7 +12,7 @@ from httpx import ASGITransport, AsyncClient
from app.config import Settings from app.config import Settings
from app.db import init_db from app.db import init_db
from app.main import create_app from app.main import create_app
from app.services.geo_service import GeoInfo from app.models.geo import GeoInfo
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Fixtures # Fixtures

View File

@@ -154,7 +154,7 @@ class TestListBansHappyPath:
async def test_returns_bans_in_range(self, f2b_db_path: str) -> None: async def test_returns_bans_in_range(self, f2b_db_path: str) -> None:
"""Only bans within the selected range are returned.""" """Only bans within the selected range are returned."""
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path), new=AsyncMock(return_value=f2b_db_path),
): ):
result = await ban_service.list_bans("/fake/sock", "24h") result = await ban_service.list_bans("/fake/sock", "24h")
@@ -166,7 +166,7 @@ class TestListBansHappyPath:
async def test_results_sorted_newest_first(self, f2b_db_path: str) -> None: async def test_results_sorted_newest_first(self, f2b_db_path: str) -> None:
"""Items are ordered by ``banned_at`` descending (newest first).""" """Items are ordered by ``banned_at`` descending (newest first)."""
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path), new=AsyncMock(return_value=f2b_db_path),
): ):
result = await ban_service.list_bans("/fake/sock", "24h") result = await ban_service.list_bans("/fake/sock", "24h")
@@ -177,7 +177,7 @@ class TestListBansHappyPath:
async def test_ban_fields_present(self, f2b_db_path: str) -> None: async def test_ban_fields_present(self, f2b_db_path: str) -> None:
"""Each item contains ip, jail, banned_at, ban_count.""" """Each item contains ip, jail, banned_at, ban_count."""
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path), new=AsyncMock(return_value=f2b_db_path),
): ):
result = await ban_service.list_bans("/fake/sock", "24h") result = await ban_service.list_bans("/fake/sock", "24h")
@@ -191,7 +191,7 @@ class TestListBansHappyPath:
async def test_service_extracted_from_first_match(self, f2b_db_path: str) -> None: async def test_service_extracted_from_first_match(self, f2b_db_path: str) -> None:
"""``service`` field is the first element of ``data.matches``.""" """``service`` field is the first element of ``data.matches``."""
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path), new=AsyncMock(return_value=f2b_db_path),
): ):
result = await ban_service.list_bans("/fake/sock", "24h") result = await ban_service.list_bans("/fake/sock", "24h")
@@ -203,7 +203,7 @@ class TestListBansHappyPath:
async def test_service_is_none_when_no_matches(self, f2b_db_path: str) -> None: async def test_service_is_none_when_no_matches(self, f2b_db_path: str) -> None:
"""``service`` is ``None`` when the ban has no stored matches.""" """``service`` is ``None`` when the ban has no stored matches."""
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path), new=AsyncMock(return_value=f2b_db_path),
): ):
# Use 7d to include the older ban with no matches. # Use 7d to include the older ban with no matches.
@@ -215,7 +215,7 @@ class TestListBansHappyPath:
async def test_empty_db_returns_zero(self, empty_f2b_db_path: str) -> None: async def test_empty_db_returns_zero(self, empty_f2b_db_path: str) -> None:
"""When no bans exist the result has total=0 and no items.""" """When no bans exist the result has total=0 and no items."""
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=empty_f2b_db_path), new=AsyncMock(return_value=empty_f2b_db_path),
): ):
result = await ban_service.list_bans("/fake/sock", "24h") result = await ban_service.list_bans("/fake/sock", "24h")
@@ -226,7 +226,7 @@ class TestListBansHappyPath:
async def test_365d_range_includes_old_bans(self, f2b_db_path: str) -> None: async def test_365d_range_includes_old_bans(self, f2b_db_path: str) -> None:
"""The ``365d`` range includes bans that are 2 days old.""" """The ``365d`` range includes bans that are 2 days old."""
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path), new=AsyncMock(return_value=f2b_db_path),
): ):
result = await ban_service.list_bans("/fake/sock", "365d") result = await ban_service.list_bans("/fake/sock", "365d")
@@ -246,7 +246,7 @@ class TestListBansGeoEnrichment:
self, f2b_db_path: str self, f2b_db_path: str
) -> None: ) -> None:
"""Geo fields are populated when an enricher returns data.""" """Geo fields are populated when an enricher returns data."""
from app.services.geo_service import GeoInfo from app.models.geo import GeoInfo
async def fake_enricher(ip: str) -> GeoInfo: async def fake_enricher(ip: str) -> GeoInfo:
return GeoInfo( return GeoInfo(
@@ -257,7 +257,7 @@ class TestListBansGeoEnrichment:
) )
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path), new=AsyncMock(return_value=f2b_db_path),
): ):
result = await ban_service.list_bans( result = await ban_service.list_bans(
@@ -278,7 +278,7 @@ class TestListBansGeoEnrichment:
raise RuntimeError("geo service down") raise RuntimeError("geo service down")
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path), new=AsyncMock(return_value=f2b_db_path),
): ):
result = await ban_service.list_bans( result = await ban_service.list_bans(
@@ -304,25 +304,27 @@ class TestListBansBatchGeoEnrichment:
"""Geo fields are populated via lookup_batch when http_session is given.""" """Geo fields are populated via lookup_batch when http_session is given."""
from unittest.mock import MagicMock from unittest.mock import MagicMock
from app.services.geo_service import GeoInfo from app.models.geo import GeoInfo
fake_session = MagicMock() fake_session = MagicMock()
fake_geo_map = { fake_geo_map = {
"1.2.3.4": GeoInfo(country_code="DE", country_name="Germany", asn="AS3320", org="Deutsche Telekom"), "1.2.3.4": GeoInfo(country_code="DE", country_name="Germany", asn="AS3320", org="Deutsche Telekom"),
"5.6.7.8": GeoInfo(country_code="US", country_name="United States", asn="AS15169", org="Google"), "5.6.7.8": GeoInfo(country_code="US", country_name="United States", asn="AS15169", org="Google"),
} }
fake_geo_batch = AsyncMock(return_value=fake_geo_map)
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path), new=AsyncMock(return_value=f2b_db_path),
), patch(
"app.services.geo_service.lookup_batch",
new=AsyncMock(return_value=fake_geo_map),
): ):
result = await ban_service.list_bans( result = await ban_service.list_bans(
"/fake/sock", "24h", http_session=fake_session "/fake/sock",
"24h",
http_session=fake_session,
geo_batch_lookup=fake_geo_batch,
) )
fake_geo_batch.assert_awaited_once_with(["1.2.3.4", "5.6.7.8"], fake_session, db=None)
assert result.total == 2 assert result.total == 2
de_item = next(i for i in result.items if i.ip == "1.2.3.4") de_item = next(i for i in result.items if i.ip == "1.2.3.4")
us_item = next(i for i in result.items if i.ip == "5.6.7.8") us_item = next(i for i in result.items if i.ip == "5.6.7.8")
@@ -339,15 +341,17 @@ class TestListBansBatchGeoEnrichment:
fake_session = MagicMock() fake_session = MagicMock()
failing_geo_batch = AsyncMock(side_effect=RuntimeError("batch geo down"))
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path), new=AsyncMock(return_value=f2b_db_path),
), patch(
"app.services.geo_service.lookup_batch",
new=AsyncMock(side_effect=RuntimeError("batch geo down")),
): ):
result = await ban_service.list_bans( result = await ban_service.list_bans(
"/fake/sock", "24h", http_session=fake_session "/fake/sock",
"24h",
http_session=fake_session,
geo_batch_lookup=failing_geo_batch,
) )
assert result.total == 2 assert result.total == 2
@@ -360,28 +364,27 @@ class TestListBansBatchGeoEnrichment:
"""When both http_session and geo_enricher are provided, batch wins.""" """When both http_session and geo_enricher are provided, batch wins."""
from unittest.mock import MagicMock from unittest.mock import MagicMock
from app.services.geo_service import GeoInfo from app.models.geo import GeoInfo
fake_session = MagicMock() fake_session = MagicMock()
fake_geo_map = { fake_geo_map = {
"1.2.3.4": GeoInfo(country_code="DE", country_name="Germany", asn=None, org=None), "1.2.3.4": GeoInfo(country_code="DE", country_name="Germany", asn=None, org=None),
"5.6.7.8": GeoInfo(country_code="DE", country_name="Germany", asn=None, org=None), "5.6.7.8": GeoInfo(country_code="DE", country_name="Germany", asn=None, org=None),
} }
fake_geo_batch = AsyncMock(return_value=fake_geo_map)
async def enricher_should_not_be_called(ip: str) -> GeoInfo: async def enricher_should_not_be_called(ip: str) -> GeoInfo:
raise AssertionError(f"geo_enricher was called for {ip!r} — should not happen") raise AssertionError(f"geo_enricher was called for {ip!r} — should not happen")
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path), new=AsyncMock(return_value=f2b_db_path),
), patch(
"app.services.geo_service.lookup_batch",
new=AsyncMock(return_value=fake_geo_map),
): ):
result = await ban_service.list_bans( result = await ban_service.list_bans(
"/fake/sock", "/fake/sock",
"24h", "24h",
http_session=fake_session, http_session=fake_session,
geo_batch_lookup=fake_geo_batch,
geo_enricher=enricher_should_not_be_called, geo_enricher=enricher_should_not_be_called,
) )
@@ -401,7 +404,7 @@ class TestListBansPagination:
async def test_page_size_respected(self, f2b_db_path: str) -> None: async def test_page_size_respected(self, f2b_db_path: str) -> None:
"""``page_size=1`` returns at most one item.""" """``page_size=1`` returns at most one item."""
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path), new=AsyncMock(return_value=f2b_db_path),
): ):
result = await ban_service.list_bans("/fake/sock", "7d", page_size=1) result = await ban_service.list_bans("/fake/sock", "7d", page_size=1)
@@ -412,7 +415,7 @@ class TestListBansPagination:
async def test_page_2_returns_remaining_items(self, f2b_db_path: str) -> None: async def test_page_2_returns_remaining_items(self, f2b_db_path: str) -> None:
"""The second page returns items not on the first page.""" """The second page returns items not on the first page."""
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path), new=AsyncMock(return_value=f2b_db_path),
): ):
page1 = await ban_service.list_bans("/fake/sock", "7d", page=1, page_size=1) page1 = await ban_service.list_bans("/fake/sock", "7d", page=1, page_size=1)
@@ -426,7 +429,7 @@ class TestListBansPagination:
) -> None: ) -> None:
"""``total`` reports all matching records regardless of pagination.""" """``total`` reports all matching records regardless of pagination."""
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path), new=AsyncMock(return_value=f2b_db_path),
): ):
result = await ban_service.list_bans("/fake/sock", "7d", page_size=1) result = await ban_service.list_bans("/fake/sock", "7d", page_size=1)
@@ -447,7 +450,7 @@ class TestBanOriginDerivation:
) -> None: ) -> None:
"""Bans from ``blocklist-import`` jail carry ``origin == "blocklist"``.""" """Bans from ``blocklist-import`` jail carry ``origin == "blocklist"``."""
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path), new=AsyncMock(return_value=mixed_origin_db_path),
): ):
result = await ban_service.list_bans("/fake/sock", "24h") result = await ban_service.list_bans("/fake/sock", "24h")
@@ -461,7 +464,7 @@ class TestBanOriginDerivation:
) -> None: ) -> None:
"""Bans from organic jails (sshd, nginx, …) carry ``origin == "selfblock"``.""" """Bans from organic jails (sshd, nginx, …) carry ``origin == "selfblock"``."""
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path), new=AsyncMock(return_value=mixed_origin_db_path),
): ):
result = await ban_service.list_bans("/fake/sock", "24h") result = await ban_service.list_bans("/fake/sock", "24h")
@@ -476,7 +479,7 @@ class TestBanOriginDerivation:
) -> None: ) -> None:
"""Every returned item has an ``origin`` field with a valid value.""" """Every returned item has an ``origin`` field with a valid value."""
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path), new=AsyncMock(return_value=mixed_origin_db_path),
): ):
result = await ban_service.list_bans("/fake/sock", "24h") result = await ban_service.list_bans("/fake/sock", "24h")
@@ -489,7 +492,7 @@ class TestBanOriginDerivation:
) -> None: ) -> None:
"""``bans_by_country`` also derives origin correctly for blocklist bans.""" """``bans_by_country`` also derives origin correctly for blocklist bans."""
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path), new=AsyncMock(return_value=mixed_origin_db_path),
): ):
result = await ban_service.bans_by_country("/fake/sock", "24h") result = await ban_service.bans_by_country("/fake/sock", "24h")
@@ -503,7 +506,7 @@ class TestBanOriginDerivation:
) -> None: ) -> None:
"""``bans_by_country`` derives origin correctly for organic jails.""" """``bans_by_country`` derives origin correctly for organic jails."""
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path), new=AsyncMock(return_value=mixed_origin_db_path),
): ):
result = await ban_service.bans_by_country("/fake/sock", "24h") result = await ban_service.bans_by_country("/fake/sock", "24h")
@@ -527,7 +530,7 @@ class TestOriginFilter:
) -> None: ) -> None:
"""``origin='blocklist'`` returns only blocklist-import jail bans.""" """``origin='blocklist'`` returns only blocklist-import jail bans."""
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path), new=AsyncMock(return_value=mixed_origin_db_path),
): ):
result = await ban_service.list_bans( result = await ban_service.list_bans(
@@ -544,7 +547,7 @@ class TestOriginFilter:
) -> None: ) -> None:
"""``origin='selfblock'`` excludes the blocklist-import jail.""" """``origin='selfblock'`` excludes the blocklist-import jail."""
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path), new=AsyncMock(return_value=mixed_origin_db_path),
): ):
result = await ban_service.list_bans( result = await ban_service.list_bans(
@@ -562,7 +565,7 @@ class TestOriginFilter:
) -> None: ) -> None:
"""``origin=None`` applies no jail restriction — all bans returned.""" """``origin=None`` applies no jail restriction — all bans returned."""
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path), new=AsyncMock(return_value=mixed_origin_db_path),
): ):
result = await ban_service.list_bans("/fake/sock", "24h", origin=None) result = await ban_service.list_bans("/fake/sock", "24h", origin=None)
@@ -574,7 +577,7 @@ class TestOriginFilter:
) -> None: ) -> None:
"""``bans_by_country`` with ``origin='blocklist'`` counts only blocklist bans.""" """``bans_by_country`` with ``origin='blocklist'`` counts only blocklist bans."""
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path), new=AsyncMock(return_value=mixed_origin_db_path),
): ):
result = await ban_service.bans_by_country( result = await ban_service.bans_by_country(
@@ -589,7 +592,7 @@ class TestOriginFilter:
) -> None: ) -> None:
"""``bans_by_country`` with ``origin='selfblock'`` excludes blocklist jails.""" """``bans_by_country`` with ``origin='selfblock'`` excludes blocklist jails."""
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path), new=AsyncMock(return_value=mixed_origin_db_path),
): ):
result = await ban_service.bans_by_country( result = await ban_service.bans_by_country(
@@ -604,7 +607,7 @@ class TestOriginFilter:
) -> None: ) -> None:
"""``bans_by_country`` with ``origin=None`` returns all bans.""" """``bans_by_country`` with ``origin=None`` returns all bans."""
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path), new=AsyncMock(return_value=mixed_origin_db_path),
): ):
result = await ban_service.bans_by_country( result = await ban_service.bans_by_country(
@@ -644,7 +647,7 @@ class TestBansbyCountryBackground:
with ( with (
patch( patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path), new=AsyncMock(return_value=mixed_origin_db_path),
), ),
patch( patch(
@@ -652,8 +655,13 @@ class TestBansbyCountryBackground:
) as mock_create_task, ) as mock_create_task,
): ):
mock_session = AsyncMock() mock_session = AsyncMock()
mock_batch = AsyncMock(return_value={})
result = await ban_service.bans_by_country( result = await ban_service.bans_by_country(
"/fake/sock", "24h", http_session=mock_session "/fake/sock",
"24h",
http_session=mock_session,
geo_cache_lookup=geo_service.lookup_cached_only,
geo_batch_lookup=mock_batch,
) )
# All countries resolved from cache — no background task needed. # All countries resolved from cache — no background task needed.
@@ -674,7 +682,7 @@ class TestBansbyCountryBackground:
with ( with (
patch( patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path), new=AsyncMock(return_value=mixed_origin_db_path),
), ),
patch( patch(
@@ -682,8 +690,13 @@ class TestBansbyCountryBackground:
) as mock_create_task, ) as mock_create_task,
): ):
mock_session = AsyncMock() mock_session = AsyncMock()
mock_batch = AsyncMock(return_value={})
result = await ban_service.bans_by_country( result = await ban_service.bans_by_country(
"/fake/sock", "24h", http_session=mock_session "/fake/sock",
"24h",
http_session=mock_session,
geo_cache_lookup=geo_service.lookup_cached_only,
geo_batch_lookup=mock_batch,
) )
# Background task must have been scheduled for uncached IPs. # Background task must have been scheduled for uncached IPs.
@@ -701,7 +714,7 @@ class TestBansbyCountryBackground:
with ( with (
patch( patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path), new=AsyncMock(return_value=mixed_origin_db_path),
), ),
patch( patch(
@@ -727,7 +740,7 @@ class TestBanTrend:
async def test_24h_returns_24_buckets(self, empty_f2b_db_path: str) -> None: async def test_24h_returns_24_buckets(self, empty_f2b_db_path: str) -> None:
"""``range_='24h'`` always yields exactly 24 buckets.""" """``range_='24h'`` always yields exactly 24 buckets."""
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=empty_f2b_db_path), new=AsyncMock(return_value=empty_f2b_db_path),
): ):
result = await ban_service.ban_trend("/fake/sock", "24h") result = await ban_service.ban_trend("/fake/sock", "24h")
@@ -738,7 +751,7 @@ class TestBanTrend:
async def test_7d_returns_28_buckets(self, empty_f2b_db_path: str) -> None: async def test_7d_returns_28_buckets(self, empty_f2b_db_path: str) -> None:
"""``range_='7d'`` yields 28 six-hour buckets.""" """``range_='7d'`` yields 28 six-hour buckets."""
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=empty_f2b_db_path), new=AsyncMock(return_value=empty_f2b_db_path),
): ):
result = await ban_service.ban_trend("/fake/sock", "7d") result = await ban_service.ban_trend("/fake/sock", "7d")
@@ -749,7 +762,7 @@ class TestBanTrend:
async def test_30d_returns_30_buckets(self, empty_f2b_db_path: str) -> None: async def test_30d_returns_30_buckets(self, empty_f2b_db_path: str) -> None:
"""``range_='30d'`` yields 30 daily buckets.""" """``range_='30d'`` yields 30 daily buckets."""
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=empty_f2b_db_path), new=AsyncMock(return_value=empty_f2b_db_path),
): ):
result = await ban_service.ban_trend("/fake/sock", "30d") result = await ban_service.ban_trend("/fake/sock", "30d")
@@ -760,7 +773,7 @@ class TestBanTrend:
async def test_365d_bucket_size_label(self, empty_f2b_db_path: str) -> None: async def test_365d_bucket_size_label(self, empty_f2b_db_path: str) -> None:
"""``range_='365d'`` uses '7d' as the bucket size label.""" """``range_='365d'`` uses '7d' as the bucket size label."""
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=empty_f2b_db_path), new=AsyncMock(return_value=empty_f2b_db_path),
): ):
result = await ban_service.ban_trend("/fake/sock", "365d") result = await ban_service.ban_trend("/fake/sock", "365d")
@@ -771,7 +784,7 @@ class TestBanTrend:
async def test_empty_db_all_buckets_zero(self, empty_f2b_db_path: str) -> None: async def test_empty_db_all_buckets_zero(self, empty_f2b_db_path: str) -> None:
"""All bucket counts are zero when the database has no bans.""" """All bucket counts are zero when the database has no bans."""
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=empty_f2b_db_path), new=AsyncMock(return_value=empty_f2b_db_path),
): ):
result = await ban_service.ban_trend("/fake/sock", "24h") result = await ban_service.ban_trend("/fake/sock", "24h")
@@ -781,7 +794,7 @@ class TestBanTrend:
async def test_buckets_are_time_ordered(self, empty_f2b_db_path: str) -> None: async def test_buckets_are_time_ordered(self, empty_f2b_db_path: str) -> None:
"""Buckets are ordered chronologically (ascending timestamps).""" """Buckets are ordered chronologically (ascending timestamps)."""
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=empty_f2b_db_path), new=AsyncMock(return_value=empty_f2b_db_path),
): ):
result = await ban_service.ban_trend("/fake/sock", "7d") result = await ban_service.ban_trend("/fake/sock", "7d")
@@ -804,7 +817,7 @@ class TestBanTrend:
) )
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=path), new=AsyncMock(return_value=path),
): ):
result = await ban_service.ban_trend("/fake/sock", "24h") result = await ban_service.ban_trend("/fake/sock", "24h")
@@ -828,7 +841,7 @@ class TestBanTrend:
) )
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=path), new=AsyncMock(return_value=path),
): ):
result = await ban_service.ban_trend( result = await ban_service.ban_trend(
@@ -854,7 +867,7 @@ class TestBanTrend:
) )
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=path), new=AsyncMock(return_value=path),
): ):
result = await ban_service.ban_trend( result = await ban_service.ban_trend(
@@ -868,7 +881,7 @@ class TestBanTrend:
from datetime import datetime from datetime import datetime
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=empty_f2b_db_path), new=AsyncMock(return_value=empty_f2b_db_path),
): ):
result = await ban_service.ban_trend("/fake/sock", "24h") result = await ban_service.ban_trend("/fake/sock", "24h")
@@ -904,7 +917,7 @@ class TestBansByJail:
) )
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=path), new=AsyncMock(return_value=path),
): ):
result = await ban_service.bans_by_jail("/fake/sock", "24h") result = await ban_service.bans_by_jail("/fake/sock", "24h")
@@ -931,7 +944,7 @@ class TestBansByJail:
) )
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=path), new=AsyncMock(return_value=path),
): ):
result = await ban_service.bans_by_jail("/fake/sock", "24h") result = await ban_service.bans_by_jail("/fake/sock", "24h")
@@ -942,7 +955,7 @@ class TestBansByJail:
async def test_empty_db_returns_empty_list(self, empty_f2b_db_path: str) -> None: async def test_empty_db_returns_empty_list(self, empty_f2b_db_path: str) -> None:
"""An empty database returns an empty jails list with total zero.""" """An empty database returns an empty jails list with total zero."""
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=empty_f2b_db_path), new=AsyncMock(return_value=empty_f2b_db_path),
): ):
result = await ban_service.bans_by_jail("/fake/sock", "24h") result = await ban_service.bans_by_jail("/fake/sock", "24h")
@@ -954,7 +967,7 @@ class TestBansByJail:
"""Bans older than the time window are not counted.""" """Bans older than the time window are not counted."""
# f2b_db_path has one ban from _TWO_DAYS_AGO, which is outside "24h". # f2b_db_path has one ban from _TWO_DAYS_AGO, which is outside "24h".
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path), new=AsyncMock(return_value=f2b_db_path),
): ):
result = await ban_service.bans_by_jail("/fake/sock", "24h") result = await ban_service.bans_by_jail("/fake/sock", "24h")
@@ -965,7 +978,7 @@ class TestBansByJail:
async def test_origin_filter_blocklist(self, mixed_origin_db_path: str) -> None: async def test_origin_filter_blocklist(self, mixed_origin_db_path: str) -> None:
"""``origin='blocklist'`` returns only the blocklist-import jail.""" """``origin='blocklist'`` returns only the blocklist-import jail."""
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path), new=AsyncMock(return_value=mixed_origin_db_path),
): ):
result = await ban_service.bans_by_jail( result = await ban_service.bans_by_jail(
@@ -979,7 +992,7 @@ class TestBansByJail:
async def test_origin_filter_selfblock(self, mixed_origin_db_path: str) -> None: async def test_origin_filter_selfblock(self, mixed_origin_db_path: str) -> None:
"""``origin='selfblock'`` excludes the blocklist-import jail.""" """``origin='selfblock'`` excludes the blocklist-import jail."""
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path), new=AsyncMock(return_value=mixed_origin_db_path),
): ):
result = await ban_service.bans_by_jail( result = await ban_service.bans_by_jail(
@@ -995,7 +1008,7 @@ class TestBansByJail:
) -> None: ) -> None:
"""``origin=None`` returns bans from all jails.""" """``origin=None`` returns bans from all jails."""
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path), new=AsyncMock(return_value=mixed_origin_db_path),
): ):
result = await ban_service.bans_by_jail( result = await ban_service.bans_by_jail(
@@ -1023,7 +1036,7 @@ class TestBansByJail:
with ( with (
patch( patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=path), new=AsyncMock(return_value=path),
), ),
patch("app.services.ban_service.log") as mock_log, patch("app.services.ban_service.log") as mock_log,

View File

@@ -19,8 +19,8 @@ from unittest.mock import AsyncMock, patch
import aiosqlite import aiosqlite
import pytest import pytest
from app.models.geo import GeoInfo
from app.services import ban_service, geo_service from app.services import ban_service, geo_service
from app.services.geo_service import GeoInfo
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Constants # Constants
@@ -161,7 +161,7 @@ class TestBanServicePerformance:
return geo_service._cache.get(ip) # noqa: SLF001 return geo_service._cache.get(ip) # noqa: SLF001
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=perf_db_path), new=AsyncMock(return_value=perf_db_path),
): ):
start = time.perf_counter() start = time.perf_counter()
@@ -191,7 +191,7 @@ class TestBanServicePerformance:
return geo_service._cache.get(ip) # noqa: SLF001 return geo_service._cache.get(ip) # noqa: SLF001
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=perf_db_path), new=AsyncMock(return_value=perf_db_path),
): ):
start = time.perf_counter() start = time.perf_counter()
@@ -217,7 +217,7 @@ class TestBanServicePerformance:
return geo_service._cache.get(ip) # noqa: SLF001 return geo_service._cache.get(ip) # noqa: SLF001
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=perf_db_path), new=AsyncMock(return_value=perf_db_path),
): ):
result = await ban_service.list_bans( result = await ban_service.list_bans(
@@ -241,7 +241,7 @@ class TestBanServicePerformance:
return geo_service._cache.get(ip) # noqa: SLF001 return geo_service._cache.get(ip) # noqa: SLF001
with patch( with patch(
"app.services.ban_service._get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=perf_db_path), new=AsyncMock(return_value=perf_db_path),
): ):
result = await ban_service.bans_by_country( result = await ban_service.bans_by_country(

View File

@@ -315,20 +315,15 @@ class TestGeoPrewarmCacheFilter:
def _mock_is_cached(ip: str) -> bool: def _mock_is_cached(ip: str) -> bool:
return ip == "1.2.3.4" return ip == "1.2.3.4"
with ( mock_batch = AsyncMock(return_value={})
patch("app.services.jail_service.ban_ip", new_callable=AsyncMock), with patch("app.services.jail_service.ban_ip", new_callable=AsyncMock):
patch(
"app.services.geo_service.is_cached",
side_effect=_mock_is_cached,
),
patch(
"app.services.geo_service.lookup_batch",
new_callable=AsyncMock,
return_value={},
) as mock_batch,
):
result = await blocklist_service.import_source( result = await blocklist_service.import_source(
source, session, "/tmp/fake.sock", db source,
session,
"/tmp/fake.sock",
db,
geo_is_cached=_mock_is_cached,
geo_batch_lookup=mock_batch,
) )
assert result.ips_imported == 3 assert result.ips_imported == 3

View File

@@ -7,8 +7,8 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from app.models.geo import GeoInfo
from app.services import geo_service from app.services import geo_service
from app.services.geo_service import GeoInfo
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Helpers # Helpers

View File

@@ -123,7 +123,7 @@ class TestListHistory:
) -> None: ) -> None:
"""No filter returns every record in the database.""" """No filter returns every record in the database."""
with patch( with patch(
"app.services.history_service._get_fail2ban_db_path", "app.services.history_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path), new=AsyncMock(return_value=f2b_db_path),
): ):
result = await history_service.list_history("fake_socket") result = await history_service.list_history("fake_socket")
@@ -135,7 +135,7 @@ class TestListHistory:
) -> None: ) -> None:
"""The ``range_`` filter excludes bans older than the window.""" """The ``range_`` filter excludes bans older than the window."""
with patch( with patch(
"app.services.history_service._get_fail2ban_db_path", "app.services.history_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path), new=AsyncMock(return_value=f2b_db_path),
): ):
# "24h" window should include only the two recent bans # "24h" window should include only the two recent bans
@@ -147,7 +147,7 @@ class TestListHistory:
async def test_jail_filter(self, f2b_db_path: str) -> None: async def test_jail_filter(self, f2b_db_path: str) -> None:
"""Jail filter restricts results to bans from that jail.""" """Jail filter restricts results to bans from that jail."""
with patch( with patch(
"app.services.history_service._get_fail2ban_db_path", "app.services.history_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path), new=AsyncMock(return_value=f2b_db_path),
): ):
result = await history_service.list_history("fake_socket", jail="nginx") result = await history_service.list_history("fake_socket", jail="nginx")
@@ -157,7 +157,7 @@ class TestListHistory:
async def test_ip_prefix_filter(self, f2b_db_path: str) -> None: async def test_ip_prefix_filter(self, f2b_db_path: str) -> None:
"""IP prefix filter restricts results to matching IPs.""" """IP prefix filter restricts results to matching IPs."""
with patch( with patch(
"app.services.history_service._get_fail2ban_db_path", "app.services.history_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path), new=AsyncMock(return_value=f2b_db_path),
): ):
result = await history_service.list_history( result = await history_service.list_history(
@@ -170,7 +170,7 @@ class TestListHistory:
async def test_combined_filters(self, f2b_db_path: str) -> None: async def test_combined_filters(self, f2b_db_path: str) -> None:
"""Jail + IP prefix filters applied together narrow the result set.""" """Jail + IP prefix filters applied together narrow the result set."""
with patch( with patch(
"app.services.history_service._get_fail2ban_db_path", "app.services.history_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path), new=AsyncMock(return_value=f2b_db_path),
): ):
result = await history_service.list_history( result = await history_service.list_history(
@@ -182,7 +182,7 @@ class TestListHistory:
async def test_unknown_ip_returns_empty(self, f2b_db_path: str) -> None: async def test_unknown_ip_returns_empty(self, f2b_db_path: str) -> None:
"""Filtering by a non-existent IP returns an empty result set.""" """Filtering by a non-existent IP returns an empty result set."""
with patch( with patch(
"app.services.history_service._get_fail2ban_db_path", "app.services.history_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path), new=AsyncMock(return_value=f2b_db_path),
): ):
result = await history_service.list_history( result = await history_service.list_history(
@@ -196,7 +196,7 @@ class TestListHistory:
) -> None: ) -> None:
"""``failures`` field is parsed from the JSON ``data`` column.""" """``failures`` field is parsed from the JSON ``data`` column."""
with patch( with patch(
"app.services.history_service._get_fail2ban_db_path", "app.services.history_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path), new=AsyncMock(return_value=f2b_db_path),
): ):
result = await history_service.list_history( result = await history_service.list_history(
@@ -210,7 +210,7 @@ class TestListHistory:
) -> None: ) -> None:
"""``matches`` list is parsed from the JSON ``data`` column.""" """``matches`` list is parsed from the JSON ``data`` column."""
with patch( with patch(
"app.services.history_service._get_fail2ban_db_path", "app.services.history_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path), new=AsyncMock(return_value=f2b_db_path),
): ):
result = await history_service.list_history( result = await history_service.list_history(
@@ -226,7 +226,7 @@ class TestListHistory:
) -> None: ) -> None:
"""Records with ``data=NULL`` produce failures=0 and matches=[].""" """Records with ``data=NULL`` produce failures=0 and matches=[]."""
with patch( with patch(
"app.services.history_service._get_fail2ban_db_path", "app.services.history_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path), new=AsyncMock(return_value=f2b_db_path),
): ):
result = await history_service.list_history( result = await history_service.list_history(
@@ -240,7 +240,7 @@ class TestListHistory:
async def test_pagination(self, f2b_db_path: str) -> None: async def test_pagination(self, f2b_db_path: str) -> None:
"""Pagination returns the correct slice.""" """Pagination returns the correct slice."""
with patch( with patch(
"app.services.history_service._get_fail2ban_db_path", "app.services.history_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path), new=AsyncMock(return_value=f2b_db_path),
): ):
result = await history_service.list_history( result = await history_service.list_history(
@@ -265,7 +265,7 @@ class TestGetIpDetail:
) -> None: ) -> None:
"""Returns ``None`` when the IP has no records in the database.""" """Returns ``None`` when the IP has no records in the database."""
with patch( with patch(
"app.services.history_service._get_fail2ban_db_path", "app.services.history_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path), new=AsyncMock(return_value=f2b_db_path),
): ):
result = await history_service.get_ip_detail("fake_socket", "99.99.99.99") result = await history_service.get_ip_detail("fake_socket", "99.99.99.99")
@@ -276,7 +276,7 @@ class TestGetIpDetail:
) -> None: ) -> None:
"""Returns an IpDetailResponse with correct totals for a known IP.""" """Returns an IpDetailResponse with correct totals for a known IP."""
with patch( with patch(
"app.services.history_service._get_fail2ban_db_path", "app.services.history_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path), new=AsyncMock(return_value=f2b_db_path),
): ):
result = await history_service.get_ip_detail("fake_socket", "1.2.3.4") result = await history_service.get_ip_detail("fake_socket", "1.2.3.4")
@@ -291,7 +291,7 @@ class TestGetIpDetail:
) -> None: ) -> None:
"""Timeline events are ordered newest-first.""" """Timeline events are ordered newest-first."""
with patch( with patch(
"app.services.history_service._get_fail2ban_db_path", "app.services.history_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path), new=AsyncMock(return_value=f2b_db_path),
): ):
result = await history_service.get_ip_detail("fake_socket", "1.2.3.4") result = await history_service.get_ip_detail("fake_socket", "1.2.3.4")
@@ -304,7 +304,7 @@ class TestGetIpDetail:
async def test_last_ban_at_is_most_recent(self, f2b_db_path: str) -> None: async def test_last_ban_at_is_most_recent(self, f2b_db_path: str) -> None:
"""``last_ban_at`` matches the banned_at of the first timeline event.""" """``last_ban_at`` matches the banned_at of the first timeline event."""
with patch( with patch(
"app.services.history_service._get_fail2ban_db_path", "app.services.history_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path), new=AsyncMock(return_value=f2b_db_path),
): ):
result = await history_service.get_ip_detail("fake_socket", "1.2.3.4") result = await history_service.get_ip_detail("fake_socket", "1.2.3.4")
@@ -316,7 +316,7 @@ class TestGetIpDetail:
self, f2b_db_path: str self, f2b_db_path: str
) -> None: ) -> None:
"""Geolocation is applied when a geo_enricher is provided.""" """Geolocation is applied when a geo_enricher is provided."""
from app.services.geo_service import GeoInfo from app.models.geo import GeoInfo
mock_geo = GeoInfo( mock_geo = GeoInfo(
country_code="US", country_code="US",
@@ -327,7 +327,7 @@ class TestGetIpDetail:
fake_enricher = AsyncMock(return_value=mock_geo) fake_enricher = AsyncMock(return_value=mock_geo)
with patch( with patch(
"app.services.history_service._get_fail2ban_db_path", "app.services.history_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path), new=AsyncMock(return_value=f2b_db_path),
): ):
result = await history_service.get_ip_detail( result = await history_service.get_ip_detail(

View File

@@ -635,7 +635,7 @@ class TestGetActiveBans:
async def test_http_session_triggers_lookup_batch(self) -> None: async def test_http_session_triggers_lookup_batch(self) -> None:
"""When http_session is provided, geo_service.lookup_batch is used.""" """When http_session is provided, geo_service.lookup_batch is used."""
from app.services.geo_service import GeoInfo from app.models.geo import GeoInfo
responses = { responses = {
"status": _make_global_status("sshd"), "status": _make_global_status("sshd"),
@@ -645,17 +645,14 @@ class TestGetActiveBans:
), ),
} }
mock_geo = {"1.2.3.4": GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP")} mock_geo = {"1.2.3.4": GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP")}
mock_batch = AsyncMock(return_value=mock_geo)
with ( with _patch_client(responses):
_patch_client(responses),
patch(
"app.services.geo_service.lookup_batch",
new=AsyncMock(return_value=mock_geo),
) as mock_batch,
):
mock_session = AsyncMock() mock_session = AsyncMock()
result = await jail_service.get_active_bans( result = await jail_service.get_active_bans(
_SOCKET, http_session=mock_session _SOCKET,
http_session=mock_session,
geo_batch_lookup=mock_batch,
) )
mock_batch.assert_awaited_once() mock_batch.assert_awaited_once()
@@ -672,16 +669,14 @@ class TestGetActiveBans:
), ),
} }
with ( failing_batch = AsyncMock(side_effect=RuntimeError("geo down"))
_patch_client(responses),
patch( with _patch_client(responses):
"app.services.geo_service.lookup_batch",
new=AsyncMock(side_effect=RuntimeError("geo down")),
),
):
mock_session = AsyncMock() mock_session = AsyncMock()
result = await jail_service.get_active_bans( result = await jail_service.get_active_bans(
_SOCKET, http_session=mock_session _SOCKET,
http_session=mock_session,
geo_batch_lookup=failing_batch,
) )
assert result.total == 1 assert result.total == 1
@@ -689,7 +684,7 @@ class TestGetActiveBans:
async def test_geo_enricher_still_used_without_http_session(self) -> None: async def test_geo_enricher_still_used_without_http_session(self) -> None:
"""Legacy geo_enricher is still called when http_session is not provided.""" """Legacy geo_enricher is still called when http_session is not provided."""
from app.services.geo_service import GeoInfo from app.models.geo import GeoInfo
responses = { responses = {
"status": _make_global_status("sshd"), "status": _make_global_status("sshd"),
@@ -987,6 +982,7 @@ class TestGetJailBannedIps:
page=1, page=1,
page_size=2, page_size=2,
http_session=http_session, http_session=http_session,
geo_batch_lookup=geo_service.lookup_batch,
) )
# Only the 2-IP page slice should be passed to geo enrichment. # Only the 2-IP page slice should be passed to geo enrichment.

View File

@@ -13,7 +13,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from app.services.geo_service import GeoInfo from app.models.geo import GeoInfo
from app.tasks.geo_re_resolve import _run_re_resolve from app.tasks.geo_re_resolve import _run_re_resolve