Add better jail configuration: file CRUD, enable/disable, log paths

Task 4 (Better Jail Configuration) implementation:
- Add fail2ban_config_dir setting to app/config.py
- New file_config_service: list/view/edit/create jail.d, filter.d, action.d files
  with path-traversal prevention and 512 KB content size limit
- New file_config router: GET/PUT/POST endpoints for jail files, filter files,
  and action files; PUT .../enabled for toggle on/off
- Extend config_service with delete_log_path() and add_log_path()
- Add DELETE /api/config/jails/{name}/logpath and POST /api/config/jails/{name}/logpath
- Extend geo router with re-resolve endpoint; add geo_re_resolve background task
- Update blocklist_service with revised scheduling helpers
- Update Docker compose files with BANGUI_FAIL2BAN_CONFIG_DIR env var and
  rw volume mount for the fail2ban config directory
- Frontend: new Jail Files, Filters, Actions tabs in ConfigPage; file editor
  with accordion-per-file, editable textarea, save/create; add/delete log paths
- Frontend: types in types/config.ts; API calls in api/config.ts and api/endpoints.ts
- 63 new backend tests (test_file_config_service, test_file_config, test_geo_re_resolve)
- 6 new frontend tests in ConfigPageLogPath.test.tsx
- ruff, mypy --strict, tsc --noEmit, eslint: all clean; 617 backend tests pass
This commit is contained in:
2026-03-12 20:08:33 +01:00
parent 59464a1592
commit ea35695221
23 changed files with 2911 additions and 91 deletions

View File

