refactor: improve backend type safety and import organization
- Add TYPE_CHECKING guards for runtime-expensive imports (aiohttp, aiosqlite) - Reorganize imports to follow PEP 8 conventions - Convert TypeAlias to modern PEP 695 type syntax (where appropriate) - Use Sequence/Mapping from collections.abc for type hints (covariant) - Replace string literals with cast() for improved type inference - Fix casting of Fail2BanResponse and TypedDict patterns - Add IpLookupResult TypedDict for precise return type annotation - Reformat overlong lines for readability (120 char limit) - Add asyncio_mode and filterwarnings to pytest config - Update test fixtures with improved type hints This improves mypy type checking and makes type relationships explicit.
This commit is contained in:
@@ -85,4 +85,4 @@ def get_settings() -> Settings:
|
|||||||
A validated :class:`Settings` object. Raises :class:`pydantic.ValidationError`
|
A validated :class:`Settings` object. Raises :class:`pydantic.ValidationError`
|
||||||
if required keys are absent or values fail validation.
|
if required keys are absent or values fail validation.
|
||||||
"""
|
"""
|
||||||
return Settings() # pydantic-settings populates required fields from env vars
|
return Settings() # type: ignore[call-arg] # pydantic-settings populates required fields from env vars
|
||||||
|
|||||||
@@ -92,7 +92,7 @@ async def get_settings(request: Request) -> Settings:
|
|||||||
Returns:
|
Returns:
|
||||||
The application settings loaded at startup.
|
The application settings loaded at startup.
|
||||||
"""
|
"""
|
||||||
state = cast(AppState, request.app.state)
|
state = cast("AppState", request.app.state)
|
||||||
return state.settings
|
return state.settings
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ from __future__ import annotations
|
|||||||
from typing import TYPE_CHECKING, TypedDict
|
from typing import TYPE_CHECKING, TypedDict
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
|
|
||||||
|
|
||||||
@@ -112,7 +114,7 @@ async def upsert_neg_entry(db: aiosqlite.Connection, ip: str) -> None:
|
|||||||
|
|
||||||
async def bulk_upsert_entries(
|
async def bulk_upsert_entries(
|
||||||
db: aiosqlite.Connection,
|
db: aiosqlite.Connection,
|
||||||
rows: list[tuple[str, str | None, str | None, str | None, str | None]],
|
rows: Sequence[tuple[str, str | None, str | None, str | None, str | None]],
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Bulk insert or update multiple geo cache entries."""
|
"""Bulk insert or update multiple geo cache entries."""
|
||||||
if not rows:
|
if not rows:
|
||||||
|
|||||||
@@ -8,10 +8,11 @@ table. All methods are plain async functions that accept a
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from collections.abc import Mapping
|
|
||||||
from typing import TYPE_CHECKING, TypedDict, cast
|
from typing import TYPE_CHECKING, TypedDict, cast
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Mapping
|
||||||
|
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
|
|
||||||
|
|
||||||
@@ -165,5 +166,5 @@ def _row_to_dict(row: object) -> ImportLogRow:
|
|||||||
Returns:
|
Returns:
|
||||||
Dict mapping column names to Python values.
|
Dict mapping column names to Python values.
|
||||||
"""
|
"""
|
||||||
mapping = cast(Mapping[str, object], row)
|
mapping = cast("Mapping[str, object]", row)
|
||||||
return cast(ImportLogRow, dict(mapping))
|
return cast("ImportLogRow", dict(mapping))
|
||||||
|
|||||||
@@ -44,8 +44,6 @@ import structlog
|
|||||||
from fastapi import APIRouter, HTTPException, Path, Query, Request, status
|
from fastapi import APIRouter, HTTPException, Path, Query, Request, status
|
||||||
|
|
||||||
from app.dependencies import AuthDep
|
from app.dependencies import AuthDep
|
||||||
|
|
||||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
|
||||||
from app.models.config import (
|
from app.models.config import (
|
||||||
ActionConfig,
|
ActionConfig,
|
||||||
ActionCreateRequest,
|
ActionCreateRequest,
|
||||||
@@ -104,6 +102,8 @@ from app.services.jail_service import JailOperationError
|
|||||||
from app.tasks.health_check import _run_probe
|
from app.tasks.health_check import _run_probe
|
||||||
from app.utils.fail2ban_client import Fail2BanConnectionError
|
from app.utils.fail2ban_client import Fail2BanConnectionError
|
||||||
|
|
||||||
|
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||||
|
|
||||||
router: APIRouter = APIRouter(prefix="/api/config", tags=["Config"])
|
router: APIRouter = APIRouter(prefix="/api/config", tags=["Config"])
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -428,9 +428,7 @@ async def restart_fail2ban(
|
|||||||
await config_file_service.start_daemon(start_cmd_parts)
|
await config_file_service.start_daemon(start_cmd_parts)
|
||||||
|
|
||||||
# Step 3: probe the socket until fail2ban is responsive or the budget expires.
|
# Step 3: probe the socket until fail2ban is responsive or the budget expires.
|
||||||
fail2ban_running: bool = await config_file_service.wait_for_fail2ban(
|
fail2ban_running: bool = await config_file_service.wait_for_fail2ban(socket_path, max_wait_seconds=10.0)
|
||||||
socket_path, max_wait_seconds=10.0
|
|
||||||
)
|
|
||||||
if not fail2ban_running:
|
if not fail2ban_running:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
@@ -604,9 +602,7 @@ async def get_map_color_thresholds(
|
|||||||
"""
|
"""
|
||||||
from app.services import setup_service
|
from app.services import setup_service
|
||||||
|
|
||||||
high, medium, low = await setup_service.get_map_color_thresholds(
|
high, medium, low = await setup_service.get_map_color_thresholds(request.app.state.db)
|
||||||
request.app.state.db
|
|
||||||
)
|
|
||||||
return MapColorThresholdsResponse(
|
return MapColorThresholdsResponse(
|
||||||
threshold_high=high,
|
threshold_high=high,
|
||||||
threshold_medium=medium,
|
threshold_medium=medium,
|
||||||
@@ -696,9 +692,7 @@ async def activate_jail(
|
|||||||
req = body if body is not None else ActivateJailRequest()
|
req = body if body is not None else ActivateJailRequest()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await config_file_service.activate_jail(
|
result = await config_file_service.activate_jail(config_dir, socket_path, name, req)
|
||||||
config_dir, socket_path, name, req
|
|
||||||
)
|
|
||||||
except JailNameError as exc:
|
except JailNameError as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except JailNotFoundInConfigError:
|
except JailNotFoundInConfigError:
|
||||||
@@ -831,9 +825,7 @@ async def delete_jail_local_override(
|
|||||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await config_file_service.delete_jail_local_override(
|
await config_file_service.delete_jail_local_override(config_dir, socket_path, name)
|
||||||
config_dir, socket_path, name
|
|
||||||
)
|
|
||||||
except JailNameError as exc:
|
except JailNameError as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except JailNotFoundInConfigError:
|
except JailNotFoundInConfigError:
|
||||||
@@ -952,9 +944,7 @@ async def rollback_jail(
|
|||||||
start_cmd_parts: list[str] = start_cmd.split()
|
start_cmd_parts: list[str] = start_cmd.split()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await config_file_service.rollback_jail(
|
result = await config_file_service.rollback_jail(config_dir, socket_path, name, start_cmd_parts)
|
||||||
config_dir, socket_path, name, start_cmd_parts
|
|
||||||
)
|
|
||||||
except JailNameError as exc:
|
except JailNameError as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except ConfigWriteError as exc:
|
except ConfigWriteError as exc:
|
||||||
@@ -1107,9 +1097,7 @@ async def update_filter(
|
|||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||||
try:
|
try:
|
||||||
return await config_file_service.update_filter(
|
return await config_file_service.update_filter(config_dir, socket_path, name, body, do_reload=reload)
|
||||||
config_dir, socket_path, name, body, do_reload=reload
|
|
||||||
)
|
|
||||||
except FilterNameError as exc:
|
except FilterNameError as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except FilterNotFoundError:
|
except FilterNotFoundError:
|
||||||
@@ -1159,9 +1147,7 @@ async def create_filter(
|
|||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||||
try:
|
try:
|
||||||
return await config_file_service.create_filter(
|
return await config_file_service.create_filter(config_dir, socket_path, body, do_reload=reload)
|
||||||
config_dir, socket_path, body, do_reload=reload
|
|
||||||
)
|
|
||||||
except FilterNameError as exc:
|
except FilterNameError as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except FilterAlreadyExistsError as exc:
|
except FilterAlreadyExistsError as exc:
|
||||||
@@ -1257,9 +1243,7 @@ async def assign_filter_to_jail(
|
|||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||||
try:
|
try:
|
||||||
await config_file_service.assign_filter_to_jail(
|
await config_file_service.assign_filter_to_jail(config_dir, socket_path, name, body, do_reload=reload)
|
||||||
config_dir, socket_path, name, body, do_reload=reload
|
|
||||||
)
|
|
||||||
except (JailNameError, FilterNameError) as exc:
|
except (JailNameError, FilterNameError) as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except JailNotFoundInConfigError:
|
except JailNotFoundInConfigError:
|
||||||
@@ -1403,9 +1387,7 @@ async def update_action(
|
|||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||||
try:
|
try:
|
||||||
return await config_file_service.update_action(
|
return await config_file_service.update_action(config_dir, socket_path, name, body, do_reload=reload)
|
||||||
config_dir, socket_path, name, body, do_reload=reload
|
|
||||||
)
|
|
||||||
except ActionNameError as exc:
|
except ActionNameError as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except ActionNotFoundError:
|
except ActionNotFoundError:
|
||||||
@@ -1451,9 +1433,7 @@ async def create_action(
|
|||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||||
try:
|
try:
|
||||||
return await config_file_service.create_action(
|
return await config_file_service.create_action(config_dir, socket_path, body, do_reload=reload)
|
||||||
config_dir, socket_path, body, do_reload=reload
|
|
||||||
)
|
|
||||||
except ActionNameError as exc:
|
except ActionNameError as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except ActionAlreadyExistsError as exc:
|
except ActionAlreadyExistsError as exc:
|
||||||
@@ -1546,9 +1526,7 @@ async def assign_action_to_jail(
|
|||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||||
try:
|
try:
|
||||||
await config_file_service.assign_action_to_jail(
|
await config_file_service.assign_action_to_jail(config_dir, socket_path, name, body, do_reload=reload)
|
||||||
config_dir, socket_path, name, body, do_reload=reload
|
|
||||||
)
|
|
||||||
except (JailNameError, ActionNameError) as exc:
|
except (JailNameError, ActionNameError) as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except JailNotFoundInConfigError:
|
except JailNotFoundInConfigError:
|
||||||
@@ -1597,9 +1575,7 @@ async def remove_action_from_jail(
|
|||||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||||
try:
|
try:
|
||||||
await config_file_service.remove_action_from_jail(
|
await config_file_service.remove_action_from_jail(config_dir, socket_path, name, action_name, do_reload=reload)
|
||||||
config_dir, socket_path, name, action_name, do_reload=reload
|
|
||||||
)
|
|
||||||
except (JailNameError, ActionNameError) as exc:
|
except (JailNameError, ActionNameError) as exc:
|
||||||
raise _bad_request(str(exc)) from exc
|
raise _bad_request(str(exc)) from exc
|
||||||
except JailNotFoundInConfigError:
|
except JailNotFoundInConfigError:
|
||||||
@@ -1689,4 +1665,3 @@ async def get_service_status(
|
|||||||
return await config_service.get_service_status(socket_path)
|
return await config_service.get_service_status(socket_path)
|
||||||
except Fail2BanConnectionError as exc:
|
except Fail2BanConnectionError as exc:
|
||||||
raise _bad_gateway(exc) from exc
|
raise _bad_gateway(exc) from exc
|
||||||
|
|
||||||
|
|||||||
@@ -13,12 +13,15 @@ from typing import TYPE_CHECKING, Annotated
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
|
from app.services.jail_service import IpLookupResult
|
||||||
|
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Path, Request, status
|
from fastapi import APIRouter, Depends, HTTPException, Path, Request, status
|
||||||
|
|
||||||
from app.dependencies import AuthDep, get_db
|
from app.dependencies import AuthDep, get_db
|
||||||
from app.models.geo import GeoCacheStatsResponse, GeoDetail, IpLookupResponse
|
from app.models.geo import GeoCacheStatsResponse, GeoDetail, IpLookupResponse
|
||||||
from app.services import geo_service, jail_service
|
from app.services import geo_service, jail_service
|
||||||
|
from app.services.geo_service import GeoInfo
|
||||||
from app.utils.fail2ban_client import Fail2BanConnectionError
|
from app.utils.fail2ban_client import Fail2BanConnectionError
|
||||||
|
|
||||||
router: APIRouter = APIRouter(prefix="/api/geo", tags=["Geo"])
|
router: APIRouter = APIRouter(prefix="/api/geo", tags=["Geo"])
|
||||||
@@ -61,7 +64,7 @@ async def lookup_ip(
|
|||||||
return await geo_service.lookup(addr, http_session)
|
return await geo_service.lookup(addr, http_session)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await jail_service.lookup_ip(
|
result: IpLookupResult = await jail_service.lookup_ip(
|
||||||
socket_path,
|
socket_path,
|
||||||
ip,
|
ip,
|
||||||
geo_enricher=_enricher,
|
geo_enricher=_enricher,
|
||||||
@@ -77,9 +80,9 @@ async def lookup_ip(
|
|||||||
detail=f"Cannot reach fail2ban: {exc}",
|
detail=f"Cannot reach fail2ban: {exc}",
|
||||||
) from exc
|
) from exc
|
||||||
|
|
||||||
raw_geo = result.get("geo")
|
raw_geo = result["geo"]
|
||||||
geo_detail: GeoDetail | None = None
|
geo_detail: GeoDetail | None = None
|
||||||
if raw_geo is not None:
|
if isinstance(raw_geo, GeoInfo):
|
||||||
geo_detail = GeoDetail(
|
geo_detail = GeoDetail(
|
||||||
country_code=raw_geo.country_code,
|
country_code=raw_geo.country_code,
|
||||||
country_name=raw_geo.country_name,
|
country_name=raw_geo.country_name,
|
||||||
|
|||||||
@@ -14,17 +14,11 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from dataclasses import asdict
|
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import TYPE_CHECKING, TypeAlias
|
from typing import TYPE_CHECKING, cast
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
import aiosqlite
|
|
||||||
|
|
||||||
from app.services.geo_service import GeoInfo
|
|
||||||
|
|
||||||
from app.models.ban import (
|
from app.models.ban import (
|
||||||
BLOCKLIST_JAIL,
|
BLOCKLIST_JAIL,
|
||||||
BUCKET_SECONDS,
|
BUCKET_SECONDS,
|
||||||
@@ -37,20 +31,25 @@ from app.models.ban import (
|
|||||||
BanTrendResponse,
|
BanTrendResponse,
|
||||||
DashboardBanItem,
|
DashboardBanItem,
|
||||||
DashboardBanListResponse,
|
DashboardBanListResponse,
|
||||||
JailBanCount as JailBanCountModel,
|
|
||||||
TimeRange,
|
TimeRange,
|
||||||
_derive_origin,
|
_derive_origin,
|
||||||
bucket_count,
|
bucket_count,
|
||||||
)
|
)
|
||||||
|
from app.models.ban import (
|
||||||
|
JailBanCount as JailBanCountModel,
|
||||||
|
)
|
||||||
from app.repositories import fail2ban_db_repo
|
from app.repositories import fail2ban_db_repo
|
||||||
from app.utils.fail2ban_client import Fail2BanClient
|
from app.utils.fail2ban_client import Fail2BanClient, Fail2BanResponse
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
import aiosqlite
|
||||||
|
|
||||||
|
from app.services.geo_service import GeoInfo
|
||||||
|
|
||||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||||
|
|
||||||
GeoEnricher: TypeAlias = Callable[[str], Awaitable["GeoInfo"] | None]
|
type GeoEnricher = Callable[[str], Awaitable[GeoInfo | None]]
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Constants
|
# Constants
|
||||||
@@ -137,7 +136,7 @@ async def _get_fail2ban_db_path(socket_path: str) -> str:
|
|||||||
response = await client.send(["get", "dbfile"])
|
response = await client.send(["get", "dbfile"])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
code, data = response
|
code, data = cast("Fail2BanResponse", response)
|
||||||
except (TypeError, ValueError) as exc:
|
except (TypeError, ValueError) as exc:
|
||||||
raise RuntimeError(f"Unexpected response from fail2ban: {response!r}") from exc
|
raise RuntimeError(f"Unexpected response from fail2ban: {response!r}") from exc
|
||||||
|
|
||||||
@@ -276,7 +275,7 @@ async def list_bans(
|
|||||||
# Batch-resolve geo data for all IPs on this page in a single API call.
|
# Batch-resolve geo data for all IPs on this page in a single API call.
|
||||||
# This avoids hitting the 45 req/min single-IP rate limit when the
|
# This avoids hitting the 45 req/min single-IP rate limit when the
|
||||||
# page contains many bans (e.g. after a large blocklist import).
|
# page contains many bans (e.g. after a large blocklist import).
|
||||||
geo_map: dict[str, "GeoInfo"] = {}
|
geo_map: dict[str, GeoInfo] = {}
|
||||||
if http_session is not None and rows:
|
if http_session is not None and rows:
|
||||||
page_ips: list[str] = [r.ip for r in rows]
|
page_ips: list[str] = [r.ip for r in rows]
|
||||||
try:
|
try:
|
||||||
@@ -428,7 +427,7 @@ async def bans_by_country(
|
|||||||
)
|
)
|
||||||
|
|
||||||
unique_ips: list[str] = [r.ip for r in agg_rows]
|
unique_ips: list[str] = [r.ip for r in agg_rows]
|
||||||
geo_map: dict[str, "GeoInfo"] = {}
|
geo_map: dict[str, GeoInfo] = {}
|
||||||
|
|
||||||
if http_session is not None and unique_ips:
|
if http_session is not None and unique_ips:
|
||||||
# Serve only what is already in the in-memory cache — no API calls on
|
# Serve only what is already in the in-memory cache — no API calls on
|
||||||
@@ -449,7 +448,7 @@ async def bans_by_country(
|
|||||||
)
|
)
|
||||||
elif geo_enricher is not None and unique_ips:
|
elif geo_enricher is not None and unique_ips:
|
||||||
# Fallback: legacy per-IP enricher (used in tests / older callers).
|
# Fallback: legacy per-IP enricher (used in tests / older callers).
|
||||||
async def _safe_lookup(ip: str) -> tuple[str, "GeoInfo" | None]:
|
async def _safe_lookup(ip: str) -> tuple[str, GeoInfo | None]:
|
||||||
try:
|
try:
|
||||||
return ip, await geo_enricher(ip)
|
return ip, await geo_enricher(ip)
|
||||||
except Exception: # noqa: BLE001
|
except Exception: # noqa: BLE001
|
||||||
@@ -636,9 +635,7 @@ async def bans_by_jail(
|
|||||||
# has *any* rows and log a warning with min/max timeofban so operators can
|
# has *any* rows and log a warning with min/max timeofban so operators can
|
||||||
# diagnose timezone or filter mismatches from logs.
|
# diagnose timezone or filter mismatches from logs.
|
||||||
if total == 0:
|
if total == 0:
|
||||||
table_row_count, min_timeofban, max_timeofban = (
|
table_row_count, min_timeofban, max_timeofban = await fail2ban_db_repo.get_bans_table_summary(db_path)
|
||||||
await fail2ban_db_repo.get_bans_table_summary(db_path)
|
|
||||||
)
|
|
||||||
if table_row_count > 0:
|
if table_row_count > 0:
|
||||||
log.warning(
|
log.warning(
|
||||||
"ban_service_bans_by_jail_empty_despite_data",
|
"ban_service_bans_by_jail_empty_despite_data",
|
||||||
|
|||||||
@@ -542,7 +542,7 @@ async def list_import_logs(
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _aiohttp_timeout(seconds: float) -> "aiohttp.ClientTimeout":
|
def _aiohttp_timeout(seconds: float) -> aiohttp.ClientTimeout:
|
||||||
"""Return an :class:`aiohttp.ClientTimeout` with the given total timeout.
|
"""Return an :class:`aiohttp.ClientTimeout` with the given total timeout.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ import os
|
|||||||
import re
|
import re
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, cast, TypeAlias
|
from typing import cast
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
|
|
||||||
@@ -59,7 +59,6 @@ from app.services.jail_service import JailNotFoundError as JailNotFoundError
|
|||||||
from app.utils import conffile_parser
|
from app.utils import conffile_parser
|
||||||
from app.utils.fail2ban_client import (
|
from app.utils.fail2ban_client import (
|
||||||
Fail2BanClient,
|
Fail2BanClient,
|
||||||
Fail2BanCommand,
|
|
||||||
Fail2BanConnectionError,
|
Fail2BanConnectionError,
|
||||||
Fail2BanResponse,
|
Fail2BanResponse,
|
||||||
)
|
)
|
||||||
@@ -73,9 +72,7 @@ log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
|||||||
_SOCKET_TIMEOUT: float = 10.0
|
_SOCKET_TIMEOUT: float = 10.0
|
||||||
|
|
||||||
# Allowlist pattern for jail names used in path construction.
|
# Allowlist pattern for jail names used in path construction.
|
||||||
_SAFE_JAIL_NAME_RE: re.Pattern[str] = re.compile(
|
_SAFE_JAIL_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$")
|
||||||
r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Sections that are not jail definitions.
|
# Sections that are not jail definitions.
|
||||||
_META_SECTIONS: frozenset[str] = frozenset({"INCLUDES", "DEFAULT"})
|
_META_SECTIONS: frozenset[str] = frozenset({"INCLUDES", "DEFAULT"})
|
||||||
@@ -167,8 +164,7 @@ class FilterReadonlyError(Exception):
|
|||||||
"""
|
"""
|
||||||
self.name: str = name
|
self.name: str = name
|
||||||
super().__init__(
|
super().__init__(
|
||||||
f"Filter {name!r} is a shipped default (.conf only); "
|
f"Filter {name!r} is a shipped default (.conf only); only user-created .local files can be deleted."
|
||||||
"only user-created .local files can be deleted."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -423,9 +419,7 @@ def _parse_jails_sync(
|
|||||||
# items() merges DEFAULT values automatically.
|
# items() merges DEFAULT values automatically.
|
||||||
jails[section] = dict(parser.items(section))
|
jails[section] = dict(parser.items(section))
|
||||||
except configparser.Error as exc:
|
except configparser.Error as exc:
|
||||||
log.warning(
|
log.warning("jail_section_parse_error", section=section, error=str(exc))
|
||||||
"jail_section_parse_error", section=section, error=str(exc)
|
|
||||||
)
|
|
||||||
|
|
||||||
log.debug("jails_parsed", count=len(jails), config_dir=str(config_dir))
|
log.debug("jails_parsed", count=len(jails), config_dir=str(config_dir))
|
||||||
return jails, source_files
|
return jails, source_files
|
||||||
@@ -522,11 +516,7 @@ def _build_inactive_jail(
|
|||||||
bantime_escalation=bantime_escalation,
|
bantime_escalation=bantime_escalation,
|
||||||
source_file=source_file,
|
source_file=source_file,
|
||||||
enabled=enabled,
|
enabled=enabled,
|
||||||
has_local_override=(
|
has_local_override=((config_dir / "jail.d" / f"{name}.local").is_file() if config_dir is not None else False),
|
||||||
(config_dir / "jail.d" / f"{name}.local").is_file()
|
|
||||||
if config_dir is not None
|
|
||||||
else False
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -557,7 +547,7 @@ async def _get_active_jail_names(socket_path: str) -> set[str]:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def _ok(response: object) -> object:
|
def _ok(response: object) -> object:
|
||||||
code, data = cast(Fail2BanResponse, response)
|
code, data = cast("Fail2BanResponse", response)
|
||||||
if code != 0:
|
if code != 0:
|
||||||
raise ValueError(f"fail2ban error {code}: {data!r}")
|
raise ValueError(f"fail2ban error {code}: {data!r}")
|
||||||
return data
|
return data
|
||||||
@@ -572,9 +562,7 @@ async def _get_active_jail_names(socket_path: str) -> set[str]:
|
|||||||
log.warning("fail2ban_unreachable_during_inactive_list")
|
log.warning("fail2ban_unreachable_during_inactive_list")
|
||||||
return set()
|
return set()
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
log.warning(
|
log.warning("fail2ban_status_error_during_inactive_list", error=str(exc))
|
||||||
"fail2ban_status_error_during_inactive_list", error=str(exc)
|
|
||||||
)
|
|
||||||
return set()
|
return set()
|
||||||
|
|
||||||
|
|
||||||
@@ -662,10 +650,7 @@ def _validate_jail_config_sync(
|
|||||||
issues.append(
|
issues.append(
|
||||||
JailValidationIssue(
|
JailValidationIssue(
|
||||||
field="filter",
|
field="filter",
|
||||||
message=(
|
message=(f"Filter file not found: filter.d/{base_filter}.conf (or .local)"),
|
||||||
f"Filter file not found: filter.d/{base_filter}.conf"
|
|
||||||
" (or .local)"
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -681,10 +666,7 @@ def _validate_jail_config_sync(
|
|||||||
issues.append(
|
issues.append(
|
||||||
JailValidationIssue(
|
JailValidationIssue(
|
||||||
field="action",
|
field="action",
|
||||||
message=(
|
message=(f"Action file not found: action.d/{action_name}.conf (or .local)"),
|
||||||
f"Action file not found: action.d/{action_name}.conf"
|
|
||||||
" (or .local)"
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -840,9 +822,7 @@ def _write_local_override_sync(
|
|||||||
try:
|
try:
|
||||||
jail_d.mkdir(parents=True, exist_ok=True)
|
jail_d.mkdir(parents=True, exist_ok=True)
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
raise ConfigWriteError(
|
raise ConfigWriteError(f"Cannot create jail.d directory: {exc}") from exc
|
||||||
f"Cannot create jail.d directory: {exc}"
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
local_path = jail_d / f"{jail_name}.local"
|
local_path = jail_d / f"{jail_name}.local"
|
||||||
|
|
||||||
@@ -867,7 +847,7 @@ def _write_local_override_sync(
|
|||||||
if overrides.get("port") is not None:
|
if overrides.get("port") is not None:
|
||||||
lines.append(f"port = {overrides['port']}")
|
lines.append(f"port = {overrides['port']}")
|
||||||
if overrides.get("logpath"):
|
if overrides.get("logpath"):
|
||||||
paths: list[str] = cast(list[str], overrides["logpath"])
|
paths: list[str] = cast("list[str]", overrides["logpath"])
|
||||||
if paths:
|
if paths:
|
||||||
lines.append(f"logpath = {paths[0]}")
|
lines.append(f"logpath = {paths[0]}")
|
||||||
for p in paths[1:]:
|
for p in paths[1:]:
|
||||||
@@ -890,9 +870,7 @@ def _write_local_override_sync(
|
|||||||
# Clean up temp file if rename failed.
|
# Clean up temp file if rename failed.
|
||||||
with contextlib.suppress(OSError):
|
with contextlib.suppress(OSError):
|
||||||
os.unlink(tmp_name) # noqa: F821 — only reachable when tmp_name is set
|
os.unlink(tmp_name) # noqa: F821 — only reachable when tmp_name is set
|
||||||
raise ConfigWriteError(
|
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
||||||
f"Failed to write {local_path}: {exc}"
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
log.info(
|
log.info(
|
||||||
"jail_local_written",
|
"jail_local_written",
|
||||||
@@ -921,9 +899,7 @@ def _restore_local_file_sync(local_path: Path, original_content: bytes | None) -
|
|||||||
try:
|
try:
|
||||||
local_path.unlink(missing_ok=True)
|
local_path.unlink(missing_ok=True)
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
raise ConfigWriteError(
|
raise ConfigWriteError(f"Failed to delete {local_path} during rollback: {exc}") from exc
|
||||||
f"Failed to delete {local_path} during rollback: {exc}"
|
|
||||||
) from exc
|
|
||||||
return
|
return
|
||||||
|
|
||||||
tmp_name: str | None = None
|
tmp_name: str | None = None
|
||||||
@@ -941,9 +917,7 @@ def _restore_local_file_sync(local_path: Path, original_content: bytes | None) -
|
|||||||
with contextlib.suppress(OSError):
|
with contextlib.suppress(OSError):
|
||||||
if tmp_name is not None:
|
if tmp_name is not None:
|
||||||
os.unlink(tmp_name)
|
os.unlink(tmp_name)
|
||||||
raise ConfigWriteError(
|
raise ConfigWriteError(f"Failed to restore {local_path} during rollback: {exc}") from exc
|
||||||
f"Failed to restore {local_path} during rollback: {exc}"
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_regex_patterns(patterns: list[str]) -> None:
|
def _validate_regex_patterns(patterns: list[str]) -> None:
|
||||||
@@ -979,9 +953,7 @@ def _write_filter_local_sync(filter_d: Path, name: str, content: str) -> None:
|
|||||||
try:
|
try:
|
||||||
filter_d.mkdir(parents=True, exist_ok=True)
|
filter_d.mkdir(parents=True, exist_ok=True)
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
raise ConfigWriteError(
|
raise ConfigWriteError(f"Cannot create filter.d directory: {exc}") from exc
|
||||||
f"Cannot create filter.d directory: {exc}"
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
local_path = filter_d / f"{name}.local"
|
local_path = filter_d / f"{name}.local"
|
||||||
try:
|
try:
|
||||||
@@ -998,9 +970,7 @@ def _write_filter_local_sync(filter_d: Path, name: str, content: str) -> None:
|
|||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
with contextlib.suppress(OSError):
|
with contextlib.suppress(OSError):
|
||||||
os.unlink(tmp_name) # noqa: F821
|
os.unlink(tmp_name) # noqa: F821
|
||||||
raise ConfigWriteError(
|
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
||||||
f"Failed to write {local_path}: {exc}"
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
log.info("filter_local_written", filter=name, path=str(local_path))
|
log.info("filter_local_written", filter=name, path=str(local_path))
|
||||||
|
|
||||||
@@ -1031,9 +1001,7 @@ def _set_jail_local_key_sync(
|
|||||||
try:
|
try:
|
||||||
jail_d.mkdir(parents=True, exist_ok=True)
|
jail_d.mkdir(parents=True, exist_ok=True)
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
raise ConfigWriteError(
|
raise ConfigWriteError(f"Cannot create jail.d directory: {exc}") from exc
|
||||||
f"Cannot create jail.d directory: {exc}"
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
local_path = jail_d / f"{jail_name}.local"
|
local_path = jail_d / f"{jail_name}.local"
|
||||||
|
|
||||||
@@ -1072,9 +1040,7 @@ def _set_jail_local_key_sync(
|
|||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
with contextlib.suppress(OSError):
|
with contextlib.suppress(OSError):
|
||||||
os.unlink(tmp_name) # noqa: F821
|
os.unlink(tmp_name) # noqa: F821
|
||||||
raise ConfigWriteError(
|
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
||||||
f"Failed to write {local_path}: {exc}"
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
log.info(
|
log.info(
|
||||||
"jail_local_key_set",
|
"jail_local_key_set",
|
||||||
@@ -1112,8 +1078,8 @@ async def list_inactive_jails(
|
|||||||
inactive jails.
|
inactive jails.
|
||||||
"""
|
"""
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
parsed_result: tuple[dict[str, dict[str, str]], dict[str, str]] = (
|
parsed_result: tuple[dict[str, dict[str, str]], dict[str, str]] = await loop.run_in_executor(
|
||||||
await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
None, _parse_jails_sync, Path(config_dir)
|
||||||
)
|
)
|
||||||
all_jails, source_files = parsed_result
|
all_jails, source_files = parsed_result
|
||||||
active_names: set[str] = await _get_active_jail_names(socket_path)
|
active_names: set[str] = await _get_active_jail_names(socket_path)
|
||||||
@@ -1170,9 +1136,7 @@ async def activate_jail(
|
|||||||
_safe_jail_name(name)
|
_safe_jail_name(name)
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
all_jails, _source_files = await loop.run_in_executor(
|
all_jails, _source_files = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
||||||
None, _parse_jails_sync, Path(config_dir)
|
|
||||||
)
|
|
||||||
|
|
||||||
if name not in all_jails:
|
if name not in all_jails:
|
||||||
raise JailNotFoundInConfigError(name)
|
raise JailNotFoundInConfigError(name)
|
||||||
@@ -1208,10 +1172,7 @@ async def activate_jail(
|
|||||||
active=False,
|
active=False,
|
||||||
fail2ban_running=True,
|
fail2ban_running=True,
|
||||||
validation_warnings=warnings,
|
validation_warnings=warnings,
|
||||||
message=(
|
message=(f"Jail {name!r} cannot be activated: " + "; ".join(i.message for i in blocking)),
|
||||||
f"Jail {name!r} cannot be activated: "
|
|
||||||
+ "; ".join(i.message for i in blocking)
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
overrides: dict[str, object] = {
|
overrides: dict[str, object] = {
|
||||||
@@ -1254,9 +1215,7 @@ async def activate_jail(
|
|||||||
jail=name,
|
jail=name,
|
||||||
error=str(exc),
|
error=str(exc),
|
||||||
)
|
)
|
||||||
recovered = await _rollback_activation_async(
|
recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content)
|
||||||
config_dir, name, socket_path, original_content
|
|
||||||
)
|
|
||||||
return JailActivationResponse(
|
return JailActivationResponse(
|
||||||
name=name,
|
name=name,
|
||||||
active=False,
|
active=False,
|
||||||
@@ -1272,9 +1231,7 @@ async def activate_jail(
|
|||||||
)
|
)
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
log.warning("reload_after_activate_failed", jail=name, error=str(exc))
|
log.warning("reload_after_activate_failed", jail=name, error=str(exc))
|
||||||
recovered = await _rollback_activation_async(
|
recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content)
|
||||||
config_dir, name, socket_path, original_content
|
|
||||||
)
|
|
||||||
return JailActivationResponse(
|
return JailActivationResponse(
|
||||||
name=name,
|
name=name,
|
||||||
active=False,
|
active=False,
|
||||||
@@ -1305,9 +1262,7 @@ async def activate_jail(
|
|||||||
jail=name,
|
jail=name,
|
||||||
message="fail2ban socket unreachable after reload — initiating rollback.",
|
message="fail2ban socket unreachable after reload — initiating rollback.",
|
||||||
)
|
)
|
||||||
recovered = await _rollback_activation_async(
|
recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content)
|
||||||
config_dir, name, socket_path, original_content
|
|
||||||
)
|
|
||||||
return JailActivationResponse(
|
return JailActivationResponse(
|
||||||
name=name,
|
name=name,
|
||||||
active=False,
|
active=False,
|
||||||
@@ -1330,9 +1285,7 @@ async def activate_jail(
|
|||||||
jail=name,
|
jail=name,
|
||||||
message="Jail did not appear in running jails — initiating rollback.",
|
message="Jail did not appear in running jails — initiating rollback.",
|
||||||
)
|
)
|
||||||
recovered = await _rollback_activation_async(
|
recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content)
|
||||||
config_dir, name, socket_path, original_content
|
|
||||||
)
|
|
||||||
return JailActivationResponse(
|
return JailActivationResponse(
|
||||||
name=name,
|
name=name,
|
||||||
active=False,
|
active=False,
|
||||||
@@ -1388,14 +1341,10 @@ async def _rollback_activation_async(
|
|||||||
|
|
||||||
# Step 1 — restore original file (or delete it).
|
# Step 1 — restore original file (or delete it).
|
||||||
try:
|
try:
|
||||||
await loop.run_in_executor(
|
await loop.run_in_executor(None, _restore_local_file_sync, local_path, original_content)
|
||||||
None, _restore_local_file_sync, local_path, original_content
|
|
||||||
)
|
|
||||||
log.info("jail_activation_rollback_file_restored", jail=name)
|
log.info("jail_activation_rollback_file_restored", jail=name)
|
||||||
except ConfigWriteError as exc:
|
except ConfigWriteError as exc:
|
||||||
log.error(
|
log.error("jail_activation_rollback_restore_failed", jail=name, error=str(exc))
|
||||||
"jail_activation_rollback_restore_failed", jail=name, error=str(exc)
|
|
||||||
)
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Step 2 — reload fail2ban with the restored config.
|
# Step 2 — reload fail2ban with the restored config.
|
||||||
@@ -1403,9 +1352,7 @@ async def _rollback_activation_async(
|
|||||||
await jail_service.reload_all(socket_path)
|
await jail_service.reload_all(socket_path)
|
||||||
log.info("jail_activation_rollback_reload_ok", jail=name)
|
log.info("jail_activation_rollback_reload_ok", jail=name)
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
log.warning(
|
log.warning("jail_activation_rollback_reload_failed", jail=name, error=str(exc))
|
||||||
"jail_activation_rollback_reload_failed", jail=name, error=str(exc)
|
|
||||||
)
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Step 3 — wait for fail2ban to come back.
|
# Step 3 — wait for fail2ban to come back.
|
||||||
@@ -1450,9 +1397,7 @@ async def deactivate_jail(
|
|||||||
_safe_jail_name(name)
|
_safe_jail_name(name)
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
all_jails, _source_files = await loop.run_in_executor(
|
all_jails, _source_files = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
||||||
None, _parse_jails_sync, Path(config_dir)
|
|
||||||
)
|
|
||||||
|
|
||||||
if name not in all_jails:
|
if name not in all_jails:
|
||||||
raise JailNotFoundInConfigError(name)
|
raise JailNotFoundInConfigError(name)
|
||||||
@@ -1510,9 +1455,7 @@ async def delete_jail_local_override(
|
|||||||
_safe_jail_name(name)
|
_safe_jail_name(name)
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
all_jails, _source_files = await loop.run_in_executor(
|
all_jails, _source_files = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
||||||
None, _parse_jails_sync, Path(config_dir)
|
|
||||||
)
|
|
||||||
|
|
||||||
if name not in all_jails:
|
if name not in all_jails:
|
||||||
raise JailNotFoundInConfigError(name)
|
raise JailNotFoundInConfigError(name)
|
||||||
@@ -1523,13 +1466,9 @@ async def delete_jail_local_override(
|
|||||||
|
|
||||||
local_path = Path(config_dir) / "jail.d" / f"{name}.local"
|
local_path = Path(config_dir) / "jail.d" / f"{name}.local"
|
||||||
try:
|
try:
|
||||||
await loop.run_in_executor(
|
await loop.run_in_executor(None, lambda: local_path.unlink(missing_ok=True))
|
||||||
None, lambda: local_path.unlink(missing_ok=True)
|
|
||||||
)
|
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
raise ConfigWriteError(
|
raise ConfigWriteError(f"Failed to delete {local_path}: {exc}") from exc
|
||||||
f"Failed to delete {local_path}: {exc}"
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
log.info("jail_local_override_deleted", jail=name, path=str(local_path))
|
log.info("jail_local_override_deleted", jail=name, path=str(local_path))
|
||||||
|
|
||||||
@@ -1610,9 +1549,7 @@ async def rollback_jail(
|
|||||||
log.info("jail_rollback_start_attempted", jail=name, start_ok=started)
|
log.info("jail_rollback_start_attempted", jail=name, start_ok=started)
|
||||||
|
|
||||||
# Wait for the socket to come back.
|
# Wait for the socket to come back.
|
||||||
fail2ban_running = await wait_for_fail2ban(
|
fail2ban_running = await wait_for_fail2ban(socket_path, max_wait_seconds=10.0, poll_interval=2.0)
|
||||||
socket_path, max_wait_seconds=10.0, poll_interval=2.0
|
|
||||||
)
|
|
||||||
|
|
||||||
active_jails = 0
|
active_jails = 0
|
||||||
if fail2ban_running:
|
if fail2ban_running:
|
||||||
@@ -1626,10 +1563,7 @@ async def rollback_jail(
|
|||||||
disabled=True,
|
disabled=True,
|
||||||
fail2ban_running=True,
|
fail2ban_running=True,
|
||||||
active_jails=active_jails,
|
active_jails=active_jails,
|
||||||
message=(
|
message=(f"Jail {name!r} disabled and fail2ban restarted successfully with {active_jails} active jail(s)."),
|
||||||
f"Jail {name!r} disabled and fail2ban restarted successfully "
|
|
||||||
f"with {active_jails} active jail(s)."
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
log.warning("jail_rollback_fail2ban_still_down", jail=name)
|
log.warning("jail_rollback_fail2ban_still_down", jail=name)
|
||||||
@@ -1650,9 +1584,7 @@ async def rollback_jail(
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
# Allowlist pattern for filter names used in path construction.
|
# Allowlist pattern for filter names used in path construction.
|
||||||
_SAFE_FILTER_NAME_RE: re.Pattern[str] = re.compile(
|
_SAFE_FILTER_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$")
|
||||||
r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class FilterNotFoundError(Exception):
|
class FilterNotFoundError(Exception):
|
||||||
@@ -1764,9 +1696,7 @@ def _parse_filters_sync(
|
|||||||
try:
|
try:
|
||||||
content = conf_path.read_text(encoding="utf-8")
|
content = conf_path.read_text(encoding="utf-8")
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
log.warning(
|
log.warning("filter_read_error", name=name, path=str(conf_path), error=str(exc))
|
||||||
"filter_read_error", name=name, path=str(conf_path), error=str(exc)
|
|
||||||
)
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if has_local:
|
if has_local:
|
||||||
@@ -1842,9 +1772,7 @@ async def list_filters(
|
|||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
# Run the synchronous scan in a thread-pool executor.
|
# Run the synchronous scan in a thread-pool executor.
|
||||||
raw_filters: list[tuple[str, str, str, bool, str]] = await loop.run_in_executor(
|
raw_filters: list[tuple[str, str, str, bool, str]] = await loop.run_in_executor(None, _parse_filters_sync, filter_d)
|
||||||
None, _parse_filters_sync, filter_d
|
|
||||||
)
|
|
||||||
|
|
||||||
# Fetch active jail names and their configs concurrently.
|
# Fetch active jail names and their configs concurrently.
|
||||||
all_jails_result, active_names = await asyncio.gather(
|
all_jails_result, active_names = await asyncio.gather(
|
||||||
@@ -1857,9 +1785,7 @@ async def list_filters(
|
|||||||
|
|
||||||
filters: list[FilterConfig] = []
|
filters: list[FilterConfig] = []
|
||||||
for name, filename, content, has_local, source_path in raw_filters:
|
for name, filename, content, has_local, source_path in raw_filters:
|
||||||
cfg = conffile_parser.parse_filter_file(
|
cfg = conffile_parser.parse_filter_file(content, name=name, filename=filename)
|
||||||
content, name=name, filename=filename
|
|
||||||
)
|
|
||||||
used_by = sorted(filter_to_jails.get(name, []))
|
used_by = sorted(filter_to_jails.get(name, []))
|
||||||
filters.append(
|
filters.append(
|
||||||
FilterConfig(
|
FilterConfig(
|
||||||
@@ -1947,9 +1873,7 @@ async def get_filter(
|
|||||||
|
|
||||||
content, has_local, source_path = await loop.run_in_executor(None, _read)
|
content, has_local, source_path = await loop.run_in_executor(None, _read)
|
||||||
|
|
||||||
cfg = conffile_parser.parse_filter_file(
|
cfg = conffile_parser.parse_filter_file(content, name=base_name, filename=f"{base_name}.conf")
|
||||||
content, name=base_name, filename=f"{base_name}.conf"
|
|
||||||
)
|
|
||||||
|
|
||||||
all_jails_result, active_names = await asyncio.gather(
|
all_jails_result, active_names = await asyncio.gather(
|
||||||
loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)),
|
loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)),
|
||||||
@@ -2182,9 +2106,7 @@ async def delete_filter(
|
|||||||
try:
|
try:
|
||||||
local_path.unlink()
|
local_path.unlink()
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
raise ConfigWriteError(
|
raise ConfigWriteError(f"Failed to delete {local_path}: {exc}") from exc
|
||||||
f"Failed to delete {local_path}: {exc}"
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
log.info("filter_local_deleted", filter=base_name, path=str(local_path))
|
log.info("filter_local_deleted", filter=base_name, path=str(local_path))
|
||||||
|
|
||||||
@@ -2226,9 +2148,7 @@ async def assign_filter_to_jail(
|
|||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
# Verify the jail exists in config.
|
# Verify the jail exists in config.
|
||||||
all_jails, _src = await loop.run_in_executor(
|
all_jails, _src = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
||||||
None, _parse_jails_sync, Path(config_dir)
|
|
||||||
)
|
|
||||||
if jail_name not in all_jails:
|
if jail_name not in all_jails:
|
||||||
raise JailNotFoundInConfigError(jail_name)
|
raise JailNotFoundInConfigError(jail_name)
|
||||||
|
|
||||||
@@ -2276,9 +2196,7 @@ async def assign_filter_to_jail(
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
# Allowlist pattern for action names used in path construction.
|
# Allowlist pattern for action names used in path construction.
|
||||||
_SAFE_ACTION_NAME_RE: re.Pattern[str] = re.compile(
|
_SAFE_ACTION_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$")
|
||||||
r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ActionNotFoundError(Exception):
|
class ActionNotFoundError(Exception):
|
||||||
@@ -2318,8 +2236,7 @@ class ActionReadonlyError(Exception):
|
|||||||
"""
|
"""
|
||||||
self.name: str = name
|
self.name: str = name
|
||||||
super().__init__(
|
super().__init__(
|
||||||
f"Action {name!r} is a shipped default (.conf only); "
|
f"Action {name!r} is a shipped default (.conf only); only user-created .local files can be deleted."
|
||||||
"only user-created .local files can be deleted."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -2428,9 +2345,7 @@ def _parse_actions_sync(
|
|||||||
try:
|
try:
|
||||||
content = conf_path.read_text(encoding="utf-8")
|
content = conf_path.read_text(encoding="utf-8")
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
log.warning(
|
log.warning("action_read_error", name=name, path=str(conf_path), error=str(exc))
|
||||||
"action_read_error", name=name, path=str(conf_path), error=str(exc)
|
|
||||||
)
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if has_local:
|
if has_local:
|
||||||
@@ -2495,9 +2410,7 @@ def _append_jail_action_sync(
|
|||||||
try:
|
try:
|
||||||
jail_d.mkdir(parents=True, exist_ok=True)
|
jail_d.mkdir(parents=True, exist_ok=True)
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
raise ConfigWriteError(
|
raise ConfigWriteError(f"Cannot create jail.d directory: {exc}") from exc
|
||||||
f"Cannot create jail.d directory: {exc}"
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
local_path = jail_d / f"{jail_name}.local"
|
local_path = jail_d / f"{jail_name}.local"
|
||||||
|
|
||||||
@@ -2517,9 +2430,7 @@ def _append_jail_action_sync(
|
|||||||
|
|
||||||
existing_raw = parser.get(jail_name, "action") if parser.has_option(jail_name, "action") else ""
|
existing_raw = parser.get(jail_name, "action") if parser.has_option(jail_name, "action") else ""
|
||||||
existing_lines = [
|
existing_lines = [
|
||||||
line.strip()
|
line.strip() for line in existing_raw.splitlines() if line.strip() and not line.strip().startswith("#")
|
||||||
for line in existing_raw.splitlines()
|
|
||||||
if line.strip() and not line.strip().startswith("#")
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# Extract base names from existing entries for duplicate checking.
|
# Extract base names from existing entries for duplicate checking.
|
||||||
@@ -2533,9 +2444,7 @@ def _append_jail_action_sync(
|
|||||||
|
|
||||||
if existing_lines:
|
if existing_lines:
|
||||||
# configparser multi-line: continuation lines start with whitespace.
|
# configparser multi-line: continuation lines start with whitespace.
|
||||||
new_value = existing_lines[0] + "".join(
|
new_value = existing_lines[0] + "".join(f"\n {line}" for line in existing_lines[1:])
|
||||||
f"\n {line}" for line in existing_lines[1:]
|
|
||||||
)
|
|
||||||
parser.set(jail_name, "action", new_value)
|
parser.set(jail_name, "action", new_value)
|
||||||
else:
|
else:
|
||||||
parser.set(jail_name, "action", action_entry)
|
parser.set(jail_name, "action", action_entry)
|
||||||
@@ -2559,9 +2468,7 @@ def _append_jail_action_sync(
|
|||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
with contextlib.suppress(OSError):
|
with contextlib.suppress(OSError):
|
||||||
os.unlink(tmp_name) # noqa: F821
|
os.unlink(tmp_name) # noqa: F821
|
||||||
raise ConfigWriteError(
|
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
||||||
f"Failed to write {local_path}: {exc}"
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
log.info(
|
log.info(
|
||||||
"jail_action_appended",
|
"jail_action_appended",
|
||||||
@@ -2612,9 +2519,7 @@ def _remove_jail_action_sync(
|
|||||||
|
|
||||||
existing_raw = parser.get(jail_name, "action")
|
existing_raw = parser.get(jail_name, "action")
|
||||||
existing_lines = [
|
existing_lines = [
|
||||||
line.strip()
|
line.strip() for line in existing_raw.splitlines() if line.strip() and not line.strip().startswith("#")
|
||||||
for line in existing_raw.splitlines()
|
|
||||||
if line.strip() and not line.strip().startswith("#")
|
|
||||||
]
|
]
|
||||||
|
|
||||||
def _base(entry: str) -> str:
|
def _base(entry: str) -> str:
|
||||||
@@ -2628,9 +2533,7 @@ def _remove_jail_action_sync(
|
|||||||
return
|
return
|
||||||
|
|
||||||
if filtered:
|
if filtered:
|
||||||
new_value = filtered[0] + "".join(
|
new_value = filtered[0] + "".join(f"\n {line}" for line in filtered[1:])
|
||||||
f"\n {line}" for line in filtered[1:]
|
|
||||||
)
|
|
||||||
parser.set(jail_name, "action", new_value)
|
parser.set(jail_name, "action", new_value)
|
||||||
else:
|
else:
|
||||||
parser.remove_option(jail_name, "action")
|
parser.remove_option(jail_name, "action")
|
||||||
@@ -2654,9 +2557,7 @@ def _remove_jail_action_sync(
|
|||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
with contextlib.suppress(OSError):
|
with contextlib.suppress(OSError):
|
||||||
os.unlink(tmp_name) # noqa: F821
|
os.unlink(tmp_name) # noqa: F821
|
||||||
raise ConfigWriteError(
|
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
||||||
f"Failed to write {local_path}: {exc}"
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
log.info(
|
log.info(
|
||||||
"jail_action_removed",
|
"jail_action_removed",
|
||||||
@@ -2683,9 +2584,7 @@ def _write_action_local_sync(action_d: Path, name: str, content: str) -> None:
|
|||||||
try:
|
try:
|
||||||
action_d.mkdir(parents=True, exist_ok=True)
|
action_d.mkdir(parents=True, exist_ok=True)
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
raise ConfigWriteError(
|
raise ConfigWriteError(f"Cannot create action.d directory: {exc}") from exc
|
||||||
f"Cannot create action.d directory: {exc}"
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
local_path = action_d / f"{name}.local"
|
local_path = action_d / f"{name}.local"
|
||||||
try:
|
try:
|
||||||
@@ -2702,9 +2601,7 @@ def _write_action_local_sync(action_d: Path, name: str, content: str) -> None:
|
|||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
with contextlib.suppress(OSError):
|
with contextlib.suppress(OSError):
|
||||||
os.unlink(tmp_name) # noqa: F821
|
os.unlink(tmp_name) # noqa: F821
|
||||||
raise ConfigWriteError(
|
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
||||||
f"Failed to write {local_path}: {exc}"
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
log.info("action_local_written", action=name, path=str(local_path))
|
log.info("action_local_written", action=name, path=str(local_path))
|
||||||
|
|
||||||
@@ -2740,9 +2637,7 @@ async def list_actions(
|
|||||||
action_d = Path(config_dir) / "action.d"
|
action_d = Path(config_dir) / "action.d"
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
raw_actions: list[tuple[str, str, str, bool, str]] = await loop.run_in_executor(
|
raw_actions: list[tuple[str, str, str, bool, str]] = await loop.run_in_executor(None, _parse_actions_sync, action_d)
|
||||||
None, _parse_actions_sync, action_d
|
|
||||||
)
|
|
||||||
|
|
||||||
all_jails_result, active_names = await asyncio.gather(
|
all_jails_result, active_names = await asyncio.gather(
|
||||||
loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)),
|
loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)),
|
||||||
@@ -2754,9 +2649,7 @@ async def list_actions(
|
|||||||
|
|
||||||
actions: list[ActionConfig] = []
|
actions: list[ActionConfig] = []
|
||||||
for name, filename, content, has_local, source_path in raw_actions:
|
for name, filename, content, has_local, source_path in raw_actions:
|
||||||
cfg = conffile_parser.parse_action_file(
|
cfg = conffile_parser.parse_action_file(content, name=name, filename=filename)
|
||||||
content, name=name, filename=filename
|
|
||||||
)
|
|
||||||
used_by = sorted(action_to_jails.get(name, []))
|
used_by = sorted(action_to_jails.get(name, []))
|
||||||
actions.append(
|
actions.append(
|
||||||
ActionConfig(
|
ActionConfig(
|
||||||
@@ -2843,9 +2736,7 @@ async def get_action(
|
|||||||
|
|
||||||
content, has_local, source_path = await loop.run_in_executor(None, _read)
|
content, has_local, source_path = await loop.run_in_executor(None, _read)
|
||||||
|
|
||||||
cfg = conffile_parser.parse_action_file(
|
cfg = conffile_parser.parse_action_file(content, name=base_name, filename=f"{base_name}.conf")
|
||||||
content, name=base_name, filename=f"{base_name}.conf"
|
|
||||||
)
|
|
||||||
|
|
||||||
all_jails_result, active_names = await asyncio.gather(
|
all_jails_result, active_names = await asyncio.gather(
|
||||||
loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)),
|
loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)),
|
||||||
@@ -3061,9 +2952,7 @@ async def delete_action(
|
|||||||
try:
|
try:
|
||||||
local_path.unlink()
|
local_path.unlink()
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
raise ConfigWriteError(
|
raise ConfigWriteError(f"Failed to delete {local_path}: {exc}") from exc
|
||||||
f"Failed to delete {local_path}: {exc}"
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
log.info("action_local_deleted", action=base_name, path=str(local_path))
|
log.info("action_local_deleted", action=base_name, path=str(local_path))
|
||||||
|
|
||||||
@@ -3105,9 +2994,7 @@ async def assign_action_to_jail(
|
|||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
all_jails, _src = await loop.run_in_executor(
|
all_jails, _src = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
||||||
None, _parse_jails_sync, Path(config_dir)
|
|
||||||
)
|
|
||||||
if jail_name not in all_jails:
|
if jail_name not in all_jails:
|
||||||
raise JailNotFoundInConfigError(jail_name)
|
raise JailNotFoundInConfigError(jail_name)
|
||||||
|
|
||||||
@@ -3187,9 +3074,7 @@ async def remove_action_from_jail(
|
|||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
all_jails, _src = await loop.run_in_executor(
|
all_jails, _src = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
||||||
None, _parse_jails_sync, Path(config_dir)
|
|
||||||
)
|
|
||||||
if jail_name not in all_jails:
|
if jail_name not in all_jails:
|
||||||
raise JailNotFoundInConfigError(jail_name)
|
raise JailNotFoundInConfigError(jail_name)
|
||||||
|
|
||||||
@@ -3218,4 +3103,3 @@ async def remove_action_from_jail(
|
|||||||
action=action_name,
|
action=action_name,
|
||||||
reload=do_reload,
|
reload=do_reload,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -95,7 +95,7 @@ def _ok(response: object) -> object:
|
|||||||
ValueError: If the return code indicates an error.
|
ValueError: If the return code indicates an error.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
code, data = cast(Fail2BanResponse, response)
|
code, data = cast("Fail2BanResponse", response)
|
||||||
except (TypeError, ValueError) as exc:
|
except (TypeError, ValueError) as exc:
|
||||||
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
|
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
|
||||||
if code != 0:
|
if code != 0:
|
||||||
@@ -128,7 +128,7 @@ def _ensure_list(value: object | None) -> list[str]:
|
|||||||
return [str(value)]
|
return [str(value)]
|
||||||
|
|
||||||
|
|
||||||
_T = TypeVar("_T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
async def _safe_get(
|
async def _safe_get(
|
||||||
@@ -143,13 +143,13 @@ async def _safe_get(
|
|||||||
return default
|
return default
|
||||||
|
|
||||||
|
|
||||||
async def _safe_get_typed(
|
async def _safe_get_typed[T](
|
||||||
client: Fail2BanClient,
|
client: Fail2BanClient,
|
||||||
command: Fail2BanCommand,
|
command: Fail2BanCommand,
|
||||||
default: _T,
|
default: T,
|
||||||
) -> _T:
|
) -> T:
|
||||||
"""Send a command and return the result typed as ``default``'s type."""
|
"""Send a command and return the result typed as ``default``'s type."""
|
||||||
return cast(_T, await _safe_get(client, command, default))
|
return cast("T", await _safe_get(client, command, default))
|
||||||
|
|
||||||
|
|
||||||
def _is_not_found_error(exc: Exception) -> bool:
|
def _is_not_found_error(exc: Exception) -> bool:
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ def _ok(response: object) -> object:
|
|||||||
ValueError: If the response indicates an error (return code ≠ 0).
|
ValueError: If the response indicates an error (return code ≠ 0).
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
code, data = cast(Fail2BanResponse, response)
|
code, data = cast("Fail2BanResponse", response)
|
||||||
except (TypeError, ValueError) as exc:
|
except (TypeError, ValueError) as exc:
|
||||||
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
|
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
|
||||||
|
|
||||||
|
|||||||
@@ -11,9 +11,11 @@ modifies or locks the fail2ban database.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
from app.services.geo_service import GeoEnricher
|
from app.services.geo_service import GeoEnricher
|
||||||
|
|
||||||
from app.models.ban import TIME_RANGE_SECONDS, TimeRange
|
from app.models.ban import TIME_RANGE_SECONDS, TimeRange
|
||||||
|
|||||||
@@ -14,7 +14,8 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
import ipaddress
|
import ipaddress
|
||||||
from typing import TYPE_CHECKING, Awaitable, Callable, cast, TypeAlias
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import TYPE_CHECKING, TypedDict, cast
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
|
|
||||||
@@ -27,6 +28,7 @@ from app.models.jail import (
|
|||||||
JailStatus,
|
JailStatus,
|
||||||
JailSummary,
|
JailSummary,
|
||||||
)
|
)
|
||||||
|
from app.services.geo_service import GeoInfo
|
||||||
from app.utils.fail2ban_client import (
|
from app.utils.fail2ban_client import (
|
||||||
Fail2BanClient,
|
Fail2BanClient,
|
||||||
Fail2BanCommand,
|
Fail2BanCommand,
|
||||||
@@ -39,11 +41,21 @@ if TYPE_CHECKING:
|
|||||||
import aiohttp
|
import aiohttp
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
|
|
||||||
from app.services.geo_service import GeoInfo
|
|
||||||
|
|
||||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||||
|
|
||||||
GeoEnricher: TypeAlias = Callable[[str], Awaitable["GeoInfo | None"]]
|
class IpLookupResult(TypedDict):
|
||||||
|
"""Result returned by :func:`lookup_ip`.
|
||||||
|
|
||||||
|
This is intentionally a :class:`TypedDict` to provide precise typing for
|
||||||
|
callers (e.g. routers) while keeping the implementation flexible.
|
||||||
|
"""
|
||||||
|
|
||||||
|
ip: str
|
||||||
|
currently_banned_in: list[str]
|
||||||
|
geo: GeoInfo | None
|
||||||
|
|
||||||
|
|
||||||
|
GeoEnricher = Callable[[str], Awaitable[GeoInfo | None]]
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Constants
|
# Constants
|
||||||
@@ -104,7 +116,7 @@ def _ok(response: object) -> object:
|
|||||||
ValueError: If the response indicates an error (return code ≠ 0).
|
ValueError: If the response indicates an error (return code ≠ 0).
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
code, data = cast(Fail2BanResponse, response)
|
code, data = cast("Fail2BanResponse", response)
|
||||||
except (TypeError, ValueError) as exc:
|
except (TypeError, ValueError) as exc:
|
||||||
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
|
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
|
||||||
|
|
||||||
@@ -202,7 +214,7 @@ async def _safe_get(
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
response = await client.send(command)
|
response = await client.send(command)
|
||||||
return _ok(cast(Fail2BanResponse, response))
|
return _ok(cast("Fail2BanResponse", response))
|
||||||
except (ValueError, TypeError, Exception):
|
except (ValueError, TypeError, Exception):
|
||||||
return default
|
return default
|
||||||
|
|
||||||
@@ -337,7 +349,6 @@ async def _fetch_jail_summary(
|
|||||||
client.send(["get", name, "backend"]),
|
client.send(["get", name, "backend"]),
|
||||||
client.send(["get", name, "idle"]),
|
client.send(["get", name, "idle"]),
|
||||||
])
|
])
|
||||||
uses_backend_backend_commands = True
|
|
||||||
else:
|
else:
|
||||||
# Commands not supported; return default values without sending.
|
# Commands not supported; return default values without sending.
|
||||||
async def _return_default(value: object | None) -> Fail2BanResponse:
|
async def _return_default(value: object | None) -> Fail2BanResponse:
|
||||||
@@ -347,7 +358,6 @@ async def _fetch_jail_summary(
|
|||||||
_return_default("polling"), # backend default
|
_return_default("polling"), # backend default
|
||||||
_return_default(False), # idle default
|
_return_default(False), # idle default
|
||||||
])
|
])
|
||||||
uses_backend_backend_commands = False
|
|
||||||
|
|
||||||
_r = await asyncio.gather(*gather_list, return_exceptions=True)
|
_r = await asyncio.gather(*gather_list, return_exceptions=True)
|
||||||
status_raw: object | Exception = _r[0]
|
status_raw: object | Exception = _r[0]
|
||||||
@@ -377,7 +387,7 @@ async def _fetch_jail_summary(
|
|||||||
if isinstance(raw, Exception):
|
if isinstance(raw, Exception):
|
||||||
return fallback
|
return fallback
|
||||||
try:
|
try:
|
||||||
return int(str(_ok(cast(Fail2BanResponse, raw))))
|
return int(str(_ok(cast("Fail2BanResponse", raw))))
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
return fallback
|
return fallback
|
||||||
|
|
||||||
@@ -385,7 +395,7 @@ async def _fetch_jail_summary(
|
|||||||
if isinstance(raw, Exception):
|
if isinstance(raw, Exception):
|
||||||
return fallback
|
return fallback
|
||||||
try:
|
try:
|
||||||
return str(_ok(cast(Fail2BanResponse, raw)))
|
return str(_ok(cast("Fail2BanResponse", raw)))
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
return fallback
|
return fallback
|
||||||
|
|
||||||
@@ -393,7 +403,7 @@ async def _fetch_jail_summary(
|
|||||||
if isinstance(raw, Exception):
|
if isinstance(raw, Exception):
|
||||||
return fallback
|
return fallback
|
||||||
try:
|
try:
|
||||||
return bool(_ok(cast(Fail2BanResponse, raw)))
|
return bool(_ok(cast("Fail2BanResponse", raw)))
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
return fallback
|
return fallback
|
||||||
|
|
||||||
@@ -687,7 +697,7 @@ async def reload_all(
|
|||||||
names_set -= set(exclude_jails)
|
names_set -= set(exclude_jails)
|
||||||
|
|
||||||
stream: list[list[object]] = [["start", n] for n in sorted(names_set)]
|
stream: list[list[object]] = [["start", n] for n in sorted(names_set)]
|
||||||
_ok(await client.send(["reload", "--all", [], cast(Fail2BanToken, stream)]))
|
_ok(await client.send(["reload", "--all", [], cast("Fail2BanToken", stream)]))
|
||||||
log.info("all_jails_reloaded")
|
log.info("all_jails_reloaded")
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
# Detect UnknownJailException (missing or invalid jail configuration)
|
# Detect UnknownJailException (missing or invalid jail configuration)
|
||||||
@@ -811,8 +821,8 @@ async def unban_ip(
|
|||||||
async def get_active_bans(
|
async def get_active_bans(
|
||||||
socket_path: str,
|
socket_path: str,
|
||||||
geo_enricher: GeoEnricher | None = None,
|
geo_enricher: GeoEnricher | None = None,
|
||||||
http_session: "aiohttp.ClientSession" | None = None,
|
http_session: aiohttp.ClientSession | None = None,
|
||||||
app_db: "aiosqlite.Connection" | None = None,
|
app_db: aiosqlite.Connection | None = None,
|
||||||
) -> ActiveBanListResponse:
|
) -> ActiveBanListResponse:
|
||||||
"""Return all currently banned IPs across every jail.
|
"""Return all currently banned IPs across every jail.
|
||||||
|
|
||||||
@@ -880,7 +890,7 @@ async def get_active_bans(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ban_list: list[str] = cast(list[str], _ok(raw_result)) or []
|
ban_list: list[str] = cast("list[str]", _ok(raw_result)) or []
|
||||||
except (TypeError, ValueError) as exc:
|
except (TypeError, ValueError) as exc:
|
||||||
log.warning(
|
log.warning(
|
||||||
"active_bans_parse_error",
|
"active_bans_parse_error",
|
||||||
@@ -1007,8 +1017,8 @@ async def get_jail_banned_ips(
|
|||||||
page: int = 1,
|
page: int = 1,
|
||||||
page_size: int = 25,
|
page_size: int = 25,
|
||||||
search: str | None = None,
|
search: str | None = None,
|
||||||
http_session: "aiohttp.ClientSession" | None = None,
|
http_session: aiohttp.ClientSession | None = None,
|
||||||
app_db: "aiosqlite.Connection" | None = None,
|
app_db: aiosqlite.Connection | None = None,
|
||||||
) -> JailBannedIpsResponse:
|
) -> JailBannedIpsResponse:
|
||||||
"""Return a paginated list of currently banned IPs for a single jail.
|
"""Return a paginated list of currently banned IPs for a single jail.
|
||||||
|
|
||||||
@@ -1055,7 +1065,7 @@ async def get_jail_banned_ips(
|
|||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
raw_result = []
|
raw_result = []
|
||||||
|
|
||||||
ban_list: list[str] = cast(list[str], raw_result) or []
|
ban_list: list[str] = cast("list[str]", raw_result) or []
|
||||||
|
|
||||||
# Parse all entries.
|
# Parse all entries.
|
||||||
all_bans: list[ActiveBan] = []
|
all_bans: list[ActiveBan] = []
|
||||||
@@ -1121,7 +1131,7 @@ async def _enrich_bans(
|
|||||||
The same list with ``country`` fields populated where lookup succeeded.
|
The same list with ``country`` fields populated where lookup succeeded.
|
||||||
"""
|
"""
|
||||||
geo_results: list[object | Exception] = await asyncio.gather(
|
geo_results: list[object | Exception] = await asyncio.gather(
|
||||||
*[cast(Awaitable[object], geo_enricher(ban.ip)) for ban in bans],
|
*[cast("Awaitable[object]", geo_enricher(ban.ip)) for ban in bans],
|
||||||
return_exceptions=True,
|
return_exceptions=True,
|
||||||
)
|
)
|
||||||
enriched: list[ActiveBan] = []
|
enriched: list[ActiveBan] = []
|
||||||
@@ -1277,7 +1287,7 @@ async def lookup_ip(
|
|||||||
socket_path: str,
|
socket_path: str,
|
||||||
ip: str,
|
ip: str,
|
||||||
geo_enricher: GeoEnricher | None = None,
|
geo_enricher: GeoEnricher | None = None,
|
||||||
) -> dict[str, object | list[str] | None]:
|
) -> IpLookupResult:
|
||||||
"""Return ban status and history for a single IP address.
|
"""Return ban status and history for a single IP address.
|
||||||
|
|
||||||
Checks every running jail for whether the IP is currently banned.
|
Checks every running jail for whether the IP is currently banned.
|
||||||
@@ -1330,7 +1340,7 @@ async def lookup_ip(
|
|||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
ban_list: list[str] = cast(list[str], _ok(result)) or []
|
ban_list: list[str] = cast("list[str]", _ok(result)) or []
|
||||||
if ip in ban_list:
|
if ip in ban_list:
|
||||||
currently_banned_in.append(jail_name)
|
currently_banned_in.append(jail_name)
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ HTTP/FastAPI concerns.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import cast, TypeAlias
|
from typing import cast
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
|
|
||||||
@@ -21,7 +21,7 @@ from app.utils.fail2ban_client import Fail2BanClient, Fail2BanCommand, Fail2BanR
|
|||||||
# Types
|
# Types
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
Fail2BanSettingValue: TypeAlias = str | int | bool
|
type Fail2BanSettingValue = str | int | bool
|
||||||
"""Allowed values for server settings commands."""
|
"""Allowed values for server settings commands."""
|
||||||
|
|
||||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||||
@@ -106,7 +106,7 @@ async def _safe_get(
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
response = await client.send(command)
|
response = await client.send(command)
|
||||||
return _ok(cast(Fail2BanResponse, response))
|
return _ok(cast("Fail2BanResponse", response))
|
||||||
except Exception:
|
except Exception:
|
||||||
return default
|
return default
|
||||||
|
|
||||||
@@ -189,7 +189,7 @@ async def update_settings(socket_path: str, update: ServerSettingsUpdate) -> Non
|
|||||||
async def _set(key: str, value: Fail2BanSettingValue) -> None:
|
async def _set(key: str, value: Fail2BanSettingValue) -> None:
|
||||||
try:
|
try:
|
||||||
response = await client.send(["set", key, value])
|
response = await client.send(["set", key, value])
|
||||||
_ok(cast(Fail2BanResponse, response))
|
_ok(cast("Fail2BanResponse", response))
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
raise ServerOperationError(f"Failed to set {key!r} = {value!r}: {exc}") from exc
|
raise ServerOperationError(f"Failed to set {key!r} = {value!r}: {exc}") from exc
|
||||||
|
|
||||||
@@ -224,7 +224,7 @@ async def flush_logs(socket_path: str) -> str:
|
|||||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||||
try:
|
try:
|
||||||
response = await client.send(["flushlogs"])
|
response = await client.send(["flushlogs"])
|
||||||
result = _ok(cast(Fail2BanResponse, response))
|
result = _ok(cast("Fail2BanResponse", response))
|
||||||
log.info("logs_flushed", result=result)
|
log.info("logs_flushed", result=result)
|
||||||
return str(result)
|
return str(result)
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ GEO_RE_RESOLVE_INTERVAL: int = 600
|
|||||||
JOB_ID: str = "geo_re_resolve"
|
JOB_ID: str = "geo_re_resolve"
|
||||||
|
|
||||||
|
|
||||||
async def _run_re_resolve(app: "FastAPI") -> None:
|
async def _run_re_resolve(app: FastAPI) -> None:
|
||||||
"""Query NULL-country IPs from the database and re-resolve them.
|
"""Query NULL-country IPs from the database and re-resolve them.
|
||||||
|
|
||||||
Reads shared resources from ``app.state`` and delegates to
|
Reads shared resources from ``app.state`` and delegates to
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ HEALTH_CHECK_INTERVAL: int = 30
|
|||||||
_ACTIVATION_CRASH_WINDOW: int = 60
|
_ACTIVATION_CRASH_WINDOW: int = 60
|
||||||
|
|
||||||
|
|
||||||
async def _run_probe(app: "FastAPI") -> None:
|
async def _run_probe(app: FastAPI) -> None:
|
||||||
"""Probe fail2ban and cache the result on *app.state*.
|
"""Probe fail2ban and cache the result on *app.state*.
|
||||||
|
|
||||||
Detects online/offline state transitions. When fail2ban goes offline
|
Detects online/offline state transitions. When fail2ban goes offline
|
||||||
|
|||||||
@@ -21,34 +21,52 @@ import contextlib
|
|||||||
import errno
|
import errno
|
||||||
import socket
|
import socket
|
||||||
import time
|
import time
|
||||||
|
from collections.abc import Mapping, Sequence, Set
|
||||||
from pickle import HIGHEST_PROTOCOL, dumps, loads
|
from pickle import HIGHEST_PROTOCOL, dumps, loads
|
||||||
from typing import TYPE_CHECKING, TypeAlias
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Types
|
# Types
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
Fail2BanToken: TypeAlias = str | int | float | bool | None | dict[str, object] | list[object]
|
# Use covariant container types so callers can pass ``list[int]`` / ``dict[str, str]``
|
||||||
|
# without needing to cast. At runtime we only accept the basic built-in
|
||||||
|
# containers supported by fail2ban's protocol (list/dict/set) and stringify
|
||||||
|
# anything else.
|
||||||
|
#
|
||||||
|
# NOTE: ``Sequence`` will also accept tuples, but tuples are stringified at
|
||||||
|
# runtime because fail2ban only understands lists.
|
||||||
|
|
||||||
|
type Fail2BanToken = (
|
||||||
|
str
|
||||||
|
| int
|
||||||
|
| float
|
||||||
|
| bool
|
||||||
|
| None
|
||||||
|
| Mapping[str, object]
|
||||||
|
| Sequence[object]
|
||||||
|
| Set[object]
|
||||||
|
)
|
||||||
"""A single token in a fail2ban command.
|
"""A single token in a fail2ban command.
|
||||||
|
|
||||||
Fail2ban accepts simple types (str/int/float/bool) plus compound types
|
Fail2ban accepts simple types (str/int/float/bool) plus compound types
|
||||||
(list/dict). Complex objects are stringified before being sent.
|
(list/dict/set). Complex objects are stringified before being sent.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Fail2BanCommand: TypeAlias = list[Fail2BanToken]
|
type Fail2BanCommand = Sequence[Fail2BanToken]
|
||||||
"""A command sent to fail2ban over the socket.
|
"""A command sent to fail2ban over the socket.
|
||||||
|
|
||||||
Commands are pickle serialised lists of tokens.
|
Commands are pickle serialised sequences of tokens.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Fail2BanResponse: TypeAlias = tuple[int, object]
|
type Fail2BanResponse = tuple[int, object]
|
||||||
"""A typical fail2ban response containing a status code and payload."""
|
"""A typical fail2ban response containing a status code and payload."""
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from types import TracebackType
|
from types import TracebackType
|
||||||
|
|
||||||
import structlog
|
|
||||||
|
|
||||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||||
|
|
||||||
# fail2ban protocol constants — inline to avoid a hard import dependency
|
# fail2ban protocol constants — inline to avoid a hard import dependency
|
||||||
@@ -200,7 +218,7 @@ def _send_command_sync(
|
|||||||
) from last_oserror
|
) from last_oserror
|
||||||
|
|
||||||
|
|
||||||
def _coerce_command_token(token: Fail2BanToken) -> Fail2BanToken:
|
def _coerce_command_token(token: object) -> Fail2BanToken:
|
||||||
"""Coerce a command token to a type that fail2ban understands.
|
"""Coerce a command token to a type that fail2ban understands.
|
||||||
|
|
||||||
fail2ban's ``CSocket.convert`` accepts ``str``, ``bool``, ``int``,
|
fail2ban's ``CSocket.convert`` accepts ``str``, ``bool``, ``int``,
|
||||||
|
|||||||
@@ -60,4 +60,5 @@ plugins = ["pydantic.mypy"]
|
|||||||
asyncio_mode = "auto"
|
asyncio_mode = "auto"
|
||||||
pythonpath = [".", "../fail2ban-master"]
|
pythonpath = [".", "../fail2ban-master"]
|
||||||
testpaths = ["tests"]
|
testpaths = ["tests"]
|
||||||
addopts = "--cov=app --cov-report=term-missing"
|
addopts = "--asyncio-mode=auto --cov=app --cov-report=term-missing"
|
||||||
|
filterwarnings = ["ignore::pytest.PytestRemovedIn9Warning"]
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Generator
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -157,12 +158,12 @@ class TestRequireAuthSessionCache:
|
|||||||
"""In-memory session token cache inside ``require_auth``."""
|
"""In-memory session token cache inside ``require_auth``."""
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def reset_cache(self) -> None: # type: ignore[misc]
|
def reset_cache(self) -> Generator[None, None, None]:
|
||||||
"""Flush the session cache before and after every test in this class."""
|
"""Flush the session cache before and after every test in this class."""
|
||||||
from app import dependencies
|
from app import dependencies
|
||||||
|
|
||||||
dependencies.clear_session_cache()
|
dependencies.clear_session_cache()
|
||||||
yield # type: ignore[misc]
|
yield
|
||||||
dependencies.clear_session_cache()
|
dependencies.clear_session_cache()
|
||||||
|
|
||||||
async def test_second_request_skips_db(self, client: AsyncClient) -> None:
|
async def test_second_request_skips_db(self, client: AsyncClient) -> None:
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ class TestGeoLookup:
|
|||||||
async def test_200_with_geo_info(self, geo_client: AsyncClient) -> None:
|
async def test_200_with_geo_info(self, geo_client: AsyncClient) -> None:
|
||||||
"""GET /api/geo/lookup/{ip} returns 200 with enriched result."""
|
"""GET /api/geo/lookup/{ip} returns 200 with enriched result."""
|
||||||
geo = GeoInfo(country_code="DE", country_name="Germany", asn="12345", org="Acme")
|
geo = GeoInfo(country_code="DE", country_name="Germany", asn="12345", org="Acme")
|
||||||
result = {
|
result: dict[str, object] = {
|
||||||
"ip": "1.2.3.4",
|
"ip": "1.2.3.4",
|
||||||
"currently_banned_in": ["sshd"],
|
"currently_banned_in": ["sshd"],
|
||||||
"geo": geo,
|
"geo": geo,
|
||||||
@@ -92,7 +92,7 @@ class TestGeoLookup:
|
|||||||
|
|
||||||
async def test_200_when_not_banned(self, geo_client: AsyncClient) -> None:
|
async def test_200_when_not_banned(self, geo_client: AsyncClient) -> None:
|
||||||
"""GET /api/geo/lookup/{ip} returns empty list when IP is not banned anywhere."""
|
"""GET /api/geo/lookup/{ip} returns empty list when IP is not banned anywhere."""
|
||||||
result = {
|
result: dict[str, object] = {
|
||||||
"ip": "8.8.8.8",
|
"ip": "8.8.8.8",
|
||||||
"currently_banned_in": [],
|
"currently_banned_in": [],
|
||||||
"geo": GeoInfo(country_code="US", country_name="United States", asn=None, org=None),
|
"geo": GeoInfo(country_code="US", country_name="United States", asn=None, org=None),
|
||||||
@@ -108,7 +108,7 @@ class TestGeoLookup:
|
|||||||
|
|
||||||
async def test_200_with_no_geo(self, geo_client: AsyncClient) -> None:
|
async def test_200_with_no_geo(self, geo_client: AsyncClient) -> None:
|
||||||
"""GET /api/geo/lookup/{ip} returns null geo when enricher fails."""
|
"""GET /api/geo/lookup/{ip} returns null geo when enricher fails."""
|
||||||
result = {
|
result: dict[str, object] = {
|
||||||
"ip": "1.2.3.4",
|
"ip": "1.2.3.4",
|
||||||
"currently_banned_in": [],
|
"currently_banned_in": [],
|
||||||
"geo": None,
|
"geo": None,
|
||||||
@@ -144,7 +144,7 @@ class TestGeoLookup:
|
|||||||
|
|
||||||
async def test_ipv6_address(self, geo_client: AsyncClient) -> None:
|
async def test_ipv6_address(self, geo_client: AsyncClient) -> None:
|
||||||
"""GET /api/geo/lookup/{ip} handles IPv6 addresses."""
|
"""GET /api/geo/lookup/{ip} handles IPv6 addresses."""
|
||||||
result = {
|
result: dict[str, object] = {
|
||||||
"ip": "2001:db8::1",
|
"ip": "2001:db8::1",
|
||||||
"currently_banned_in": [],
|
"currently_banned_in": [],
|
||||||
"geo": None,
|
"geo": None,
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from httpx import ASGITransport, AsyncClient
|
|||||||
from app.config import Settings
|
from app.config import Settings
|
||||||
from app.db import init_db
|
from app.db import init_db
|
||||||
from app.main import create_app
|
from app.main import create_app
|
||||||
|
from app.models.ban import JailBannedIpsResponse
|
||||||
from app.models.jail import Jail, JailDetailResponse, JailListResponse, JailStatus, JailSummary
|
from app.models.jail import Jail, JailDetailResponse, JailListResponse, JailStatus, JailSummary
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -801,17 +802,17 @@ class TestGetJailBannedIps:
|
|||||||
def _mock_response(
|
def _mock_response(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
items: list[dict] | None = None,
|
items: list[dict[str, str | None]] | None = None,
|
||||||
total: int = 2,
|
total: int = 2,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
page_size: int = 25,
|
page_size: int = 25,
|
||||||
) -> "JailBannedIpsResponse": # type: ignore[name-defined]
|
) -> JailBannedIpsResponse:
|
||||||
from app.models.ban import ActiveBan, JailBannedIpsResponse
|
from app.models.ban import ActiveBan, JailBannedIpsResponse
|
||||||
|
|
||||||
ban_items = (
|
ban_items = (
|
||||||
[
|
[
|
||||||
ActiveBan(
|
ActiveBan(
|
||||||
ip=item.get("ip", "1.2.3.4"),
|
ip=item.get("ip") or "1.2.3.4",
|
||||||
jail="sshd",
|
jail="sshd",
|
||||||
banned_at=item.get("banned_at", "2025-01-01T10:00:00+00:00"),
|
banned_at=item.get("banned_at", "2025-01-01T10:00:00+00:00"),
|
||||||
expires_at=item.get("expires_at", "2025-01-01T10:10:00+00:00"),
|
expires_at=item.get("expires_at", "2025-01-01T10:10:00+00:00"),
|
||||||
|
|||||||
@@ -247,9 +247,9 @@ class TestSetupCompleteCaching:
|
|||||||
assert not getattr(app.state, "_setup_complete_cached", False)
|
assert not getattr(app.state, "_setup_complete_cached", False)
|
||||||
|
|
||||||
# First non-exempt request — middleware queries DB and sets the flag.
|
# First non-exempt request — middleware queries DB and sets the flag.
|
||||||
await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]}) # type: ignore[call-overload]
|
await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]})
|
||||||
|
|
||||||
assert app.state._setup_complete_cached is True # type: ignore[attr-defined]
|
assert app.state._setup_complete_cached is True
|
||||||
|
|
||||||
async def test_cached_path_skips_is_setup_complete(
|
async def test_cached_path_skips_is_setup_complete(
|
||||||
self,
|
self,
|
||||||
@@ -267,12 +267,12 @@ class TestSetupCompleteCaching:
|
|||||||
|
|
||||||
# Do setup and warm the cache.
|
# Do setup and warm the cache.
|
||||||
await client.post("/api/setup", json=_SETUP_PAYLOAD)
|
await client.post("/api/setup", json=_SETUP_PAYLOAD)
|
||||||
await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]}) # type: ignore[call-overload]
|
await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]})
|
||||||
assert app.state._setup_complete_cached is True # type: ignore[attr-defined]
|
assert app.state._setup_complete_cached is True
|
||||||
|
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def _counting(db): # type: ignore[no-untyped-def]
|
async def _counting(db: aiosqlite.Connection) -> bool:
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
return True
|
return True
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ class TestCheckPasswordAsync:
|
|||||||
auth_service._check_password("secret", hashed), # noqa: SLF001
|
auth_service._check_password("secret", hashed), # noqa: SLF001
|
||||||
auth_service._check_password("wrong", hashed), # noqa: SLF001
|
auth_service._check_password("wrong", hashed), # noqa: SLF001
|
||||||
)
|
)
|
||||||
assert results == [True, False]
|
assert tuple(results) == (True, False)
|
||||||
|
|
||||||
|
|
||||||
class TestLogin:
|
class TestLogin:
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ async def _create_f2b_db(path: str, rows: list[dict[str, Any]]) -> None:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc]
|
async def f2b_db_path(tmp_path: Path) -> str:
|
||||||
"""Return the path to a test fail2ban SQLite database with several bans."""
|
"""Return the path to a test fail2ban SQLite database with several bans."""
|
||||||
path = str(tmp_path / "fail2ban_test.sqlite3")
|
path = str(tmp_path / "fail2ban_test.sqlite3")
|
||||||
await _create_f2b_db(
|
await _create_f2b_db(
|
||||||
@@ -103,7 +103,7 @@ async def f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc]
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def mixed_origin_db_path(tmp_path: Path) -> str: # type: ignore[misc]
|
async def mixed_origin_db_path(tmp_path: Path) -> str:
|
||||||
"""Return a database with bans from both blocklist-import and organic jails."""
|
"""Return a database with bans from both blocklist-import and organic jails."""
|
||||||
path = str(tmp_path / "fail2ban_mixed_origin.sqlite3")
|
path = str(tmp_path / "fail2ban_mixed_origin.sqlite3")
|
||||||
await _create_f2b_db(
|
await _create_f2b_db(
|
||||||
@@ -136,7 +136,7 @@ async def mixed_origin_db_path(tmp_path: Path) -> str: # type: ignore[misc]
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def empty_f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc]
|
async def empty_f2b_db_path(tmp_path: Path) -> str:
|
||||||
"""Return the path to a fail2ban SQLite database with no ban records."""
|
"""Return the path to a fail2ban SQLite database with no ban records."""
|
||||||
path = str(tmp_path / "fail2ban_empty.sqlite3")
|
path = str(tmp_path / "fail2ban_empty.sqlite3")
|
||||||
await _create_f2b_db(path, [])
|
await _create_f2b_db(path, [])
|
||||||
@@ -632,13 +632,13 @@ class TestBansbyCountryBackground:
|
|||||||
from app.services import geo_service
|
from app.services import geo_service
|
||||||
|
|
||||||
# Pre-populate the cache for all three IPs in the fixture.
|
# Pre-populate the cache for all three IPs in the fixture.
|
||||||
geo_service._cache["10.0.0.1"] = geo_service.GeoInfo( # type: ignore[attr-defined]
|
geo_service._cache["10.0.0.1"] = geo_service.GeoInfo(
|
||||||
country_code="DE", country_name="Germany", asn=None, org=None
|
country_code="DE", country_name="Germany", asn=None, org=None
|
||||||
)
|
)
|
||||||
geo_service._cache["10.0.0.2"] = geo_service.GeoInfo( # type: ignore[attr-defined]
|
geo_service._cache["10.0.0.2"] = geo_service.GeoInfo(
|
||||||
country_code="US", country_name="United States", asn=None, org=None
|
country_code="US", country_name="United States", asn=None, org=None
|
||||||
)
|
)
|
||||||
geo_service._cache["10.0.0.3"] = geo_service.GeoInfo( # type: ignore[attr-defined]
|
geo_service._cache["10.0.0.3"] = geo_service.GeoInfo(
|
||||||
country_code="JP", country_name="Japan", asn=None, org=None
|
country_code="JP", country_name="Japan", asn=None, org=None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -114,13 +114,13 @@ async def _seed_f2b_db(path: str, n: int) -> list[str]:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def event_loop_policy() -> None: # type: ignore[misc]
|
def event_loop_policy() -> None:
|
||||||
"""Use the default event loop policy for module-scoped fixtures."""
|
"""Use the default event loop policy for module-scoped fixtures."""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
async def perf_db_path(tmp_path_factory: Any) -> str: # type: ignore[misc]
|
async def perf_db_path(tmp_path_factory: Any) -> str:
|
||||||
"""Return the path to a fail2ban DB seeded with 10 000 synthetic bans.
|
"""Return the path to a fail2ban DB seeded with 10 000 synthetic bans.
|
||||||
|
|
||||||
Module-scoped so the database is created only once for all perf tests.
|
Module-scoped so the database is created only once for all perf tests.
|
||||||
|
|||||||
@@ -13,15 +13,19 @@ from app.services.config_file_service import (
|
|||||||
JailNameError,
|
JailNameError,
|
||||||
JailNotFoundInConfigError,
|
JailNotFoundInConfigError,
|
||||||
_build_inactive_jail,
|
_build_inactive_jail,
|
||||||
|
_extract_action_base_name,
|
||||||
|
_extract_filter_base_name,
|
||||||
_ordered_config_files,
|
_ordered_config_files,
|
||||||
_parse_jails_sync,
|
_parse_jails_sync,
|
||||||
_resolve_filter,
|
_resolve_filter,
|
||||||
_safe_jail_name,
|
_safe_jail_name,
|
||||||
|
_validate_jail_config_sync,
|
||||||
_write_local_override_sync,
|
_write_local_override_sync,
|
||||||
activate_jail,
|
activate_jail,
|
||||||
deactivate_jail,
|
deactivate_jail,
|
||||||
list_inactive_jails,
|
list_inactive_jails,
|
||||||
rollback_jail,
|
rollback_jail,
|
||||||
|
validate_jail_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -292,9 +296,7 @@ class TestBuildInactiveJail:
|
|||||||
|
|
||||||
def test_has_local_override_absent(self, tmp_path: Path) -> None:
|
def test_has_local_override_absent(self, tmp_path: Path) -> None:
|
||||||
"""has_local_override is False when no .local file exists."""
|
"""has_local_override is False when no .local file exists."""
|
||||||
jail = _build_inactive_jail(
|
jail = _build_inactive_jail("sshd", {}, "/etc/fail2ban/jail.d/sshd.conf", config_dir=tmp_path)
|
||||||
"sshd", {}, "/etc/fail2ban/jail.d/sshd.conf", config_dir=tmp_path
|
|
||||||
)
|
|
||||||
assert jail.has_local_override is False
|
assert jail.has_local_override is False
|
||||||
|
|
||||||
def test_has_local_override_present(self, tmp_path: Path) -> None:
|
def test_has_local_override_present(self, tmp_path: Path) -> None:
|
||||||
@@ -302,9 +304,7 @@ class TestBuildInactiveJail:
|
|||||||
local = tmp_path / "jail.d" / "sshd.local"
|
local = tmp_path / "jail.d" / "sshd.local"
|
||||||
local.parent.mkdir(parents=True, exist_ok=True)
|
local.parent.mkdir(parents=True, exist_ok=True)
|
||||||
local.write_text("[sshd]\nenabled = false\n")
|
local.write_text("[sshd]\nenabled = false\n")
|
||||||
jail = _build_inactive_jail(
|
jail = _build_inactive_jail("sshd", {}, "/etc/fail2ban/jail.d/sshd.conf", config_dir=tmp_path)
|
||||||
"sshd", {}, "/etc/fail2ban/jail.d/sshd.conf", config_dir=tmp_path
|
|
||||||
)
|
|
||||||
assert jail.has_local_override is True
|
assert jail.has_local_override is True
|
||||||
|
|
||||||
def test_has_local_override_no_config_dir(self) -> None:
|
def test_has_local_override_no_config_dir(self) -> None:
|
||||||
@@ -363,9 +363,7 @@ class TestWriteLocalOverrideSync:
|
|||||||
assert "2222" in content
|
assert "2222" in content
|
||||||
|
|
||||||
def test_override_logpath_list(self, tmp_path: Path) -> None:
|
def test_override_logpath_list(self, tmp_path: Path) -> None:
|
||||||
_write_local_override_sync(
|
_write_local_override_sync(tmp_path, "sshd", True, {"logpath": ["/var/log/auth.log", "/var/log/secure"]})
|
||||||
tmp_path, "sshd", True, {"logpath": ["/var/log/auth.log", "/var/log/secure"]}
|
|
||||||
)
|
|
||||||
content = (tmp_path / "jail.d" / "sshd.local").read_text()
|
content = (tmp_path / "jail.d" / "sshd.local").read_text()
|
||||||
assert "/var/log/auth.log" in content
|
assert "/var/log/auth.log" in content
|
||||||
assert "/var/log/secure" in content
|
assert "/var/log/secure" in content
|
||||||
@@ -447,9 +445,7 @@ class TestListInactiveJails:
|
|||||||
assert "sshd" in names
|
assert "sshd" in names
|
||||||
assert "apache-auth" in names
|
assert "apache-auth" in names
|
||||||
|
|
||||||
async def test_has_local_override_true_when_local_file_exists(
|
async def test_has_local_override_true_when_local_file_exists(self, tmp_path: Path) -> None:
|
||||||
self, tmp_path: Path
|
|
||||||
) -> None:
|
|
||||||
"""has_local_override is True for a jail whose jail.d .local file exists."""
|
"""has_local_override is True for a jail whose jail.d .local file exists."""
|
||||||
_write(tmp_path / "jail.conf", JAIL_CONF)
|
_write(tmp_path / "jail.conf", JAIL_CONF)
|
||||||
local = tmp_path / "jail.d" / "apache-auth.local"
|
local = tmp_path / "jail.d" / "apache-auth.local"
|
||||||
@@ -463,9 +459,7 @@ class TestListInactiveJails:
|
|||||||
jail = next(j for j in result.jails if j.name == "apache-auth")
|
jail = next(j for j in result.jails if j.name == "apache-auth")
|
||||||
assert jail.has_local_override is True
|
assert jail.has_local_override is True
|
||||||
|
|
||||||
async def test_has_local_override_false_when_no_local_file(
|
async def test_has_local_override_false_when_no_local_file(self, tmp_path: Path) -> None:
|
||||||
self, tmp_path: Path
|
|
||||||
) -> None:
|
|
||||||
"""has_local_override is False when no jail.d .local file exists."""
|
"""has_local_override is False when no jail.d .local file exists."""
|
||||||
_write(tmp_path / "jail.conf", JAIL_CONF)
|
_write(tmp_path / "jail.conf", JAIL_CONF)
|
||||||
with patch(
|
with patch(
|
||||||
@@ -608,7 +602,8 @@ class TestActivateJail:
|
|||||||
patch(
|
patch(
|
||||||
"app.services.config_file_service._get_active_jail_names",
|
"app.services.config_file_service._get_active_jail_names",
|
||||||
new=AsyncMock(return_value=set()),
|
new=AsyncMock(return_value=set()),
|
||||||
),pytest.raises(JailNotFoundInConfigError)
|
),
|
||||||
|
pytest.raises(JailNotFoundInConfigError),
|
||||||
):
|
):
|
||||||
await activate_jail(str(tmp_path), "/fake.sock", "nonexistent", req)
|
await activate_jail(str(tmp_path), "/fake.sock", "nonexistent", req)
|
||||||
|
|
||||||
@@ -621,7 +616,8 @@ class TestActivateJail:
|
|||||||
patch(
|
patch(
|
||||||
"app.services.config_file_service._get_active_jail_names",
|
"app.services.config_file_service._get_active_jail_names",
|
||||||
new=AsyncMock(return_value={"sshd"}),
|
new=AsyncMock(return_value={"sshd"}),
|
||||||
),pytest.raises(JailAlreadyActiveError)
|
),
|
||||||
|
pytest.raises(JailAlreadyActiveError),
|
||||||
):
|
):
|
||||||
await activate_jail(str(tmp_path), "/fake.sock", "sshd", req)
|
await activate_jail(str(tmp_path), "/fake.sock", "sshd", req)
|
||||||
|
|
||||||
@@ -691,7 +687,8 @@ class TestDeactivateJail:
|
|||||||
patch(
|
patch(
|
||||||
"app.services.config_file_service._get_active_jail_names",
|
"app.services.config_file_service._get_active_jail_names",
|
||||||
new=AsyncMock(return_value={"sshd"}),
|
new=AsyncMock(return_value={"sshd"}),
|
||||||
),pytest.raises(JailNotFoundInConfigError)
|
),
|
||||||
|
pytest.raises(JailNotFoundInConfigError),
|
||||||
):
|
):
|
||||||
await deactivate_jail(str(tmp_path), "/fake.sock", "nonexistent")
|
await deactivate_jail(str(tmp_path), "/fake.sock", "nonexistent")
|
||||||
|
|
||||||
@@ -701,7 +698,8 @@ class TestDeactivateJail:
|
|||||||
patch(
|
patch(
|
||||||
"app.services.config_file_service._get_active_jail_names",
|
"app.services.config_file_service._get_active_jail_names",
|
||||||
new=AsyncMock(return_value=set()),
|
new=AsyncMock(return_value=set()),
|
||||||
),pytest.raises(JailAlreadyInactiveError)
|
),
|
||||||
|
pytest.raises(JailAlreadyInactiveError),
|
||||||
):
|
):
|
||||||
await deactivate_jail(str(tmp_path), "/fake.sock", "apache-auth")
|
await deactivate_jail(str(tmp_path), "/fake.sock", "apache-auth")
|
||||||
|
|
||||||
@@ -710,38 +708,6 @@ class TestDeactivateJail:
|
|||||||
await deactivate_jail(str(tmp_path), "/fake.sock", "a/b")
|
await deactivate_jail(str(tmp_path), "/fake.sock", "a/b")
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# _extract_filter_base_name
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestExtractFilterBaseName:
|
|
||||||
def test_simple_name(self) -> None:
|
|
||||||
from app.services.config_file_service import _extract_filter_base_name
|
|
||||||
|
|
||||||
assert _extract_filter_base_name("sshd") == "sshd"
|
|
||||||
|
|
||||||
def test_name_with_mode(self) -> None:
|
|
||||||
from app.services.config_file_service import _extract_filter_base_name
|
|
||||||
|
|
||||||
assert _extract_filter_base_name("sshd[mode=aggressive]") == "sshd"
|
|
||||||
|
|
||||||
def test_name_with_variable_mode(self) -> None:
|
|
||||||
from app.services.config_file_service import _extract_filter_base_name
|
|
||||||
|
|
||||||
assert _extract_filter_base_name("sshd[mode=%(mode)s]") == "sshd"
|
|
||||||
|
|
||||||
def test_whitespace_stripped(self) -> None:
|
|
||||||
from app.services.config_file_service import _extract_filter_base_name
|
|
||||||
|
|
||||||
assert _extract_filter_base_name(" nginx ") == "nginx"
|
|
||||||
|
|
||||||
def test_empty_string(self) -> None:
|
|
||||||
from app.services.config_file_service import _extract_filter_base_name
|
|
||||||
|
|
||||||
assert _extract_filter_base_name("") == ""
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# _build_filter_to_jails_map
|
# _build_filter_to_jails_map
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -757,9 +723,7 @@ class TestBuildFilterToJailsMap:
|
|||||||
def test_inactive_jail_not_included(self) -> None:
|
def test_inactive_jail_not_included(self) -> None:
|
||||||
from app.services.config_file_service import _build_filter_to_jails_map
|
from app.services.config_file_service import _build_filter_to_jails_map
|
||||||
|
|
||||||
result = _build_filter_to_jails_map(
|
result = _build_filter_to_jails_map({"apache-auth": {"filter": "apache-auth"}}, set())
|
||||||
{"apache-auth": {"filter": "apache-auth"}}, set()
|
|
||||||
)
|
|
||||||
assert result == {}
|
assert result == {}
|
||||||
|
|
||||||
def test_multiple_jails_sharing_filter(self) -> None:
|
def test_multiple_jails_sharing_filter(self) -> None:
|
||||||
@@ -775,9 +739,7 @@ class TestBuildFilterToJailsMap:
|
|||||||
def test_mode_suffix_stripped(self) -> None:
|
def test_mode_suffix_stripped(self) -> None:
|
||||||
from app.services.config_file_service import _build_filter_to_jails_map
|
from app.services.config_file_service import _build_filter_to_jails_map
|
||||||
|
|
||||||
result = _build_filter_to_jails_map(
|
result = _build_filter_to_jails_map({"sshd": {"filter": "sshd[mode=aggressive]"}}, {"sshd"})
|
||||||
{"sshd": {"filter": "sshd[mode=aggressive]"}}, {"sshd"}
|
|
||||||
)
|
|
||||||
assert "sshd" in result
|
assert "sshd" in result
|
||||||
|
|
||||||
def test_missing_filter_key_falls_back_to_jail_name(self) -> None:
|
def test_missing_filter_key_falls_back_to_jail_name(self) -> None:
|
||||||
@@ -988,10 +950,13 @@ class TestGetFilter:
|
|||||||
async def test_raises_filter_not_found(self, tmp_path: Path) -> None:
|
async def test_raises_filter_not_found(self, tmp_path: Path) -> None:
|
||||||
from app.services.config_file_service import FilterNotFoundError, get_filter
|
from app.services.config_file_service import FilterNotFoundError, get_filter
|
||||||
|
|
||||||
with patch(
|
with (
|
||||||
|
patch(
|
||||||
"app.services.config_file_service._get_active_jail_names",
|
"app.services.config_file_service._get_active_jail_names",
|
||||||
new=AsyncMock(return_value=set()),
|
new=AsyncMock(return_value=set()),
|
||||||
), pytest.raises(FilterNotFoundError):
|
),
|
||||||
|
pytest.raises(FilterNotFoundError),
|
||||||
|
):
|
||||||
await get_filter(str(tmp_path), "/fake.sock", "nonexistent")
|
await get_filter(str(tmp_path), "/fake.sock", "nonexistent")
|
||||||
|
|
||||||
async def test_has_local_override_detected(self, tmp_path: Path) -> None:
|
async def test_has_local_override_detected(self, tmp_path: Path) -> None:
|
||||||
@@ -1093,10 +1058,13 @@ class TestGetFilterLocalOnly:
|
|||||||
async def test_raises_when_neither_conf_nor_local(self, tmp_path: Path) -> None:
|
async def test_raises_when_neither_conf_nor_local(self, tmp_path: Path) -> None:
|
||||||
from app.services.config_file_service import FilterNotFoundError, get_filter
|
from app.services.config_file_service import FilterNotFoundError, get_filter
|
||||||
|
|
||||||
with patch(
|
with (
|
||||||
|
patch(
|
||||||
"app.services.config_file_service._get_active_jail_names",
|
"app.services.config_file_service._get_active_jail_names",
|
||||||
new=AsyncMock(return_value=set()),
|
new=AsyncMock(return_value=set()),
|
||||||
), pytest.raises(FilterNotFoundError):
|
),
|
||||||
|
pytest.raises(FilterNotFoundError),
|
||||||
|
):
|
||||||
await get_filter(str(tmp_path), "/fake.sock", "nonexistent")
|
await get_filter(str(tmp_path), "/fake.sock", "nonexistent")
|
||||||
|
|
||||||
async def test_accepts_local_extension(self, tmp_path: Path) -> None:
|
async def test_accepts_local_extension(self, tmp_path: Path) -> None:
|
||||||
@@ -1212,9 +1180,7 @@ class TestSetJailLocalKeySync:
|
|||||||
|
|
||||||
jail_d = tmp_path / "jail.d"
|
jail_d = tmp_path / "jail.d"
|
||||||
jail_d.mkdir()
|
jail_d.mkdir()
|
||||||
(jail_d / "sshd.local").write_text(
|
(jail_d / "sshd.local").write_text("[sshd]\nenabled = true\n")
|
||||||
"[sshd]\nenabled = true\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
_set_jail_local_key_sync(tmp_path, "sshd", "filter", "newfilter")
|
_set_jail_local_key_sync(tmp_path, "sshd", "filter", "newfilter")
|
||||||
|
|
||||||
@@ -1300,10 +1266,13 @@ class TestUpdateFilter:
|
|||||||
from app.models.config import FilterUpdateRequest
|
from app.models.config import FilterUpdateRequest
|
||||||
from app.services.config_file_service import FilterNotFoundError, update_filter
|
from app.services.config_file_service import FilterNotFoundError, update_filter
|
||||||
|
|
||||||
with patch(
|
with (
|
||||||
|
patch(
|
||||||
"app.services.config_file_service._get_active_jail_names",
|
"app.services.config_file_service._get_active_jail_names",
|
||||||
new=AsyncMock(return_value=set()),
|
new=AsyncMock(return_value=set()),
|
||||||
), pytest.raises(FilterNotFoundError):
|
),
|
||||||
|
pytest.raises(FilterNotFoundError),
|
||||||
|
):
|
||||||
await update_filter(
|
await update_filter(
|
||||||
str(tmp_path),
|
str(tmp_path),
|
||||||
"/fake.sock",
|
"/fake.sock",
|
||||||
@@ -1321,10 +1290,13 @@ class TestUpdateFilter:
|
|||||||
filter_d = tmp_path / "filter.d"
|
filter_d = tmp_path / "filter.d"
|
||||||
_write(filter_d / "sshd.conf", _FILTER_CONF_WITH_REGEX)
|
_write(filter_d / "sshd.conf", _FILTER_CONF_WITH_REGEX)
|
||||||
|
|
||||||
with patch(
|
with (
|
||||||
|
patch(
|
||||||
"app.services.config_file_service._get_active_jail_names",
|
"app.services.config_file_service._get_active_jail_names",
|
||||||
new=AsyncMock(return_value=set()),
|
new=AsyncMock(return_value=set()),
|
||||||
), pytest.raises(FilterInvalidRegexError):
|
),
|
||||||
|
pytest.raises(FilterInvalidRegexError),
|
||||||
|
):
|
||||||
await update_filter(
|
await update_filter(
|
||||||
str(tmp_path),
|
str(tmp_path),
|
||||||
"/fake.sock",
|
"/fake.sock",
|
||||||
@@ -1351,13 +1323,16 @@ class TestUpdateFilter:
|
|||||||
filter_d = tmp_path / "filter.d"
|
filter_d = tmp_path / "filter.d"
|
||||||
_write(filter_d / "sshd.conf", _FILTER_CONF)
|
_write(filter_d / "sshd.conf", _FILTER_CONF)
|
||||||
|
|
||||||
with patch(
|
with (
|
||||||
|
patch(
|
||||||
"app.services.config_file_service._get_active_jail_names",
|
"app.services.config_file_service._get_active_jail_names",
|
||||||
new=AsyncMock(return_value=set()),
|
new=AsyncMock(return_value=set()),
|
||||||
), patch(
|
),
|
||||||
|
patch(
|
||||||
"app.services.config_file_service.jail_service.reload_all",
|
"app.services.config_file_service.jail_service.reload_all",
|
||||||
new=AsyncMock(),
|
new=AsyncMock(),
|
||||||
) as mock_reload:
|
) as mock_reload,
|
||||||
|
):
|
||||||
await update_filter(
|
await update_filter(
|
||||||
str(tmp_path),
|
str(tmp_path),
|
||||||
"/fake.sock",
|
"/fake.sock",
|
||||||
@@ -1405,10 +1380,13 @@ class TestCreateFilter:
|
|||||||
filter_d = tmp_path / "filter.d"
|
filter_d = tmp_path / "filter.d"
|
||||||
_write(filter_d / "sshd.conf", _FILTER_CONF)
|
_write(filter_d / "sshd.conf", _FILTER_CONF)
|
||||||
|
|
||||||
with patch(
|
with (
|
||||||
|
patch(
|
||||||
"app.services.config_file_service._get_active_jail_names",
|
"app.services.config_file_service._get_active_jail_names",
|
||||||
new=AsyncMock(return_value=set()),
|
new=AsyncMock(return_value=set()),
|
||||||
), pytest.raises(FilterAlreadyExistsError):
|
),
|
||||||
|
pytest.raises(FilterAlreadyExistsError),
|
||||||
|
):
|
||||||
await create_filter(
|
await create_filter(
|
||||||
str(tmp_path),
|
str(tmp_path),
|
||||||
"/fake.sock",
|
"/fake.sock",
|
||||||
@@ -1422,10 +1400,13 @@ class TestCreateFilter:
|
|||||||
filter_d = tmp_path / "filter.d"
|
filter_d = tmp_path / "filter.d"
|
||||||
_write(filter_d / "custom.local", "[Definition]\n")
|
_write(filter_d / "custom.local", "[Definition]\n")
|
||||||
|
|
||||||
with patch(
|
with (
|
||||||
|
patch(
|
||||||
"app.services.config_file_service._get_active_jail_names",
|
"app.services.config_file_service._get_active_jail_names",
|
||||||
new=AsyncMock(return_value=set()),
|
new=AsyncMock(return_value=set()),
|
||||||
), pytest.raises(FilterAlreadyExistsError):
|
),
|
||||||
|
pytest.raises(FilterAlreadyExistsError),
|
||||||
|
):
|
||||||
await create_filter(
|
await create_filter(
|
||||||
str(tmp_path),
|
str(tmp_path),
|
||||||
"/fake.sock",
|
"/fake.sock",
|
||||||
@@ -1436,10 +1417,13 @@ class TestCreateFilter:
|
|||||||
from app.models.config import FilterCreateRequest
|
from app.models.config import FilterCreateRequest
|
||||||
from app.services.config_file_service import FilterInvalidRegexError, create_filter
|
from app.services.config_file_service import FilterInvalidRegexError, create_filter
|
||||||
|
|
||||||
with patch(
|
with (
|
||||||
|
patch(
|
||||||
"app.services.config_file_service._get_active_jail_names",
|
"app.services.config_file_service._get_active_jail_names",
|
||||||
new=AsyncMock(return_value=set()),
|
new=AsyncMock(return_value=set()),
|
||||||
), pytest.raises(FilterInvalidRegexError):
|
),
|
||||||
|
pytest.raises(FilterInvalidRegexError),
|
||||||
|
):
|
||||||
await create_filter(
|
await create_filter(
|
||||||
str(tmp_path),
|
str(tmp_path),
|
||||||
"/fake.sock",
|
"/fake.sock",
|
||||||
@@ -1461,13 +1445,16 @@ class TestCreateFilter:
|
|||||||
from app.models.config import FilterCreateRequest
|
from app.models.config import FilterCreateRequest
|
||||||
from app.services.config_file_service import create_filter
|
from app.services.config_file_service import create_filter
|
||||||
|
|
||||||
with patch(
|
with (
|
||||||
|
patch(
|
||||||
"app.services.config_file_service._get_active_jail_names",
|
"app.services.config_file_service._get_active_jail_names",
|
||||||
new=AsyncMock(return_value=set()),
|
new=AsyncMock(return_value=set()),
|
||||||
), patch(
|
),
|
||||||
|
patch(
|
||||||
"app.services.config_file_service.jail_service.reload_all",
|
"app.services.config_file_service.jail_service.reload_all",
|
||||||
new=AsyncMock(),
|
new=AsyncMock(),
|
||||||
) as mock_reload:
|
) as mock_reload,
|
||||||
|
):
|
||||||
await create_filter(
|
await create_filter(
|
||||||
str(tmp_path),
|
str(tmp_path),
|
||||||
"/fake.sock",
|
"/fake.sock",
|
||||||
@@ -1485,9 +1472,7 @@ class TestCreateFilter:
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
class TestDeleteFilter:
|
class TestDeleteFilter:
|
||||||
async def test_deletes_local_file_when_conf_and_local_exist(
|
async def test_deletes_local_file_when_conf_and_local_exist(self, tmp_path: Path) -> None:
|
||||||
self, tmp_path: Path
|
|
||||||
) -> None:
|
|
||||||
from app.services.config_file_service import delete_filter
|
from app.services.config_file_service import delete_filter
|
||||||
|
|
||||||
filter_d = tmp_path / "filter.d"
|
filter_d = tmp_path / "filter.d"
|
||||||
@@ -1524,9 +1509,7 @@ class TestDeleteFilter:
|
|||||||
with pytest.raises(FilterNotFoundError):
|
with pytest.raises(FilterNotFoundError):
|
||||||
await delete_filter(str(tmp_path), "nonexistent")
|
await delete_filter(str(tmp_path), "nonexistent")
|
||||||
|
|
||||||
async def test_accepts_filter_name_error_for_invalid_name(
|
async def test_accepts_filter_name_error_for_invalid_name(self, tmp_path: Path) -> None:
|
||||||
self, tmp_path: Path
|
|
||||||
) -> None:
|
|
||||||
from app.services.config_file_service import FilterNameError, delete_filter
|
from app.services.config_file_service import FilterNameError, delete_filter
|
||||||
|
|
||||||
with pytest.raises(FilterNameError):
|
with pytest.raises(FilterNameError):
|
||||||
@@ -1607,9 +1590,7 @@ class TestAssignFilterToJail:
|
|||||||
AssignFilterRequest(filter_name="sshd"),
|
AssignFilterRequest(filter_name="sshd"),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def test_raises_filter_name_error_for_invalid_filter(
|
async def test_raises_filter_name_error_for_invalid_filter(self, tmp_path: Path) -> None:
|
||||||
self, tmp_path: Path
|
|
||||||
) -> None:
|
|
||||||
from app.models.config import AssignFilterRequest
|
from app.models.config import AssignFilterRequest
|
||||||
from app.services.config_file_service import FilterNameError, assign_filter_to_jail
|
from app.services.config_file_service import FilterNameError, assign_filter_to_jail
|
||||||
|
|
||||||
@@ -1719,34 +1700,26 @@ class TestBuildActionToJailsMap:
|
|||||||
def test_active_jail_maps_to_action(self) -> None:
|
def test_active_jail_maps_to_action(self) -> None:
|
||||||
from app.services.config_file_service import _build_action_to_jails_map
|
from app.services.config_file_service import _build_action_to_jails_map
|
||||||
|
|
||||||
result = _build_action_to_jails_map(
|
result = _build_action_to_jails_map({"sshd": {"action": "iptables-multiport"}}, {"sshd"})
|
||||||
{"sshd": {"action": "iptables-multiport"}}, {"sshd"}
|
|
||||||
)
|
|
||||||
assert result == {"iptables-multiport": ["sshd"]}
|
assert result == {"iptables-multiport": ["sshd"]}
|
||||||
|
|
||||||
def test_inactive_jail_not_included(self) -> None:
|
def test_inactive_jail_not_included(self) -> None:
|
||||||
from app.services.config_file_service import _build_action_to_jails_map
|
from app.services.config_file_service import _build_action_to_jails_map
|
||||||
|
|
||||||
result = _build_action_to_jails_map(
|
result = _build_action_to_jails_map({"sshd": {"action": "iptables-multiport"}}, set())
|
||||||
{"sshd": {"action": "iptables-multiport"}}, set()
|
|
||||||
)
|
|
||||||
assert result == {}
|
assert result == {}
|
||||||
|
|
||||||
def test_multiple_actions_per_jail(self) -> None:
|
def test_multiple_actions_per_jail(self) -> None:
|
||||||
from app.services.config_file_service import _build_action_to_jails_map
|
from app.services.config_file_service import _build_action_to_jails_map
|
||||||
|
|
||||||
result = _build_action_to_jails_map(
|
result = _build_action_to_jails_map({"sshd": {"action": "iptables-multiport\niptables-ipset"}}, {"sshd"})
|
||||||
{"sshd": {"action": "iptables-multiport\niptables-ipset"}}, {"sshd"}
|
|
||||||
)
|
|
||||||
assert "iptables-multiport" in result
|
assert "iptables-multiport" in result
|
||||||
assert "iptables-ipset" in result
|
assert "iptables-ipset" in result
|
||||||
|
|
||||||
def test_parameter_block_stripped(self) -> None:
|
def test_parameter_block_stripped(self) -> None:
|
||||||
from app.services.config_file_service import _build_action_to_jails_map
|
from app.services.config_file_service import _build_action_to_jails_map
|
||||||
|
|
||||||
result = _build_action_to_jails_map(
|
result = _build_action_to_jails_map({"sshd": {"action": "iptables[port=ssh, protocol=tcp]"}}, {"sshd"})
|
||||||
{"sshd": {"action": "iptables[port=ssh, protocol=tcp]"}}, {"sshd"}
|
|
||||||
)
|
|
||||||
assert "iptables" in result
|
assert "iptables" in result
|
||||||
|
|
||||||
def test_multiple_jails_sharing_action(self) -> None:
|
def test_multiple_jails_sharing_action(self) -> None:
|
||||||
@@ -2001,10 +1974,13 @@ class TestGetAction:
|
|||||||
async def test_raises_for_unknown_action(self, tmp_path: Path) -> None:
|
async def test_raises_for_unknown_action(self, tmp_path: Path) -> None:
|
||||||
from app.services.config_file_service import ActionNotFoundError, get_action
|
from app.services.config_file_service import ActionNotFoundError, get_action
|
||||||
|
|
||||||
with patch(
|
with (
|
||||||
|
patch(
|
||||||
"app.services.config_file_service._get_active_jail_names",
|
"app.services.config_file_service._get_active_jail_names",
|
||||||
new=AsyncMock(return_value=set()),
|
new=AsyncMock(return_value=set()),
|
||||||
), pytest.raises(ActionNotFoundError):
|
),
|
||||||
|
pytest.raises(ActionNotFoundError),
|
||||||
|
):
|
||||||
await get_action(str(tmp_path), "/fake.sock", "nonexistent")
|
await get_action(str(tmp_path), "/fake.sock", "nonexistent")
|
||||||
|
|
||||||
async def test_local_only_action_returned(self, tmp_path: Path) -> None:
|
async def test_local_only_action_returned(self, tmp_path: Path) -> None:
|
||||||
@@ -2118,10 +2094,13 @@ class TestUpdateAction:
|
|||||||
from app.models.config import ActionUpdateRequest
|
from app.models.config import ActionUpdateRequest
|
||||||
from app.services.config_file_service import ActionNotFoundError, update_action
|
from app.services.config_file_service import ActionNotFoundError, update_action
|
||||||
|
|
||||||
with patch(
|
with (
|
||||||
|
patch(
|
||||||
"app.services.config_file_service._get_active_jail_names",
|
"app.services.config_file_service._get_active_jail_names",
|
||||||
new=AsyncMock(return_value=set()),
|
new=AsyncMock(return_value=set()),
|
||||||
), pytest.raises(ActionNotFoundError):
|
),
|
||||||
|
pytest.raises(ActionNotFoundError),
|
||||||
|
):
|
||||||
await update_action(
|
await update_action(
|
||||||
str(tmp_path),
|
str(tmp_path),
|
||||||
"/fake.sock",
|
"/fake.sock",
|
||||||
@@ -2587,9 +2566,7 @@ class TestRemoveActionFromJail:
|
|||||||
"app.services.config_file_service._get_active_jail_names",
|
"app.services.config_file_service._get_active_jail_names",
|
||||||
new=AsyncMock(return_value=set()),
|
new=AsyncMock(return_value=set()),
|
||||||
):
|
):
|
||||||
await remove_action_from_jail(
|
await remove_action_from_jail(str(tmp_path), "/fake.sock", "sshd", "iptables-multiport")
|
||||||
str(tmp_path), "/fake.sock", "sshd", "iptables-multiport"
|
|
||||||
)
|
|
||||||
|
|
||||||
content = (jail_d / "sshd.local").read_text()
|
content = (jail_d / "sshd.local").read_text()
|
||||||
assert "iptables-multiport" not in content
|
assert "iptables-multiport" not in content
|
||||||
@@ -2601,17 +2578,13 @@ class TestRemoveActionFromJail:
|
|||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(JailNotFoundInConfigError):
|
with pytest.raises(JailNotFoundInConfigError):
|
||||||
await remove_action_from_jail(
|
await remove_action_from_jail(str(tmp_path), "/fake.sock", "nonexistent", "iptables")
|
||||||
str(tmp_path), "/fake.sock", "nonexistent", "iptables"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def test_raises_jail_name_error(self, tmp_path: Path) -> None:
|
async def test_raises_jail_name_error(self, tmp_path: Path) -> None:
|
||||||
from app.services.config_file_service import JailNameError, remove_action_from_jail
|
from app.services.config_file_service import JailNameError, remove_action_from_jail
|
||||||
|
|
||||||
with pytest.raises(JailNameError):
|
with pytest.raises(JailNameError):
|
||||||
await remove_action_from_jail(
|
await remove_action_from_jail(str(tmp_path), "/fake.sock", "../evil", "iptables")
|
||||||
str(tmp_path), "/fake.sock", "../evil", "iptables"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def test_raises_action_name_error(self, tmp_path: Path) -> None:
|
async def test_raises_action_name_error(self, tmp_path: Path) -> None:
|
||||||
from app.services.config_file_service import ActionNameError, remove_action_from_jail
|
from app.services.config_file_service import ActionNameError, remove_action_from_jail
|
||||||
@@ -2619,9 +2592,7 @@ class TestRemoveActionFromJail:
|
|||||||
_write(tmp_path / "jail.conf", JAIL_CONF)
|
_write(tmp_path / "jail.conf", JAIL_CONF)
|
||||||
|
|
||||||
with pytest.raises(ActionNameError):
|
with pytest.raises(ActionNameError):
|
||||||
await remove_action_from_jail(
|
await remove_action_from_jail(str(tmp_path), "/fake.sock", "sshd", "../evil")
|
||||||
str(tmp_path), "/fake.sock", "sshd", "../evil"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def test_triggers_reload_when_requested(self, tmp_path: Path) -> None:
|
async def test_triggers_reload_when_requested(self, tmp_path: Path) -> None:
|
||||||
from app.services.config_file_service import remove_action_from_jail
|
from app.services.config_file_service import remove_action_from_jail
|
||||||
@@ -2640,9 +2611,7 @@ class TestRemoveActionFromJail:
|
|||||||
new=AsyncMock(),
|
new=AsyncMock(),
|
||||||
) as mock_reload,
|
) as mock_reload,
|
||||||
):
|
):
|
||||||
await remove_action_from_jail(
|
await remove_action_from_jail(str(tmp_path), "/fake.sock", "sshd", "iptables", do_reload=True)
|
||||||
str(tmp_path), "/fake.sock", "sshd", "iptables", do_reload=True
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_reload.assert_awaited_once()
|
mock_reload.assert_awaited_once()
|
||||||
|
|
||||||
@@ -2680,13 +2649,9 @@ class TestActivateJailReloadArgs:
|
|||||||
mock_js.reload_all = AsyncMock()
|
mock_js.reload_all = AsyncMock()
|
||||||
await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
|
await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
|
||||||
|
|
||||||
mock_js.reload_all.assert_awaited_once_with(
|
mock_js.reload_all.assert_awaited_once_with("/fake.sock", include_jails=["apache-auth"])
|
||||||
"/fake.sock", include_jails=["apache-auth"]
|
|
||||||
)
|
|
||||||
|
|
||||||
async def test_activate_returns_active_true_when_jail_starts(
|
async def test_activate_returns_active_true_when_jail_starts(self, tmp_path: Path) -> None:
|
||||||
self, tmp_path: Path
|
|
||||||
) -> None:
|
|
||||||
"""activate_jail returns active=True when the jail appears in post-reload names."""
|
"""activate_jail returns active=True when the jail appears in post-reload names."""
|
||||||
_write(tmp_path / "jail.conf", JAIL_CONF)
|
_write(tmp_path / "jail.conf", JAIL_CONF)
|
||||||
from app.models.config import ActivateJailRequest, JailValidationResult
|
from app.models.config import ActivateJailRequest, JailValidationResult
|
||||||
@@ -2708,16 +2673,12 @@ class TestActivateJailReloadArgs:
|
|||||||
),
|
),
|
||||||
):
|
):
|
||||||
mock_js.reload_all = AsyncMock()
|
mock_js.reload_all = AsyncMock()
|
||||||
result = await activate_jail(
|
result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
|
||||||
str(tmp_path), "/fake.sock", "apache-auth", req
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.active is True
|
assert result.active is True
|
||||||
assert "activated" in result.message.lower()
|
assert "activated" in result.message.lower()
|
||||||
|
|
||||||
async def test_activate_returns_active_false_when_jail_does_not_start(
|
async def test_activate_returns_active_false_when_jail_does_not_start(self, tmp_path: Path) -> None:
|
||||||
self, tmp_path: Path
|
|
||||||
) -> None:
|
|
||||||
"""activate_jail returns active=False when the jail is absent after reload.
|
"""activate_jail returns active=False when the jail is absent after reload.
|
||||||
|
|
||||||
This covers the Stage 3.1 requirement: if the jail config is invalid
|
This covers the Stage 3.1 requirement: if the jail config is invalid
|
||||||
@@ -2746,9 +2707,7 @@ class TestActivateJailReloadArgs:
|
|||||||
),
|
),
|
||||||
):
|
):
|
||||||
mock_js.reload_all = AsyncMock()
|
mock_js.reload_all = AsyncMock()
|
||||||
result = await activate_jail(
|
result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
|
||||||
str(tmp_path), "/fake.sock", "apache-auth", req
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.active is False
|
assert result.active is False
|
||||||
assert "apache-auth" in result.name
|
assert "apache-auth" in result.name
|
||||||
@@ -2776,23 +2735,13 @@ class TestDeactivateJailReloadArgs:
|
|||||||
mock_js.reload_all = AsyncMock()
|
mock_js.reload_all = AsyncMock()
|
||||||
await deactivate_jail(str(tmp_path), "/fake.sock", "sshd")
|
await deactivate_jail(str(tmp_path), "/fake.sock", "sshd")
|
||||||
|
|
||||||
mock_js.reload_all.assert_awaited_once_with(
|
mock_js.reload_all.assert_awaited_once_with("/fake.sock", exclude_jails=["sshd"])
|
||||||
"/fake.sock", exclude_jails=["sshd"]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# _validate_jail_config_sync (Task 3)
|
# _validate_jail_config_sync (Task 3)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
from app.services.config_file_service import ( # noqa: E402 (added after block)
|
|
||||||
_validate_jail_config_sync,
|
|
||||||
_extract_filter_base_name,
|
|
||||||
_extract_action_base_name,
|
|
||||||
validate_jail_config,
|
|
||||||
rollback_jail,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestExtractFilterBaseName:
|
class TestExtractFilterBaseName:
|
||||||
def test_plain_name(self) -> None:
|
def test_plain_name(self) -> None:
|
||||||
@@ -2938,11 +2887,11 @@ class TestRollbackJail:
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"app.services.config_file_service._start_daemon",
|
"app.services.config_file_service.start_daemon",
|
||||||
new=AsyncMock(return_value=True),
|
new=AsyncMock(return_value=True),
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"app.services.config_file_service._wait_for_fail2ban",
|
"app.services.config_file_service.wait_for_fail2ban",
|
||||||
new=AsyncMock(return_value=True),
|
new=AsyncMock(return_value=True),
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
@@ -2950,9 +2899,7 @@ class TestRollbackJail:
|
|||||||
new=AsyncMock(return_value=set()),
|
new=AsyncMock(return_value=set()),
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
result = await rollback_jail(
|
result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
|
||||||
str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.disabled is True
|
assert result.disabled is True
|
||||||
assert result.fail2ban_running is True
|
assert result.fail2ban_running is True
|
||||||
@@ -2968,26 +2915,22 @@ class TestRollbackJail:
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"app.services.config_file_service._start_daemon",
|
"app.services.config_file_service.start_daemon",
|
||||||
new=AsyncMock(return_value=False),
|
new=AsyncMock(return_value=False),
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"app.services.config_file_service._wait_for_fail2ban",
|
"app.services.config_file_service.wait_for_fail2ban",
|
||||||
new=AsyncMock(return_value=False),
|
new=AsyncMock(return_value=False),
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
result = await rollback_jail(
|
result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
|
||||||
str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.fail2ban_running is False
|
assert result.fail2ban_running is False
|
||||||
assert result.disabled is True
|
assert result.disabled is True
|
||||||
|
|
||||||
async def test_rollback_rejects_unsafe_name(self, tmp_path: Path) -> None:
|
async def test_rollback_rejects_unsafe_name(self, tmp_path: Path) -> None:
|
||||||
with pytest.raises(JailNameError):
|
with pytest.raises(JailNameError):
|
||||||
await rollback_jail(
|
await rollback_jail(str(tmp_path), "/fake.sock", "../evil", ["fail2ban-client", "start"])
|
||||||
str(tmp_path), "/fake.sock", "../evil", ["fail2ban-client", "start"]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -3096,9 +3039,7 @@ class TestActivateJailBlocking:
|
|||||||
class TestActivateJailRollback:
|
class TestActivateJailRollback:
|
||||||
"""Rollback logic in activate_jail restores the .local file and recovers."""
|
"""Rollback logic in activate_jail restores the .local file and recovers."""
|
||||||
|
|
||||||
async def test_activate_jail_rollback_on_reload_failure(
|
async def test_activate_jail_rollback_on_reload_failure(self, tmp_path: Path) -> None:
|
||||||
self, tmp_path: Path
|
|
||||||
) -> None:
|
|
||||||
"""Rollback when reload_all raises on the activation reload.
|
"""Rollback when reload_all raises on the activation reload.
|
||||||
|
|
||||||
Expects:
|
Expects:
|
||||||
@@ -3135,23 +3076,17 @@ class TestActivateJailRollback:
|
|||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"app.services.config_file_service._validate_jail_config_sync",
|
"app.services.config_file_service._validate_jail_config_sync",
|
||||||
return_value=JailValidationResult(
|
return_value=JailValidationResult(jail_name="apache-auth", valid=True),
|
||||||
jail_name="apache-auth", valid=True
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
mock_js.reload_all = AsyncMock(side_effect=reload_side_effect)
|
mock_js.reload_all = AsyncMock(side_effect=reload_side_effect)
|
||||||
result = await activate_jail(
|
result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
|
||||||
str(tmp_path), "/fake.sock", "apache-auth", req
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.active is False
|
assert result.active is False
|
||||||
assert result.recovered is True
|
assert result.recovered is True
|
||||||
assert local_path.read_text() == original_local
|
assert local_path.read_text() == original_local
|
||||||
|
|
||||||
async def test_activate_jail_rollback_on_health_check_failure(
|
async def test_activate_jail_rollback_on_health_check_failure(self, tmp_path: Path) -> None:
|
||||||
self, tmp_path: Path
|
|
||||||
) -> None:
|
|
||||||
"""Rollback when fail2ban is unreachable after the activation reload.
|
"""Rollback when fail2ban is unreachable after the activation reload.
|
||||||
|
|
||||||
Expects:
|
Expects:
|
||||||
@@ -3190,15 +3125,11 @@ class TestActivateJailRollback:
|
|||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"app.services.config_file_service._validate_jail_config_sync",
|
"app.services.config_file_service._validate_jail_config_sync",
|
||||||
return_value=JailValidationResult(
|
return_value=JailValidationResult(jail_name="apache-auth", valid=True),
|
||||||
jail_name="apache-auth", valid=True
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
mock_js.reload_all = AsyncMock()
|
mock_js.reload_all = AsyncMock()
|
||||||
result = await activate_jail(
|
result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
|
||||||
str(tmp_path), "/fake.sock", "apache-auth", req
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.active is False
|
assert result.active is False
|
||||||
assert result.recovered is True
|
assert result.recovered is True
|
||||||
@@ -3232,25 +3163,17 @@ class TestActivateJailRollback:
|
|||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"app.services.config_file_service._validate_jail_config_sync",
|
"app.services.config_file_service._validate_jail_config_sync",
|
||||||
return_value=JailValidationResult(
|
return_value=JailValidationResult(jail_name="apache-auth", valid=True),
|
||||||
jail_name="apache-auth", valid=True
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
# Both the activation reload and the recovery reload fail.
|
# Both the activation reload and the recovery reload fail.
|
||||||
mock_js.reload_all = AsyncMock(
|
mock_js.reload_all = AsyncMock(side_effect=RuntimeError("fail2ban unavailable"))
|
||||||
side_effect=RuntimeError("fail2ban unavailable")
|
result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
|
||||||
)
|
|
||||||
result = await activate_jail(
|
|
||||||
str(tmp_path), "/fake.sock", "apache-auth", req
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.active is False
|
assert result.active is False
|
||||||
assert result.recovered is False
|
assert result.recovered is False
|
||||||
|
|
||||||
async def test_activate_jail_rollback_on_jail_not_found_error(
|
async def test_activate_jail_rollback_on_jail_not_found_error(self, tmp_path: Path) -> None:
|
||||||
self, tmp_path: Path
|
|
||||||
) -> None:
|
|
||||||
"""Rollback when reload_all raises JailNotFoundError (invalid config).
|
"""Rollback when reload_all raises JailNotFoundError (invalid config).
|
||||||
|
|
||||||
When fail2ban cannot create a jail due to invalid configuration
|
When fail2ban cannot create a jail due to invalid configuration
|
||||||
@@ -3294,16 +3217,12 @@ class TestActivateJailRollback:
|
|||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"app.services.config_file_service._validate_jail_config_sync",
|
"app.services.config_file_service._validate_jail_config_sync",
|
||||||
return_value=JailValidationResult(
|
return_value=JailValidationResult(jail_name="apache-auth", valid=True),
|
||||||
jail_name="apache-auth", valid=True
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
mock_js.reload_all = AsyncMock(side_effect=reload_side_effect)
|
mock_js.reload_all = AsyncMock(side_effect=reload_side_effect)
|
||||||
mock_js.JailNotFoundError = JailNotFoundError
|
mock_js.JailNotFoundError = JailNotFoundError
|
||||||
result = await activate_jail(
|
result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
|
||||||
str(tmp_path), "/fake.sock", "apache-auth", req
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.active is False
|
assert result.active is False
|
||||||
assert result.recovered is True
|
assert result.recovered is True
|
||||||
@@ -3311,9 +3230,7 @@ class TestActivateJailRollback:
|
|||||||
# Verify the error message mentions logpath issues.
|
# Verify the error message mentions logpath issues.
|
||||||
assert "logpath" in result.message.lower() or "check that all logpath" in result.message.lower()
|
assert "logpath" in result.message.lower() or "check that all logpath" in result.message.lower()
|
||||||
|
|
||||||
async def test_activate_jail_rollback_deletes_file_when_no_prior_local(
|
async def test_activate_jail_rollback_deletes_file_when_no_prior_local(self, tmp_path: Path) -> None:
|
||||||
self, tmp_path: Path
|
|
||||||
) -> None:
|
|
||||||
"""Rollback deletes the .local file when none existed before activation.
|
"""Rollback deletes the .local file when none existed before activation.
|
||||||
|
|
||||||
When a jail had no .local override before activation, activate_jail
|
When a jail had no .local override before activation, activate_jail
|
||||||
@@ -3355,15 +3272,11 @@ class TestActivateJailRollback:
|
|||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"app.services.config_file_service._validate_jail_config_sync",
|
"app.services.config_file_service._validate_jail_config_sync",
|
||||||
return_value=JailValidationResult(
|
return_value=JailValidationResult(jail_name="apache-auth", valid=True),
|
||||||
jail_name="apache-auth", valid=True
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
mock_js.reload_all = AsyncMock(side_effect=reload_side_effect)
|
mock_js.reload_all = AsyncMock(side_effect=reload_side_effect)
|
||||||
result = await activate_jail(
|
result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
|
||||||
str(tmp_path), "/fake.sock", "apache-auth", req
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.active is False
|
assert result.active is False
|
||||||
assert result.recovered is True
|
assert result.recovered is True
|
||||||
@@ -3376,7 +3289,7 @@ class TestActivateJailRollback:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
class TestRollbackJail:
|
class TestRollbackJailIntegration:
|
||||||
"""Integration tests for :func:`~app.services.config_file_service.rollback_jail`."""
|
"""Integration tests for :func:`~app.services.config_file_service.rollback_jail`."""
|
||||||
|
|
||||||
async def test_local_file_written_enabled_false(self, tmp_path: Path) -> None:
|
async def test_local_file_written_enabled_false(self, tmp_path: Path) -> None:
|
||||||
@@ -3419,15 +3332,11 @@ class TestRollbackJail:
|
|||||||
AsyncMock(return_value={"other"}),
|
AsyncMock(return_value={"other"}),
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
await rollback_jail(
|
await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
|
||||||
str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_start.assert_awaited_once_with(["fail2ban-client", "start"])
|
mock_start.assert_awaited_once_with(["fail2ban-client", "start"])
|
||||||
|
|
||||||
async def test_fail2ban_running_reflects_socket_probe_not_subprocess_exit(
|
async def test_fail2ban_running_reflects_socket_probe_not_subprocess_exit(self, tmp_path: Path) -> None:
|
||||||
self, tmp_path: Path
|
|
||||||
) -> None:
|
|
||||||
"""fail2ban_running in the response reflects the socket probe result.
|
"""fail2ban_running in the response reflects the socket probe result.
|
||||||
|
|
||||||
Even when start_daemon returns True (subprocess exit 0), if the socket
|
Even when start_daemon returns True (subprocess exit 0), if the socket
|
||||||
@@ -3443,15 +3352,11 @@ class TestRollbackJail:
|
|||||||
AsyncMock(return_value=False), # socket still unresponsive
|
AsyncMock(return_value=False), # socket still unresponsive
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
result = await rollback_jail(
|
result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
|
||||||
str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.fail2ban_running is False
|
assert result.fail2ban_running is False
|
||||||
|
|
||||||
async def test_active_jails_zero_when_fail2ban_not_running(
|
async def test_active_jails_zero_when_fail2ban_not_running(self, tmp_path: Path) -> None:
|
||||||
self, tmp_path: Path
|
|
||||||
) -> None:
|
|
||||||
"""active_jails is 0 in the response when fail2ban_running is False."""
|
"""active_jails is 0 in the response when fail2ban_running is False."""
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
@@ -3463,15 +3368,11 @@ class TestRollbackJail:
|
|||||||
AsyncMock(return_value=False),
|
AsyncMock(return_value=False),
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
result = await rollback_jail(
|
result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
|
||||||
str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.active_jails == 0
|
assert result.active_jails == 0
|
||||||
|
|
||||||
async def test_active_jails_count_from_socket_when_running(
|
async def test_active_jails_count_from_socket_when_running(self, tmp_path: Path) -> None:
|
||||||
self, tmp_path: Path
|
|
||||||
) -> None:
|
|
||||||
"""active_jails reflects the actual jail count from the socket when fail2ban is up."""
|
"""active_jails reflects the actual jail count from the socket when fail2ban is up."""
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
@@ -3487,15 +3388,11 @@ class TestRollbackJail:
|
|||||||
AsyncMock(return_value={"sshd", "nginx", "apache-auth"}),
|
AsyncMock(return_value={"sshd", "nginx", "apache-auth"}),
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
result = await rollback_jail(
|
result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
|
||||||
str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.active_jails == 3
|
assert result.active_jails == 3
|
||||||
|
|
||||||
async def test_fail2ban_down_at_start_still_succeeds_file_write(
|
async def test_fail2ban_down_at_start_still_succeeds_file_write(self, tmp_path: Path) -> None:
|
||||||
self, tmp_path: Path
|
|
||||||
) -> None:
|
|
||||||
"""rollback_jail writes the local file even when fail2ban is down at call time."""
|
"""rollback_jail writes the local file even when fail2ban is down at call time."""
|
||||||
# fail2ban is down: start_daemon fails and wait_for_fail2ban returns False.
|
# fail2ban is down: start_daemon fails and wait_for_fail2ban returns False.
|
||||||
with (
|
with (
|
||||||
@@ -3508,12 +3405,9 @@ class TestRollbackJail:
|
|||||||
AsyncMock(return_value=False),
|
AsyncMock(return_value=False),
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
result = await rollback_jail(
|
result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
|
||||||
str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]
|
|
||||||
)
|
|
||||||
|
|
||||||
local = tmp_path / "jail.d" / "sshd.local"
|
local = tmp_path / "jail.d" / "sshd.local"
|
||||||
assert local.is_file(), "local file must be written even when fail2ban is down"
|
assert local.is_file(), "local file must be written even when fail2ban is down"
|
||||||
assert result.disabled is True
|
assert result.disabled is True
|
||||||
assert result.fail2ban_running is False
|
assert result.fail2ban_running is False
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Mapping, Sequence
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -44,7 +45,7 @@ def _make_session(response_json: dict[str, object], status: int = 200) -> MagicM
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def clear_geo_cache() -> None: # type: ignore[misc]
|
def clear_geo_cache() -> None:
|
||||||
"""Flush the module-level geo cache before every test."""
|
"""Flush the module-level geo cache before every test."""
|
||||||
geo_service.clear_cache()
|
geo_service.clear_cache()
|
||||||
|
|
||||||
@@ -68,7 +69,7 @@ class TestLookupSuccess:
|
|||||||
"org": "AS3320 Deutsche Telekom AG",
|
"org": "AS3320 Deutsche Telekom AG",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
result = await geo_service.lookup("1.2.3.4", session)
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.country_code == "DE"
|
assert result.country_code == "DE"
|
||||||
@@ -84,7 +85,7 @@ class TestLookupSuccess:
|
|||||||
"org": "Google LLC",
|
"org": "Google LLC",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
result = await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type]
|
result = await geo_service.lookup("8.8.8.8", session)
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.country_name == "United States"
|
assert result.country_name == "United States"
|
||||||
@@ -100,7 +101,7 @@ class TestLookupSuccess:
|
|||||||
"org": "Deutsche Telekom",
|
"org": "Deutsche Telekom",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
result = await geo_service.lookup("1.2.3.4", session)
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.asn == "AS3320"
|
assert result.asn == "AS3320"
|
||||||
@@ -116,7 +117,7 @@ class TestLookupSuccess:
|
|||||||
"org": "Google LLC",
|
"org": "Google LLC",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
result = await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type]
|
result = await geo_service.lookup("8.8.8.8", session)
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.org == "Google LLC"
|
assert result.org == "Google LLC"
|
||||||
@@ -142,8 +143,8 @@ class TestLookupCaching:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
await geo_service.lookup("1.2.3.4", session)
|
||||||
await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
await geo_service.lookup("1.2.3.4", session)
|
||||||
|
|
||||||
# The session.get() should only have been called once.
|
# The session.get() should only have been called once.
|
||||||
assert session.get.call_count == 1
|
assert session.get.call_count == 1
|
||||||
@@ -160,9 +161,9 @@ class TestLookupCaching:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
await geo_service.lookup("2.3.4.5", session) # type: ignore[arg-type]
|
await geo_service.lookup("2.3.4.5", session)
|
||||||
geo_service.clear_cache()
|
geo_service.clear_cache()
|
||||||
await geo_service.lookup("2.3.4.5", session) # type: ignore[arg-type]
|
await geo_service.lookup("2.3.4.5", session)
|
||||||
|
|
||||||
assert session.get.call_count == 2
|
assert session.get.call_count == 2
|
||||||
|
|
||||||
@@ -172,8 +173,8 @@ class TestLookupCaching:
|
|||||||
{"status": "fail", "message": "reserved range"}
|
{"status": "fail", "message": "reserved range"}
|
||||||
)
|
)
|
||||||
|
|
||||||
await geo_service.lookup("192.168.1.1", session) # type: ignore[arg-type]
|
await geo_service.lookup("192.168.1.1", session)
|
||||||
await geo_service.lookup("192.168.1.1", session) # type: ignore[arg-type]
|
await geo_service.lookup("192.168.1.1", session)
|
||||||
|
|
||||||
# Second call is blocked by the negative cache — only one API hit.
|
# Second call is blocked by the negative cache — only one API hit.
|
||||||
assert session.get.call_count == 1
|
assert session.get.call_count == 1
|
||||||
@@ -190,7 +191,7 @@ class TestLookupFailures:
|
|||||||
async def test_non_200_response_returns_null_geo_info(self) -> None:
|
async def test_non_200_response_returns_null_geo_info(self) -> None:
|
||||||
"""A 429 or 500 status returns GeoInfo with null fields (not None)."""
|
"""A 429 or 500 status returns GeoInfo with null fields (not None)."""
|
||||||
session = _make_session({}, status=429)
|
session = _make_session({}, status=429)
|
||||||
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
result = await geo_service.lookup("1.2.3.4", session)
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert isinstance(result, GeoInfo)
|
assert isinstance(result, GeoInfo)
|
||||||
assert result.country_code is None
|
assert result.country_code is None
|
||||||
@@ -203,7 +204,7 @@ class TestLookupFailures:
|
|||||||
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||||
session.get = MagicMock(return_value=mock_ctx)
|
session.get = MagicMock(return_value=mock_ctx)
|
||||||
|
|
||||||
result = await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type]
|
result = await geo_service.lookup("10.0.0.1", session)
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert isinstance(result, GeoInfo)
|
assert isinstance(result, GeoInfo)
|
||||||
assert result.country_code is None
|
assert result.country_code is None
|
||||||
@@ -211,7 +212,7 @@ class TestLookupFailures:
|
|||||||
async def test_failed_status_returns_geo_info_with_nulls(self) -> None:
|
async def test_failed_status_returns_geo_info_with_nulls(self) -> None:
|
||||||
"""When ip-api returns ``status=fail`` a GeoInfo with null fields is returned (but not cached)."""
|
"""When ip-api returns ``status=fail`` a GeoInfo with null fields is returned (but not cached)."""
|
||||||
session = _make_session({"status": "fail", "message": "private range"})
|
session = _make_session({"status": "fail", "message": "private range"})
|
||||||
result = await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type]
|
result = await geo_service.lookup("10.0.0.1", session)
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert isinstance(result, GeoInfo)
|
assert isinstance(result, GeoInfo)
|
||||||
@@ -231,8 +232,8 @@ class TestNegativeCache:
|
|||||||
"""After a failed lookup the second call is served from the neg cache."""
|
"""After a failed lookup the second call is served from the neg cache."""
|
||||||
session = _make_session({"status": "fail", "message": "private range"})
|
session = _make_session({"status": "fail", "message": "private range"})
|
||||||
|
|
||||||
r1 = await geo_service.lookup("192.0.2.1", session) # type: ignore[arg-type]
|
r1 = await geo_service.lookup("192.0.2.1", session)
|
||||||
r2 = await geo_service.lookup("192.0.2.1", session) # type: ignore[arg-type]
|
r2 = await geo_service.lookup("192.0.2.1", session)
|
||||||
|
|
||||||
# Only one HTTP call should have been made; second served from neg cache.
|
# Only one HTTP call should have been made; second served from neg cache.
|
||||||
assert session.get.call_count == 1
|
assert session.get.call_count == 1
|
||||||
@@ -243,12 +244,12 @@ class TestNegativeCache:
|
|||||||
"""When the neg-cache entry is older than the TTL a new API call is made."""
|
"""When the neg-cache entry is older than the TTL a new API call is made."""
|
||||||
session = _make_session({"status": "fail", "message": "private range"})
|
session = _make_session({"status": "fail", "message": "private range"})
|
||||||
|
|
||||||
await geo_service.lookup("192.0.2.2", session) # type: ignore[arg-type]
|
await geo_service.lookup("192.0.2.2", session)
|
||||||
|
|
||||||
# Manually expire the neg-cache entry.
|
# Manually expire the neg-cache entry.
|
||||||
geo_service._neg_cache["192.0.2.2"] -= geo_service._NEG_CACHE_TTL + 1 # type: ignore[attr-defined]
|
geo_service._neg_cache["192.0.2.2"] -= geo_service._NEG_CACHE_TTL + 1
|
||||||
|
|
||||||
await geo_service.lookup("192.0.2.2", session) # type: ignore[arg-type]
|
await geo_service.lookup("192.0.2.2", session)
|
||||||
|
|
||||||
# Both calls should have hit the API.
|
# Both calls should have hit the API.
|
||||||
assert session.get.call_count == 2
|
assert session.get.call_count == 2
|
||||||
@@ -257,9 +258,9 @@ class TestNegativeCache:
|
|||||||
"""After clearing the neg cache the IP is eligible for a new API call."""
|
"""After clearing the neg cache the IP is eligible for a new API call."""
|
||||||
session = _make_session({"status": "fail", "message": "private range"})
|
session = _make_session({"status": "fail", "message": "private range"})
|
||||||
|
|
||||||
await geo_service.lookup("192.0.2.3", session) # type: ignore[arg-type]
|
await geo_service.lookup("192.0.2.3", session)
|
||||||
geo_service.clear_neg_cache()
|
geo_service.clear_neg_cache()
|
||||||
await geo_service.lookup("192.0.2.3", session) # type: ignore[arg-type]
|
await geo_service.lookup("192.0.2.3", session)
|
||||||
|
|
||||||
assert session.get.call_count == 2
|
assert session.get.call_count == 2
|
||||||
|
|
||||||
@@ -275,9 +276,9 @@ class TestNegativeCache:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
await geo_service.lookup("1.2.3.4", session)
|
||||||
|
|
||||||
assert "1.2.3.4" not in geo_service._neg_cache # type: ignore[attr-defined]
|
assert "1.2.3.4" not in geo_service._neg_cache
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -307,7 +308,7 @@ class TestGeoipFallback:
|
|||||||
mock_reader = self._make_geoip_reader("DE", "Germany")
|
mock_reader = self._make_geoip_reader("DE", "Germany")
|
||||||
|
|
||||||
with patch.object(geo_service, "_geoip_reader", mock_reader):
|
with patch.object(geo_service, "_geoip_reader", mock_reader):
|
||||||
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
result = await geo_service.lookup("1.2.3.4", session)
|
||||||
|
|
||||||
mock_reader.country.assert_called_once_with("1.2.3.4")
|
mock_reader.country.assert_called_once_with("1.2.3.4")
|
||||||
assert result is not None
|
assert result is not None
|
||||||
@@ -320,12 +321,12 @@ class TestGeoipFallback:
|
|||||||
mock_reader = self._make_geoip_reader("US", "United States")
|
mock_reader = self._make_geoip_reader("US", "United States")
|
||||||
|
|
||||||
with patch.object(geo_service, "_geoip_reader", mock_reader):
|
with patch.object(geo_service, "_geoip_reader", mock_reader):
|
||||||
await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type]
|
await geo_service.lookup("8.8.8.8", session)
|
||||||
# Second call must be served from positive cache without hitting API.
|
# Second call must be served from positive cache without hitting API.
|
||||||
await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type]
|
await geo_service.lookup("8.8.8.8", session)
|
||||||
|
|
||||||
assert session.get.call_count == 1
|
assert session.get.call_count == 1
|
||||||
assert "8.8.8.8" in geo_service._cache # type: ignore[attr-defined]
|
assert "8.8.8.8" in geo_service._cache
|
||||||
|
|
||||||
async def test_geoip_fallback_not_called_on_api_success(self) -> None:
|
async def test_geoip_fallback_not_called_on_api_success(self) -> None:
|
||||||
"""When ip-api succeeds, the geoip2 reader must not be consulted."""
|
"""When ip-api succeeds, the geoip2 reader must not be consulted."""
|
||||||
@@ -341,7 +342,7 @@ class TestGeoipFallback:
|
|||||||
mock_reader = self._make_geoip_reader("XX", "Nowhere")
|
mock_reader = self._make_geoip_reader("XX", "Nowhere")
|
||||||
|
|
||||||
with patch.object(geo_service, "_geoip_reader", mock_reader):
|
with patch.object(geo_service, "_geoip_reader", mock_reader):
|
||||||
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
result = await geo_service.lookup("1.2.3.4", session)
|
||||||
|
|
||||||
mock_reader.country.assert_not_called()
|
mock_reader.country.assert_not_called()
|
||||||
assert result is not None
|
assert result is not None
|
||||||
@@ -352,7 +353,7 @@ class TestGeoipFallback:
|
|||||||
session = _make_session({"status": "fail", "message": "private range"})
|
session = _make_session({"status": "fail", "message": "private range"})
|
||||||
|
|
||||||
with patch.object(geo_service, "_geoip_reader", None):
|
with patch.object(geo_service, "_geoip_reader", None):
|
||||||
result = await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type]
|
result = await geo_service.lookup("10.0.0.1", session)
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.country_code is None
|
assert result.country_code is None
|
||||||
@@ -363,7 +364,7 @@ class TestGeoipFallback:
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _make_batch_session(batch_response: list[dict[str, object]]) -> MagicMock:
|
def _make_batch_session(batch_response: Sequence[Mapping[str, object]]) -> MagicMock:
|
||||||
"""Build a mock aiohttp.ClientSession for batch POST calls.
|
"""Build a mock aiohttp.ClientSession for batch POST calls.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -412,7 +413,7 @@ class TestLookupBatchSingleCommit:
|
|||||||
session = _make_batch_session(batch_response)
|
session = _make_batch_session(batch_response)
|
||||||
db = _make_async_db()
|
db = _make_async_db()
|
||||||
|
|
||||||
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type]
|
await geo_service.lookup_batch(ips, session, db=db)
|
||||||
|
|
||||||
db.commit.assert_awaited_once()
|
db.commit.assert_awaited_once()
|
||||||
|
|
||||||
@@ -426,7 +427,7 @@ class TestLookupBatchSingleCommit:
|
|||||||
session = _make_batch_session(batch_response)
|
session = _make_batch_session(batch_response)
|
||||||
db = _make_async_db()
|
db = _make_async_db()
|
||||||
|
|
||||||
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type]
|
await geo_service.lookup_batch(ips, session, db=db)
|
||||||
|
|
||||||
db.commit.assert_awaited_once()
|
db.commit.assert_awaited_once()
|
||||||
|
|
||||||
@@ -452,13 +453,13 @@ class TestLookupBatchSingleCommit:
|
|||||||
|
|
||||||
async def test_no_commit_for_all_cached_ips(self) -> None:
|
async def test_no_commit_for_all_cached_ips(self) -> None:
|
||||||
"""When all IPs are already cached, no HTTP call and no commit occur."""
|
"""When all IPs are already cached, no HTTP call and no commit occur."""
|
||||||
geo_service._cache["5.5.5.5"] = GeoInfo( # type: ignore[attr-defined]
|
geo_service._cache["5.5.5.5"] = GeoInfo(
|
||||||
country_code="FR", country_name="France", asn="AS1", org="ISP"
|
country_code="FR", country_name="France", asn="AS1", org="ISP"
|
||||||
)
|
)
|
||||||
db = _make_async_db()
|
db = _make_async_db()
|
||||||
session = _make_batch_session([])
|
session = _make_batch_session([])
|
||||||
|
|
||||||
result = await geo_service.lookup_batch(["5.5.5.5"], session, db=db) # type: ignore[arg-type]
|
result = await geo_service.lookup_batch(["5.5.5.5"], session, db=db)
|
||||||
|
|
||||||
assert result["5.5.5.5"].country_code == "FR"
|
assert result["5.5.5.5"].country_code == "FR"
|
||||||
db.commit.assert_not_awaited()
|
db.commit.assert_not_awaited()
|
||||||
@@ -476,26 +477,26 @@ class TestDirtySetTracking:
|
|||||||
def test_successful_resolution_adds_to_dirty(self) -> None:
|
def test_successful_resolution_adds_to_dirty(self) -> None:
|
||||||
"""Storing a GeoInfo with a country_code adds the IP to _dirty."""
|
"""Storing a GeoInfo with a country_code adds the IP to _dirty."""
|
||||||
info = GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP")
|
info = GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP")
|
||||||
geo_service._store("1.2.3.4", info) # type: ignore[attr-defined]
|
geo_service._store("1.2.3.4", info)
|
||||||
|
|
||||||
assert "1.2.3.4" in geo_service._dirty # type: ignore[attr-defined]
|
assert "1.2.3.4" in geo_service._dirty
|
||||||
|
|
||||||
def test_null_country_does_not_add_to_dirty(self) -> None:
|
def test_null_country_does_not_add_to_dirty(self) -> None:
|
||||||
"""Storing a GeoInfo with country_code=None must not pollute _dirty."""
|
"""Storing a GeoInfo with country_code=None must not pollute _dirty."""
|
||||||
info = GeoInfo(country_code=None, country_name=None, asn=None, org=None)
|
info = GeoInfo(country_code=None, country_name=None, asn=None, org=None)
|
||||||
geo_service._store("10.0.0.1", info) # type: ignore[attr-defined]
|
geo_service._store("10.0.0.1", info)
|
||||||
|
|
||||||
assert "10.0.0.1" not in geo_service._dirty # type: ignore[attr-defined]
|
assert "10.0.0.1" not in geo_service._dirty
|
||||||
|
|
||||||
def test_clear_cache_also_clears_dirty(self) -> None:
|
def test_clear_cache_also_clears_dirty(self) -> None:
|
||||||
"""clear_cache() must discard any pending dirty entries."""
|
"""clear_cache() must discard any pending dirty entries."""
|
||||||
info = GeoInfo(country_code="US", country_name="United States", asn="AS1", org="ISP")
|
info = GeoInfo(country_code="US", country_name="United States", asn="AS1", org="ISP")
|
||||||
geo_service._store("8.8.8.8", info) # type: ignore[attr-defined]
|
geo_service._store("8.8.8.8", info)
|
||||||
assert geo_service._dirty # type: ignore[attr-defined]
|
assert geo_service._dirty
|
||||||
|
|
||||||
geo_service.clear_cache()
|
geo_service.clear_cache()
|
||||||
|
|
||||||
assert not geo_service._dirty # type: ignore[attr-defined]
|
assert not geo_service._dirty
|
||||||
|
|
||||||
async def test_lookup_batch_populates_dirty(self) -> None:
|
async def test_lookup_batch_populates_dirty(self) -> None:
|
||||||
"""After lookup_batch() with db=None, resolved IPs appear in _dirty."""
|
"""After lookup_batch() with db=None, resolved IPs appear in _dirty."""
|
||||||
@@ -509,7 +510,7 @@ class TestDirtySetTracking:
|
|||||||
await geo_service.lookup_batch(ips, session, db=None)
|
await geo_service.lookup_batch(ips, session, db=None)
|
||||||
|
|
||||||
for ip in ips:
|
for ip in ips:
|
||||||
assert ip in geo_service._dirty # type: ignore[attr-defined]
|
assert ip in geo_service._dirty
|
||||||
|
|
||||||
|
|
||||||
class TestFlushDirty:
|
class TestFlushDirty:
|
||||||
@@ -518,8 +519,8 @@ class TestFlushDirty:
|
|||||||
async def test_flush_writes_and_clears_dirty(self) -> None:
|
async def test_flush_writes_and_clears_dirty(self) -> None:
|
||||||
"""flush_dirty() inserts all dirty IPs and clears _dirty afterwards."""
|
"""flush_dirty() inserts all dirty IPs and clears _dirty afterwards."""
|
||||||
info = GeoInfo(country_code="GB", country_name="United Kingdom", asn="AS2856", org="BT")
|
info = GeoInfo(country_code="GB", country_name="United Kingdom", asn="AS2856", org="BT")
|
||||||
geo_service._store("100.0.0.1", info) # type: ignore[attr-defined]
|
geo_service._store("100.0.0.1", info)
|
||||||
assert "100.0.0.1" in geo_service._dirty # type: ignore[attr-defined]
|
assert "100.0.0.1" in geo_service._dirty
|
||||||
|
|
||||||
db = _make_async_db()
|
db = _make_async_db()
|
||||||
count = await geo_service.flush_dirty(db)
|
count = await geo_service.flush_dirty(db)
|
||||||
@@ -527,7 +528,7 @@ class TestFlushDirty:
|
|||||||
assert count == 1
|
assert count == 1
|
||||||
db.executemany.assert_awaited_once()
|
db.executemany.assert_awaited_once()
|
||||||
db.commit.assert_awaited_once()
|
db.commit.assert_awaited_once()
|
||||||
assert "100.0.0.1" not in geo_service._dirty # type: ignore[attr-defined]
|
assert "100.0.0.1" not in geo_service._dirty
|
||||||
|
|
||||||
async def test_flush_returns_zero_when_nothing_dirty(self) -> None:
|
async def test_flush_returns_zero_when_nothing_dirty(self) -> None:
|
||||||
"""flush_dirty() returns 0 and makes no DB calls when _dirty is empty."""
|
"""flush_dirty() returns 0 and makes no DB calls when _dirty is empty."""
|
||||||
@@ -541,7 +542,7 @@ class TestFlushDirty:
|
|||||||
async def test_flush_re_adds_to_dirty_on_db_error(self) -> None:
|
async def test_flush_re_adds_to_dirty_on_db_error(self) -> None:
|
||||||
"""When the DB write fails, entries are re-added to _dirty for retry."""
|
"""When the DB write fails, entries are re-added to _dirty for retry."""
|
||||||
info = GeoInfo(country_code="AU", country_name="Australia", asn="AS1", org="ISP")
|
info = GeoInfo(country_code="AU", country_name="Australia", asn="AS1", org="ISP")
|
||||||
geo_service._store("200.0.0.1", info) # type: ignore[attr-defined]
|
geo_service._store("200.0.0.1", info)
|
||||||
|
|
||||||
db = _make_async_db()
|
db = _make_async_db()
|
||||||
db.executemany = AsyncMock(side_effect=OSError("disk full"))
|
db.executemany = AsyncMock(side_effect=OSError("disk full"))
|
||||||
@@ -549,7 +550,7 @@ class TestFlushDirty:
|
|||||||
count = await geo_service.flush_dirty(db)
|
count = await geo_service.flush_dirty(db)
|
||||||
|
|
||||||
assert count == 0
|
assert count == 0
|
||||||
assert "200.0.0.1" in geo_service._dirty # type: ignore[attr-defined]
|
assert "200.0.0.1" in geo_service._dirty
|
||||||
|
|
||||||
async def test_flush_batch_and_lookup_batch_integration(self) -> None:
|
async def test_flush_batch_and_lookup_batch_integration(self) -> None:
|
||||||
"""lookup_batch() populates _dirty; flush_dirty() then persists them."""
|
"""lookup_batch() populates _dirty; flush_dirty() then persists them."""
|
||||||
@@ -562,14 +563,14 @@ class TestFlushDirty:
|
|||||||
|
|
||||||
# Resolve without DB to populate only in-memory cache and _dirty.
|
# Resolve without DB to populate only in-memory cache and _dirty.
|
||||||
await geo_service.lookup_batch(ips, session, db=None)
|
await geo_service.lookup_batch(ips, session, db=None)
|
||||||
assert geo_service._dirty == set(ips) # type: ignore[attr-defined]
|
assert geo_service._dirty == set(ips)
|
||||||
|
|
||||||
# Now flush to the DB.
|
# Now flush to the DB.
|
||||||
db = _make_async_db()
|
db = _make_async_db()
|
||||||
count = await geo_service.flush_dirty(db)
|
count = await geo_service.flush_dirty(db)
|
||||||
|
|
||||||
assert count == 2
|
assert count == 2
|
||||||
assert not geo_service._dirty # type: ignore[attr-defined]
|
assert not geo_service._dirty
|
||||||
db.commit.assert_awaited_once()
|
db.commit.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
@@ -585,7 +586,7 @@ class TestLookupBatchThrottling:
|
|||||||
"""When more than _BATCH_SIZE IPs are sent, asyncio.sleep is called
|
"""When more than _BATCH_SIZE IPs are sent, asyncio.sleep is called
|
||||||
between consecutive batch HTTP calls with at least _BATCH_DELAY."""
|
between consecutive batch HTTP calls with at least _BATCH_DELAY."""
|
||||||
# Generate _BATCH_SIZE + 1 IPs so we get exactly 2 batch calls.
|
# Generate _BATCH_SIZE + 1 IPs so we get exactly 2 batch calls.
|
||||||
batch_size: int = geo_service._BATCH_SIZE # type: ignore[attr-defined]
|
batch_size: int = geo_service._BATCH_SIZE
|
||||||
ips = [f"10.0.{i // 256}.{i % 256}" for i in range(batch_size + 1)]
|
ips = [f"10.0.{i // 256}.{i % 256}" for i in range(batch_size + 1)]
|
||||||
|
|
||||||
def _make_result(chunk: list[str], _session: object) -> dict[str, GeoInfo]:
|
def _make_result(chunk: list[str], _session: object) -> dict[str, GeoInfo]:
|
||||||
@@ -608,7 +609,7 @@ class TestLookupBatchThrottling:
|
|||||||
assert mock_batch.call_count == 2
|
assert mock_batch.call_count == 2
|
||||||
mock_sleep.assert_awaited_once()
|
mock_sleep.assert_awaited_once()
|
||||||
delay_arg: float = mock_sleep.call_args[0][0]
|
delay_arg: float = mock_sleep.call_args[0][0]
|
||||||
assert delay_arg >= geo_service._BATCH_DELAY # type: ignore[attr-defined]
|
assert delay_arg >= geo_service._BATCH_DELAY
|
||||||
|
|
||||||
async def test_lookup_batch_retries_on_full_chunk_failure(self) -> None:
|
async def test_lookup_batch_retries_on_full_chunk_failure(self) -> None:
|
||||||
"""When a chunk returns all-None on first try, it retries and succeeds."""
|
"""When a chunk returns all-None on first try, it retries and succeeds."""
|
||||||
@@ -650,7 +651,7 @@ class TestLookupBatchThrottling:
|
|||||||
_empty = GeoInfo(country_code=None, country_name=None, asn=None, org=None)
|
_empty = GeoInfo(country_code=None, country_name=None, asn=None, org=None)
|
||||||
_failure: dict[str, GeoInfo] = dict.fromkeys(ips, _empty)
|
_failure: dict[str, GeoInfo] = dict.fromkeys(ips, _empty)
|
||||||
|
|
||||||
max_retries: int = geo_service._BATCH_MAX_RETRIES # type: ignore[attr-defined]
|
max_retries: int = geo_service._BATCH_MAX_RETRIES
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
@@ -667,11 +668,11 @@ class TestLookupBatchThrottling:
|
|||||||
# IP should have no country.
|
# IP should have no country.
|
||||||
assert result["9.9.9.9"].country_code is None
|
assert result["9.9.9.9"].country_code is None
|
||||||
# Negative cache should contain the IP.
|
# Negative cache should contain the IP.
|
||||||
assert "9.9.9.9" in geo_service._neg_cache # type: ignore[attr-defined]
|
assert "9.9.9.9" in geo_service._neg_cache
|
||||||
# Sleep called for each retry with exponential backoff.
|
# Sleep called for each retry with exponential backoff.
|
||||||
assert mock_sleep.call_count == max_retries
|
assert mock_sleep.call_count == max_retries
|
||||||
backoff_values = [call.args[0] for call in mock_sleep.call_args_list]
|
backoff_values = [call.args[0] for call in mock_sleep.call_args_list]
|
||||||
batch_delay: float = geo_service._BATCH_DELAY # type: ignore[attr-defined]
|
batch_delay: float = geo_service._BATCH_DELAY
|
||||||
for i, val in enumerate(backoff_values):
|
for i, val in enumerate(backoff_values):
|
||||||
expected = batch_delay * (2 ** (i + 1))
|
expected = batch_delay * (2 ** (i + 1))
|
||||||
assert val == pytest.approx(expected)
|
assert val == pytest.approx(expected)
|
||||||
@@ -709,7 +710,7 @@ class TestErrorLogging:
|
|||||||
import structlog.testing
|
import structlog.testing
|
||||||
|
|
||||||
with structlog.testing.capture_logs() as captured:
|
with structlog.testing.capture_logs() as captured:
|
||||||
result = await geo_service.lookup("197.221.98.153", session) # type: ignore[arg-type]
|
result = await geo_service.lookup("197.221.98.153", session)
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result.country_code is None
|
assert result.country_code is None
|
||||||
@@ -733,7 +734,7 @@ class TestErrorLogging:
|
|||||||
import structlog.testing
|
import structlog.testing
|
||||||
|
|
||||||
with structlog.testing.capture_logs() as captured:
|
with structlog.testing.capture_logs() as captured:
|
||||||
await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type]
|
await geo_service.lookup("10.0.0.1", session)
|
||||||
|
|
||||||
request_failed = [e for e in captured if e.get("event") == "geo_lookup_request_failed"]
|
request_failed = [e for e in captured if e.get("event") == "geo_lookup_request_failed"]
|
||||||
assert len(request_failed) == 1
|
assert len(request_failed) == 1
|
||||||
@@ -757,7 +758,7 @@ class TestErrorLogging:
|
|||||||
import structlog.testing
|
import structlog.testing
|
||||||
|
|
||||||
with structlog.testing.capture_logs() as captured:
|
with structlog.testing.capture_logs() as captured:
|
||||||
result = await geo_service._batch_api_call(["1.2.3.4"], session) # type: ignore[attr-defined]
|
result = await geo_service._batch_api_call(["1.2.3.4"], session)
|
||||||
|
|
||||||
assert result["1.2.3.4"].country_code is None
|
assert result["1.2.3.4"].country_code is None
|
||||||
|
|
||||||
@@ -778,7 +779,7 @@ class TestLookupCachedOnly:
|
|||||||
|
|
||||||
def test_returns_cached_ips(self) -> None:
|
def test_returns_cached_ips(self) -> None:
|
||||||
"""IPs already in the cache are returned in the geo_map."""
|
"""IPs already in the cache are returned in the geo_map."""
|
||||||
geo_service._cache["1.1.1.1"] = GeoInfo( # type: ignore[attr-defined]
|
geo_service._cache["1.1.1.1"] = GeoInfo(
|
||||||
country_code="AU", country_name="Australia", asn="AS13335", org="Cloudflare"
|
country_code="AU", country_name="Australia", asn="AS13335", org="Cloudflare"
|
||||||
)
|
)
|
||||||
geo_map, uncached = geo_service.lookup_cached_only(["1.1.1.1"])
|
geo_map, uncached = geo_service.lookup_cached_only(["1.1.1.1"])
|
||||||
@@ -798,7 +799,7 @@ class TestLookupCachedOnly:
|
|||||||
"""IPs in the negative cache within TTL are not re-queued as uncached."""
|
"""IPs in the negative cache within TTL are not re-queued as uncached."""
|
||||||
import time
|
import time
|
||||||
|
|
||||||
geo_service._neg_cache["10.0.0.1"] = time.monotonic() # type: ignore[attr-defined]
|
geo_service._neg_cache["10.0.0.1"] = time.monotonic()
|
||||||
|
|
||||||
geo_map, uncached = geo_service.lookup_cached_only(["10.0.0.1"])
|
geo_map, uncached = geo_service.lookup_cached_only(["10.0.0.1"])
|
||||||
|
|
||||||
@@ -807,7 +808,7 @@ class TestLookupCachedOnly:
|
|||||||
|
|
||||||
def test_expired_neg_cache_requeued(self) -> None:
|
def test_expired_neg_cache_requeued(self) -> None:
|
||||||
"""IPs whose neg-cache entry has expired are listed as uncached."""
|
"""IPs whose neg-cache entry has expired are listed as uncached."""
|
||||||
geo_service._neg_cache["10.0.0.2"] = 0.0 # epoch 0 → expired # type: ignore[attr-defined]
|
geo_service._neg_cache["10.0.0.2"] = 0.0 # epoch 0 → expired
|
||||||
|
|
||||||
_geo_map, uncached = geo_service.lookup_cached_only(["10.0.0.2"])
|
_geo_map, uncached = geo_service.lookup_cached_only(["10.0.0.2"])
|
||||||
|
|
||||||
@@ -815,12 +816,12 @@ class TestLookupCachedOnly:
|
|||||||
|
|
||||||
def test_mixed_ips(self) -> None:
|
def test_mixed_ips(self) -> None:
|
||||||
"""A mix of cached, neg-cached, and unknown IPs is split correctly."""
|
"""A mix of cached, neg-cached, and unknown IPs is split correctly."""
|
||||||
geo_service._cache["1.2.3.4"] = GeoInfo( # type: ignore[attr-defined]
|
geo_service._cache["1.2.3.4"] = GeoInfo(
|
||||||
country_code="DE", country_name="Germany", asn=None, org=None
|
country_code="DE", country_name="Germany", asn=None, org=None
|
||||||
)
|
)
|
||||||
import time
|
import time
|
||||||
|
|
||||||
geo_service._neg_cache["5.5.5.5"] = time.monotonic() # type: ignore[attr-defined]
|
geo_service._neg_cache["5.5.5.5"] = time.monotonic()
|
||||||
|
|
||||||
geo_map, uncached = geo_service.lookup_cached_only(["1.2.3.4", "5.5.5.5", "9.9.9.9"])
|
geo_map, uncached = geo_service.lookup_cached_only(["1.2.3.4", "5.5.5.5", "9.9.9.9"])
|
||||||
|
|
||||||
@@ -829,7 +830,7 @@ class TestLookupCachedOnly:
|
|||||||
|
|
||||||
def test_deduplication(self) -> None:
|
def test_deduplication(self) -> None:
|
||||||
"""Duplicate IPs in the input appear at most once in the output."""
|
"""Duplicate IPs in the input appear at most once in the output."""
|
||||||
geo_service._cache["1.2.3.4"] = GeoInfo( # type: ignore[attr-defined]
|
geo_service._cache["1.2.3.4"] = GeoInfo(
|
||||||
country_code="US", country_name="United States", asn=None, org=None
|
country_code="US", country_name="United States", asn=None, org=None
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -866,7 +867,7 @@ class TestLookupBatchBulkWrites:
|
|||||||
session = _make_batch_session(batch_response)
|
session = _make_batch_session(batch_response)
|
||||||
db = _make_async_db()
|
db = _make_async_db()
|
||||||
|
|
||||||
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type]
|
await geo_service.lookup_batch(ips, session, db=db)
|
||||||
|
|
||||||
# One executemany for the positive rows.
|
# One executemany for the positive rows.
|
||||||
assert db.executemany.await_count >= 1
|
assert db.executemany.await_count >= 1
|
||||||
@@ -883,7 +884,7 @@ class TestLookupBatchBulkWrites:
|
|||||||
session = _make_batch_session(batch_response)
|
session = _make_batch_session(batch_response)
|
||||||
db = _make_async_db()
|
db = _make_async_db()
|
||||||
|
|
||||||
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type]
|
await geo_service.lookup_batch(ips, session, db=db)
|
||||||
|
|
||||||
assert db.executemany.await_count >= 1
|
assert db.executemany.await_count >= 1
|
||||||
db.execute.assert_not_awaited()
|
db.execute.assert_not_awaited()
|
||||||
@@ -905,7 +906,7 @@ class TestLookupBatchBulkWrites:
|
|||||||
session = _make_batch_session(batch_response)
|
session = _make_batch_session(batch_response)
|
||||||
db = _make_async_db()
|
db = _make_async_db()
|
||||||
|
|
||||||
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type]
|
await geo_service.lookup_batch(ips, session, db=db)
|
||||||
|
|
||||||
# One executemany for positives, one for negatives.
|
# One executemany for positives, one for negatives.
|
||||||
assert db.executemany.await_count == 2
|
assert db.executemany.await_count == 2
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ async def _create_f2b_db(path: str, rows: list[dict[str, Any]]) -> None:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc]
|
async def f2b_db_path(tmp_path: Path) -> str:
|
||||||
"""Return the path to a test fail2ban SQLite database."""
|
"""Return the path to a test fail2ban SQLite database."""
|
||||||
path = str(tmp_path / "fail2ban_test.sqlite3")
|
path = str(tmp_path / "fail2ban_test.sqlite3")
|
||||||
await _create_f2b_db(
|
await _create_f2b_db(
|
||||||
|
|||||||
@@ -996,9 +996,6 @@ class TestGetJailBannedIps:
|
|||||||
|
|
||||||
async def test_unknown_jail_raises_jail_not_found_error(self) -> None:
|
async def test_unknown_jail_raises_jail_not_found_error(self) -> None:
|
||||||
"""get_jail_banned_ips raises JailNotFoundError for unknown jail."""
|
"""get_jail_banned_ips raises JailNotFoundError for unknown jail."""
|
||||||
responses = {
|
|
||||||
"status|ghost|short": (0, pytest.raises), # will be overridden
|
|
||||||
}
|
|
||||||
# Simulate fail2ban returning an "unknown jail" error.
|
# Simulate fail2ban returning an "unknown jail" error.
|
||||||
class _FakeClient:
|
class _FakeClient:
|
||||||
def __init__(self, **_kw: Any) -> None:
|
def __init__(self, **_kw: Any) -> None:
|
||||||
|
|||||||
@@ -270,7 +270,7 @@ class TestCrashDetection:
|
|||||||
async def test_crash_within_window_creates_pending_recovery(self) -> None:
|
async def test_crash_within_window_creates_pending_recovery(self) -> None:
|
||||||
"""An online→offline transition within 60 s of activation must set pending_recovery."""
|
"""An online→offline transition within 60 s of activation must set pending_recovery."""
|
||||||
app = _make_app(prev_online=True)
|
app = _make_app(prev_online=True)
|
||||||
now = datetime.datetime.now(tz=datetime.timezone.utc)
|
now = datetime.datetime.now(tz=datetime.UTC)
|
||||||
app.state.last_activation = {
|
app.state.last_activation = {
|
||||||
"jail_name": "sshd",
|
"jail_name": "sshd",
|
||||||
"at": now - datetime.timedelta(seconds=10),
|
"at": now - datetime.timedelta(seconds=10),
|
||||||
@@ -297,7 +297,7 @@ class TestCrashDetection:
|
|||||||
app = _make_app(prev_online=True)
|
app = _make_app(prev_online=True)
|
||||||
app.state.last_activation = {
|
app.state.last_activation = {
|
||||||
"jail_name": "sshd",
|
"jail_name": "sshd",
|
||||||
"at": datetime.datetime.now(tz=datetime.timezone.utc)
|
"at": datetime.datetime.now(tz=datetime.UTC)
|
||||||
- datetime.timedelta(seconds=120),
|
- datetime.timedelta(seconds=120),
|
||||||
}
|
}
|
||||||
app.state.pending_recovery = None
|
app.state.pending_recovery = None
|
||||||
@@ -315,8 +315,8 @@ class TestCrashDetection:
|
|||||||
async def test_came_online_marks_pending_recovery_resolved(self) -> None:
|
async def test_came_online_marks_pending_recovery_resolved(self) -> None:
|
||||||
"""An offline→online transition must mark an existing pending_recovery as recovered."""
|
"""An offline→online transition must mark an existing pending_recovery as recovered."""
|
||||||
app = _make_app(prev_online=False)
|
app = _make_app(prev_online=False)
|
||||||
activated_at = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(seconds=30)
|
activated_at = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(seconds=30)
|
||||||
detected_at = datetime.datetime.now(tz=datetime.timezone.utc)
|
detected_at = datetime.datetime.now(tz=datetime.UTC)
|
||||||
app.state.pending_recovery = PendingRecovery(
|
app.state.pending_recovery = PendingRecovery(
|
||||||
jail_name="sshd",
|
jail_name="sshd",
|
||||||
activated_at=activated_at,
|
activated_at=activated_at,
|
||||||
|
|||||||
Reference in New Issue
Block a user