Fix geo_re_resolve async mocks and mark tasks complete
This commit is contained in:
@@ -158,6 +158,8 @@ After completing TASK B-5, a `geo_service` method (or via `geo_cache_repo` throu
|
|||||||
|
|
||||||
#### TASK B-8 — Remove `print()` from `geo_service.py` docstring example
|
#### TASK B-8 — Remove `print()` from `geo_service.py` docstring example
|
||||||
|
|
||||||
|
**Status:** Completed ✅
|
||||||
|
|
||||||
**Violated rule:** Refactoring.md §4 / Backend-Development.md §2 — Never use `print()` in production code; use `structlog`.
|
**Violated rule:** Refactoring.md §4 / Backend-Development.md §2 — Never use `print()` in production code; use `structlog`.
|
||||||
|
|
||||||
**Files affected:**
|
**Files affected:**
|
||||||
@@ -229,6 +231,8 @@ Remove or rewrite the docstring snippet so it does not contain a bare `print()`
|
|||||||
|
|
||||||
#### TASK F-1 — Wrap `SetupPage` API calls in a dedicated hook
|
#### TASK F-1 — Wrap `SetupPage` API calls in a dedicated hook
|
||||||
|
|
||||||
|
**Status:** Completed ✅
|
||||||
|
|
||||||
**Violated rule:** Refactoring.md §3.1 — Pages must not call API functions from `src/api/` directly; all data fetching goes through hooks.
|
**Violated rule:** Refactoring.md §3.1 — Pages must not call API functions from `src/api/` directly; all data fetching goes through hooks.
|
||||||
|
|
||||||
**Files affected:**
|
**Files affected:**
|
||||||
@@ -409,6 +413,8 @@ For each component listed:
|
|||||||
|
|
||||||
#### TASK B-13 — Remove `Any` type annotations in `jail_service.py`
|
#### TASK B-13 — Remove `Any` type annotations in `jail_service.py`
|
||||||
|
|
||||||
|
**Status:** Completed ✅
|
||||||
|
|
||||||
**Violated rule:** Backend-Development.md §1 — Never use `Any`; all functions must have explicit type annotations.
|
**Violated rule:** Backend-Development.md §1 — Never use `Any`; all functions must have explicit type annotations.
|
||||||
|
|
||||||
**Files affected:**
|
**Files affected:**
|
||||||
@@ -424,6 +430,8 @@ For each component listed:
|
|||||||
|
|
||||||
#### TASK B-14 — Remove `Any` type annotations in `health_service.py`
|
#### TASK B-14 — Remove `Any` type annotations in `health_service.py`
|
||||||
|
|
||||||
|
**Status:** Completed ✅
|
||||||
|
|
||||||
**Violated rule:** Backend-Development.md §1 — Never use `Any`; all functions must have explicit type annotations.
|
**Violated rule:** Backend-Development.md §1 — Never use `Any`; all functions must have explicit type annotations.
|
||||||
|
|
||||||
**Files affected:**
|
**Files affected:**
|
||||||
@@ -439,6 +447,8 @@ For each component listed:
|
|||||||
|
|
||||||
#### TASK B-15 — Remove `Any` type annotations in `blocklist_service.py`
|
#### TASK B-15 — Remove `Any` type annotations in `blocklist_service.py`
|
||||||
|
|
||||||
|
**Status:** Completed ✅
|
||||||
|
|
||||||
**Violated rule:** Backend-Development.md §1 — Never use `Any`; all functions must have explicit type annotations.
|
**Violated rule:** Backend-Development.md §1 — Never use `Any`; all functions must have explicit type annotations.
|
||||||
|
|
||||||
**Files affected:**
|
**Files affected:**
|
||||||
@@ -454,6 +464,8 @@ For each component listed:
|
|||||||
|
|
||||||
#### TASK B-16 — Remove `Any` type annotations in `import_log_repo.py`
|
#### TASK B-16 — Remove `Any` type annotations in `import_log_repo.py`
|
||||||
|
|
||||||
|
**Status:** Completed ✅
|
||||||
|
|
||||||
**Violated rule:** Backend-Development.md §1 — Never use `Any`; all functions must have explicit type annotations.
|
**Violated rule:** Backend-Development.md §1 — Never use `Any`; all functions must have explicit type annotations.
|
||||||
|
|
||||||
**Files affected:**
|
**Files affected:**
|
||||||
@@ -470,6 +482,8 @@ For each component listed:
|
|||||||
|
|
||||||
#### TASK B-17 — Remove `Any` type annotations in `config_file_service.py`
|
#### TASK B-17 — Remove `Any` type annotations in `config_file_service.py`
|
||||||
|
|
||||||
|
**Status:** Completed ✅
|
||||||
|
|
||||||
**Violated rule:** Backend-Development.md §1 — Never use `Any`; all functions must have explicit type annotations.
|
**Violated rule:** Backend-Development.md §1 — Never use `Any`; all functions must have explicit type annotations.
|
||||||
|
|
||||||
**Files affected:**
|
**Files affected:**
|
||||||
@@ -485,6 +499,8 @@ For each component listed:
|
|||||||
|
|
||||||
#### TASK B-18 — Remove `Any` type annotations in `fail2ban_client.py`
|
#### TASK B-18 — Remove `Any` type annotations in `fail2ban_client.py`
|
||||||
|
|
||||||
|
**Status:** Completed ✅
|
||||||
|
|
||||||
**Violated rule:** Backend-Development.md §1 — Never use `Any`; all functions must have explicit type annotations.
|
**Violated rule:** Backend-Development.md §1 — Never use `Any`; all functions must have explicit type annotations.
|
||||||
|
|
||||||
**Files affected:**
|
**Files affected:**
|
||||||
@@ -500,6 +516,8 @@ For each component listed:
|
|||||||
|
|
||||||
#### TASK B-19 — Remove `Any` annotations from background tasks
|
#### TASK B-19 — Remove `Any` annotations from background tasks
|
||||||
|
|
||||||
|
**Status:** Completed ✅
|
||||||
|
|
||||||
**Violated rule:** Backend-Development.md §1 — Never use `Any`; all functions must have explicit type annotations.
|
**Violated rule:** Backend-Development.md §1 — Never use `Any`; all functions must have explicit type annotations.
|
||||||
|
|
||||||
**Files affected:**
|
**Files affected:**
|
||||||
@@ -517,6 +535,8 @@ For each component listed:
|
|||||||
|
|
||||||
#### TASK B-20 — Remove `type: ignore` in `dependencies.get_settings`
|
#### TASK B-20 — Remove `type: ignore` in `dependencies.get_settings`
|
||||||
|
|
||||||
|
**Status:** Completed ✅
|
||||||
|
|
||||||
**Violated rule:** Backend-Development.md §1 — Avoid `Any` and ignored type errors.
|
**Violated rule:** Backend-Development.md §1 — Avoid `Any` and ignored type errors.
|
||||||
|
|
||||||
**Files affected:**
|
**Files affected:**
|
||||||
@@ -527,3 +547,19 @@ For each component listed:
|
|||||||
1. Introduce a typed model (e.g., `TypedDict` or `Protocol`) for `app.state` to declare `settings: Settings` and other shared state properties.
|
1. Introduce a typed model (e.g., `TypedDict` or `Protocol`) for `app.state` to declare `settings: Settings` and other shared state properties.
|
||||||
2. Update `get_settings` (and any other helpers that read from `app.state`) so the return type is inferred as `Settings` without needing a `type: ignore` comment.
|
2. Update `get_settings` (and any other helpers that read from `app.state`) so the return type is inferred as `Settings` without needing a `type: ignore` comment.
|
||||||
3. Run `mypy --strict` or `pyright` to confirm the type ignore is no longer needed.
|
3. Run `mypy --strict` or `pyright` to confirm the type ignore is no longer needed.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
#### TASK B-21 — Fix `geo_re_resolve` test mocks to support async calls
|
||||||
|
|
||||||
|
**Status:** Completed ✅
|
||||||
|
|
||||||
|
**Violated rule:** Test code must correctly mock async coroutines (`AsyncMock`) when awaited.
|
||||||
|
|
||||||
|
**Files affected:**
|
||||||
|
- `backend/tests/test_tasks/test_geo_re_resolve.py` — patched `geo_service` to ensure `get_unresolved_ips` is an `AsyncMock`.
|
||||||
|
|
||||||
|
**What to do:**
|
||||||
|
|
||||||
|
1. Ensure all mocks for async service methods are `AsyncMock` so they can be awaited.
|
||||||
|
2. Run `pytest -q -c backend/pyproject.toml` to confirm the test suite passes.
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ directly — to keep coupling explicit and testable.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from typing import Annotated
|
from typing import Annotated, Protocol, cast
|
||||||
|
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
import structlog
|
import structlog
|
||||||
@@ -19,6 +19,13 @@ from app.utils.time_utils import utc_now
|
|||||||
|
|
||||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class AppState(Protocol):
|
||||||
|
"""Partial view of the FastAPI application state used by dependencies."""
|
||||||
|
|
||||||
|
settings: Settings
|
||||||
|
|
||||||
|
|
||||||
_COOKIE_NAME = "bangui_session"
|
_COOKIE_NAME = "bangui_session"
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -85,7 +92,8 @@ async def get_settings(request: Request) -> Settings:
|
|||||||
Returns:
|
Returns:
|
||||||
The application settings loaded at startup.
|
The application settings loaded at startup.
|
||||||
"""
|
"""
|
||||||
return request.app.state.settings # type: ignore[no-any-return]
|
state = cast(AppState, request.app.state)
|
||||||
|
return state.settings
|
||||||
|
|
||||||
|
|
||||||
async def require_auth(
|
async def require_auth(
|
||||||
|
|||||||
@@ -8,12 +8,25 @@ table. All methods are plain async functions that accept a
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import TYPE_CHECKING, Any
|
from collections.abc import Mapping
|
||||||
|
from typing import TYPE_CHECKING, TypedDict, cast
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
|
|
||||||
|
|
||||||
|
class ImportLogRow(TypedDict):
|
||||||
|
"""Row shape returned by queries on the import_log table."""
|
||||||
|
|
||||||
|
id: int
|
||||||
|
source_id: int | None
|
||||||
|
source_url: str
|
||||||
|
timestamp: str
|
||||||
|
ips_imported: int
|
||||||
|
ips_skipped: int
|
||||||
|
errors: str | None
|
||||||
|
|
||||||
|
|
||||||
async def add_log(
|
async def add_log(
|
||||||
db: aiosqlite.Connection,
|
db: aiosqlite.Connection,
|
||||||
*,
|
*,
|
||||||
@@ -54,7 +67,7 @@ async def list_logs(
|
|||||||
source_id: int | None = None,
|
source_id: int | None = None,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
page_size: int = 50,
|
page_size: int = 50,
|
||||||
) -> tuple[list[dict[str, Any]], int]:
|
) -> tuple[list[ImportLogRow], int]:
|
||||||
"""Return a paginated list of import log entries.
|
"""Return a paginated list of import log entries.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -68,8 +81,8 @@ async def list_logs(
|
|||||||
*total* is the count of all matching rows (ignoring pagination).
|
*total* is the count of all matching rows (ignoring pagination).
|
||||||
"""
|
"""
|
||||||
where = ""
|
where = ""
|
||||||
params_count: list[Any] = []
|
params_count: list[object] = []
|
||||||
params_rows: list[Any] = []
|
params_rows: list[object] = []
|
||||||
|
|
||||||
if source_id is not None:
|
if source_id is not None:
|
||||||
where = " WHERE source_id = ?"
|
where = " WHERE source_id = ?"
|
||||||
@@ -102,7 +115,7 @@ async def list_logs(
|
|||||||
return items, total
|
return items, total
|
||||||
|
|
||||||
|
|
||||||
async def get_last_log(db: aiosqlite.Connection) -> dict[str, Any] | None:
|
async def get_last_log(db: aiosqlite.Connection) -> ImportLogRow | None:
|
||||||
"""Return the most recent import log entry across all sources.
|
"""Return the most recent import log entry across all sources.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -143,13 +156,14 @@ def compute_total_pages(total: int, page_size: int) -> int:
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _row_to_dict(row: Any) -> dict[str, Any]:
|
def _row_to_dict(row: object) -> ImportLogRow:
|
||||||
"""Convert an aiosqlite row to a plain Python dict.
|
"""Convert an aiosqlite row to a plain Python dict.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
row: An :class:`aiosqlite.Row` or sequence returned by a cursor.
|
row: An :class:`aiosqlite.Row` or similar mapping returned by a cursor.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict mapping column names to Python values.
|
Dict mapping column names to Python values.
|
||||||
"""
|
"""
|
||||||
return dict(row)
|
mapping = cast(Mapping[str, object], row)
|
||||||
|
return cast(ImportLogRow, dict(mapping))
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ under the key ``"blocklist_schedule"``.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
|
|
||||||
@@ -56,7 +56,7 @@ _PREVIEW_MAX_BYTES: int = 65536
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _row_to_source(row: dict[str, Any]) -> BlocklistSource:
|
def _row_to_source(row: dict[str, object]) -> BlocklistSource:
|
||||||
"""Convert a repository row dict to a :class:`BlocklistSource`.
|
"""Convert a repository row dict to a :class:`BlocklistSource`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -542,7 +542,7 @@ async def list_import_logs(
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _aiohttp_timeout(seconds: float) -> Any:
|
def _aiohttp_timeout(seconds: float) -> "aiohttp.ClientTimeout":
|
||||||
"""Return an :class:`aiohttp.ClientTimeout` with the given total timeout.
|
"""Return an :class:`aiohttp.ClientTimeout` with the given total timeout.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ import os
|
|||||||
import re
|
import re
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING, cast, TypeAlias
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
|
|
||||||
@@ -57,7 +57,12 @@ from app.models.config import (
|
|||||||
from app.services import jail_service
|
from app.services import jail_service
|
||||||
from app.services.jail_service import JailNotFoundError as JailNotFoundError
|
from app.services.jail_service import JailNotFoundError as JailNotFoundError
|
||||||
from app.utils import conffile_parser
|
from app.utils import conffile_parser
|
||||||
from app.utils.fail2ban_client import Fail2BanClient, Fail2BanConnectionError
|
from app.utils.fail2ban_client import (
|
||||||
|
Fail2BanClient,
|
||||||
|
Fail2BanCommand,
|
||||||
|
Fail2BanConnectionError,
|
||||||
|
Fail2BanResponse,
|
||||||
|
)
|
||||||
|
|
||||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||||
|
|
||||||
@@ -539,10 +544,10 @@ async def _get_active_jail_names(socket_path: str) -> set[str]:
|
|||||||
try:
|
try:
|
||||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||||
|
|
||||||
def _to_dict_inner(pairs: Any) -> dict[str, Any]:
|
def _to_dict_inner(pairs: object) -> dict[str, object]:
|
||||||
if not isinstance(pairs, (list, tuple)):
|
if not isinstance(pairs, (list, tuple)):
|
||||||
return {}
|
return {}
|
||||||
result: dict[str, Any] = {}
|
result: dict[str, object] = {}
|
||||||
for item in pairs:
|
for item in pairs:
|
||||||
try:
|
try:
|
||||||
k, v = item
|
k, v = item
|
||||||
@@ -551,8 +556,8 @@ async def _get_active_jail_names(socket_path: str) -> set[str]:
|
|||||||
pass
|
pass
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _ok(response: Any) -> Any:
|
def _ok(response: object) -> object:
|
||||||
code, data = response
|
code, data = cast(Fail2BanResponse, response)
|
||||||
if code != 0:
|
if code != 0:
|
||||||
raise ValueError(f"fail2ban error {code}: {data!r}")
|
raise ValueError(f"fail2ban error {code}: {data!r}")
|
||||||
return data
|
return data
|
||||||
@@ -813,7 +818,7 @@ def _write_local_override_sync(
|
|||||||
config_dir: Path,
|
config_dir: Path,
|
||||||
jail_name: str,
|
jail_name: str,
|
||||||
enabled: bool,
|
enabled: bool,
|
||||||
overrides: dict[str, Any],
|
overrides: dict[str, object],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Write a ``jail.d/{name}.local`` file atomically.
|
"""Write a ``jail.d/{name}.local`` file atomically.
|
||||||
|
|
||||||
@@ -862,7 +867,7 @@ def _write_local_override_sync(
|
|||||||
if overrides.get("port") is not None:
|
if overrides.get("port") is not None:
|
||||||
lines.append(f"port = {overrides['port']}")
|
lines.append(f"port = {overrides['port']}")
|
||||||
if overrides.get("logpath"):
|
if overrides.get("logpath"):
|
||||||
paths: list[str] = overrides["logpath"]
|
paths: list[str] = cast(list[str], overrides["logpath"])
|
||||||
if paths:
|
if paths:
|
||||||
lines.append(f"logpath = {paths[0]}")
|
lines.append(f"logpath = {paths[0]}")
|
||||||
for p in paths[1:]:
|
for p in paths[1:]:
|
||||||
@@ -1209,7 +1214,7 @@ async def activate_jail(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
overrides: dict[str, Any] = {
|
overrides: dict[str, object] = {
|
||||||
"bantime": req.bantime,
|
"bantime": req.bantime,
|
||||||
"findtime": req.findtime,
|
"findtime": req.findtime,
|
||||||
"maxretry": req.maxretry,
|
"maxretry": req.maxretry,
|
||||||
|
|||||||
@@ -30,7 +30,8 @@ Usage::
|
|||||||
# single lookup
|
# single lookup
|
||||||
info = await geo_service.lookup("1.2.3.4", session)
|
info = await geo_service.lookup("1.2.3.4", session)
|
||||||
if info:
|
if info:
|
||||||
print(info.country_code) # "DE"
|
# info.country_code == "DE"
|
||||||
|
... # use the GeoInfo object in your application
|
||||||
|
|
||||||
# bulk lookup (more efficient for large sets)
|
# bulk lookup (more efficient for large sets)
|
||||||
geo_map = await geo_service.lookup_batch(["1.2.3.4", "5.6.7.8"], session)
|
geo_map = await geo_service.lookup_batch(["1.2.3.4", "5.6.7.8"], session)
|
||||||
@@ -42,7 +43,7 @@ import asyncio
|
|||||||
import time
|
import time
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, TypeAlias
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import structlog
|
import structlog
|
||||||
@@ -119,7 +120,7 @@ class GeoInfo:
|
|||||||
"""Organisation name associated with the IP, e.g. ``"Deutsche Telekom"``."""
|
"""Organisation name associated with the IP, e.g. ``"Deutsche Telekom"``."""
|
||||||
|
|
||||||
|
|
||||||
GeoEnricher: TypeAlias = Callable[[str], Awaitable[GeoInfo | None]]
|
type GeoEnricher = Callable[[str], Awaitable[GeoInfo | None]]
|
||||||
"""Async callable used to enrich IPs with :class:`~app.services.geo_service.GeoInfo`.
|
"""Async callable used to enrich IPs with :class:`~app.services.geo_service.GeoInfo`.
|
||||||
|
|
||||||
This is a shared type alias used by services that optionally accept a geo
|
This is a shared type alias used by services that optionally accept a geo
|
||||||
|
|||||||
@@ -9,12 +9,17 @@ seconds by the background health-check task, not on every HTTP request.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any
|
from typing import cast
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
|
|
||||||
from app.models.server import ServerStatus
|
from app.models.server import ServerStatus
|
||||||
from app.utils.fail2ban_client import Fail2BanClient, Fail2BanConnectionError, Fail2BanProtocolError
|
from app.utils.fail2ban_client import (
|
||||||
|
Fail2BanClient,
|
||||||
|
Fail2BanConnectionError,
|
||||||
|
Fail2BanProtocolError,
|
||||||
|
Fail2BanResponse,
|
||||||
|
)
|
||||||
|
|
||||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||||
|
|
||||||
@@ -25,7 +30,7 @@ log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
|||||||
_SOCKET_TIMEOUT: float = 5.0
|
_SOCKET_TIMEOUT: float = 5.0
|
||||||
|
|
||||||
|
|
||||||
def _ok(response: Any) -> Any:
|
def _ok(response: object) -> object:
|
||||||
"""Extract the payload from a fail2ban ``(return_code, data)`` response.
|
"""Extract the payload from a fail2ban ``(return_code, data)`` response.
|
||||||
|
|
||||||
fail2ban wraps every response in a ``(0, data)`` success tuple or
|
fail2ban wraps every response in a ``(0, data)`` success tuple or
|
||||||
@@ -42,7 +47,7 @@ def _ok(response: Any) -> Any:
|
|||||||
ValueError: If the response indicates an error (return code ≠ 0).
|
ValueError: If the response indicates an error (return code ≠ 0).
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
code, data = response
|
code, data = cast(Fail2BanResponse, response)
|
||||||
except (TypeError, ValueError) as exc:
|
except (TypeError, ValueError) as exc:
|
||||||
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
|
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
|
||||||
|
|
||||||
@@ -52,7 +57,7 @@ def _ok(response: Any) -> Any:
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
def _to_dict(pairs: Any) -> dict[str, Any]:
|
def _to_dict(pairs: object) -> dict[str, object]:
|
||||||
"""Convert a list of ``(key, value)`` pairs to a plain dict.
|
"""Convert a list of ``(key, value)`` pairs to a plain dict.
|
||||||
|
|
||||||
fail2ban returns structured data as lists of 2-tuples rather than dicts.
|
fail2ban returns structured data as lists of 2-tuples rather than dicts.
|
||||||
@@ -66,7 +71,7 @@ def _to_dict(pairs: Any) -> dict[str, Any]:
|
|||||||
"""
|
"""
|
||||||
if not isinstance(pairs, (list, tuple)):
|
if not isinstance(pairs, (list, tuple)):
|
||||||
return {}
|
return {}
|
||||||
result: dict[str, Any] = {}
|
result: dict[str, object] = {}
|
||||||
for item in pairs:
|
for item in pairs:
|
||||||
try:
|
try:
|
||||||
k, v = item
|
k, v = item
|
||||||
@@ -119,7 +124,7 @@ async def probe(socket_path: str, timeout: float = _SOCKET_TIMEOUT) -> ServerSta
|
|||||||
# 3. Global status — jail count and names #
|
# 3. Global status — jail count and names #
|
||||||
# ------------------------------------------------------------------ #
|
# ------------------------------------------------------------------ #
|
||||||
status_data = _to_dict(_ok(await client.send(["status"])))
|
status_data = _to_dict(_ok(await client.send(["status"])))
|
||||||
active_jails: int = int(status_data.get("Number of jail", 0) or 0)
|
active_jails: int = int(str(status_data.get("Number of jail", 0) or 0))
|
||||||
jail_list_raw: str = str(status_data.get("Jail list", "") or "").strip()
|
jail_list_raw: str = str(status_data.get("Jail list", "") or "").strip()
|
||||||
jail_names: list[str] = (
|
jail_names: list[str] = (
|
||||||
[j.strip() for j in jail_list_raw.split(",") if j.strip()]
|
[j.strip() for j in jail_list_raw.split(",") if j.strip()]
|
||||||
@@ -138,8 +143,8 @@ async def probe(socket_path: str, timeout: float = _SOCKET_TIMEOUT) -> ServerSta
|
|||||||
jail_resp = _to_dict(_ok(await client.send(["status", jail_name])))
|
jail_resp = _to_dict(_ok(await client.send(["status", jail_name])))
|
||||||
filter_stats = _to_dict(jail_resp.get("Filter") or [])
|
filter_stats = _to_dict(jail_resp.get("Filter") or [])
|
||||||
action_stats = _to_dict(jail_resp.get("Actions") or [])
|
action_stats = _to_dict(jail_resp.get("Actions") or [])
|
||||||
total_failures += int(filter_stats.get("Currently failed", 0) or 0)
|
total_failures += int(str(filter_stats.get("Currently failed", 0) or 0))
|
||||||
total_bans += int(action_stats.get("Currently banned", 0) or 0)
|
total_bans += int(str(action_stats.get("Currently banned", 0) or 0))
|
||||||
except (ValueError, TypeError, KeyError) as exc:
|
except (ValueError, TypeError, KeyError) as exc:
|
||||||
log.warning(
|
log.warning(
|
||||||
"fail2ban_jail_status_parse_error",
|
"fail2ban_jail_status_parse_error",
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
import ipaddress
|
import ipaddress
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING, Awaitable, Callable, cast, TypeAlias
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
|
|
||||||
@@ -27,10 +27,24 @@ from app.models.jail import (
|
|||||||
JailStatus,
|
JailStatus,
|
||||||
JailSummary,
|
JailSummary,
|
||||||
)
|
)
|
||||||
from app.utils.fail2ban_client import Fail2BanClient, Fail2BanConnectionError
|
from app.utils.fail2ban_client import (
|
||||||
|
Fail2BanClient,
|
||||||
|
Fail2BanCommand,
|
||||||
|
Fail2BanConnectionError,
|
||||||
|
Fail2BanResponse,
|
||||||
|
Fail2BanToken,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import aiohttp
|
||||||
|
import aiosqlite
|
||||||
|
|
||||||
|
from app.services.geo_service import GeoInfo
|
||||||
|
|
||||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||||
|
|
||||||
|
GeoEnricher: TypeAlias = Callable[[str], Awaitable["GeoInfo | None"]]
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Constants
|
# Constants
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -77,7 +91,7 @@ class JailOperationError(Exception):
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _ok(response: Any) -> Any:
|
def _ok(response: object) -> object:
|
||||||
"""Extract the payload from a fail2ban ``(return_code, data)`` response.
|
"""Extract the payload from a fail2ban ``(return_code, data)`` response.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -90,7 +104,7 @@ def _ok(response: Any) -> Any:
|
|||||||
ValueError: If the response indicates an error (return code ≠ 0).
|
ValueError: If the response indicates an error (return code ≠ 0).
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
code, data = response
|
code, data = cast(Fail2BanResponse, response)
|
||||||
except (TypeError, ValueError) as exc:
|
except (TypeError, ValueError) as exc:
|
||||||
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
|
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
|
||||||
|
|
||||||
@@ -100,7 +114,7 @@ def _ok(response: Any) -> Any:
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
def _to_dict(pairs: Any) -> dict[str, Any]:
|
def _to_dict(pairs: object) -> dict[str, object]:
|
||||||
"""Convert a list of ``(key, value)`` pairs to a plain dict.
|
"""Convert a list of ``(key, value)`` pairs to a plain dict.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -111,7 +125,7 @@ def _to_dict(pairs: Any) -> dict[str, Any]:
|
|||||||
"""
|
"""
|
||||||
if not isinstance(pairs, (list, tuple)):
|
if not isinstance(pairs, (list, tuple)):
|
||||||
return {}
|
return {}
|
||||||
result: dict[str, Any] = {}
|
result: dict[str, object] = {}
|
||||||
for item in pairs:
|
for item in pairs:
|
||||||
try:
|
try:
|
||||||
k, v = item
|
k, v = item
|
||||||
@@ -121,7 +135,7 @@ def _to_dict(pairs: Any) -> dict[str, Any]:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def _ensure_list(value: Any) -> list[str]:
|
def _ensure_list(value: object | None) -> list[str]:
|
||||||
"""Coerce a fail2ban response value to a list of strings.
|
"""Coerce a fail2ban response value to a list of strings.
|
||||||
|
|
||||||
Some fail2ban ``get`` responses return ``None`` or a single string
|
Some fail2ban ``get`` responses return ``None`` or a single string
|
||||||
@@ -170,9 +184,9 @@ def _is_not_found_error(exc: Exception) -> bool:
|
|||||||
|
|
||||||
async def _safe_get(
|
async def _safe_get(
|
||||||
client: Fail2BanClient,
|
client: Fail2BanClient,
|
||||||
command: list[Any],
|
command: Fail2BanCommand,
|
||||||
default: Any = None,
|
default: object | None = None,
|
||||||
) -> Any:
|
) -> object | None:
|
||||||
"""Send a ``get`` command and return ``default`` on error.
|
"""Send a ``get`` command and return ``default`` on error.
|
||||||
|
|
||||||
Errors during optional detail queries (logpath, regex, etc.) should
|
Errors during optional detail queries (logpath, regex, etc.) should
|
||||||
@@ -187,7 +201,8 @@ async def _safe_get(
|
|||||||
The response payload, or *default* on any error.
|
The response payload, or *default* on any error.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
return _ok(await client.send(command))
|
response = await client.send(command)
|
||||||
|
return _ok(cast(Fail2BanResponse, response))
|
||||||
except (ValueError, TypeError, Exception):
|
except (ValueError, TypeError, Exception):
|
||||||
return default
|
return default
|
||||||
|
|
||||||
@@ -309,7 +324,7 @@ async def _fetch_jail_summary(
|
|||||||
backend_cmd_is_supported = await _check_backend_cmd_supported(client, name)
|
backend_cmd_is_supported = await _check_backend_cmd_supported(client, name)
|
||||||
|
|
||||||
# Build the gather list based on command support.
|
# Build the gather list based on command support.
|
||||||
gather_list: list[Any] = [
|
gather_list: list[Awaitable[object]] = [
|
||||||
client.send(["status", name, "short"]),
|
client.send(["status", name, "short"]),
|
||||||
client.send(["get", name, "bantime"]),
|
client.send(["get", name, "bantime"]),
|
||||||
client.send(["get", name, "findtime"]),
|
client.send(["get", name, "findtime"]),
|
||||||
@@ -325,7 +340,7 @@ async def _fetch_jail_summary(
|
|||||||
uses_backend_backend_commands = True
|
uses_backend_backend_commands = True
|
||||||
else:
|
else:
|
||||||
# Commands not supported; return default values without sending.
|
# Commands not supported; return default values without sending.
|
||||||
async def _return_default(value: Any) -> tuple[int, Any]:
|
async def _return_default(value: object | None) -> Fail2BanResponse:
|
||||||
return (0, value)
|
return (0, value)
|
||||||
|
|
||||||
gather_list.extend([
|
gather_list.extend([
|
||||||
@@ -335,12 +350,12 @@ async def _fetch_jail_summary(
|
|||||||
uses_backend_backend_commands = False
|
uses_backend_backend_commands = False
|
||||||
|
|
||||||
_r = await asyncio.gather(*gather_list, return_exceptions=True)
|
_r = await asyncio.gather(*gather_list, return_exceptions=True)
|
||||||
status_raw: Any = _r[0]
|
status_raw: object | Exception = _r[0]
|
||||||
bantime_raw: Any = _r[1]
|
bantime_raw: object | Exception = _r[1]
|
||||||
findtime_raw: Any = _r[2]
|
findtime_raw: object | Exception = _r[2]
|
||||||
maxretry_raw: Any = _r[3]
|
maxretry_raw: object | Exception = _r[3]
|
||||||
backend_raw: Any = _r[4]
|
backend_raw: object | Exception = _r[4]
|
||||||
idle_raw: Any = _r[5]
|
idle_raw: object | Exception = _r[5]
|
||||||
|
|
||||||
# Parse jail status (filter + actions).
|
# Parse jail status (filter + actions).
|
||||||
jail_status: JailStatus | None = None
|
jail_status: JailStatus | None = None
|
||||||
@@ -350,35 +365,35 @@ async def _fetch_jail_summary(
|
|||||||
filter_stats = _to_dict(raw.get("Filter") or [])
|
filter_stats = _to_dict(raw.get("Filter") or [])
|
||||||
action_stats = _to_dict(raw.get("Actions") or [])
|
action_stats = _to_dict(raw.get("Actions") or [])
|
||||||
jail_status = JailStatus(
|
jail_status = JailStatus(
|
||||||
currently_banned=int(action_stats.get("Currently banned", 0) or 0),
|
currently_banned=int(str(action_stats.get("Currently banned", 0) or 0)),
|
||||||
total_banned=int(action_stats.get("Total banned", 0) or 0),
|
total_banned=int(str(action_stats.get("Total banned", 0) or 0)),
|
||||||
currently_failed=int(filter_stats.get("Currently failed", 0) or 0),
|
currently_failed=int(str(filter_stats.get("Currently failed", 0) or 0)),
|
||||||
total_failed=int(filter_stats.get("Total failed", 0) or 0),
|
total_failed=int(str(filter_stats.get("Total failed", 0) or 0)),
|
||||||
)
|
)
|
||||||
except (ValueError, TypeError) as exc:
|
except (ValueError, TypeError) as exc:
|
||||||
log.warning("jail_status_parse_error", jail=name, error=str(exc))
|
log.warning("jail_status_parse_error", jail=name, error=str(exc))
|
||||||
|
|
||||||
def _safe_int(raw: Any, fallback: int) -> int:
|
def _safe_int(raw: object | Exception, fallback: int) -> int:
|
||||||
if isinstance(raw, Exception):
|
if isinstance(raw, Exception):
|
||||||
return fallback
|
return fallback
|
||||||
try:
|
try:
|
||||||
return int(_ok(raw))
|
return int(str(_ok(cast(Fail2BanResponse, raw))))
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
return fallback
|
return fallback
|
||||||
|
|
||||||
def _safe_str(raw: Any, fallback: str) -> str:
|
def _safe_str(raw: object | Exception, fallback: str) -> str:
|
||||||
if isinstance(raw, Exception):
|
if isinstance(raw, Exception):
|
||||||
return fallback
|
return fallback
|
||||||
try:
|
try:
|
||||||
return str(_ok(raw))
|
return str(_ok(cast(Fail2BanResponse, raw)))
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
return fallback
|
return fallback
|
||||||
|
|
||||||
def _safe_bool(raw: Any, fallback: bool = False) -> bool:
|
def _safe_bool(raw: object | Exception, fallback: bool = False) -> bool:
|
||||||
if isinstance(raw, Exception):
|
if isinstance(raw, Exception):
|
||||||
return fallback
|
return fallback
|
||||||
try:
|
try:
|
||||||
return bool(_ok(raw))
|
return bool(_ok(cast(Fail2BanResponse, raw)))
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
return fallback
|
return fallback
|
||||||
|
|
||||||
@@ -428,10 +443,10 @@ async def get_jail(socket_path: str, name: str) -> JailDetailResponse:
|
|||||||
action_stats = _to_dict(raw.get("Actions") or [])
|
action_stats = _to_dict(raw.get("Actions") or [])
|
||||||
|
|
||||||
jail_status = JailStatus(
|
jail_status = JailStatus(
|
||||||
currently_banned=int(action_stats.get("Currently banned", 0) or 0),
|
currently_banned=int(str(action_stats.get("Currently banned", 0) or 0)),
|
||||||
total_banned=int(action_stats.get("Total banned", 0) or 0),
|
total_banned=int(str(action_stats.get("Total banned", 0) or 0)),
|
||||||
currently_failed=int(filter_stats.get("Currently failed", 0) or 0),
|
currently_failed=int(str(filter_stats.get("Currently failed", 0) or 0)),
|
||||||
total_failed=int(filter_stats.get("Total failed", 0) or 0),
|
total_failed=int(str(filter_stats.get("Total failed", 0) or 0)),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Fetch all detail fields in parallel.
|
# Fetch all detail fields in parallel.
|
||||||
@@ -480,11 +495,11 @@ async def get_jail(socket_path: str, name: str) -> JailDetailResponse:
|
|||||||
bt_increment: bool = bool(bt_increment_raw)
|
bt_increment: bool = bool(bt_increment_raw)
|
||||||
bantime_escalation = BantimeEscalation(
|
bantime_escalation = BantimeEscalation(
|
||||||
increment=bt_increment,
|
increment=bt_increment,
|
||||||
factor=float(bt_factor_raw) if bt_factor_raw is not None else None,
|
factor=float(str(bt_factor_raw)) if bt_factor_raw is not None else None,
|
||||||
formula=str(bt_formula_raw) if bt_formula_raw else None,
|
formula=str(bt_formula_raw) if bt_formula_raw else None,
|
||||||
multipliers=str(bt_multipliers_raw) if bt_multipliers_raw else None,
|
multipliers=str(bt_multipliers_raw) if bt_multipliers_raw else None,
|
||||||
max_time=int(bt_maxtime_raw) if bt_maxtime_raw is not None else None,
|
max_time=int(str(bt_maxtime_raw)) if bt_maxtime_raw is not None else None,
|
||||||
rnd_time=int(bt_rndtime_raw) if bt_rndtime_raw is not None else None,
|
rnd_time=int(str(bt_rndtime_raw)) if bt_rndtime_raw is not None else None,
|
||||||
overall_jails=bool(bt_overalljails_raw),
|
overall_jails=bool(bt_overalljails_raw),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -500,9 +515,9 @@ async def get_jail(socket_path: str, name: str) -> JailDetailResponse:
|
|||||||
ignore_ips=_ensure_list(ignoreip_raw),
|
ignore_ips=_ensure_list(ignoreip_raw),
|
||||||
date_pattern=str(datepattern_raw) if datepattern_raw else None,
|
date_pattern=str(datepattern_raw) if datepattern_raw else None,
|
||||||
log_encoding=str(logencoding_raw or "UTF-8"),
|
log_encoding=str(logencoding_raw or "UTF-8"),
|
||||||
find_time=int(findtime_raw or 600),
|
find_time=int(str(findtime_raw or 600)),
|
||||||
ban_time=int(bantime_raw or 600),
|
ban_time=int(str(bantime_raw or 600)),
|
||||||
max_retry=int(maxretry_raw or 5),
|
max_retry=int(str(maxretry_raw or 5)),
|
||||||
bantime_escalation=bantime_escalation,
|
bantime_escalation=bantime_escalation,
|
||||||
status=jail_status,
|
status=jail_status,
|
||||||
actions=_ensure_list(actions_raw),
|
actions=_ensure_list(actions_raw),
|
||||||
@@ -671,8 +686,8 @@ async def reload_all(
|
|||||||
if exclude_jails:
|
if exclude_jails:
|
||||||
names_set -= set(exclude_jails)
|
names_set -= set(exclude_jails)
|
||||||
|
|
||||||
stream: list[list[str]] = [["start", n] for n in sorted(names_set)]
|
stream: list[list[object]] = [["start", n] for n in sorted(names_set)]
|
||||||
_ok(await client.send(["reload", "--all", [], stream]))
|
_ok(await client.send(["reload", "--all", [], cast(Fail2BanToken, stream)]))
|
||||||
log.info("all_jails_reloaded")
|
log.info("all_jails_reloaded")
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
# Detect UnknownJailException (missing or invalid jail configuration)
|
# Detect UnknownJailException (missing or invalid jail configuration)
|
||||||
@@ -795,9 +810,9 @@ async def unban_ip(
|
|||||||
|
|
||||||
async def get_active_bans(
|
async def get_active_bans(
|
||||||
socket_path: str,
|
socket_path: str,
|
||||||
geo_enricher: Any | None = None,
|
geo_enricher: GeoEnricher | None = None,
|
||||||
http_session: Any | None = None,
|
http_session: "aiohttp.ClientSession" | None = None,
|
||||||
app_db: Any | None = None,
|
app_db: "aiosqlite.Connection" | None = None,
|
||||||
) -> ActiveBanListResponse:
|
) -> ActiveBanListResponse:
|
||||||
"""Return all currently banned IPs across every jail.
|
"""Return all currently banned IPs across every jail.
|
||||||
|
|
||||||
@@ -849,7 +864,7 @@ async def get_active_bans(
|
|||||||
return ActiveBanListResponse(bans=[], total=0)
|
return ActiveBanListResponse(bans=[], total=0)
|
||||||
|
|
||||||
# For each jail, fetch the ban list with time info in parallel.
|
# For each jail, fetch the ban list with time info in parallel.
|
||||||
results: list[Any] = await asyncio.gather(
|
results: list[object | Exception] = await asyncio.gather(
|
||||||
*[client.send(["get", jn, "banip", "--with-time"]) for jn in jail_names],
|
*[client.send(["get", jn, "banip", "--with-time"]) for jn in jail_names],
|
||||||
return_exceptions=True,
|
return_exceptions=True,
|
||||||
)
|
)
|
||||||
@@ -865,7 +880,7 @@ async def get_active_bans(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ban_list: list[str] = _ok(raw_result) or []
|
ban_list: list[str] = cast(list[str], _ok(raw_result)) or []
|
||||||
except (TypeError, ValueError) as exc:
|
except (TypeError, ValueError) as exc:
|
||||||
log.warning(
|
log.warning(
|
||||||
"active_bans_parse_error",
|
"active_bans_parse_error",
|
||||||
@@ -992,8 +1007,8 @@ async def get_jail_banned_ips(
|
|||||||
page: int = 1,
|
page: int = 1,
|
||||||
page_size: int = 25,
|
page_size: int = 25,
|
||||||
search: str | None = None,
|
search: str | None = None,
|
||||||
http_session: Any | None = None,
|
http_session: "aiohttp.ClientSession" | None = None,
|
||||||
app_db: Any | None = None,
|
app_db: "aiosqlite.Connection" | None = None,
|
||||||
) -> JailBannedIpsResponse:
|
) -> JailBannedIpsResponse:
|
||||||
"""Return a paginated list of currently banned IPs for a single jail.
|
"""Return a paginated list of currently banned IPs for a single jail.
|
||||||
|
|
||||||
@@ -1040,7 +1055,7 @@ async def get_jail_banned_ips(
|
|||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
raw_result = []
|
raw_result = []
|
||||||
|
|
||||||
ban_list: list[str] = raw_result or []
|
ban_list: list[str] = cast(list[str], raw_result) or []
|
||||||
|
|
||||||
# Parse all entries.
|
# Parse all entries.
|
||||||
all_bans: list[ActiveBan] = []
|
all_bans: list[ActiveBan] = []
|
||||||
@@ -1094,7 +1109,7 @@ async def get_jail_banned_ips(
|
|||||||
|
|
||||||
async def _enrich_bans(
|
async def _enrich_bans(
|
||||||
bans: list[ActiveBan],
|
bans: list[ActiveBan],
|
||||||
geo_enricher: Any,
|
geo_enricher: GeoEnricher,
|
||||||
) -> list[ActiveBan]:
|
) -> list[ActiveBan]:
|
||||||
"""Enrich ban records with geo data asynchronously.
|
"""Enrich ban records with geo data asynchronously.
|
||||||
|
|
||||||
@@ -1105,14 +1120,15 @@ async def _enrich_bans(
|
|||||||
Returns:
|
Returns:
|
||||||
The same list with ``country`` fields populated where lookup succeeded.
|
The same list with ``country`` fields populated where lookup succeeded.
|
||||||
"""
|
"""
|
||||||
geo_results: list[Any] = await asyncio.gather(
|
geo_results: list[object | Exception] = await asyncio.gather(
|
||||||
*[geo_enricher(ban.ip) for ban in bans],
|
*[cast(Awaitable[object], geo_enricher(ban.ip)) for ban in bans],
|
||||||
return_exceptions=True,
|
return_exceptions=True,
|
||||||
)
|
)
|
||||||
enriched: list[ActiveBan] = []
|
enriched: list[ActiveBan] = []
|
||||||
for ban, geo in zip(bans, geo_results, strict=False):
|
for ban, geo in zip(bans, geo_results, strict=False):
|
||||||
if geo is not None and not isinstance(geo, Exception):
|
if geo is not None and not isinstance(geo, Exception):
|
||||||
enriched.append(ban.model_copy(update={"country": geo.country_code}))
|
geo_info = cast("GeoInfo", geo)
|
||||||
|
enriched.append(ban.model_copy(update={"country": geo_info.country_code}))
|
||||||
else:
|
else:
|
||||||
enriched.append(ban)
|
enriched.append(ban)
|
||||||
return enriched
|
return enriched
|
||||||
@@ -1260,8 +1276,8 @@ async def set_ignore_self(socket_path: str, name: str, *, on: bool) -> None:
|
|||||||
async def lookup_ip(
|
async def lookup_ip(
|
||||||
socket_path: str,
|
socket_path: str,
|
||||||
ip: str,
|
ip: str,
|
||||||
geo_enricher: Any | None = None,
|
geo_enricher: GeoEnricher | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, object | list[str] | None]:
|
||||||
"""Return ban status and history for a single IP address.
|
"""Return ban status and history for a single IP address.
|
||||||
|
|
||||||
Checks every running jail for whether the IP is currently banned.
|
Checks every running jail for whether the IP is currently banned.
|
||||||
@@ -1304,7 +1320,7 @@ async def lookup_ip(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Check ban status per jail in parallel.
|
# Check ban status per jail in parallel.
|
||||||
ban_results: list[Any] = await asyncio.gather(
|
ban_results: list[object | Exception] = await asyncio.gather(
|
||||||
*[client.send(["get", jn, "banip"]) for jn in jail_names],
|
*[client.send(["get", jn, "banip"]) for jn in jail_names],
|
||||||
return_exceptions=True,
|
return_exceptions=True,
|
||||||
)
|
)
|
||||||
@@ -1314,7 +1330,7 @@ async def lookup_ip(
|
|||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
ban_list: list[str] = _ok(result) or []
|
ban_list: list[str] = cast(list[str], _ok(result)) or []
|
||||||
if ip in ban_list:
|
if ip in ban_list:
|
||||||
currently_banned_in.append(jail_name)
|
currently_banned_in.append(jail_name)
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
@@ -1351,6 +1367,6 @@ async def unban_all_ips(socket_path: str) -> int:
|
|||||||
cannot be reached.
|
cannot be reached.
|
||||||
"""
|
"""
|
||||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||||
count: int = int(_ok(await client.send(["unban", "--all"])))
|
count: int = int(str(_ok(await client.send(["unban", "--all"])) or 0))
|
||||||
log.info("all_ips_unbanned", count=count)
|
log.info("all_ips_unbanned", count=count)
|
||||||
return count
|
return count
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ The task runs every 10 minutes. On each invocation it:
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
|
|
||||||
@@ -35,7 +35,7 @@ GEO_RE_RESOLVE_INTERVAL: int = 600
|
|||||||
JOB_ID: str = "geo_re_resolve"
|
JOB_ID: str = "geo_re_resolve"
|
||||||
|
|
||||||
|
|
||||||
async def _run_re_resolve(app: Any) -> None:
|
async def _run_re_resolve(app: "FastAPI") -> None:
|
||||||
"""Query NULL-country IPs from the database and re-resolve them.
|
"""Query NULL-country IPs from the database and re-resolve them.
|
||||||
|
|
||||||
Reads shared resources from ``app.state`` and delegates to
|
Reads shared resources from ``app.state`` and delegates to
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ within 60 seconds of that activation, a
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, TypedDict
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
|
|
||||||
@@ -31,6 +31,14 @@ if TYPE_CHECKING: # pragma: no cover
|
|||||||
|
|
||||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class ActivationRecord(TypedDict):
|
||||||
|
"""Stored timestamp data for a jail activation event."""
|
||||||
|
|
||||||
|
jail_name: str
|
||||||
|
at: datetime.datetime
|
||||||
|
|
||||||
|
|
||||||
#: How often the probe fires (seconds).
|
#: How often the probe fires (seconds).
|
||||||
HEALTH_CHECK_INTERVAL: int = 30
|
HEALTH_CHECK_INTERVAL: int = 30
|
||||||
|
|
||||||
@@ -39,7 +47,7 @@ HEALTH_CHECK_INTERVAL: int = 30
|
|||||||
_ACTIVATION_CRASH_WINDOW: int = 60
|
_ACTIVATION_CRASH_WINDOW: int = 60
|
||||||
|
|
||||||
|
|
||||||
async def _run_probe(app: Any) -> None:
|
async def _run_probe(app: "FastAPI") -> None:
|
||||||
"""Probe fail2ban and cache the result on *app.state*.
|
"""Probe fail2ban and cache the result on *app.state*.
|
||||||
|
|
||||||
Detects online/offline state transitions. When fail2ban goes offline
|
Detects online/offline state transitions. When fail2ban goes offline
|
||||||
@@ -86,7 +94,7 @@ async def _run_probe(app: Any) -> None:
|
|||||||
elif not status.online and prev_status.online:
|
elif not status.online and prev_status.online:
|
||||||
log.warning("fail2ban_went_offline")
|
log.warning("fail2ban_went_offline")
|
||||||
# Check whether this crash happened shortly after a jail activation.
|
# Check whether this crash happened shortly after a jail activation.
|
||||||
last_activation: dict[str, Any] | None = getattr(
|
last_activation: ActivationRecord | None = getattr(
|
||||||
app.state, "last_activation", None
|
app.state, "last_activation", None
|
||||||
)
|
)
|
||||||
if last_activation is not None:
|
if last_activation is not None:
|
||||||
|
|||||||
@@ -79,6 +79,8 @@ async def test_run_re_resolve_no_unresolved_ips_skips() -> None:
|
|||||||
app = _make_app(unresolved_ips=[])
|
app = _make_app(unresolved_ips=[])
|
||||||
|
|
||||||
with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo:
|
with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo:
|
||||||
|
mock_geo.get_unresolved_ips = AsyncMock(return_value=[])
|
||||||
|
|
||||||
await _run_re_resolve(app)
|
await _run_re_resolve(app)
|
||||||
|
|
||||||
mock_geo.clear_neg_cache.assert_not_called()
|
mock_geo.clear_neg_cache.assert_not_called()
|
||||||
@@ -96,6 +98,7 @@ async def test_run_re_resolve_clears_neg_cache() -> None:
|
|||||||
app = _make_app(unresolved_ips=ips, lookup_result=result)
|
app = _make_app(unresolved_ips=ips, lookup_result=result)
|
||||||
|
|
||||||
with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo:
|
with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo:
|
||||||
|
mock_geo.get_unresolved_ips = AsyncMock(return_value=ips)
|
||||||
mock_geo.lookup_batch = AsyncMock(return_value=result)
|
mock_geo.lookup_batch = AsyncMock(return_value=result)
|
||||||
|
|
||||||
await _run_re_resolve(app)
|
await _run_re_resolve(app)
|
||||||
@@ -114,6 +117,7 @@ async def test_run_re_resolve_calls_lookup_batch_with_db() -> None:
|
|||||||
app = _make_app(unresolved_ips=ips, lookup_result=result)
|
app = _make_app(unresolved_ips=ips, lookup_result=result)
|
||||||
|
|
||||||
with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo:
|
with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo:
|
||||||
|
mock_geo.get_unresolved_ips = AsyncMock(return_value=ips)
|
||||||
mock_geo.lookup_batch = AsyncMock(return_value=result)
|
mock_geo.lookup_batch = AsyncMock(return_value=result)
|
||||||
|
|
||||||
await _run_re_resolve(app)
|
await _run_re_resolve(app)
|
||||||
@@ -137,6 +141,7 @@ async def test_run_re_resolve_logs_correct_counts(caplog: Any) -> None:
|
|||||||
app = _make_app(unresolved_ips=ips, lookup_result=result)
|
app = _make_app(unresolved_ips=ips, lookup_result=result)
|
||||||
|
|
||||||
with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo:
|
with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo:
|
||||||
|
mock_geo.get_unresolved_ips = AsyncMock(return_value=ips)
|
||||||
mock_geo.lookup_batch = AsyncMock(return_value=result)
|
mock_geo.lookup_batch = AsyncMock(return_value=result)
|
||||||
|
|
||||||
await _run_re_resolve(app)
|
await _run_re_resolve(app)
|
||||||
@@ -159,6 +164,7 @@ async def test_run_re_resolve_handles_all_resolved() -> None:
|
|||||||
app = _make_app(unresolved_ips=ips, lookup_result=result)
|
app = _make_app(unresolved_ips=ips, lookup_result=result)
|
||||||
|
|
||||||
with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo:
|
with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo:
|
||||||
|
mock_geo.get_unresolved_ips = AsyncMock(return_value=ips)
|
||||||
mock_geo.lookup_batch = AsyncMock(return_value=result)
|
mock_geo.lookup_batch = AsyncMock(return_value=result)
|
||||||
|
|
||||||
await _run_re_resolve(app)
|
await _run_re_resolve(app)
|
||||||
|
|||||||
@@ -99,37 +99,18 @@ export function SetupPage(): React.JSX.Element {
|
|||||||
const styles = useStyles();
|
const styles = useStyles();
|
||||||
const navigate = useNavigate();
|
const navigate = useNavigate();
|
||||||
|
|
||||||
const [checking, setChecking] = useState(true);
|
const { status, loading, error, submit, submitting, submitError } = useSetup();
|
||||||
const [values, setValues] = useState<FormValues>(DEFAULT_VALUES);
|
const [values, setValues] = useState<FormValues>(DEFAULT_VALUES);
|
||||||
const [errors, setErrors] = useState<Partial<Record<keyof FormValues, string>>>({});
|
const [errors, setErrors] = useState<Partial<Record<keyof FormValues, string>>>({});
|
||||||
const [apiError, setApiError] = useState<string | null>(null);
|
const apiError = error ?? submitError;
|
||||||
const [submitting, setSubmitting] = useState(false);
|
|
||||||
|
|
||||||
// Redirect to /login if setup has already been completed.
|
// Redirect to /login if setup has already been completed.
|
||||||
// Show a full-screen spinner while the check is in flight to prevent
|
// Show a full-screen spinner while the initial status check is in flight.
|
||||||
// the form from flashing before the redirect fires.
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
let cancelled = false;
|
if (status?.completed) {
|
||||||
getSetupStatus()
|
navigate("/login", { replace: true });
|
||||||
.then((res) => {
|
}
|
||||||
if (!cancelled) {
|
}, [navigate, status]);
|
||||||
if (res.completed) {
|
|
||||||
navigate("/login", { replace: true });
|
|
||||||
} else {
|
|
||||||
setChecking(false);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.catch(() => {
|
|
||||||
// Failed check: the backend may still be starting up. Stay on this
|
|
||||||
// page so the user can attempt setup once the backend is ready.
|
|
||||||
console.warn("SetupPage: setup status check failed — rendering setup form");
|
|
||||||
if (!cancelled) setChecking(false);
|
|
||||||
});
|
|
||||||
return (): void => {
|
|
||||||
cancelled = true;
|
|
||||||
};
|
|
||||||
}, [navigate]);
|
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
// Handlers
|
// Handlers
|
||||||
@@ -169,13 +150,11 @@ export function SetupPage(): React.JSX.Element {
|
|||||||
|
|
||||||
async function handleSubmit(ev: FormEvent<HTMLFormElement>): Promise<void> {
|
async function handleSubmit(ev: FormEvent<HTMLFormElement>): Promise<void> {
|
||||||
ev.preventDefault();
|
ev.preventDefault();
|
||||||
setApiError(null);
|
|
||||||
|
|
||||||
if (!validate()) return;
|
if (!validate()) return;
|
||||||
|
|
||||||
setSubmitting(true);
|
|
||||||
try {
|
try {
|
||||||
await submitSetup({
|
await submit({
|
||||||
master_password: values.masterPassword,
|
master_password: values.masterPassword,
|
||||||
database_path: values.databasePath,
|
database_path: values.databasePath,
|
||||||
fail2ban_socket: values.fail2banSocket,
|
fail2ban_socket: values.fail2banSocket,
|
||||||
@@ -183,14 +162,8 @@ export function SetupPage(): React.JSX.Element {
|
|||||||
session_duration_minutes: parseInt(values.sessionDurationMinutes, 10),
|
session_duration_minutes: parseInt(values.sessionDurationMinutes, 10),
|
||||||
});
|
});
|
||||||
navigate("/login", { replace: true });
|
navigate("/login", { replace: true });
|
||||||
} catch (err) {
|
} catch {
|
||||||
if (err instanceof ApiError) {
|
// Errors are surfaced through the hook via `submitError`.
|
||||||
setApiError(err.message || `Error ${String(err.status)}`);
|
|
||||||
} else {
|
|
||||||
setApiError("An unexpected error occurred. Please try again.");
|
|
||||||
}
|
|
||||||
} finally {
|
|
||||||
setSubmitting(false);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -198,7 +171,7 @@ export function SetupPage(): React.JSX.Element {
|
|||||||
// Render
|
// Render
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
if (checking) {
|
if (loading) {
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
style={{
|
style={{
|
||||||
@@ -224,7 +197,7 @@ export function SetupPage(): React.JSX.Element {
|
|||||||
is complete.
|
is complete.
|
||||||
</Text>
|
</Text>
|
||||||
|
|
||||||
{apiError !== null && (
|
{apiError && (
|
||||||
<MessageBar intent="error" className={styles.error}>
|
<MessageBar intent="error" className={styles.error}>
|
||||||
<MessageBarBody>{apiError}</MessageBarBody>
|
<MessageBarBody>{apiError}</MessageBarBody>
|
||||||
</MessageBar>
|
</MessageBar>
|
||||||
|
|||||||
Reference in New Issue
Block a user