@@ -52,6 +52,14 @@ class Settings(BaseSettings):
"When set, failed ip-api.com lookups fall back to local resolution."
),
)
fail2ban_config_dir: str = Field(
default="/config/fail2ban",
description=(
"Path to the fail2ban configuration directory. "
"Must contain subdirectories jail.d/, filter.d/, and action.d/. "
"Used for listing, viewing, and editing configuration files through the web UI."
),
)
model_config = SettingsConfigDict(
env_prefix="BANGUI_",

View File

@@ -33,8 +33,21 @@ from starlette.middleware.base import BaseHTTPMiddleware
from app.config import Settings, get_settings
from app.db import init_db
from app.routers import auth, bans, blocklist, config, dashboard, geo, health, history, jails, server, setup
from app.tasks import blocklist_import, geo_cache_flush, health_check
from app.routers import (
auth,
bans,
blocklist,
config,
dashboard,
file_config,
geo,
health,
history,
jails,
server,
setup,
)
from app.tasks import blocklist_import, geo_cache_flush, geo_re_resolve, health_check
from app.utils.fail2ban_client import Fail2BanConnectionError, Fail2BanProtocolError
# ---------------------------------------------------------------------------
@@ -140,6 +153,15 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
geo_service.init_geoip(settings.geoip_db_path)
await geo_service.load_cache_from_db(db)
# Log unresolved geo entries so the operator can see the scope of the issue.
async with db.execute(
"SELECT COUNT(*) FROM geo_cache WHERE country_code IS NULL"
) as cur:
row = await cur.fetchone()
unresolved_count: int = int(row[0]) if row else 0
if unresolved_count > 0:
log.warning("geo_cache_unresolved_ips", unresolved=unresolved_count)
# --- Background task scheduler ---
scheduler: AsyncIOScheduler = AsyncIOScheduler(timezone="UTC")
scheduler.start()
@@ -154,6 +176,9 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
# --- Periodic geo cache flush to SQLite ---
geo_cache_flush.register(app)
# --- Periodic re-resolve of NULL-country geo entries ---
geo_re_resolve.register(app)
log.info("bangui_started")
try:
@@ -375,6 +400,7 @@ def create_app(settings: Settings | None = None) -> FastAPI:
app.include_router(bans.router)
app.include_router(geo.router)
app.include_router(config.router)
app.include_router(file_config.router)
app.include_router(server.router)
app.include_router(history.router)
app.include_router(blocklist.router)

View File

@@ -0,0 +1,109 @@
"""Pydantic models for file-based fail2ban configuration management.
Covers jail config files (``jail.d/``), filter definitions (``filter.d/``),
and action definitions (``action.d/``).
"""
from pydantic import BaseModel, ConfigDict, Field
# ---------------------------------------------------------------------------
# Jail config file models (Task 4a)
# ---------------------------------------------------------------------------
class JailConfigFile(BaseModel):
"""Metadata for a single jail configuration file in ``jail.d/``."""
model_config = ConfigDict(strict=True)
name: str = Field(..., description="Jail name (file stem, e.g. ``sshd``).")
filename: str = Field(..., description="Actual filename (e.g. ``sshd.conf``).")
enabled: bool = Field(
...,
description=(
"Whether the jail is enabled. Derived from the ``enabled`` key "
"inside the file; defaults to ``true`` when the key is absent."
),
)
class JailConfigFilesResponse(BaseModel):
"""Response for ``GET /api/config/jail-files``."""
model_config = ConfigDict(strict=True)
files: list[JailConfigFile] = Field(default_factory=list)
total: int = Field(..., ge=0)
class JailConfigFileContent(BaseModel):
"""Single jail config file with its raw content."""
model_config = ConfigDict(strict=True)
name: str = Field(..., description="Jail name (file stem).")
filename: str = Field(..., description="Actual filename.")
enabled: bool = Field(..., description="Whether the jail is enabled.")
content: str = Field(..., description="Raw file content.")
class JailConfigFileEnabledUpdate(BaseModel):
"""Payload for ``PUT /api/config/jail-files/{filename}/enabled``."""
model_config = ConfigDict(strict=True)
enabled: bool = Field(..., description="New enabled state for this jail.")
# ---------------------------------------------------------------------------
# Generic conf-file entry (shared by filter.d and action.d)
# ---------------------------------------------------------------------------
class ConfFileEntry(BaseModel):
"""Metadata for a single ``.conf`` or ``.local`` file."""
model_config = ConfigDict(strict=True)
name: str = Field(..., description="Base name without extension (e.g. ``sshd``).")
filename: str = Field(..., description="Actual filename (e.g. ``sshd.conf``).")
class ConfFilesResponse(BaseModel):
"""Response for list endpoints (``GET /api/config/filters`` and ``GET /api/config/actions``)."""
model_config = ConfigDict(strict=True)
files: list[ConfFileEntry] = Field(default_factory=list)
total: int = Field(..., ge=0)
class ConfFileContent(BaseModel):
"""A conf file with its raw text content."""
model_config = ConfigDict(strict=True)
name: str = Field(..., description="Base name without extension.")
filename: str = Field(..., description="Actual filename.")
content: str = Field(..., description="Raw file content.")
class ConfFileUpdateRequest(BaseModel):
"""Payload for ``PUT /api/config/filters/{name}`` and ``PUT /api/config/actions/{name}``."""
model_config = ConfigDict(strict=True)
content: str = Field(..., description="New raw file content (must not exceed 512 KB).")
class ConfFileCreateRequest(BaseModel):
"""Payload for ``POST /api/config/filters`` and ``POST /api/config/actions``."""
model_config = ConfigDict(strict=True)
name: str = Field(
...,
description="New file base name (without extension). Must contain only "
"alphanumeric characters, hyphens, underscores, and dots.",
)
content: str = Field(..., description="Initial raw file content (must not exceed 512 KB).")

View File

@@ -32,6 +32,21 @@ class GeoDetail(BaseModel):
)
class GeoCacheStatsResponse(BaseModel):
"""Response for ``GET /api/geo/stats``.
Exposes diagnostic counters of the geo cache subsystem so operators
can assess resolution health from the UI or CLI.
"""
model_config = ConfigDict(strict=True)
cache_size: int = Field(..., description="Number of positive entries in the in-memory cache.")
unresolved: int = Field(..., description="Number of geo_cache rows with country_code IS NULL.")
neg_cache_size: int = Field(..., description="Number of entries in the in-memory negative cache.")
dirty_size: int = Field(..., description="Number of newly resolved entries not yet flushed to disk.")
class IpLookupResponse(BaseModel):
"""Response for ``GET /api/geo/lookup/{ip}``.

View File

@@ -18,7 +18,7 @@ from __future__ import annotations
from typing import Annotated
from fastapi import APIRouter, HTTPException, Path, Request, status
from fastapi import APIRouter, HTTPException, Path, Query, Request, status
from app.dependencies import AuthDep
from app.models.config import (
@@ -354,9 +354,42 @@ async def add_log_path(
raise _bad_gateway(exc) from exc
# ---------------------------------------------------------------------------
# Log preview
# ---------------------------------------------------------------------------
@router.delete(
"/jails/{name}/logpath",
status_code=status.HTTP_204_NO_CONTENT,
summary="Remove a monitored log path from a jail",
)
async def delete_log_path(
request: Request,
_auth: AuthDep,
name: _NamePath,
log_path: str = Query(..., description="Absolute path of the log file to stop monitoring."),
) -> None:
"""Stop a jail from monitoring the specified log file.
Uses ``set <jail> dellogpath <path>`` to remove the log path at runtime
without requiring a daemon restart.
Args:
request: Incoming request.
_auth: Validated session.
name: Jail name.
log_path: Absolute path to the log file to remove (query parameter).
Raises:
HTTPException: 404 when the jail does not exist.
HTTPException: 400 when the command is rejected.
HTTPException: 502 when fail2ban is unreachable.
"""
socket_path: str = request.app.state.settings.fail2ban_socket
try:
await config_service.delete_log_path(socket_path, name, log_path)
except JailNotFoundError:
raise _not_found(name) from None
except ConfigOperationError as exc:
raise _bad_request(str(exc)) from exc
except Fail2BanConnectionError as exc:
raise _bad_gateway(exc) from exc
@router.post(

View File

@@ -0,0 +1,495 @@
"""File-based fail2ban configuration router.
Provides endpoints to list, view, edit, and create fail2ban configuration
files directly on the filesystem (``jail.d/``, ``filter.d/``, ``action.d/``).
Endpoints:
* ``GET /api/config/jail-files`` — list all jail config files
* ``GET /api/config/jail-files/{filename}`` — get one jail config file (with content)
* ``PUT /api/config/jail-files/{filename}/enabled`` — enable/disable a jail config
* ``GET /api/config/filters`` — list all filter files
* ``GET /api/config/filters/{name}`` — get one filter file (with content)
* ``PUT /api/config/filters/{name}`` — update a filter file
* ``POST /api/config/filters`` — create a new filter file
* ``GET /api/config/actions`` — list all action files
* ``GET /api/config/actions/{name}`` — get one action file (with content)
* ``PUT /api/config/actions/{name}`` — update an action file
* ``POST /api/config/actions`` — create a new action file
"""
from __future__ import annotations
from typing import Annotated
from fastapi import APIRouter, HTTPException, Path, Request, status
from app.dependencies import AuthDep
from app.models.file_config import (
ConfFileContent,
ConfFileCreateRequest,
ConfFilesResponse,
ConfFileUpdateRequest,
JailConfigFileContent,
JailConfigFileEnabledUpdate,
JailConfigFilesResponse,
)
from app.services import file_config_service
from app.services.file_config_service import (
ConfigDirError,
ConfigFileExistsError,
ConfigFileNameError,
ConfigFileNotFoundError,
ConfigFileWriteError,
)
router: APIRouter = APIRouter(prefix="/api/config", tags=["Config"])
# ---------------------------------------------------------------------------
# Path type aliases
# ---------------------------------------------------------------------------
_FilenamePath = Annotated[
str, Path(description="Config filename including extension (e.g. ``sshd.conf``).")
]
_NamePath = Annotated[
str, Path(description="Base name with or without extension (e.g. ``sshd`` or ``sshd.conf``).")
]
# ---------------------------------------------------------------------------
# Error helpers
# ---------------------------------------------------------------------------
def _not_found(filename: str) -> HTTPException:
return HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Config file not found: {filename!r}",
)
def _bad_request(message: str) -> HTTPException:
return HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=message,
)
def _conflict(filename: str) -> HTTPException:
return HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"Config file already exists: {filename!r}",
)
def _service_unavailable(message: str) -> HTTPException:
return HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=message,
)
# ---------------------------------------------------------------------------
# Jail config file endpoints (Task 4a)
# ---------------------------------------------------------------------------
@router.get(
"/jail-files",
response_model=JailConfigFilesResponse,
summary="List all jail config files",
)
async def list_jail_config_files(
request: Request,
_auth: AuthDep,
) -> JailConfigFilesResponse:
"""Return metadata for every ``.conf`` and ``.local`` file in ``jail.d/``.
The ``enabled`` field reflects the value of the ``enabled`` key inside the
file (defaulting to ``true`` when the key is absent).
Args:
request: Incoming request (used for ``app.state.settings``).
_auth: Validated session — enforces authentication.
Returns:
:class:`~app.models.file_config.JailConfigFilesResponse`.
"""
config_dir: str = request.app.state.settings.fail2ban_config_dir
try:
return await file_config_service.list_jail_config_files(config_dir)
except ConfigDirError as exc:
raise _service_unavailable(str(exc)) from exc
@router.get(
"/jail-files/{filename}",
response_model=JailConfigFileContent,
summary="Return a single jail config file with its content",
)
async def get_jail_config_file(
request: Request,
_auth: AuthDep,
filename: _FilenamePath,
) -> JailConfigFileContent:
"""Return the metadata and raw content of one jail config file.
Args:
request: Incoming request.
_auth: Validated session.
filename: Filename including extension (e.g. ``sshd.conf``).
Returns:
:class:`~app.models.file_config.JailConfigFileContent`.
Raises:
HTTPException: 400 if *filename* is unsafe.
HTTPException: 404 if the file does not exist.
HTTPException: 503 if the config directory is unavailable.
"""
config_dir: str = request.app.state.settings.fail2ban_config_dir
try:
return await file_config_service.get_jail_config_file(config_dir, filename)
except ConfigFileNameError as exc:
raise _bad_request(str(exc)) from exc
except ConfigFileNotFoundError:
raise _not_found(filename) from None
except ConfigDirError as exc:
raise _service_unavailable(str(exc)) from exc
@router.put(
"/jail-files/{filename}/enabled",
status_code=status.HTTP_204_NO_CONTENT,
summary="Enable or disable a jail configuration file",
)
async def set_jail_config_file_enabled(
request: Request,
_auth: AuthDep,
filename: _FilenamePath,
body: JailConfigFileEnabledUpdate,
) -> None:
"""Set the ``enabled = true/false`` key inside a jail config file.
The change modifies the file on disk. You must reload fail2ban
(``POST /api/config/reload``) separately for the change to take effect.
Args:
request: Incoming request.
_auth: Validated session.
filename: Filename of the jail config file (e.g. ``sshd.conf``).
body: New enabled state.
Raises:
HTTPException: 400 if *filename* is unsafe or the operation fails.
HTTPException: 404 if the file does not exist.
HTTPException: 503 if the config directory is unavailable.
"""
config_dir: str = request.app.state.settings.fail2ban_config_dir
try:
await file_config_service.set_jail_config_enabled(
config_dir, filename, body.enabled
)
except ConfigFileNameError as exc:
raise _bad_request(str(exc)) from exc
except ConfigFileNotFoundError:
raise _not_found(filename) from None
except ConfigFileWriteError as exc:
raise _bad_request(str(exc)) from exc
except ConfigDirError as exc:
raise _service_unavailable(str(exc)) from exc
# ---------------------------------------------------------------------------
# Filter file endpoints (Task 4d)
# ---------------------------------------------------------------------------
@router.get(
"/filters",
response_model=ConfFilesResponse,
summary="List all filter definition files",
)
async def list_filter_files(
request: Request,
_auth: AuthDep,
) -> ConfFilesResponse:
"""Return a list of every ``.conf`` and ``.local`` file in ``filter.d/``.
Args:
request: Incoming request.
_auth: Validated session.
Returns:
:class:`~app.models.file_config.ConfFilesResponse`.
"""
config_dir: str = request.app.state.settings.fail2ban_config_dir
try:
return await file_config_service.list_filter_files(config_dir)
except ConfigDirError as exc:
raise _service_unavailable(str(exc)) from exc
@router.get(
"/filters/{name}",
response_model=ConfFileContent,
summary="Return a filter definition file with its content",
)
async def get_filter_file(
request: Request,
_auth: AuthDep,
name: _NamePath,
) -> ConfFileContent:
"""Return the content of a filter definition file.
Args:
request: Incoming request.
_auth: Validated session.
name: Base name with or without extension (e.g. ``sshd`` or ``sshd.conf``).
Returns:
:class:`~app.models.file_config.ConfFileContent`.
Raises:
HTTPException: 400 if *name* is unsafe.
HTTPException: 404 if the file does not exist.
HTTPException: 503 if the config directory is unavailable.
"""
config_dir: str = request.app.state.settings.fail2ban_config_dir
try:
return await file_config_service.get_filter_file(config_dir, name)
except ConfigFileNameError as exc:
raise _bad_request(str(exc)) from exc
except ConfigFileNotFoundError:
raise _not_found(name) from None
except ConfigDirError as exc:
raise _service_unavailable(str(exc)) from exc
@router.put(
"/filters/{name}",
status_code=status.HTTP_204_NO_CONTENT,
summary="Update a filter definition file",
)
async def write_filter_file(
request: Request,
_auth: AuthDep,
name: _NamePath,
body: ConfFileUpdateRequest,
) -> None:
"""Overwrite the content of an existing filter definition file.
Args:
request: Incoming request.
_auth: Validated session.
name: Base name with or without extension.
body: New file content.
Raises:
HTTPException: 400 if *name* is unsafe or content exceeds the size limit.
HTTPException: 404 if the file does not exist.
HTTPException: 503 if the config directory is unavailable.
"""
config_dir: str = request.app.state.settings.fail2ban_config_dir
try:
await file_config_service.write_filter_file(config_dir, name, body)
except ConfigFileNameError as exc:
raise _bad_request(str(exc)) from exc
except ConfigFileNotFoundError:
raise _not_found(name) from None
except ConfigFileWriteError as exc:
raise _bad_request(str(exc)) from exc
except ConfigDirError as exc:
raise _service_unavailable(str(exc)) from exc
@router.post(
"/filters",
status_code=status.HTTP_201_CREATED,
response_model=ConfFileContent,
summary="Create a new filter definition file",
)
async def create_filter_file(
request: Request,
_auth: AuthDep,
body: ConfFileCreateRequest,
) -> ConfFileContent:
"""Create a new ``.conf`` file in ``filter.d/``.
Args:
request: Incoming request.
_auth: Validated session.
body: Name and initial content for the new file.
Returns:
The created :class:`~app.models.file_config.ConfFileContent`.
Raises:
HTTPException: 400 if *name* is invalid or content exceeds limit.
HTTPException: 409 if a file with that name already exists.
HTTPException: 503 if the config directory is unavailable.
"""
config_dir: str = request.app.state.settings.fail2ban_config_dir
try:
filename = await file_config_service.create_filter_file(config_dir, body)
except ConfigFileNameError as exc:
raise _bad_request(str(exc)) from exc
except ConfigFileExistsError:
raise _conflict(body.name) from None
except ConfigFileWriteError as exc:
raise _bad_request(str(exc)) from exc
except ConfigDirError as exc:
raise _service_unavailable(str(exc)) from exc
return ConfFileContent(
name=body.name,
filename=filename,
content=body.content,
)
# ---------------------------------------------------------------------------
# Action file endpoints (Task 4e)
# ---------------------------------------------------------------------------
@router.get(
"/actions",
response_model=ConfFilesResponse,
summary="List all action definition files",
)
async def list_action_files(
request: Request,
_auth: AuthDep,
) -> ConfFilesResponse:
"""Return a list of every ``.conf`` and ``.local`` file in ``action.d/``.
Args:
request: Incoming request.
_auth: Validated session.
Returns:
:class:`~app.models.file_config.ConfFilesResponse`.
"""
config_dir: str = request.app.state.settings.fail2ban_config_dir
try:
return await file_config_service.list_action_files(config_dir)
except ConfigDirError as exc:
raise _service_unavailable(str(exc)) from exc
@router.get(
"/actions/{name}",
response_model=ConfFileContent,
summary="Return an action definition file with its content",
)
async def get_action_file(
request: Request,
_auth: AuthDep,
name: _NamePath,
) -> ConfFileContent:
"""Return the content of an action definition file.
Args:
request: Incoming request.
_auth: Validated session.
name: Base name with or without extension.
Returns:
:class:`~app.models.file_config.ConfFileContent`.
Raises:
HTTPException: 400 if *name* is unsafe.
HTTPException: 404 if the file does not exist.
HTTPException: 503 if the config directory is unavailable.
"""
config_dir: str = request.app.state.settings.fail2ban_config_dir
try:
return await file_config_service.get_action_file(config_dir, name)
except ConfigFileNameError as exc:
raise _bad_request(str(exc)) from exc
except ConfigFileNotFoundError:
raise _not_found(name) from None
except ConfigDirError as exc:
raise _service_unavailable(str(exc)) from exc
@router.put(
"/actions/{name}",
status_code=status.HTTP_204_NO_CONTENT,
summary="Update an action definition file",
)
async def write_action_file(
request: Request,
_auth: AuthDep,
name: _NamePath,
body: ConfFileUpdateRequest,
) -> None:
"""Overwrite the content of an existing action definition file.
Args:
request: Incoming request.
_auth: Validated session.
name: Base name with or without extension.
body: New file content.
Raises:
HTTPException: 400 if *name* is unsafe or content exceeds the size limit.
HTTPException: 404 if the file does not exist.
HTTPException: 503 if the config directory is unavailable.
"""
config_dir: str = request.app.state.settings.fail2ban_config_dir
try:
await file_config_service.write_action_file(config_dir, name, body)
except ConfigFileNameError as exc:
raise _bad_request(str(exc)) from exc
except ConfigFileNotFoundError:
raise _not_found(name) from None
except ConfigFileWriteError as exc:
raise _bad_request(str(exc)) from exc
except ConfigDirError as exc:
raise _service_unavailable(str(exc)) from exc
@router.post(
"/actions",
status_code=status.HTTP_201_CREATED,
response_model=ConfFileContent,
summary="Create a new action definition file",
)
async def create_action_file(
request: Request,
_auth: AuthDep,
body: ConfFileCreateRequest,
) -> ConfFileContent:
"""Create a new ``.conf`` file in ``action.d/``.
Args:
request: Incoming request.
_auth: Validated session.
body: Name and initial content for the new file.
Returns:
The created :class:`~app.models.file_config.ConfFileContent`.
Raises:
HTTPException: 400 if *name* is invalid or content exceeds limit.
HTTPException: 409 if a file with that name already exists.
HTTPException: 503 if the config directory is unavailable.
"""
config_dir: str = request.app.state.settings.fail2ban_config_dir
try:
filename = await file_config_service.create_action_file(config_dir, body)
except ConfigFileNameError as exc:
raise _bad_request(str(exc)) from exc
except ConfigFileExistsError:
raise _conflict(body.name) from None
except ConfigFileWriteError as exc:
raise _bad_request(str(exc)) from exc
except ConfigDirError as exc:
raise _service_unavailable(str(exc)) from exc
return ConfFileContent(
name=body.name,
filename=filename,
content=body.content,
)

View File

@@ -17,7 +17,7 @@ import aiosqlite
from fastapi import APIRouter, Depends, HTTPException, Path, Request, status
from app.dependencies import AuthDep, get_db
from app.models.geo import GeoDetail, IpLookupResponse
from app.models.geo import GeoCacheStatsResponse, GeoDetail, IpLookupResponse
from app.services import geo_service, jail_service
from app.utils.fail2ban_client import Fail2BanConnectionError
@@ -99,6 +99,35 @@ async def lookup_ip(
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# GET /api/geo/stats
# ---------------------------------------------------------------------------
@router.get(
"/stats",
response_model=GeoCacheStatsResponse,
summary="Geo cache diagnostic counters",
)
async def geo_stats(
_auth: AuthDep,
db: Annotated[aiosqlite.Connection, Depends(get_db)],
) -> GeoCacheStatsResponse:
"""Return diagnostic counters for the geo cache subsystem.
Useful for operators and the UI to gauge geo-resolution health.
Args:
_auth: Validated session — enforces authentication.
db: BanGUI application database connection.
Returns:
:class:`~app.models.geo.GeoCacheStatsResponse` with current counters.
"""
stats: dict[str, int] = await geo_service.cache_stats(db)
return GeoCacheStatsResponse(**stats)
@router.post(
"/re-resolve",
summary="Re-resolve all IPs whose country could not be determined",

View File

@@ -340,20 +340,34 @@ async def import_source(
if imported_ips:
from app.services import geo_service # noqa: PLC0415
try:
await geo_service.lookup_batch(imported_ips, http_session, db=db)
uncached_ips: list[str] = [
ip for ip in imported_ips if not geo_service.is_cached(ip)
]
skipped_geo: int = len(imported_ips) - len(uncached_ips)
if skipped_geo > 0:
log.info(
"blocklist_geo_prewarm_complete",
"blocklist_geo_prewarm_cache_hit",
source_id=source.id,
count=len(imported_ips),
)
except Exception as exc: # noqa: BLE001
log.warning(
"blocklist_geo_prewarm_failed",
source_id=source.id,
error=str(exc),
skipped=skipped_geo,
to_lookup=len(uncached_ips),
)
if uncached_ips:
try:
await geo_service.lookup_batch(uncached_ips, http_session, db=db)
log.info(
"blocklist_geo_prewarm_complete",
source_id=source.id,
count=len(uncached_ips),
)
except Exception as exc: # noqa: BLE001
log.warning(
"blocklist_geo_prewarm_failed",
source_id=source.id,
error=str(exc),
)
return ImportSourceResult(
source_id=source.id,
source_url=source.url,

View File

@@ -520,6 +520,42 @@ async def add_log_path(
raise ConfigOperationError(f"Failed to add log path {req.log_path!r}: {exc}") from exc
async def delete_log_path(
socket_path: str,
jail: str,
log_path: str,
) -> None:
"""Remove a monitored log path from an existing jail.
Uses ``set <jail> dellogpath <path>`` to remove the path at runtime
without requiring a daemon restart.
Args:
socket_path: Path to the fail2ban Unix domain socket.
jail: Jail name from which the log path should be removed.
log_path: Absolute path of the log file to stop monitoring.
Raises:
JailNotFoundError: If *jail* is not a known jail.
ConfigOperationError: If the command is rejected by fail2ban.
~app.utils.fail2ban_client.Fail2BanConnectionError: Socket unreachable.
"""
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
try:
_ok(await client.send(["status", jail, "short"]))
except ValueError as exc:
if _is_not_found_error(exc):
raise JailNotFoundError(jail) from exc
raise
try:
_ok(await client.send(["set", jail, "dellogpath", log_path]))
log.info("log_path_deleted", jail=jail, path=log_path)
except ValueError as exc:
raise ConfigOperationError(f"Failed to delete log path {log_path!r}: {exc}") from exc
async def preview_log(req: LogPreviewRequest) -> LogPreviewResponse:
"""Read the last *num_lines* of a log file and test *fail_regex* against each.

View File

@@ -0,0 +1,725 @@
"""File-based fail2ban configuration service.
Provides functions to list, read, and write files in the fail2ban
configuration directory (``jail.d/``, ``filter.d/``, ``action.d/``).
All file operations are synchronous (wrapped in
:func:`asyncio.get_event_loop().run_in_executor` by callers that need async
behaviour) because the config files are small and infrequently touched — the
overhead of async I/O is not warranted here.
Security note: every path-related helper validates that the resolved path
stays strictly inside the configured config directory to prevent directory
traversal attacks.
"""
from __future__ import annotations
import asyncio
import configparser
import re
from pathlib import Path
import structlog
from app.models.file_config import (
ConfFileContent,
ConfFileCreateRequest,
ConfFileEntry,
ConfFilesResponse,
ConfFileUpdateRequest,
JailConfigFile,
JailConfigFileContent,
JailConfigFilesResponse,
)
log: structlog.stdlib.BoundLogger = structlog.get_logger()
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
_MAX_CONTENT_BYTES: int = 512 * 1024 # 512 KB hard cap on file write size
_CONF_EXTENSIONS: tuple[str, str] = (".conf", ".local")
# Allowed characters in a new file's base name. Tighter than the OS allows
# on purpose: alphanumeric, hyphen, underscore, dot (but not leading dot).
_SAFE_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$")
# ---------------------------------------------------------------------------
# Custom exceptions
# ---------------------------------------------------------------------------
class ConfigDirError(Exception):
"""Raised when the fail2ban config directory is missing or inaccessible."""
class ConfigFileNotFoundError(Exception):
"""Raised when a requested config file does not exist."""
def __init__(self, filename: str) -> None:
"""Initialise with the filename that was not found.
Args:
filename: The filename that could not be located.
"""
self.filename = filename
super().__init__(f"Config file not found: {filename!r}")
class ConfigFileExistsError(Exception):
"""Raised when trying to create a file that already exists."""
def __init__(self, filename: str) -> None:
"""Initialise with the filename that already exists.
Args:
filename: The filename that conflicts.
"""
self.filename = filename
super().__init__(f"Config file already exists: {filename!r}")
class ConfigFileWriteError(Exception):
"""Raised when a file cannot be written (permissions, disk full, etc.)."""
class ConfigFileNameError(Exception):
"""Raised when a supplied filename is invalid or unsafe."""
# ---------------------------------------------------------------------------
# Internal path helpers
# ---------------------------------------------------------------------------
def _resolve_subdir(config_dir: str, subdir: str) -> Path:
"""Resolve and return the path of *subdir* inside *config_dir*.
Args:
config_dir: The top-level fail2ban config directory.
subdir: Subdirectory name (e.g. ``"jail.d"``).
Returns:
Resolved :class:`~pathlib.Path` to the subdirectory.
Raises:
ConfigDirError: If *config_dir* does not exist or is not a directory.
"""
base = Path(config_dir).resolve()
if not base.is_dir():
raise ConfigDirError(f"fail2ban config directory not found: {config_dir!r}")
return base / subdir
def _assert_within(base: Path, target: Path) -> None:
"""Raise :class:`ConfigFileNameError` if *target* is outside *base*.
Args:
base: The allowed root directory (resolved).
target: The path to validate (resolved).
Raises:
ConfigFileNameError: If *target* would escape *base*.
"""
try:
target.relative_to(base)
except ValueError as err:
raise ConfigFileNameError(
f"Path {str(target)!r} escapes config directory {str(base)!r}"
) from err
def _validate_new_name(name: str) -> None:
"""Validate a base name for a new config file.
Args:
name: The proposed base name (without extension).
Raises:
ConfigFileNameError: If *name* contains invalid characters or patterns.
"""
if not _SAFE_NAME_RE.match(name):
raise ConfigFileNameError(
f"Invalid config file name {name!r}. "
"Use only alphanumeric characters, hyphens, underscores, and dots; "
"must start with an alphanumeric character."
)
def _validate_content(content: str) -> None:
"""Reject content that exceeds the size limit.
Args:
content: The proposed file content.
Raises:
ConfigFileWriteError: If *content* exceeds :data:`_MAX_CONTENT_BYTES`.
"""
if len(content.encode("utf-8")) > _MAX_CONTENT_BYTES:
raise ConfigFileWriteError(
f"Content exceeds maximum allowed size of {_MAX_CONTENT_BYTES // 1024} KB."
)
# ---------------------------------------------------------------------------
# Internal helpers — INI parsing / patching
# ---------------------------------------------------------------------------
def _parse_enabled(path: Path) -> bool:
"""Return the ``enabled`` value for the primary section in *path*.
Reads the INI file with :mod:`configparser` and looks for an ``enabled``
key in the section whose name matches the file stem (or in ``DEFAULT``).
Returns ``True`` if the key is absent (fail2ban's own default).
Args:
path: Path to a ``.conf`` or ``.local`` jail config file.
Returns:
``True`` if the jail is (or defaults to) enabled, ``False`` otherwise.
"""
cp = configparser.ConfigParser(
# Treat all keys case-insensitively; interpolation disabled because
# fail2ban uses %(variables)s which would confuse configparser.
interpolation=None,
)
try:
cp.read(str(path), encoding="utf-8")
except configparser.Error:
return True # Unreadable files are treated as enabled (safe default).
jail_name = path.stem
# Prefer the jail-specific section; fall back to DEFAULT.
for section in (jail_name, "DEFAULT"):
if cp.has_option(section, "enabled"):
raw = cp.get(section, "enabled").strip().lower()
return raw in ("true", "1", "yes")
return True
def _set_enabled_in_content(content: str, enabled: bool) -> str:
"""Return *content* with the first ``enabled = …`` line replaced.
If no ``enabled`` line exists, appends one to the last ``[section]`` block
found in the file.
Args:
content: Current raw file content.
enabled: New value for the ``enabled`` key.
Returns:
Modified file content as a string.
"""
value = "true" if enabled else "false"
# Try to replace an existing "enabled = ..." line (inside any section).
pattern = re.compile(
r"^(\s*enabled\s*=\s*).*$",
re.MULTILINE | re.IGNORECASE,
)
if pattern.search(content):
return pattern.sub(rf"\g<1>{value}", content, count=1)
# No existing enabled line. Find the last [section] header and append
# the enabled setting right after it.
section_pattern = re.compile(r"^\[([^\[\]]+)\]\s*$", re.MULTILINE)
matches = list(section_pattern.finditer(content))
if matches:
# Insert after the last section header line.
last_match = matches[-1]
insert_pos = last_match.end()
return content[:insert_pos] + f"\nenabled = {value}" + content[insert_pos:]
# No section found at all — prepend a minimal block.
return f"[DEFAULT]\nenabled = {value}\n\n" + content
# ---------------------------------------------------------------------------
# Public API — jail config files (Task 4a)
# ---------------------------------------------------------------------------
async def list_jail_config_files(config_dir: str) -> JailConfigFilesResponse:
"""List all jail config files in ``<config_dir>/jail.d/``.
Only ``.conf`` and ``.local`` files are returned. The ``enabled`` state
is parsed from each file's content.
Args:
config_dir: Path to the fail2ban configuration directory.
Returns:
:class:`~app.models.file_config.JailConfigFilesResponse`.
Raises:
ConfigDirError: If *config_dir* does not exist.
"""
def _do() -> JailConfigFilesResponse:
jail_d = _resolve_subdir(config_dir, "jail.d")
if not jail_d.is_dir():
log.warning("jail_d_not_found", config_dir=config_dir)
return JailConfigFilesResponse(files=[], total=0)
files: list[JailConfigFile] = []
for path in sorted(jail_d.iterdir()):
if not path.is_file():
continue
if path.suffix not in _CONF_EXTENSIONS:
continue
_assert_within(jail_d.resolve(), path.resolve())
files.append(
JailConfigFile(
name=path.stem,
filename=path.name,
enabled=_parse_enabled(path),
)
)
log.info("jail_config_files_listed", count=len(files))
return JailConfigFilesResponse(files=files, total=len(files))
return await asyncio.get_event_loop().run_in_executor(None, _do)
async def get_jail_config_file(config_dir: str, filename: str) -> JailConfigFileContent:
"""Return the content and metadata of a single jail config file.
Args:
config_dir: Path to the fail2ban configuration directory.
filename: The filename (e.g. ``sshd.conf``) — must end in ``.conf`` or ``.local``.
Returns:
:class:`~app.models.file_config.JailConfigFileContent`.
Raises:
ConfigFileNameError: If *filename* is unsafe.
ConfigFileNotFoundError: If the file does not exist.
ConfigDirError: If the config directory does not exist.
"""
def _do() -> JailConfigFileContent:
jail_d = _resolve_subdir(config_dir, "jail.d").resolve()
if not jail_d.is_dir():
raise ConfigFileNotFoundError(filename)
path = (jail_d / filename).resolve()
_assert_within(jail_d, path)
if path.suffix not in _CONF_EXTENSIONS:
raise ConfigFileNameError(
f"Invalid file extension for {filename!r}. "
"Only .conf and .local files are supported."
)
if not path.is_file():
raise ConfigFileNotFoundError(filename)
content = path.read_text(encoding="utf-8", errors="replace")
return JailConfigFileContent(
name=path.stem,
filename=path.name,
enabled=_parse_enabled(path),
content=content,
)
return await asyncio.get_event_loop().run_in_executor(None, _do)
async def set_jail_config_enabled(
config_dir: str,
filename: str,
enabled: bool,
) -> None:
"""Set the ``enabled`` flag in a jail config file.
Reads the file, modifies (or inserts) the ``enabled`` key, and writes it
back. The update preserves all other content including comments.
Args:
config_dir: Path to the fail2ban configuration directory.
filename: The filename (e.g. ``sshd.conf``).
enabled: New value for the ``enabled`` key.
Raises:
ConfigFileNameError: If *filename* is unsafe.
ConfigFileNotFoundError: If the file does not exist.
ConfigFileWriteError: If the file cannot be written.
ConfigDirError: If the config directory does not exist.
"""
def _do() -> None:
jail_d = _resolve_subdir(config_dir, "jail.d").resolve()
if not jail_d.is_dir():
raise ConfigFileNotFoundError(filename)
path = (jail_d / filename).resolve()
_assert_within(jail_d, path)
if path.suffix not in _CONF_EXTENSIONS:
raise ConfigFileNameError(
f"Only .conf and .local files are supported, got {filename!r}."
)
if not path.is_file():
raise ConfigFileNotFoundError(filename)
original = path.read_text(encoding="utf-8", errors="replace")
updated = _set_enabled_in_content(original, enabled)
try:
path.write_text(updated, encoding="utf-8")
except OSError as exc:
raise ConfigFileWriteError(
f"Cannot write {filename!r}: {exc}"
) from exc
log.info(
"jail_config_file_enabled_set",
filename=filename,
enabled=enabled,
)
await asyncio.get_event_loop().run_in_executor(None, _do)
# ---------------------------------------------------------------------------
# Internal helpers — generic conf file listing / reading / writing
# ---------------------------------------------------------------------------
def _list_conf_files(subdir: Path) -> ConfFilesResponse:
"""List ``.conf`` and ``.local`` files in *subdir*.
Args:
subdir: Resolved path to the directory to scan.
Returns:
:class:`~app.models.file_config.ConfFilesResponse`.
"""
if not subdir.is_dir():
return ConfFilesResponse(files=[], total=0)
files: list[ConfFileEntry] = []
for path in sorted(subdir.iterdir()):
if not path.is_file():
continue
if path.suffix not in _CONF_EXTENSIONS:
continue
_assert_within(subdir.resolve(), path.resolve())
files.append(ConfFileEntry(name=path.stem, filename=path.name))
return ConfFilesResponse(files=files, total=len(files))
def _read_conf_file(subdir: Path, name: str) -> ConfFileContent:
"""Read a single conf file by base name.
Args:
subdir: Resolved path to the containing directory.
name: Base name with optional extension. If no extension is given,
``.conf`` is tried first, then ``.local``.
Returns:
:class:`~app.models.file_config.ConfFileContent`.
Raises:
ConfigFileNameError: If *name* is unsafe.
ConfigFileNotFoundError: If no matching file is found.
"""
resolved_subdir = subdir.resolve()
# Accept names with or without extension.
if "." in name and not name.startswith("."):
candidates = [resolved_subdir / name]
else:
candidates = [resolved_subdir / (name + ext) for ext in _CONF_EXTENSIONS]
for path in candidates:
resolved = path.resolve()
_assert_within(resolved_subdir, resolved)
if resolved.is_file():
content = resolved.read_text(encoding="utf-8", errors="replace")
return ConfFileContent(
name=resolved.stem,
filename=resolved.name,
content=content,
)
raise ConfigFileNotFoundError(name)
def _write_conf_file(subdir: Path, name: str, content: str) -> None:
"""Overwrite or create a conf file.
Args:
subdir: Resolved path to the containing directory.
name: Base name with optional extension.
content: New file content.
Raises:
ConfigFileNameError: If *name* is unsafe.
ConfigFileNotFoundError: If *name* does not match an existing file
(use :func:`_create_conf_file` for new files).
ConfigFileWriteError: If the file cannot be written.
"""
resolved_subdir = subdir.resolve()
_validate_content(content)
# Accept names with or without extension.
if "." in name and not name.startswith("."):
candidates = [resolved_subdir / name]
else:
candidates = [resolved_subdir / (name + ext) for ext in _CONF_EXTENSIONS]
target: Path | None = None
for path in candidates:
resolved = path.resolve()
_assert_within(resolved_subdir, resolved)
if resolved.is_file():
target = resolved
break
if target is None:
raise ConfigFileNotFoundError(name)
try:
target.write_text(content, encoding="utf-8")
except OSError as exc:
raise ConfigFileWriteError(f"Cannot write {name!r}: {exc}") from exc
def _create_conf_file(subdir: Path, name: str, content: str) -> str:
"""Create a new ``.conf`` file in *subdir*.
Args:
subdir: Resolved path to the containing directory.
name: Base name for the new file (without extension).
content: Initial file content.
Returns:
The filename that was created (e.g. ``myfilter.conf``).
Raises:
ConfigFileNameError: If *name* is invalid.
ConfigFileExistsError: If a ``.conf`` or ``.local`` file with *name* already exists.
ConfigFileWriteError: If the file cannot be written.
"""
resolved_subdir = subdir.resolve()
_validate_new_name(name)
_validate_content(content)
for ext in _CONF_EXTENSIONS:
existing = (resolved_subdir / (name + ext)).resolve()
_assert_within(resolved_subdir, existing)
if existing.exists():
raise ConfigFileExistsError(name + ext)
target = (resolved_subdir / (name + ".conf")).resolve()
_assert_within(resolved_subdir, target)
try:
target.write_text(content, encoding="utf-8")
except OSError as exc:
raise ConfigFileWriteError(f"Cannot create {name!r}: {exc}") from exc
return target.name
# ---------------------------------------------------------------------------
# Public API — filter files (Task 4d)
# ---------------------------------------------------------------------------
async def list_filter_files(config_dir: str) -> ConfFilesResponse:
"""List all filter definition files in ``<config_dir>/filter.d/``.
Args:
config_dir: Path to the fail2ban configuration directory.
Returns:
:class:`~app.models.file_config.ConfFilesResponse`.
Raises:
ConfigDirError: If *config_dir* does not exist.
"""
def _do() -> ConfFilesResponse:
filter_d = _resolve_subdir(config_dir, "filter.d")
result = _list_conf_files(filter_d)
log.info("filter_files_listed", count=result.total)
return result
return await asyncio.get_event_loop().run_in_executor(None, _do)
async def get_filter_file(config_dir: str, name: str) -> ConfFileContent:
"""Return the content of a filter definition file.
Args:
config_dir: Path to the fail2ban configuration directory.
name: Base name (with or without ``.conf``/``.local`` extension).
Returns:
:class:`~app.models.file_config.ConfFileContent`.
Raises:
ConfigFileNotFoundError: If no matching file is found.
ConfigDirError: If *config_dir* does not exist.
"""
def _do() -> ConfFileContent:
filter_d = _resolve_subdir(config_dir, "filter.d")
return _read_conf_file(filter_d, name)
return await asyncio.get_event_loop().run_in_executor(None, _do)
async def write_filter_file(
config_dir: str,
name: str,
req: ConfFileUpdateRequest,
) -> None:
"""Overwrite an existing filter definition file.
Args:
config_dir: Path to the fail2ban configuration directory.
name: Base name of the file to update (with or without extension).
req: :class:`~app.models.file_config.ConfFileUpdateRequest` with new content.
Raises:
ConfigFileNotFoundError: If no matching file is found.
ConfigFileWriteError: If the file cannot be written.
ConfigDirError: If *config_dir* does not exist.
"""
def _do() -> None:
filter_d = _resolve_subdir(config_dir, "filter.d")
_write_conf_file(filter_d, name, req.content)
log.info("filter_file_written", name=name)
await asyncio.get_event_loop().run_in_executor(None, _do)
async def create_filter_file(
config_dir: str,
req: ConfFileCreateRequest,
) -> str:
"""Create a new filter definition file.
Args:
config_dir: Path to the fail2ban configuration directory.
req: :class:`~app.models.file_config.ConfFileCreateRequest`.
Returns:
The filename that was created.
Raises:
ConfigFileExistsError: If a file with that name already exists.
ConfigFileNameError: If the name is invalid.
ConfigFileWriteError: If the file cannot be created.
ConfigDirError: If *config_dir* does not exist.
"""
def _do() -> str:
filter_d = _resolve_subdir(config_dir, "filter.d")
filename = _create_conf_file(filter_d, req.name, req.content)
log.info("filter_file_created", filename=filename)
return filename
return await asyncio.get_event_loop().run_in_executor(None, _do)
# ---------------------------------------------------------------------------
# Public API — action files (Task 4e)
# ---------------------------------------------------------------------------
async def list_action_files(config_dir: str) -> ConfFilesResponse:
"""List all action definition files in ``<config_dir>/action.d/``.
Args:
config_dir: Path to the fail2ban configuration directory.
Returns:
:class:`~app.models.file_config.ConfFilesResponse`.
Raises:
ConfigDirError: If *config_dir* does not exist.
"""
def _do() -> ConfFilesResponse:
action_d = _resolve_subdir(config_dir, "action.d")
result = _list_conf_files(action_d)
log.info("action_files_listed", count=result.total)
return result
return await asyncio.get_event_loop().run_in_executor(None, _do)
async def get_action_file(config_dir: str, name: str) -> ConfFileContent:
"""Return the content of an action definition file.
Args:
config_dir: Path to the fail2ban configuration directory.
name: Base name (with or without ``.conf``/``.local`` extension).
Returns:
:class:`~app.models.file_config.ConfFileContent`.
Raises:
ConfigFileNotFoundError: If no matching file is found.
ConfigDirError: If *config_dir* does not exist.
"""
def _do() -> ConfFileContent:
action_d = _resolve_subdir(config_dir, "action.d")
return _read_conf_file(action_d, name)
return await asyncio.get_event_loop().run_in_executor(None, _do)
async def write_action_file(
config_dir: str,
name: str,
req: ConfFileUpdateRequest,
) -> None:
"""Overwrite an existing action definition file.
Args:
config_dir: Path to the fail2ban configuration directory.
name: Base name of the file to update.
req: :class:`~app.models.file_config.ConfFileUpdateRequest` with new content.
Raises:
ConfigFileNotFoundError: If no matching file is found.
ConfigFileWriteError: If the file cannot be written.
ConfigDirError: If *config_dir* does not exist.
"""
def _do() -> None:
action_d = _resolve_subdir(config_dir, "action.d")
_write_conf_file(action_d, name, req.content)
log.info("action_file_written", name=name)
await asyncio.get_event_loop().run_in_executor(None, _do)
async def create_action_file(
config_dir: str,
req: ConfFileCreateRequest,
) -> str:
"""Create a new action definition file.
Args:
config_dir: Path to the fail2ban configuration directory.
req: :class:`~app.models.file_config.ConfFileCreateRequest`.
Returns:
The filename that was created.
Raises:
ConfigFileExistsError: If a file with that name already exists.
ConfigFileNameError: If the name is invalid.
ConfigFileWriteError: If the file cannot be created.
ConfigDirError: If *config_dir* does not exist.
"""
def _do() -> str:
action_d = _resolve_subdir(config_dir, "action.d")
filename = _create_conf_file(action_d, req.name, req.content)
log.info("action_file_created", filename=filename)
return filename
return await asyncio.get_event_loop().run_in_executor(None, _do)

View File

@@ -0,0 +1,103 @@
"""Geo re-resolve background task.
Registers an APScheduler job that periodically retries IP addresses in the
``geo_cache`` table whose ``country_code`` is ``NULL``. These are IPs that
previously failed to resolve (e.g. due to ip-api.com rate limiting) and were
recorded as negative entries.
The task runs every 10 minutes. On each invocation it:
1. Queries all ``NULL``-country rows from ``geo_cache``.
2. Clears the in-memory negative cache so those IPs are eligible for a fresh
API attempt.
3. Delegates to :func:`~app.services.geo_service.lookup_batch` which already
handles rate-limit throttling and retries.
4. Logs how many IPs were retried and how many resolved successfully.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
import structlog
from app.services import geo_service
if TYPE_CHECKING:
from fastapi import FastAPI
log: structlog.stdlib.BoundLogger = structlog.get_logger()
#: How often the re-resolve job fires (seconds). 10 minutes.
GEO_RE_RESOLVE_INTERVAL: int = 600
#: Stable APScheduler job ID — ensures re-registration replaces, not duplicates.
JOB_ID: str = "geo_re_resolve"
async def _run_re_resolve(app: Any) -> None:
"""Query NULL-country IPs from the database and re-resolve them.
Reads shared resources from ``app.state`` and delegates to
:func:`~app.services.geo_service.lookup_batch`.
Args:
app: The :class:`fastapi.FastAPI` application instance passed via
APScheduler ``kwargs``.
"""
db = app.state.db
http_session = app.state.http_session
# Fetch all IPs with NULL country_code from the persistent cache.
unresolved_ips: list[str] = []
async with db.execute(
"SELECT ip FROM geo_cache WHERE country_code IS NULL"
) as cursor:
async for row in cursor:
unresolved_ips.append(str(row[0]))
if not unresolved_ips:
log.debug("geo_re_resolve_skip", reason="no_unresolved_ips")
return
log.info("geo_re_resolve_start", unresolved=len(unresolved_ips))
# Clear the negative cache so these IPs are eligible for fresh API calls.
geo_service.clear_neg_cache()
# lookup_batch handles throttling, retries, and persistence when db is
# passed. This is a background task so DB writes are allowed.
results = await geo_service.lookup_batch(unresolved_ips, http_session, db=db)
resolved_count: int = sum(
1 for info in results.values() if info.country_code is not None
)
log.info(
"geo_re_resolve_complete",
retried=len(unresolved_ips),
resolved=resolved_count,
)
def register(app: FastAPI) -> None:
"""Add (or replace) the geo re-resolve job in the application scheduler.
Must be called after the scheduler has been started (i.e., inside the
lifespan handler, after ``scheduler.start()``).
The first invocation is deferred by one full interval so the initial
blocklist prewarm has time to finish before re-resolve kicks in.
Args:
app: The :class:`fastapi.FastAPI` application instance whose
``app.state.scheduler`` will receive the job.
"""
app.state.scheduler.add_job(
_run_re_resolve,
trigger="interval",
seconds=GEO_RE_RESOLVE_INTERVAL,
kwargs={"app": app},
id=JOB_ID,
replace_existing=True,
)
log.info("geo_re_resolve_scheduled", interval_seconds=GEO_RE_RESOLVE_INTERVAL)

View File

@@ -0,0 +1,379 @@
"""Tests for the file_config router endpoints."""
from __future__ import annotations
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import aiosqlite
import pytest
from httpx import ASGITransport, AsyncClient
from app.config import Settings
from app.db import init_db
from app.main import create_app
from app.models.file_config import (
ConfFileContent,
ConfFileEntry,
ConfFilesResponse,
JailConfigFile,
JailConfigFileContent,
JailConfigFilesResponse,
)
from app.services.file_config_service import (
ConfigDirError,
ConfigFileExistsError,
ConfigFileNameError,
ConfigFileNotFoundError,
ConfigFileWriteError,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
_SETUP_PAYLOAD = {
"master_password": "testpassword1",
"database_path": "bangui.db",
"fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
"timezone": "UTC",
"session_duration_minutes": 60,
}
@pytest.fixture
async def file_config_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
"""Provide an authenticated ``AsyncClient`` for file_config endpoint tests."""
settings = Settings(
database_path=str(tmp_path / "file_config_test.db"),
fail2ban_socket="/tmp/fake.sock",
session_secret="test-file-config-secret",
session_duration_minutes=60,
timezone="UTC",
log_level="debug",
)
app = create_app(settings=settings)
db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path)
db.row_factory = aiosqlite.Row
await init_db(db)
app.state.db = db
app.state.http_session = MagicMock()
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac:
await ac.post("/api/setup", json=_SETUP_PAYLOAD)
login = await ac.post(
"/api/auth/login",
json={"password": _SETUP_PAYLOAD["master_password"]},
)
assert login.status_code == 200
yield ac
await db.close()
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _jail_files_resp(files: list[JailConfigFile] | None = None) -> JailConfigFilesResponse:
files = files or [JailConfigFile(name="sshd", filename="sshd.conf", enabled=True)]
return JailConfigFilesResponse(files=files, total=len(files))
def _conf_files_resp(files: list[ConfFileEntry] | None = None) -> ConfFilesResponse:
files = files or [ConfFileEntry(name="nginx", filename="nginx.conf")]
return ConfFilesResponse(files=files, total=len(files))
def _conf_file_content(name: str = "nginx") -> ConfFileContent:
return ConfFileContent(
name=name,
filename=f"{name}.conf",
content=f"[Definition]\n# {name} filter\n",
)
# ---------------------------------------------------------------------------
# GET /api/config/jail-files
# ---------------------------------------------------------------------------
class TestListJailConfigFiles:
async def test_200_returns_file_list(
self, file_config_client: AsyncClient
) -> None:
with patch(
"app.routers.file_config.file_config_service.list_jail_config_files",
AsyncMock(return_value=_jail_files_resp()),
):
resp = await file_config_client.get("/api/config/jail-files")
assert resp.status_code == 200
data = resp.json()
assert data["total"] == 1
assert data["files"][0]["filename"] == "sshd.conf"
async def test_503_on_config_dir_error(
self, file_config_client: AsyncClient
) -> None:
with patch(
"app.routers.file_config.file_config_service.list_jail_config_files",
AsyncMock(side_effect=ConfigDirError("not found")),
):
resp = await file_config_client.get("/api/config/jail-files")
assert resp.status_code == 503
async def test_401_unauthenticated(self, file_config_client: AsyncClient) -> None:
resp = await AsyncClient(
transport=ASGITransport(app=file_config_client._transport.app), # type: ignore[attr-defined]
base_url="http://test",
).get("/api/config/jail-files")
assert resp.status_code == 401
# ---------------------------------------------------------------------------
# GET /api/config/jail-files/{filename}
# ---------------------------------------------------------------------------
class TestGetJailConfigFile:
async def test_200_returns_content(
self, file_config_client: AsyncClient
) -> None:
content = JailConfigFileContent(
name="sshd",
filename="sshd.conf",
enabled=True,
content="[sshd]\nenabled = true\n",
)
with patch(
"app.routers.file_config.file_config_service.get_jail_config_file",
AsyncMock(return_value=content),
):
resp = await file_config_client.get("/api/config/jail-files/sshd.conf")
assert resp.status_code == 200
assert resp.json()["content"] == "[sshd]\nenabled = true\n"
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
with patch(
"app.routers.file_config.file_config_service.get_jail_config_file",
AsyncMock(side_effect=ConfigFileNotFoundError("missing.conf")),
):
resp = await file_config_client.get("/api/config/jail-files/missing.conf")
assert resp.status_code == 404
async def test_400_invalid_filename(
self, file_config_client: AsyncClient
) -> None:
with patch(
"app.routers.file_config.file_config_service.get_jail_config_file",
AsyncMock(side_effect=ConfigFileNameError("bad name")),
):
resp = await file_config_client.get("/api/config/jail-files/bad.txt")
assert resp.status_code == 400
# ---------------------------------------------------------------------------
# PUT /api/config/jail-files/{filename}/enabled
# ---------------------------------------------------------------------------
class TestSetJailConfigEnabled:
async def test_204_on_success(self, file_config_client: AsyncClient) -> None:
with patch(
"app.routers.file_config.file_config_service.set_jail_config_enabled",
AsyncMock(return_value=None),
):
resp = await file_config_client.put(
"/api/config/jail-files/sshd.conf/enabled",
json={"enabled": False},
)
assert resp.status_code == 204
async def test_404_file_not_found(self, file_config_client: AsyncClient) -> None:
with patch(
"app.routers.file_config.file_config_service.set_jail_config_enabled",
AsyncMock(side_effect=ConfigFileNotFoundError("missing.conf")),
):
resp = await file_config_client.put(
"/api/config/jail-files/missing.conf/enabled",
json={"enabled": True},
)
assert resp.status_code == 404
# ---------------------------------------------------------------------------
# GET /api/config/filters
# ---------------------------------------------------------------------------
class TestListFilterFiles:
async def test_200_returns_files(self, file_config_client: AsyncClient) -> None:
with patch(
"app.routers.file_config.file_config_service.list_filter_files",
AsyncMock(return_value=_conf_files_resp()),
):
resp = await file_config_client.get("/api/config/filters")
assert resp.status_code == 200
assert resp.json()["total"] == 1
async def test_503_on_config_dir_error(
self, file_config_client: AsyncClient
) -> None:
with patch(
"app.routers.file_config.file_config_service.list_filter_files",
AsyncMock(side_effect=ConfigDirError("x")),
):
resp = await file_config_client.get("/api/config/filters")
assert resp.status_code == 503
# ---------------------------------------------------------------------------
# GET /api/config/filters/{name}
# ---------------------------------------------------------------------------
class TestGetFilterFile:
async def test_200_returns_content(self, file_config_client: AsyncClient) -> None:
with patch(
"app.routers.file_config.file_config_service.get_filter_file",
AsyncMock(return_value=_conf_file_content("nginx")),
):
resp = await file_config_client.get("/api/config/filters/nginx")
assert resp.status_code == 200
assert resp.json()["name"] == "nginx"
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
with patch(
"app.routers.file_config.file_config_service.get_filter_file",
AsyncMock(side_effect=ConfigFileNotFoundError("missing")),
):
resp = await file_config_client.get("/api/config/filters/missing")
assert resp.status_code == 404
# ---------------------------------------------------------------------------
# PUT /api/config/filters/{name}
# ---------------------------------------------------------------------------
class TestUpdateFilterFile:
async def test_204_on_success(self, file_config_client: AsyncClient) -> None:
with patch(
"app.routers.file_config.file_config_service.write_filter_file",
AsyncMock(return_value=None),
):
resp = await file_config_client.put(
"/api/config/filters/nginx",
json={"content": "[Definition]\nfailregex = test\n"},
)
assert resp.status_code == 204
async def test_400_write_error(self, file_config_client: AsyncClient) -> None:
with patch(
"app.routers.file_config.file_config_service.write_filter_file",
AsyncMock(side_effect=ConfigFileWriteError("disk full")),
):
resp = await file_config_client.put(
"/api/config/filters/nginx",
json={"content": "x"},
)
assert resp.status_code == 400
# ---------------------------------------------------------------------------
# POST /api/config/filters
# ---------------------------------------------------------------------------
class TestCreateFilterFile:
async def test_201_creates_file(self, file_config_client: AsyncClient) -> None:
with patch(
"app.routers.file_config.file_config_service.create_filter_file",
AsyncMock(return_value="myfilter.conf"),
):
resp = await file_config_client.post(
"/api/config/filters",
json={"name": "myfilter", "content": "[Definition]\n"},
)
assert resp.status_code == 201
assert resp.json()["filename"] == "myfilter.conf"
async def test_409_conflict(self, file_config_client: AsyncClient) -> None:
with patch(
"app.routers.file_config.file_config_service.create_filter_file",
AsyncMock(side_effect=ConfigFileExistsError("myfilter.conf")),
):
resp = await file_config_client.post(
"/api/config/filters",
json={"name": "myfilter", "content": "[Definition]\n"},
)
assert resp.status_code == 409
async def test_400_invalid_name(self, file_config_client: AsyncClient) -> None:
with patch(
"app.routers.file_config.file_config_service.create_filter_file",
AsyncMock(side_effect=ConfigFileNameError("bad/../name")),
):
resp = await file_config_client.post(
"/api/config/filters",
json={"name": "../escape", "content": "[Definition]\n"},
)
assert resp.status_code == 400
# ---------------------------------------------------------------------------
# GET /api/config/actions (smoke test — same logic as filters)
# ---------------------------------------------------------------------------
class TestListActionFiles:
async def test_200_returns_files(self, file_config_client: AsyncClient) -> None:
action_entry = ConfFileEntry(name="iptables", filename="iptables.conf")
resp_data = ConfFilesResponse(files=[action_entry], total=1)
with patch(
"app.routers.file_config.file_config_service.list_action_files",
AsyncMock(return_value=resp_data),
):
resp = await file_config_client.get("/api/config/actions")
assert resp.status_code == 200
assert resp.json()["files"][0]["filename"] == "iptables.conf"
# ---------------------------------------------------------------------------
# POST /api/config/actions
# ---------------------------------------------------------------------------
class TestCreateActionFile:
async def test_201_creates_file(self, file_config_client: AsyncClient) -> None:
with patch(
"app.routers.file_config.file_config_service.create_action_file",
AsyncMock(return_value="myaction.conf"),
):
resp = await file_config_client.post(
"/api/config/actions",
json={"name": "myaction", "content": "[Definition]\n"},
)
assert resp.status_code == 201
assert resp.json()["filename"] == "myaction.conf"

View File

@@ -215,3 +215,66 @@ class TestReResolve:
base_url="http://test",
).post("/api/geo/re-resolve")
assert resp.status_code == 401
# ---------------------------------------------------------------------------
# GET /api/geo/stats
# ---------------------------------------------------------------------------
class TestGeoStats:
"""Tests for ``GET /api/geo/stats``."""
async def test_returns_200_with_stats(self, geo_client: AsyncClient) -> None:
"""GET /api/geo/stats returns 200 with the expected keys."""
stats = {
"cache_size": 100,
"unresolved": 5,
"neg_cache_size": 2,
"dirty_size": 0,
}
with patch(
"app.routers.geo.geo_service.cache_stats",
AsyncMock(return_value=stats),
):
resp = await geo_client.get("/api/geo/stats")
assert resp.status_code == 200
data = resp.json()
assert data["cache_size"] == 100
assert data["unresolved"] == 5
assert data["neg_cache_size"] == 2
assert data["dirty_size"] == 0
async def test_stats_empty_cache(self, geo_client: AsyncClient) -> None:
"""GET /api/geo/stats returns all zeros on a fresh database."""
resp = await geo_client.get("/api/geo/stats")
assert resp.status_code == 200
data = resp.json()
assert data["cache_size"] >= 0
assert data["unresolved"] == 0
assert data["neg_cache_size"] >= 0
assert data["dirty_size"] >= 0
async def test_stats_counts_unresolved(self, geo_client: AsyncClient) -> None:
"""GET /api/geo/stats counts NULL-country rows correctly."""
app = geo_client._transport.app # type: ignore[attr-defined]
db: aiosqlite.Connection = app.state.db
await db.execute("INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)", ("7.7.7.7",))
await db.execute("INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)", ("8.8.8.8",))
await db.commit()
resp = await geo_client.get("/api/geo/stats")
assert resp.status_code == 200
assert resp.json()["unresolved"] >= 2
async def test_401_when_unauthenticated(self, geo_client: AsyncClient) -> None:
"""GET /api/geo/stats requires authentication."""
app = geo_client._transport.app # type: ignore[attr-defined]
resp = await AsyncClient(
transport=ASGITransport(app=app),
base_url="http://test",
).get("/api/geo/stats")
assert resp.status_code == 401

View File

@@ -293,3 +293,47 @@ class TestSchedule:
)
info = await blocklist_service.get_schedule_info(db, None)
assert info.last_run_errors is True
# ---------------------------------------------------------------------------
# Geo prewarm cache filtering
# ---------------------------------------------------------------------------
class TestGeoPrewarmCacheFilter:
async def test_import_source_skips_cached_ips_for_geo_prewarm(
self, db: aiosqlite.Connection
) -> None:
"""import_source only sends uncached IPs to geo_service.lookup_batch."""
content = "1.2.3.4\n5.6.7.8\n9.10.11.12\n"
session = _make_session(content)
source = await blocklist_service.create_source(
db, "Geo Filter", "https://gf.test/"
)
# Pretend 1.2.3.4 is already cached.
def _mock_is_cached(ip: str) -> bool:
return ip == "1.2.3.4"
with (
patch("app.services.jail_service.ban_ip", new_callable=AsyncMock),
patch(
"app.services.geo_service.is_cached",
side_effect=_mock_is_cached,
),
patch(
"app.services.geo_service.lookup_batch",
new_callable=AsyncMock,
return_value={},
) as mock_batch,
):
result = await blocklist_service.import_source(
source, session, "/tmp/fake.sock", db
)
assert result.ips_imported == 3
# lookup_batch should receive only the 2 uncached IPs.
mock_batch.assert_called_once()
call_ips = mock_batch.call_args[0][0]
assert "1.2.3.4" not in call_ips
assert set(call_ips) == {"5.6.7.8", "9.10.11.12"}

View File

@@ -0,0 +1,401 @@
"""Tests for file_config_service functions."""
from __future__ import annotations
from pathlib import Path
import pytest
from app.models.file_config import ConfFileCreateRequest, ConfFileUpdateRequest
from app.services.file_config_service import (
ConfigDirError,
ConfigFileExistsError,
ConfigFileNameError,
ConfigFileNotFoundError,
ConfigFileWriteError,
_parse_enabled,
_set_enabled_in_content,
_validate_new_name,
create_action_file,
create_filter_file,
get_action_file,
get_filter_file,
get_jail_config_file,
list_action_files,
list_filter_files,
list_jail_config_files,
set_jail_config_enabled,
write_action_file,
write_filter_file,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_config_dir(tmp_path: Path) -> Path:
"""Create a minimal fail2ban config directory structure."""
config_dir = tmp_path / "fail2ban"
(config_dir / "jail.d").mkdir(parents=True)
(config_dir / "filter.d").mkdir(parents=True)
(config_dir / "action.d").mkdir(parents=True)
return config_dir
# ---------------------------------------------------------------------------
# _parse_enabled
# ---------------------------------------------------------------------------
def test_parse_enabled_explicit_true(tmp_path: Path) -> None:
f = tmp_path / "sshd.conf"
f.write_text("[sshd]\nenabled = true\n")
assert _parse_enabled(f) is True
def test_parse_enabled_explicit_false(tmp_path: Path) -> None:
f = tmp_path / "sshd.conf"
f.write_text("[sshd]\nenabled = false\n")
assert _parse_enabled(f) is False
def test_parse_enabled_default_true_when_absent(tmp_path: Path) -> None:
f = tmp_path / "sshd.conf"
f.write_text("[sshd]\nbantime = 600\n")
assert _parse_enabled(f) is True
def test_parse_enabled_in_default_section(tmp_path: Path) -> None:
f = tmp_path / "custom.conf"
f.write_text("[DEFAULT]\nenabled = false\n")
assert _parse_enabled(f) is False
# ---------------------------------------------------------------------------
# _set_enabled_in_content
# ---------------------------------------------------------------------------
def test_set_enabled_replaces_existing_line() -> None:
src = "[sshd]\nenabled = false\nbantime = 600\n"
result = _set_enabled_in_content(src, True)
assert "enabled = true" in result
assert "enabled = false" not in result
def test_set_enabled_inserts_after_section() -> None:
src = "[sshd]\nbantime = 600\n"
result = _set_enabled_in_content(src, False)
assert "enabled = false" in result
def test_set_enabled_prepends_default_when_no_section() -> None:
result = _set_enabled_in_content("bantime = 600\n", True)
assert "enabled = true" in result
# ---------------------------------------------------------------------------
# _validate_new_name
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("name", ["sshd", "my-filter", "test.local", "A1_filter"])
def test_validate_new_name_valid(name: str) -> None:
_validate_new_name(name) # should not raise
@pytest.mark.parametrize(
"name",
[
"",
".",
".hidden",
"../escape",
"bad/slash",
"a" * 129, # too long
"hello world", # space
],
)
def test_validate_new_name_invalid(name: str) -> None:
with pytest.raises(ConfigFileNameError):
_validate_new_name(name)
# ---------------------------------------------------------------------------
# list_jail_config_files
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_list_jail_config_files_empty(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
resp = await list_jail_config_files(str(config_dir))
assert resp.files == []
assert resp.total == 0
@pytest.mark.asyncio
async def test_list_jail_config_files_returns_conf_files(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
(config_dir / "jail.d" / "sshd.conf").write_text("[sshd]\nenabled = true\n")
(config_dir / "jail.d" / "nginx.conf").write_text("[nginx]\n")
(config_dir / "jail.d" / "other.txt").write_text("ignored")
resp = await list_jail_config_files(str(config_dir))
names = {f.filename for f in resp.files}
assert names == {"sshd.conf", "nginx.conf"}
assert resp.total == 2
@pytest.mark.asyncio
async def test_list_jail_config_files_enabled_state(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
(config_dir / "jail.d" / "a.conf").write_text("[a]\nenabled = false\n")
(config_dir / "jail.d" / "b.conf").write_text("[b]\n")
resp = await list_jail_config_files(str(config_dir))
by_name = {f.filename: f for f in resp.files}
assert by_name["a.conf"].enabled is False
assert by_name["b.conf"].enabled is True
@pytest.mark.asyncio
async def test_list_jail_config_files_missing_config_dir(tmp_path: Path) -> None:
with pytest.raises(ConfigDirError):
await list_jail_config_files(str(tmp_path / "nonexistent"))
# ---------------------------------------------------------------------------
# get_jail_config_file
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_jail_config_file_returns_content(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
(config_dir / "jail.d" / "sshd.conf").write_text("[sshd]\nenabled = true\n")
result = await get_jail_config_file(str(config_dir), "sshd.conf")
assert result.filename == "sshd.conf"
assert result.name == "sshd"
assert result.enabled is True
assert "[sshd]" in result.content
@pytest.mark.asyncio
async def test_get_jail_config_file_not_found(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
with pytest.raises(ConfigFileNotFoundError):
await get_jail_config_file(str(config_dir), "missing.conf")
@pytest.mark.asyncio
async def test_get_jail_config_file_invalid_extension(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
(config_dir / "jail.d" / "bad.txt").write_text("content")
with pytest.raises(ConfigFileNameError):
await get_jail_config_file(str(config_dir), "bad.txt")
@pytest.mark.asyncio
async def test_get_jail_config_file_path_traversal(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
with pytest.raises((ConfigFileNameError, ConfigFileNotFoundError)):
await get_jail_config_file(str(config_dir), "../jail.conf")
# ---------------------------------------------------------------------------
# set_jail_config_enabled
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_set_jail_config_enabled_writes_false(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
path = config_dir / "jail.d" / "sshd.conf"
path.write_text("[sshd]\nenabled = true\n")
await set_jail_config_enabled(str(config_dir), "sshd.conf", False)
assert "enabled = false" in path.read_text()
@pytest.mark.asyncio
async def test_set_jail_config_enabled_inserts_when_missing(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
path = config_dir / "jail.d" / "sshd.conf"
path.write_text("[sshd]\nbantime = 600\n")
await set_jail_config_enabled(str(config_dir), "sshd.conf", False)
assert "enabled = false" in path.read_text()
@pytest.mark.asyncio
async def test_set_jail_config_enabled_file_not_found(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
with pytest.raises(ConfigFileNotFoundError):
await set_jail_config_enabled(str(config_dir), "missing.conf", True)
# ---------------------------------------------------------------------------
# list_filter_files / list_action_files
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_list_filter_files_empty(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
resp = await list_filter_files(str(config_dir))
assert resp.files == []
@pytest.mark.asyncio
async def test_list_filter_files_returns_files(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
(config_dir / "filter.d" / "nginx.conf").write_text("[Definition]\n")
(config_dir / "filter.d" / "sshd.local").write_text("[Definition]\n")
(config_dir / "filter.d" / "ignore.py").write_text("# ignored")
resp = await list_filter_files(str(config_dir))
names = {f.filename for f in resp.files}
assert names == {"nginx.conf", "sshd.local"}
@pytest.mark.asyncio
async def test_list_action_files_returns_files(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
(config_dir / "action.d" / "iptables.conf").write_text("[Definition]\n")
resp = await list_action_files(str(config_dir))
assert resp.files[0].filename == "iptables.conf"
# ---------------------------------------------------------------------------
# get_filter_file / get_action_file
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_filter_file_by_stem(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
(config_dir / "filter.d" / "nginx.conf").write_text("[Definition]\nfailregex = test\n")
result = await get_filter_file(str(config_dir), "nginx")
assert result.name == "nginx"
assert "failregex" in result.content
@pytest.mark.asyncio
async def test_get_filter_file_by_full_name(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
(config_dir / "filter.d" / "nginx.conf").write_text("[Definition]\n")
result = await get_filter_file(str(config_dir), "nginx.conf")
assert result.filename == "nginx.conf"
@pytest.mark.asyncio
async def test_get_filter_file_not_found(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
with pytest.raises(ConfigFileNotFoundError):
await get_filter_file(str(config_dir), "nonexistent")
@pytest.mark.asyncio
async def test_get_action_file_returns_content(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
(config_dir / "action.d" / "iptables.conf").write_text("[Definition]\nactionban = <ip>\n")
result = await get_action_file(str(config_dir), "iptables")
assert "actionban" in result.content
# ---------------------------------------------------------------------------
# write_filter_file / write_action_file
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_write_filter_file_updates_content(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
(config_dir / "filter.d" / "nginx.conf").write_text("[Definition]\n")
req = ConfFileUpdateRequest(content="[Definition]\nfailregex = new\n")
await write_filter_file(str(config_dir), "nginx", req)
assert "failregex = new" in (config_dir / "filter.d" / "nginx.conf").read_text()
@pytest.mark.asyncio
async def test_write_filter_file_not_found(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
req = ConfFileUpdateRequest(content="[Definition]\n")
with pytest.raises(ConfigFileNotFoundError):
await write_filter_file(str(config_dir), "missing", req)
@pytest.mark.asyncio
async def test_write_filter_file_too_large(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
(config_dir / "filter.d" / "nginx.conf").write_text("[Definition]\n")
big_content = "x" * (512 * 1024 + 1)
req = ConfFileUpdateRequest(content=big_content)
with pytest.raises(ConfigFileWriteError):
await write_filter_file(str(config_dir), "nginx", req)
@pytest.mark.asyncio
async def test_write_action_file_updates_content(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
(config_dir / "action.d" / "iptables.conf").write_text("[Definition]\n")
req = ConfFileUpdateRequest(content="[Definition]\nactionban = new\n")
await write_action_file(str(config_dir), "iptables", req)
assert "actionban = new" in (config_dir / "action.d" / "iptables.conf").read_text()
# ---------------------------------------------------------------------------
# create_filter_file / create_action_file
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_create_filter_file_creates_file(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
req = ConfFileCreateRequest(name="myfilter", content="[Definition]\n")
result = await create_filter_file(str(config_dir), req)
assert result == "myfilter.conf"
assert (config_dir / "filter.d" / "myfilter.conf").is_file()
@pytest.mark.asyncio
async def test_create_filter_file_conflict(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
(config_dir / "filter.d" / "ngx.conf").write_text("[Definition]\n")
req = ConfFileCreateRequest(name="ngx", content="[Definition]\n")
with pytest.raises(ConfigFileExistsError):
await create_filter_file(str(config_dir), req)
@pytest.mark.asyncio
async def test_create_filter_file_invalid_name(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
req = ConfFileCreateRequest(name="../escape", content="[Definition]\n")
with pytest.raises(ConfigFileNameError):
await create_filter_file(str(config_dir), req)
@pytest.mark.asyncio
async def test_create_action_file_creates_file(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
req = ConfFileCreateRequest(name="my-action", content="[Definition]\n")
result = await create_action_file(str(config_dir), req)
assert result == "my-action.conf"
assert (config_dir / "action.d" / "my-action.conf").is_file()

View File

@@ -0,0 +1 @@
"""APScheduler task tests package."""

View File

@@ -0,0 +1,167 @@
"""Tests for the geo re-resolve background task.
Validates that :func:`~app.tasks.geo_re_resolve._run_re_resolve` correctly
queries NULL-country IPs from the database, clears the negative cache, and
delegates to :func:`~app.services.geo_service.lookup_batch` for a fresh
resolution attempt.
"""
from __future__ import annotations
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.services.geo_service import GeoInfo
from app.tasks.geo_re_resolve import _run_re_resolve
class _AsyncRowIterator:
"""Minimal async iterator over a list of row tuples."""
def __init__(self, rows: list[tuple[str]]) -> None:
self._iter = iter(rows)
def __aiter__(self) -> _AsyncRowIterator:
return self
async def __anext__(self) -> tuple[str]:
try:
return next(self._iter)
except StopIteration:
raise StopAsyncIteration # noqa: B904
def _make_app(
unresolved_ips: list[str],
lookup_result: dict[str, GeoInfo] | None = None,
) -> MagicMock:
"""Build a minimal mock ``app`` with ``state.db`` and ``state.http_session``.
The mock database returns *unresolved_ips* when the re-resolve task
queries ``SELECT ip FROM geo_cache WHERE country_code IS NULL``.
Args:
unresolved_ips: IPs to return from the mocked DB query.
lookup_result: Value returned by the mocked ``lookup_batch``.
Defaults to an empty dict.
Returns:
A :class:`unittest.mock.MagicMock` that mimics ``fastapi.FastAPI``.
"""
if lookup_result is None:
lookup_result = {}
rows = [(ip,) for ip in unresolved_ips]
cursor = _AsyncRowIterator(rows)
# db.execute() returns an async context manager yielding the cursor.
ctx = AsyncMock()
ctx.__aenter__ = AsyncMock(return_value=cursor)
ctx.__aexit__ = AsyncMock(return_value=False)
db = AsyncMock()
db.execute = MagicMock(return_value=ctx)
http_session = MagicMock()
app = MagicMock()
app.state.db = db
app.state.http_session = http_session
return app
@pytest.mark.asyncio
async def test_run_re_resolve_no_unresolved_ips_skips() -> None:
"""The task should return immediately when no NULL-country IPs exist."""
app = _make_app(unresolved_ips=[])
with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo:
await _run_re_resolve(app)
mock_geo.clear_neg_cache.assert_not_called()
mock_geo.lookup_batch.assert_not_called()
@pytest.mark.asyncio
async def test_run_re_resolve_clears_neg_cache() -> None:
"""The task must clear the negative cache before calling lookup_batch."""
ips = ["1.2.3.4", "5.6.7.8"]
result: dict[str, GeoInfo] = {
"1.2.3.4": GeoInfo(country_code="DE", country_name="Germany", asn="AS3320", org="DTAG"),
"5.6.7.8": GeoInfo(country_code="US", country_name="United States", asn="AS15169", org="Google"),
}
app = _make_app(unresolved_ips=ips, lookup_result=result)
with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo:
mock_geo.lookup_batch = AsyncMock(return_value=result)
await _run_re_resolve(app)
mock_geo.clear_neg_cache.assert_called_once()
@pytest.mark.asyncio
async def test_run_re_resolve_calls_lookup_batch_with_db() -> None:
"""The task must pass the real db to lookup_batch for persistence."""
ips = ["10.0.0.1", "10.0.0.2"]
result: dict[str, GeoInfo] = {
"10.0.0.1": GeoInfo(country_code="FR", country_name="France", asn=None, org=None),
"10.0.0.2": GeoInfo(country_code=None, country_name=None, asn=None, org=None),
}
app = _make_app(unresolved_ips=ips, lookup_result=result)
with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo:
mock_geo.lookup_batch = AsyncMock(return_value=result)
await _run_re_resolve(app)
mock_geo.lookup_batch.assert_called_once_with(
ips,
app.state.http_session,
db=app.state.db,
)
@pytest.mark.asyncio
async def test_run_re_resolve_logs_correct_counts(caplog: Any) -> None:
"""The task should log the number retried and number resolved."""
ips = ["1.1.1.1", "2.2.2.2", "3.3.3.3"]
result: dict[str, GeoInfo] = {
"1.1.1.1": GeoInfo(country_code="AU", country_name="Australia", asn=None, org=None),
"2.2.2.2": GeoInfo(country_code="JP", country_name="Japan", asn=None, org=None),
"3.3.3.3": GeoInfo(country_code=None, country_name=None, asn=None, org=None),
}
app = _make_app(unresolved_ips=ips, lookup_result=result)
with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo:
mock_geo.lookup_batch = AsyncMock(return_value=result)
await _run_re_resolve(app)
# Verify lookup_batch was called (the logging assertions rely on
# structlog which is hard to capture in caplog; instead we verify
# the function ran to completion and the counts are correct by
# checking that lookup_batch received the right number of IPs).
call_args = mock_geo.lookup_batch.call_args
assert len(call_args[0][0]) == 3
@pytest.mark.asyncio
async def test_run_re_resolve_handles_all_resolved() -> None:
"""When every IP resolves successfully the task should complete normally."""
ips = ["4.4.4.4"]
result: dict[str, GeoInfo] = {
"4.4.4.4": GeoInfo(country_code="GB", country_name="United Kingdom", asn=None, org=None),
}
app = _make_app(unresolved_ips=ips, lookup_result=result)
with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo:
mock_geo.lookup_batch = AsyncMock(return_value=result)
await _run_re_resolve(app)
mock_geo.clear_neg_cache.assert_called_once()
mock_geo.lookup_batch.assert_called_once()