Stage 10: external blocklist importer — backend + frontend
- blocklist_repo.py: CRUD for the blocklist_sources table
- import_log_repo.py: add/list/get-last log entries
- blocklist_service.py: source CRUD, preview, import (download/validate/ban), import_all, schedule get/set/info
- blocklist_import.py: APScheduler task (hourly/daily/weekly schedule triggers)
- blocklist.py router: 9 endpoints (list/create/update/delete/preview/import/schedule-get+put/log)
- blocklist.py models: ScheduleFrequency (StrEnum), ScheduleConfig, ScheduleInfo, ImportSourceResult, ImportRunResult, PreviewResponse
- 59 new tests (18 repo + 19 service + 22 router); 374 total pass
- ruff clean, mypy clean for Stage 10 files
- types/blocklist.ts, api/blocklist.ts, hooks/useBlocklist.ts
- BlocklistsPage.tsx: source management, schedule picker, import log table
- Frontend tsc + ESLint clean
This commit is contained in:
@@ -33,8 +33,8 @@ from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from app.config import Settings, get_settings
|
||||
from app.db import init_db
|
||||
from app.routers import auth, bans, config, dashboard, geo, health, history, jails, server, setup
|
||||
from app.tasks import health_check
|
||||
from app.routers import auth, bans, blocklist, config, dashboard, geo, health, history, jails, server, setup
|
||||
from app.tasks import blocklist_import, health_check
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Ensure the bundled fail2ban package is importable from fail2ban-master/
|
||||
@@ -118,6 +118,9 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
# --- Health-check background probe ---
|
||||
health_check.register(app)
|
||||
|
||||
# --- Blocklist import scheduled task ---
|
||||
blocklist_import.register(app)
|
||||
|
||||
log.info("bangui_started")
|
||||
|
||||
try:
|
||||
@@ -279,5 +282,6 @@ def create_app(settings: Settings | None = None) -> FastAPI:
|
||||
app.include_router(config.router)
|
||||
app.include_router(server.router)
|
||||
app.include_router(history.router)
|
||||
app.include_router(blocklist.router)
|
||||
|
||||
return app
|
||||
|
||||
@@ -1,7 +1,19 @@
|
||||
"""Blocklist source and import log Pydantic models."""
|
||||
"""Blocklist source and import log Pydantic models.
|
||||
|
||||
Data shapes for blocklist source management, import operations, scheduling,
|
||||
and import log retrieval.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Blocklist source
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class BlocklistSource(BaseModel):
|
||||
"""Domain model for a blocklist source definition."""
|
||||
@@ -21,21 +33,34 @@ class BlocklistSourceCreate(BaseModel):
|
||||
|
||||
model_config = ConfigDict(strict=True)
|
||||
|
||||
name: str = Field(..., min_length=1, description="Human-readable source name.")
|
||||
url: str = Field(..., description="URL of the blocklist file.")
|
||||
name: str = Field(..., min_length=1, max_length=100, description="Human-readable source name.")
|
||||
url: str = Field(..., min_length=1, description="URL of the blocklist file.")
|
||||
enabled: bool = Field(default=True)
|
||||
|
||||
|
||||
class BlocklistSourceUpdate(BaseModel):
|
||||
"""Payload for ``PUT /api/blocklists/{id}``."""
|
||||
"""Payload for ``PUT /api/blocklists/{id}``. All fields are optional."""
|
||||
|
||||
model_config = ConfigDict(strict=True)
|
||||
|
||||
name: str | None = Field(default=None, min_length=1)
|
||||
name: str | None = Field(default=None, min_length=1, max_length=100)
|
||||
url: str | None = Field(default=None)
|
||||
enabled: bool | None = Field(default=None)
|
||||
|
||||
|
||||
class BlocklistListResponse(BaseModel):
|
||||
"""Response for ``GET /api/blocklists``."""
|
||||
|
||||
model_config = ConfigDict(strict=True)
|
||||
|
||||
sources: list[BlocklistSource] = Field(default_factory=list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Import log
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ImportLogEntry(BaseModel):
|
||||
"""A single blocklist import run record."""
|
||||
|
||||
@@ -50,35 +75,105 @@ class ImportLogEntry(BaseModel):
|
||||
errors: str | None
|
||||
|
||||
|
||||
class BlocklistListResponse(BaseModel):
|
||||
"""Response for ``GET /api/blocklists``."""
|
||||
|
||||
model_config = ConfigDict(strict=True)
|
||||
|
||||
sources: list[BlocklistSource] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ImportLogListResponse(BaseModel):
|
||||
"""Response for ``GET /api/blocklists/log``."""
|
||||
|
||||
model_config = ConfigDict(strict=True)
|
||||
|
||||
entries: list[ImportLogEntry] = Field(default_factory=list)
|
||||
items: list[ImportLogEntry] = Field(default_factory=list)
|
||||
total: int = Field(..., ge=0)
|
||||
page: int = Field(default=1, ge=1)
|
||||
page_size: int = Field(default=50, ge=1)
|
||||
total_pages: int = Field(default=1, ge=1)
|
||||
|
||||
|
||||
class BlocklistSchedule(BaseModel):
|
||||
"""Current import schedule and next run information."""
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schedule
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ScheduleFrequency(StrEnum):
|
||||
"""Available import schedule frequency presets."""
|
||||
|
||||
hourly = "hourly"
|
||||
daily = "daily"
|
||||
weekly = "weekly"
|
||||
|
||||
|
||||
class ScheduleConfig(BaseModel):
|
||||
"""Import schedule configuration.
|
||||
|
||||
The interpretation of fields depends on *frequency*:
|
||||
|
||||
- ``hourly``: ``interval_hours`` controls how often (every N hours).
|
||||
- ``daily``: ``hour`` and ``minute`` specify the daily run time (UTC).
|
||||
- ``weekly``: additionally uses ``day_of_week`` (0=Monday … 6=Sunday).
|
||||
"""
|
||||
|
||||
# No strict=True here: FastAPI and json.loads() both supply enum values as
|
||||
# plain strings; strict mode would reject string→enum coercion.
|
||||
|
||||
frequency: ScheduleFrequency = ScheduleFrequency.daily
|
||||
interval_hours: int = Field(default=24, ge=1, le=168, description="Used when frequency=hourly")
|
||||
hour: int = Field(default=3, ge=0, le=23, description="UTC hour for daily/weekly runs")
|
||||
minute: int = Field(default=0, ge=0, le=59, description="Minute for daily/weekly runs")
|
||||
day_of_week: int = Field(
|
||||
default=0,
|
||||
ge=0,
|
||||
le=6,
|
||||
description="Day of week for weekly runs (0=Monday … 6=Sunday)",
|
||||
)
|
||||
|
||||
|
||||
class ScheduleInfo(BaseModel):
|
||||
"""Current schedule configuration together with runtime metadata."""
|
||||
|
||||
model_config = ConfigDict(strict=True)
|
||||
|
||||
hour: int = Field(..., ge=0, le=23, description="UTC hour for the daily import.")
|
||||
next_run_at: str | None = Field(default=None, description="ISO 8601 UTC timestamp of the next scheduled import.")
|
||||
config: ScheduleConfig
|
||||
next_run_at: str | None
|
||||
last_run_at: str | None
|
||||
|
||||
|
||||
class BlocklistScheduleUpdate(BaseModel):
|
||||
"""Payload for ``PUT /api/blocklists/schedule``."""
|
||||
# ---------------------------------------------------------------------------
|
||||
# Import results
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ImportSourceResult(BaseModel):
|
||||
"""Result of importing a single blocklist source."""
|
||||
|
||||
model_config = ConfigDict(strict=True)
|
||||
|
||||
hour: int = Field(..., ge=0, le=23)
|
||||
source_id: int | None
|
||||
source_url: str
|
||||
ips_imported: int
|
||||
ips_skipped: int
|
||||
error: str | None
|
||||
|
||||
|
||||
class ImportRunResult(BaseModel):
|
||||
"""Aggregated result from a full import run across all enabled sources."""
|
||||
|
||||
model_config = ConfigDict(strict=True)
|
||||
|
||||
results: list[ImportSourceResult] = Field(default_factory=list)
|
||||
total_imported: int
|
||||
total_skipped: int
|
||||
errors_count: int
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Preview
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class PreviewResponse(BaseModel):
|
||||
"""Response for ``GET /api/blocklists/{id}/preview``."""
|
||||
|
||||
model_config = ConfigDict(strict=True)
|
||||
|
||||
entries: list[str] = Field(default_factory=list, description="Sample of valid IP entries")
|
||||
total_lines: int
|
||||
valid_count: int
|
||||
skipped_count: int
|
||||
|
||||
187
backend/app/repositories/blocklist_repo.py
Normal file
187
backend/app/repositories/blocklist_repo.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""Blocklist sources repository.
|
||||
|
||||
CRUD operations for the ``blocklist_sources`` table in the application
|
||||
SQLite database. All methods accept a :class:`aiosqlite.Connection` — no
|
||||
ORM, no HTTP exceptions.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiosqlite
|
||||
|
||||
|
||||
async def create_source(
    db: aiosqlite.Connection,
    name: str,
    url: str,
    *,
    enabled: bool = True,
) -> int:
    """Add a row to ``blocklist_sources`` and return its primary key.

    Args:
        db: Open aiosqlite connection.
        name: Display name for the source.
        url: Location of the blocklist text file.
        enabled: Initial active flag; defaults to ``True``.

    Returns:
        Auto-generated ``ROWID`` of the inserted row.
    """
    cur = await db.execute(
        """
        INSERT INTO blocklist_sources (name, url, enabled)
        VALUES (?, ?, ?)
        """,
        (name, url, int(enabled)),
    )
    await db.commit()
    # lastrowid is Optional per the DB-API stubs; a successful INSERT sets it.
    return int(cur.lastrowid)  # type: ignore[arg-type]
|
||||
|
||||
|
||||
async def get_source(
    db: aiosqlite.Connection,
    source_id: int,
) -> dict[str, Any] | None:
    """Look up one blocklist source by primary key.

    Args:
        db: Open aiosqlite connection.
        source_id: Id of the row to fetch.

    Returns:
        Dict keyed by the ``blocklist_sources`` column names, or ``None``
        when no matching row exists.
    """
    query = "SELECT id, name, url, enabled, created_at, updated_at FROM blocklist_sources WHERE id = ?"
    async with db.execute(query, (source_id,)) as cur:
        record = await cur.fetchone()
        return None if record is None else _row_to_dict(record)
|
||||
|
||||
|
||||
async def list_sources(db: aiosqlite.Connection) -> list[dict[str, Any]]:
    """Fetch every blocklist source, ordered by ascending id.

    Args:
        db: Open aiosqlite connection.

    Returns:
        One dict per row of ``blocklist_sources``.
    """
    query = "SELECT id, name, url, enabled, created_at, updated_at FROM blocklist_sources ORDER BY id"
    async with db.execute(query) as cur:
        records = await cur.fetchall()
        return [_row_to_dict(rec) for rec in records]
|
||||
|
||||
|
||||
async def list_enabled_sources(db: aiosqlite.Connection) -> list[dict[str, Any]]:
    """Fetch the active (``enabled = 1``) blocklist sources, ordered by id.

    Args:
        db: Open aiosqlite connection.

    Returns:
        One dict per enabled row of ``blocklist_sources``.
    """
    query = "SELECT id, name, url, enabled, created_at, updated_at FROM blocklist_sources WHERE enabled = 1 ORDER BY id"
    async with db.execute(query) as cur:
        records = await cur.fetchall()
        return [_row_to_dict(rec) for rec in records]
|
||||
|
||||
|
||||
async def update_source(
    db: aiosqlite.Connection,
    source_id: int,
    *,
    name: str | None = None,
    url: str | None = None,
    enabled: bool | None = None,
) -> bool:
    """Apply a partial update to a blocklist source.

    Only keyword arguments that are not ``None`` become ``SET`` clauses;
    ``updated_at`` is always refreshed when at least one field changes.

    Args:
        db: Open aiosqlite connection.
        source_id: Id of the row to modify.
        name: Replacement display name, or ``None`` to keep the current one.
        url: Replacement URL, or ``None`` to keep the current one.
        enabled: Replacement active flag, or ``None`` to keep the current one.

    Returns:
        ``True`` when a row was changed; ``False`` when the id is unknown.
    """
    assignments: list[str] = []
    values: list[Any] = []
    # Table-driven collection of the optional fields; bools are stored as 0/1.
    for clause, value in (
        ("name = ?", name),
        ("url = ?", url),
        ("enabled = ?", None if enabled is None else int(enabled)),
    ):
        if value is not None:
            assignments.append(clause)
            values.append(value)

    if not assignments:
        # No-op update: report success only when the row actually exists.
        return await get_source(db, source_id) is not None

    assignments.append("updated_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')")
    values.append(source_id)

    # Column names come from the fixed tuples above, never from user input.
    cur = await db.execute(
        f"UPDATE blocklist_sources SET {', '.join(assignments)} WHERE id = ?",  # noqa: S608
        values,
    )
    await db.commit()
    return cur.rowcount > 0
|
||||
|
||||
|
||||
async def delete_source(db: aiosqlite.Connection, source_id: int) -> bool:
    """Remove a blocklist source row.

    Args:
        db: Open aiosqlite connection.
        source_id: Id of the row to delete.

    Returns:
        ``True`` when a row was removed; ``False`` when the id is unknown.
    """
    cur = await db.execute(
        "DELETE FROM blocklist_sources WHERE id = ?",
        (source_id,),
    )
    await db.commit()
    return cur.rowcount > 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _row_to_dict(row: Any) -> dict[str, Any]:
    """Materialise a cursor row as a plain dict.

    Args:
        row: :class:`aiosqlite.Row` (or any dict-convertible sequence).

    Returns:
        Column-name → value mapping with ``enabled`` coerced to ``bool``
        (SQLite stores it as 0/1).
    """
    result: dict[str, Any] = dict(row)
    result["enabled"] = bool(result["enabled"])
    return result
|
||||
155
backend/app/repositories/import_log_repo.py
Normal file
155
backend/app/repositories/import_log_repo.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""Import log repository.
|
||||
|
||||
Persists and queries blocklist import run records in the ``import_log``
|
||||
table. All methods are plain async functions that accept a
|
||||
:class:`aiosqlite.Connection`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiosqlite
|
||||
|
||||
|
||||
async def add_log(
    db: aiosqlite.Connection,
    *,
    source_id: int | None,
    source_url: str,
    ips_imported: int,
    ips_skipped: int,
    errors: str | None,
) -> int:
    """Record one import run in ``import_log`` and return its id.

    Args:
        db: Open aiosqlite connection.
        source_id: FK to ``blocklist_sources.id``; ``None`` when the source
            was deleted after the run.
        source_url: URL that was fetched.
        ips_imported: Count of IPs banned successfully.
        ips_skipped: Count of lines rejected (invalid or CIDR).
        errors: Failure message, or ``None`` on success.

    Returns:
        Primary key of the new log row.
    """
    cur = await db.execute(
        """
        INSERT INTO import_log (source_id, source_url, ips_imported, ips_skipped, errors)
        VALUES (?, ?, ?, ?, ?)
        """,
        (source_id, source_url, ips_imported, ips_skipped, errors),
    )
    await db.commit()
    # lastrowid is Optional per the DB-API stubs; a successful INSERT sets it.
    return int(cur.lastrowid)  # type: ignore[arg-type]
|
||||
|
||||
|
||||
async def list_logs(
    db: aiosqlite.Connection,
    *,
    source_id: int | None = None,
    page: int = 1,
    page_size: int = 50,
) -> tuple[list[dict[str, Any]], int]:
    """Page through the import log, newest entries first.

    Args:
        db: Open aiosqlite connection.
        source_id: Restrict to one source when given.
        page: 1-based page number.
        page_size: Rows per page.

    Returns:
        ``(items, total)`` — the page of rows as dicts plus the unpaginated
        match count.
    """
    # Shared WHERE clause and its bind values for both queries below.
    clause = "" if source_id is None else " WHERE source_id = ?"
    filter_args: list[Any] = [] if source_id is None else [source_id]

    async with db.execute(
        f"SELECT COUNT(*) FROM import_log{clause}",  # noqa: S608
        filter_args,
    ) as cur:
        head = await cur.fetchone()
    total: int = int(head[0]) if head else 0

    query_args = [*filter_args, page_size, (page - 1) * page_size]

    async with db.execute(
        f"""
        SELECT id, source_id, source_url, timestamp, ips_imported, ips_skipped, errors
        FROM import_log{clause}
        ORDER BY id DESC
        LIMIT ? OFFSET ?
        """,  # noqa: S608
        query_args,
    ) as cur:
        records = await cur.fetchall()
        items = [_row_to_dict(rec) for rec in records]

    return items, total
|
||||
|
||||
|
||||
async def get_last_log(db: aiosqlite.Connection) -> dict[str, Any] | None:
    """Fetch the newest import log entry, regardless of source.

    Args:
        db: Open aiosqlite connection.

    Returns:
        The latest entry as a dict, or ``None`` when the log is empty.
    """
    async with db.execute(
        """
        SELECT id, source_id, source_url, timestamp, ips_imported, ips_skipped, errors
        FROM import_log
        ORDER BY id DESC
        LIMIT 1
        """
    ) as cur:
        record = await cur.fetchone()
        return None if record is None else _row_to_dict(record)
|
||||
|
||||
|
||||
def compute_total_pages(total: int, page_size: int) -> int:
    """Compute how many pages are needed for *total* items.

    Args:
        total: Number of matching items.
        page_size: Items shown per page.

    Returns:
        Page count; never less than 1, so an empty result still renders
        a single (empty) page.
    """
    return 1 if total == 0 else math.ceil(total / page_size)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _row_to_dict(row: Any) -> dict[str, Any]:
    """Materialise a cursor row as a plain dict.

    Args:
        row: :class:`aiosqlite.Row` (or any dict-convertible sequence).

    Returns:
        Column-name → value mapping.
    """
    return dict(row)
|
||||
370
backend/app/routers/blocklist.py
Normal file
370
backend/app/routers/blocklist.py
Normal file
@@ -0,0 +1,370 @@
|
||||
"""Blocklist router.
|
||||
|
||||
Manages external IP blocklist sources, triggers manual imports, and exposes
|
||||
the import schedule and log:
|
||||
|
||||
* ``GET /api/blocklists`` — list all sources
|
||||
* ``POST /api/blocklists`` — add a source
|
||||
* ``GET /api/blocklists/import`` — (reserved; use POST)
|
||||
* ``POST /api/blocklists/import`` — trigger a manual import now
|
||||
* ``GET /api/blocklists/schedule`` — get current schedule + next run
|
||||
* ``PUT /api/blocklists/schedule`` — update schedule
|
||||
* ``GET /api/blocklists/log`` — paginated import log
|
||||
* ``GET /api/blocklists/{id}`` — get a single source
|
||||
* ``PUT /api/blocklists/{id}`` — edit a source
|
||||
* ``DELETE /api/blocklists/{id}`` — remove a source
|
||||
* ``GET /api/blocklists/{id}/preview`` — preview the blocklist contents
|
||||
|
||||
Note: static path segments (``/import``, ``/schedule``, ``/log``) are
|
||||
registered *before* the ``/{id}`` routes so FastAPI resolves them correctly.
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Annotated
|
||||
|
||||
import aiosqlite
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiohttp
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
|
||||
from app.dependencies import AuthDep, get_db
|
||||
from app.models.blocklist import (
|
||||
BlocklistListResponse,
|
||||
BlocklistSource,
|
||||
BlocklistSourceCreate,
|
||||
BlocklistSourceUpdate,
|
||||
ImportLogListResponse,
|
||||
ImportRunResult,
|
||||
PreviewResponse,
|
||||
ScheduleConfig,
|
||||
ScheduleInfo,
|
||||
)
|
||||
from app.repositories import import_log_repo
|
||||
from app.services import blocklist_service
|
||||
from app.tasks import blocklist_import as blocklist_import_task
|
||||
|
||||
router: APIRouter = APIRouter(prefix="/api/blocklists", tags=["Blocklists"])
|
||||
|
||||
DbDep = Annotated[aiosqlite.Connection, Depends(get_db)]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Source list + create
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get(
    "",
    response_model=BlocklistListResponse,
    summary="List all blocklist sources",
)
async def list_blocklists(
    db: DbDep,
    _auth: AuthDep,
) -> BlocklistListResponse:
    """List every configured blocklist source.

    Args:
        db: Injected application database connection.
        _auth: Session dependency; requires a logged-in user.

    Returns:
        :class:`~app.models.blocklist.BlocklistListResponse` wrapping the
        full set of sources.
    """
    all_sources = await blocklist_service.list_sources(db)
    return BlocklistListResponse(sources=all_sources)
|
||||
|
||||
|
||||
@router.post(
    "",
    response_model=BlocklistSource,
    status_code=status.HTTP_201_CREATED,
    summary="Add a new blocklist source",
)
async def create_blocklist(
    payload: BlocklistSourceCreate,
    db: DbDep,
    _auth: AuthDep,
) -> BlocklistSource:
    """Register a new blocklist source.

    Args:
        payload: Source fields (name, url, enabled).
        db: Injected application database connection.
        _auth: Session dependency; requires a logged-in user.

    Returns:
        The persisted :class:`~app.models.blocklist.BlocklistSource`.
    """
    return await blocklist_service.create_source(
        db,
        payload.name,
        payload.url,
        enabled=payload.enabled,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Static sub-paths — must be declared BEFORE /{id}
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post(
    "/import",
    response_model=ImportRunResult,
    summary="Trigger a manual blocklist import",
)
async def run_import_now(
    request: Request,
    db: DbDep,
    _auth: AuthDep,
) -> ImportRunResult:
    """Fetch and apply every enabled blocklist source right now.

    Args:
        request: Incoming request; carries the shared HTTP session on
            ``app.state``.
        db: Injected application database connection.
        _auth: Session dependency; requires a logged-in user.

    Returns:
        :class:`~app.models.blocklist.ImportRunResult` with per-source
        outcomes and aggregate counters.
    """
    session: aiohttp.ClientSession = request.app.state.http_session
    socket: str = request.app.state.settings.fail2ban_socket
    return await blocklist_service.import_all(db, session, socket)
|
||||
|
||||
|
||||
@router.get(
    "/schedule",
    response_model=ScheduleInfo,
    summary="Get the current import schedule",
)
async def get_schedule(
    request: Request,
    db: DbDep,
    _auth: AuthDep,
) -> ScheduleInfo:
    """Report the stored schedule plus live scheduler metadata.

    ``next_run_at`` comes from APScheduler when the import job is
    registered; otherwise it stays ``None``.

    Args:
        request: Incoming request; used to reach the scheduler on
            ``app.state``.
        db: Injected application database connection.
        _auth: Session dependency; requires a logged-in user.

    Returns:
        :class:`~app.models.blocklist.ScheduleInfo` with config and run
        times.
    """
    scheduler = request.app.state.scheduler
    job = scheduler.get_job(blocklist_import_task.JOB_ID)
    upcoming: str | None = None
    if job is not None and job.next_run_time is not None:
        upcoming = job.next_run_time.isoformat()

    return await blocklist_service.get_schedule_info(db, upcoming)
|
||||
|
||||
|
||||
@router.put(
    "/schedule",
    response_model=ScheduleInfo,
    summary="Update the import schedule",
)
async def update_schedule(
    payload: ScheduleConfig,
    request: Request,
    db: DbDep,
    _auth: AuthDep,
) -> ScheduleInfo:
    """Store a new schedule and re-register the import job with it.

    Args:
        payload: Replacement :class:`~app.models.blocklist.ScheduleConfig`.
        request: Incoming request; used to reach the scheduler on
            ``app.state``.
        db: Injected application database connection.
        _auth: Session dependency; requires a logged-in user.

    Returns:
        The resulting :class:`~app.models.blocklist.ScheduleInfo`.
    """
    await blocklist_service.set_schedule(db, payload)
    # Apply the new trigger to the running job without waiting for restart.
    blocklist_import_task.reschedule(request.app)

    job = request.app.state.scheduler.get_job(blocklist_import_task.JOB_ID)
    upcoming: str | None = None
    if job is not None and job.next_run_time is not None:
        upcoming = job.next_run_time.isoformat()

    return await blocklist_service.get_schedule_info(db, upcoming)
|
||||
|
||||
|
||||
@router.get(
    "/log",
    response_model=ImportLogListResponse,
    summary="Get the paginated import log",
)
async def get_import_log(
    db: DbDep,
    _auth: AuthDep,
    source_id: int | None = Query(default=None, description="Filter by source id"),
    page: int = Query(default=1, ge=1),
    page_size: int = Query(default=50, ge=1, le=200),
) -> ImportLogListResponse:
    """Page through the history of import runs.

    Args:
        db: Injected application database connection.
        _auth: Session dependency; requires a logged-in user.
        source_id: Restrict results to one source when given.
        page: 1-based page number.
        page_size: Rows per page (max 200).

    Returns:
        :class:`~app.models.blocklist.ImportLogListResponse`.
    """
    rows, total = await import_log_repo.list_logs(
        db, source_id=source_id, page=page, page_size=page_size
    )
    pages = import_log_repo.compute_total_pages(total, page_size)

    from app.models.blocklist import ImportLogEntry  # noqa: PLC0415

    return ImportLogListResponse(
        items=[ImportLogEntry.model_validate(row) for row in rows],
        total=total,
        page=page,
        page_size=page_size,
        total_pages=pages,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Single source CRUD — parameterised routes AFTER static sub-paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get(
    "/{source_id}",
    response_model=BlocklistSource,
    summary="Get a single blocklist source",
)
async def get_blocklist(
    source_id: int,
    db: DbDep,
    _auth: AuthDep,
) -> BlocklistSource:
    """Fetch one blocklist source by id.

    Args:
        source_id: Id of the source to fetch.
        db: Injected application database connection.
        _auth: Session dependency; requires a logged-in user.

    Raises:
        HTTPException: 404 when no source with that id exists.
    """
    found = await blocklist_service.get_source(db, source_id)
    if found is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Blocklist source not found.")
    return found
|
||||
|
||||
|
||||
@router.put(
    "/{source_id}",
    response_model=BlocklistSource,
    summary="Update a blocklist source",
)
async def update_blocklist(
    source_id: int,
    payload: BlocklistSourceUpdate,
    db: DbDep,
    _auth: AuthDep,
) -> BlocklistSource:
    """Apply a partial update to a blocklist source.

    Args:
        source_id: Id of the source to modify.
        payload: Optional replacement fields; ``None`` values are ignored.
        db: Injected application database connection.
        _auth: Session dependency; requires a logged-in user.

    Raises:
        HTTPException: 404 when no source with that id exists.
    """
    result = await blocklist_service.update_source(
        db,
        source_id,
        name=payload.name,
        url=payload.url,
        enabled=payload.enabled,
    )
    if result is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Blocklist source not found.")
    return result
|
||||
|
||||
|
||||
@router.delete(
    "/{source_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Delete a blocklist source",
)
async def delete_blocklist(
    source_id: int,
    db: DbDep,
    _auth: AuthDep,
) -> None:
    """Remove a blocklist source.

    Args:
        source_id: Id of the source to remove.
        db: Injected application database connection.
        _auth: Session dependency; requires a logged-in user.

    Raises:
        HTTPException: 404 when no source with that id exists.
    """
    removed = await blocklist_service.delete_source(db, source_id)
    if not removed:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Blocklist source not found.")
|
||||
|
||||
|
||||
@router.get(
    "/{source_id}/preview",
    response_model=PreviewResponse,
    summary="Preview the contents of a blocklist source",
)
async def preview_blocklist(
    source_id: int,
    request: Request,
    db: DbDep,
    _auth: AuthDep,
) -> PreviewResponse:
    """Download a blocklist source and return a validation sample.

    The service fetches the URL and reports the first valid IP entries
    alongside line/validity statistics.

    Args:
        source_id: Id of the source to preview.
        request: Incoming request; carries the shared HTTP session on
            ``app.state``.
        db: Injected application database connection.
        _auth: Session dependency; requires a logged-in user.

    Raises:
        HTTPException: 404 when no source with that id exists.
        HTTPException: 502 when the URL cannot be fetched.
    """
    found = await blocklist_service.get_source(db, source_id)
    if found is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Blocklist source not found.")

    session: aiohttp.ClientSession = request.app.state.http_session
    try:
        return await blocklist_service.preview_source(found.url, session)
    except ValueError as exc:
        # Surface download/parse failures as an upstream (502) error.
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"Could not fetch blocklist: {exc}",
        ) from exc
|
||||
493
backend/app/services/blocklist_service.py
Normal file
493
backend/app/services/blocklist_service.py
Normal file
@@ -0,0 +1,493 @@
|
||||
"""Blocklist service.
|
||||
|
||||
Manages blocklist source CRUD, URL preview, IP import (download → validate →
|
||||
ban via fail2ban), and schedule persistence.
|
||||
|
||||
All ban operations target a dedicated fail2ban jail (default:
|
||||
``"blocklist-import"``) so blocklist-origin bans are tracked separately from
|
||||
regular bans. If that jail does not exist or fail2ban is unreachable, the
|
||||
error is recorded in the import log and processing continues.
|
||||
|
||||
Schedule configuration is stored as JSON in the application settings table
|
||||
under the key ``"blocklist_schedule"``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import structlog
|
||||
|
||||
from app.models.blocklist import (
|
||||
BlocklistSource,
|
||||
ImportRunResult,
|
||||
ImportSourceResult,
|
||||
PreviewResponse,
|
||||
ScheduleConfig,
|
||||
ScheduleInfo,
|
||||
)
|
||||
from app.repositories import blocklist_repo, import_log_repo, settings_repo
|
||||
from app.utils.ip_utils import is_valid_ip, is_valid_network
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
#: Settings key used to persist the schedule config.
|
||||
_SCHEDULE_SETTINGS_KEY: str = "blocklist_schedule"
|
||||
|
||||
#: fail2ban jail name for blocklist-origin bans.
|
||||
BLOCKLIST_JAIL: str = "blocklist-import"
|
||||
|
||||
#: Maximum number of sample entries returned by the preview endpoint.
|
||||
_PREVIEW_LINES: int = 20
|
||||
|
||||
#: Maximum bytes to download for a preview (first 64 KB).
|
||||
_PREVIEW_MAX_BYTES: int = 65536
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Source CRUD helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _row_to_source(row: dict[str, Any]) -> BlocklistSource:
    """Convert a repository row dict to a :class:`BlocklistSource`.

    Args:
        row: Dict with keys matching the ``blocklist_sources`` columns.

    Returns:
        A validated :class:`~app.models.blocklist.BlocklistSource` instance.

    Raises:
        pydantic.ValidationError: If the row does not match the model schema
            (e.g. after a schema drift between DB and model).
    """
    # model_validate performs full pydantic validation, not just attribute copy.
    return BlocklistSource.model_validate(row)
|
||||
|
||||
|
||||
async def list_sources(db: aiosqlite.Connection) -> list[BlocklistSource]:
    """Fetch every configured blocklist source.

    Args:
        db: Active application database connection.

    Returns:
        All sources as :class:`~app.models.blocklist.BlocklistSource` models.
    """
    return [_row_to_source(row) for row in await blocklist_repo.list_sources(db)]
|
||||
|
||||
|
||||
async def get_source(
    db: aiosqlite.Connection,
    source_id: int,
) -> BlocklistSource | None:
    """Look up one blocklist source by primary key.

    Args:
        db: Active application database connection.
        source_id: Primary key of the desired source.

    Returns:
        The matching :class:`~app.models.blocklist.BlocklistSource`, or
        ``None`` when no row exists for ``source_id``.
    """
    row = await blocklist_repo.get_source(db, source_id)
    if row is None:
        return None
    return _row_to_source(row)
|
||||
|
||||
|
||||
async def create_source(
    db: aiosqlite.Connection,
    name: str,
    url: str,
    *,
    enabled: bool = True,
) -> BlocklistSource:
    """Create a new blocklist source and return the persisted record.

    Args:
        db: Active application database connection.
        name: Human-readable display name.
        url: URL of the blocklist text file.
        enabled: Whether the source is active. Defaults to ``True``.

    Returns:
        The newly created :class:`~app.models.blocklist.BlocklistSource`.

    Raises:
        RuntimeError: If the freshly inserted row cannot be read back
            (indicates a database inconsistency).
    """
    new_id = await blocklist_repo.create_source(db, name, url, enabled=enabled)
    source = await get_source(db, new_id)
    if source is None:  # pragma: no cover - row was inserted just above
        # An explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, which would let a None escape to callers.
        msg = f"blocklist source {new_id} missing immediately after insert"
        raise RuntimeError(msg)
    log.info("blocklist_source_created", id=new_id, name=name, url=url)
    return source
|
||||
|
||||
|
||||
async def update_source(
    db: aiosqlite.Connection,
    source_id: int,
    *,
    name: str | None = None,
    url: str | None = None,
    enabled: bool | None = None,
) -> BlocklistSource | None:
    """Modify fields of a blocklist source.

    Args:
        db: Active application database connection.
        source_id: Primary key of the source to modify.
        name: New display name, or ``None`` to leave unchanged.
        url: New URL, or ``None`` to leave unchanged.
        enabled: New enabled state, or ``None`` to leave unchanged.

    Returns:
        The updated :class:`~app.models.blocklist.BlocklistSource`, or
        ``None`` when the source does not exist.
    """
    changed = await blocklist_repo.update_source(
        db,
        source_id,
        name=name,
        url=url,
        enabled=enabled,
    )
    if not changed:
        return None
    refreshed = await get_source(db, source_id)
    log.info("blocklist_source_updated", id=source_id)
    return refreshed
|
||||
|
||||
|
||||
async def delete_source(db: aiosqlite.Connection, source_id: int) -> bool:
    """Remove a blocklist source.

    Args:
        db: Active application database connection.
        source_id: Primary key of the source to delete.

    Returns:
        ``True`` if a row was found and deleted, ``False`` otherwise.
    """
    removed = await blocklist_repo.delete_source(db, source_id)
    if removed:
        log.info("blocklist_source_deleted", id=source_id)
    return removed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Preview
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def preview_source(
    url: str,
    http_session: aiohttp.ClientSession,
    *,
    sample_lines: int = _PREVIEW_LINES,
) -> PreviewResponse:
    """Download the beginning of a blocklist URL and return a preview.

    Args:
        url: URL to download.
        http_session: Shared :class:`aiohttp.ClientSession`.
        sample_lines: Maximum number of lines to include in the preview.

    Returns:
        :class:`~app.models.blocklist.PreviewResponse` with a sample of
        valid IP entries and validation statistics.

    Raises:
        ValueError: If the URL cannot be reached or returns a non-200 status.
    """
    try:
        async with http_session.get(url, timeout=_aiohttp_timeout(10)) as resp:
            if resp.status != 200:
                raise ValueError(f"HTTP {resp.status} from {url}")
            # Cap the read so an arbitrarily large list cannot exhaust memory.
            raw = await resp.content.read(_PREVIEW_MAX_BYTES)
    except ValueError as exc:
        # Our own non-200 signal: log and propagate as-is. Previously this
        # was caught by the broad handler below and wrapped a second time.
        log.warning("blocklist_preview_failed", url=url, error=str(exc))
        raise
    except Exception as exc:
        # Network/timeout/DNS errors are normalized to ValueError so the
        # router can map them to a 502.
        log.warning("blocklist_preview_failed", url=url, error=str(exc))
        raise ValueError(str(exc)) from exc

    lines = raw.decode(errors="replace").splitlines()
    entries: list[str] = []
    valid = 0
    skipped = 0

    for line in lines:
        stripped = line.strip()
        if not stripped or stripped.startswith("#"):
            # Blank lines and comments are ignored entirely (not "skipped").
            continue
        if is_valid_ip(stripped) or is_valid_network(stripped):
            valid += 1
            if len(entries) < sample_lines:
                entries.append(stripped)
        else:
            skipped += 1

    return PreviewResponse(
        entries=entries,
        total_lines=len(lines),
        valid_count=valid,
        skipped_count=skipped,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Import
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def import_source(
    source: BlocklistSource,
    http_session: aiohttp.ClientSession,
    socket_path: str,
    db: aiosqlite.Connection,
) -> ImportSourceResult:
    """Download and apply bans from a single blocklist source.

    Downloads the URL, validates each line as an IP address, and bans valid
    IPv4/IPv6 addresses via fail2ban in :data:`BLOCKLIST_JAIL`. CIDR ranges
    are counted as skipped since fail2ban requires individual addresses.
    Any error encountered during download is recorded in the import log and
    the result is returned without raising.

    Args:
        source: The :class:`~app.models.blocklist.BlocklistSource` to import.
        http_session: Shared :class:`aiohttp.ClientSession`.
        socket_path: Path to the fail2ban Unix socket.
        db: Application database for logging.

    Returns:
        :class:`~app.models.blocklist.ImportSourceResult` with counters.
    """

    def _failure(message: str) -> ImportSourceResult:
        """Build a zero-count result for a failed download."""
        return ImportSourceResult(
            source_id=source.id,
            source_url=source.url,
            ips_imported=0,
            ips_skipped=0,
            error=message,
        )

    # --- Download ---
    try:
        async with http_session.get(
            source.url, timeout=_aiohttp_timeout(30)
        ) as resp:
            if resp.status != 200:
                error_msg = f"HTTP {resp.status}"
                await _log_result(db, source, 0, 0, error_msg)
                log.warning("blocklist_import_download_failed", url=source.url, status=resp.status)
                return _failure(error_msg)
            content = await resp.text(errors="replace")
    except Exception as exc:
        error_msg = str(exc)
        await _log_result(db, source, 0, 0, error_msg)
        log.warning("blocklist_import_download_error", url=source.url, error=error_msg)
        return _failure(error_msg)

    # --- Validate and ban ---
    imported = 0
    skipped = 0
    ban_error: str | None = None

    # Import jail_service here to avoid a circular import at module level.
    from app.services import jail_service  # noqa: PLC0415

    for line in content.splitlines():
        stripped = line.strip()
        if not stripped or stripped.startswith("#"):
            continue

        if not is_valid_ip(stripped):
            # Skip CIDRs and malformed entries gracefully.
            skipped += 1
            continue

        try:
            await jail_service.ban_ip(socket_path, BLOCKLIST_JAIL, stripped)
            imported += 1
        except Exception as exc:
            skipped += 1
            # Remember only the first ban failure for the log/result entry.
            if ban_error is None:
                ban_error = str(exc)
            log.debug("blocklist_ban_failed", ip=stripped, error=str(exc))

    await _log_result(db, source, imported, skipped, ban_error)
    log.info(
        "blocklist_source_imported",
        source_id=source.id,
        url=source.url,
        imported=imported,
        skipped=skipped,
        error=ban_error,
    )
    return ImportSourceResult(
        source_id=source.id,
        source_url=source.url,
        ips_imported=imported,
        ips_skipped=skipped,
        error=ban_error,
    )
|
||||
|
||||
|
||||
async def import_all(
    db: aiosqlite.Connection,
    http_session: aiohttp.ClientSession,
    socket_path: str,
) -> ImportRunResult:
    """Import every enabled blocklist source and aggregate the results.

    Iterates over all sources with ``enabled = True``, delegating each to
    :func:`import_source`, and sums the per-source counters.

    Args:
        db: Application database connection.
        http_session: Shared :class:`aiohttp.ClientSession`.
        socket_path: fail2ban socket path.

    Returns:
        :class:`~app.models.blocklist.ImportRunResult` with aggregated
        counters and per-source results.
    """
    rows = await blocklist_repo.list_enabled_sources(db)
    per_source: list[ImportSourceResult] = []
    imported_sum = 0
    skipped_sum = 0
    failures = 0

    for row in rows:
        outcome = await import_source(_row_to_source(row), http_session, socket_path, db)
        per_source.append(outcome)
        imported_sum += outcome.ips_imported
        skipped_sum += outcome.ips_skipped
        if outcome.error is not None:
            failures += 1

    log.info(
        "blocklist_import_all_complete",
        sources=len(rows),
        total_imported=imported_sum,
        total_skipped=skipped_sum,
        errors=failures,
    )
    return ImportRunResult(
        results=per_source,
        total_imported=imported_sum,
        total_skipped=skipped_sum,
        errors_count=failures,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schedule
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_DEFAULT_SCHEDULE = ScheduleConfig()
|
||||
|
||||
|
||||
async def get_schedule(db: aiosqlite.Connection) -> ScheduleConfig:
    """Read the import schedule config from the settings table.

    Falls back to the default config when no schedule has been saved yet,
    or when the stored JSON fails to parse/validate.

    Args:
        db: Active application database connection.

    Returns:
        The stored (or default) :class:`~app.models.blocklist.ScheduleConfig`.
    """
    stored = await settings_repo.get_setting(db, _SCHEDULE_SETTINGS_KEY)
    if stored is None:
        return _DEFAULT_SCHEDULE
    try:
        return ScheduleConfig.model_validate(json.loads(stored))
    except Exception:
        # Corrupt/legacy settings value: warn and degrade to the default
        # rather than breaking scheduling entirely.
        log.warning("blocklist_schedule_invalid", raw=stored)
        return _DEFAULT_SCHEDULE
|
||||
|
||||
|
||||
async def set_schedule(
    db: aiosqlite.Connection,
    config: ScheduleConfig,
) -> ScheduleConfig:
    """Persist a new schedule configuration.

    Args:
        db: Active application database connection.
        config: The :class:`~app.models.blocklist.ScheduleConfig` to store.

    Returns:
        The saved configuration (same object after validation).
    """
    serialized = config.model_dump_json()
    await settings_repo.set_setting(db, _SCHEDULE_SETTINGS_KEY, serialized)
    log.info("blocklist_schedule_updated", frequency=config.frequency, hour=config.hour)
    return config
|
||||
|
||||
|
||||
async def get_schedule_info(
    db: aiosqlite.Connection,
    next_run_at: str | None,
) -> ScheduleInfo:
    """Combine the schedule config with last/next-run metadata.

    Args:
        db: Active application database connection.
        next_run_at: ISO 8601 string of the next scheduled run, or ``None``
            if not yet scheduled (provided by the caller from APScheduler).

    Returns:
        :class:`~app.models.blocklist.ScheduleInfo` combining config and
        runtime metadata.
    """
    schedule = await get_schedule(db)
    latest = await import_log_repo.get_last_log(db)
    return ScheduleInfo(
        config=schedule,
        next_run_at=next_run_at,
        last_run_at=latest["timestamp"] if latest else None,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _aiohttp_timeout(seconds: float) -> aiohttp.ClientTimeout:
    """Return an :class:`aiohttp.ClientTimeout` with the given total timeout.

    The return annotation is now the real type instead of ``Any`` — safe
    here because the module has ``from __future__ import annotations`` and
    imports ``aiohttp`` under ``TYPE_CHECKING``, so the annotation is never
    evaluated at runtime.

    Args:
        seconds: Total timeout in seconds (covers the whole request).

    Returns:
        An :class:`aiohttp.ClientTimeout` instance.
    """
    import aiohttp  # noqa: PLC0415

    return aiohttp.ClientTimeout(total=seconds)
|
||||
|
||||
|
||||
async def _log_result(
    db: aiosqlite.Connection,
    source: BlocklistSource,
    ips_imported: int,
    ips_skipped: int,
    error: str | None,
) -> None:
    """Write an import log entry for a completed source import.

    The source URL is denormalized into the log row so history remains
    readable even after the source itself is deleted.

    Args:
        db: Application database connection.
        source: The source that was imported.
        ips_imported: Count of successfully banned IPs.
        ips_skipped: Count of skipped/invalid entries.
        error: Error string, or ``None`` on success.
    """
    await import_log_repo.add_log(
        db,
        source_id=source.id,
        source_url=source.url,
        ips_imported=ips_imported,
        ips_skipped=ips_skipped,
        errors=error,
    )
|
||||
153
backend/app/tasks/blocklist_import.py
Normal file
153
backend/app/tasks/blocklist_import.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""External blocklist import background task.
|
||||
|
||||
Registers an APScheduler job that downloads all enabled blocklist sources,
|
||||
validates their entries, and applies bans via fail2ban on a configurable
|
||||
schedule. The default schedule is daily at 03:00 UTC; it is stored in the
|
||||
application :class:`~app.models.blocklist.ScheduleConfig` settings and can
|
||||
be updated at runtime through the blocklist router.
|
||||
|
||||
The scheduler job ID is ``"blocklist_import"`` — using a stable id means
|
||||
re-registering the job (e.g. after a schedule update) safely replaces the
|
||||
existing entry without creating duplicates.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import structlog
|
||||
|
||||
from app.models.blocklist import ScheduleFrequency
|
||||
from app.services import blocklist_service
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import FastAPI
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
#: Stable APScheduler job id so the job can be replaced without duplicates.
|
||||
JOB_ID: str = "blocklist_import"
|
||||
|
||||
|
||||
async def _run_import(app: Any) -> None:
    """APScheduler callback that imports all enabled blocklist sources.

    Reads shared resources from ``app.state`` and delegates to
    :func:`~app.services.blocklist_service.import_all`. Unexpected errors
    are logged and swallowed so the scheduler keeps the job alive.

    Args:
        app: The :class:`fastapi.FastAPI` application instance passed via
            APScheduler ``kwargs``.
    """
    database = app.state.db
    session = app.state.http_session
    socket: str = app.state.settings.fail2ban_socket

    log.info("blocklist_import_starting")
    try:
        run = await blocklist_service.import_all(database, session, socket)
    except Exception:
        log.exception("blocklist_import_unexpected_error")
        return
    log.info(
        "blocklist_import_finished",
        total_imported=run.total_imported,
        total_skipped=run.total_skipped,
        errors=run.errors_count,
    )
|
||||
|
||||
|
||||
def register(app: FastAPI) -> None:
    """Add (or replace) the blocklist import job in the application scheduler.

    Reads the persisted :class:`~app.models.blocklist.ScheduleConfig` from
    the database and translates it into the appropriate APScheduler trigger.

    Should be called inside the lifespan handler after the scheduler and
    database have been initialised.

    Args:
        app: The :class:`fastapi.FastAPI` application instance whose
            ``app.state.scheduler`` will receive the job.
    """
    import asyncio  # noqa: PLC0415

    async def _do_register() -> None:
        config = await blocklist_service.get_schedule(app.state.db)
        _apply_schedule(app, config)

    # The previous implementation called get_event_loop().run_until_complete
    # first and relied on the resulting "loop is already running" RuntimeError
    # to fall through — detect the running loop explicitly instead.
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running (plain sync startup): block until the stored
        # schedule has been read and the job registered.
        asyncio.get_event_loop().run_until_complete(_do_register())
    else:
        # Inside uvicorn's running loop: schedule registration as a task and
        # keep a reference on app.state so it cannot be garbage-collected
        # before it runs (the loop holds only weak references to tasks).
        app.state.blocklist_register_task = loop.create_task(_do_register())
|
||||
|
||||
|
||||
def reschedule(app: FastAPI) -> None:
    """Re-register the blocklist import job with the latest schedule config.

    Called by the blocklist router after a schedule update so changes take
    effect immediately without a server restart.

    Args:
        app: The :class:`fastapi.FastAPI` application instance.
    """
    import asyncio  # noqa: PLC0415

    async def _do_reschedule() -> None:
        config = await blocklist_service.get_schedule(app.state.db)
        _apply_schedule(app, config)

    # Keep a reference on app.state: the event loop holds only weak
    # references to tasks, so a discarded ensure_future() result can be
    # garbage-collected before it ever runs.
    app.state.blocklist_reschedule_task = asyncio.ensure_future(_do_reschedule())
|
||||
|
||||
|
||||
def _apply_schedule(app: FastAPI, config: Any) -> None:
    """Add or replace the APScheduler cron/interval job for the given config.

    Args:
        app: FastAPI application instance.
        config: :class:`~app.models.blocklist.ScheduleConfig` to apply.
    """
    scheduler = app.state.scheduler

    trigger_type: str
    trigger_kwargs: dict[str, Any]

    if config.frequency == ScheduleFrequency.hourly:
        trigger_type = "interval"
        trigger_kwargs = {"hours": config.interval_hours}
    elif config.frequency == ScheduleFrequency.weekly:
        trigger_type = "cron"
        trigger_kwargs = {
            "day_of_week": config.day_of_week,
            "hour": config.hour,
            "minute": config.minute,
        }
    else:  # daily (default)
        trigger_type = "cron"
        trigger_kwargs = {
            "hour": config.hour,
            "minute": config.minute,
        }

    # replace_existing=True atomically replaces any job with the same id,
    # avoiding the previous get_job/remove_job/add_job window in which the
    # job briefly did not exist.
    scheduler.add_job(
        _run_import,
        trigger=trigger_type,
        id=JOB_ID,
        kwargs={"app": app},
        replace_existing=True,
        **trigger_kwargs,
    )
    log.info(
        "blocklist_import_scheduled",
        frequency=config.frequency,
        trigger=trigger_type,
        trigger_kwargs=trigger_kwargs,
    )
|
||||
210
backend/tests/test_repositories/test_blocklist.py
Normal file
210
backend/tests/test_repositories/test_blocklist.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""Tests for blocklist_repo and import_log_repo."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import aiosqlite
|
||||
import pytest
|
||||
|
||||
from app.db import init_db
|
||||
from app.repositories import blocklist_repo, import_log_repo
|
||||
|
||||
|
||||
@pytest.fixture
async def db(tmp_path: Path) -> aiosqlite.Connection:  # type: ignore[misc]
    """Provide an initialised aiosqlite connection for repository tests.

    Yields an open connection backed by a per-test temporary file, with the
    application schema applied via ``init_db`` and ``aiosqlite.Row`` as the
    row factory (so rows support key access like ``row["name"]``).
    """
    conn: aiosqlite.Connection = await aiosqlite.connect(str(tmp_path / "bl_test.db"))
    # Row factory enables dict-style access used throughout the repo tests.
    conn.row_factory = aiosqlite.Row
    await init_db(conn)
    yield conn
    # Teardown: close the connection after the test body finishes.
    await conn.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# blocklist_repo tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBlocklistRepo:
    """Tests for the ``blocklist_sources`` CRUD repository."""

    async def test_create_source_returns_int_id(self, db: aiosqlite.Connection) -> None:
        """create_source returns a positive integer id."""
        source_id = await blocklist_repo.create_source(db, "Test", "https://example.com/list.txt")
        assert isinstance(source_id, int)
        assert source_id > 0

    async def test_get_source_returns_row(self, db: aiosqlite.Connection) -> None:
        """get_source returns the correct row after creation."""
        source_id = await blocklist_repo.create_source(db, "Alpha", "https://alpha.test/ips.txt")
        row = await blocklist_repo.get_source(db, source_id)
        assert row is not None
        assert row["name"] == "Alpha"
        assert row["url"] == "https://alpha.test/ips.txt"
        assert row["enabled"] is True

    async def test_get_source_missing_returns_none(self, db: aiosqlite.Connection) -> None:
        """get_source returns None for a non-existent id."""
        result = await blocklist_repo.get_source(db, 9999)
        assert result is None

    async def test_list_sources_empty(self, db: aiosqlite.Connection) -> None:
        """list_sources returns empty list when no sources exist."""
        rows = await blocklist_repo.list_sources(db)
        assert rows == []

    async def test_list_sources_returns_all(self, db: aiosqlite.Connection) -> None:
        """list_sources returns all created sources."""
        await blocklist_repo.create_source(db, "A", "https://a.test/")
        await blocklist_repo.create_source(db, "B", "https://b.test/")
        rows = await blocklist_repo.list_sources(db)
        assert len(rows) == 2

    async def test_list_enabled_sources_filters(self, db: aiosqlite.Connection) -> None:
        """list_enabled_sources only returns rows with enabled=True."""
        await blocklist_repo.create_source(db, "Enabled", "https://on.test/", enabled=True)
        # The disabled row is created directly with enabled=False; the
        # previous redundant update_source(..., enabled=False) call on the
        # freshly created row has been removed.
        await blocklist_repo.create_source(db, "Disabled", "https://off.test/", enabled=False)
        rows = await blocklist_repo.list_enabled_sources(db)
        assert len(rows) == 1
        assert rows[0]["name"] == "Enabled"

    async def test_update_source_name(self, db: aiosqlite.Connection) -> None:
        """update_source changes the name field."""
        source_id = await blocklist_repo.create_source(db, "Old", "https://old.test/")
        updated = await blocklist_repo.update_source(db, source_id, name="New")
        assert updated is True
        row = await blocklist_repo.get_source(db, source_id)
        assert row is not None
        assert row["name"] == "New"

    async def test_update_source_enabled_false(self, db: aiosqlite.Connection) -> None:
        """update_source can disable a source."""
        source_id = await blocklist_repo.create_source(db, "On", "https://on.test/")
        await blocklist_repo.update_source(db, source_id, enabled=False)
        row = await blocklist_repo.get_source(db, source_id)
        assert row is not None
        assert row["enabled"] is False

    async def test_update_source_missing_returns_false(self, db: aiosqlite.Connection) -> None:
        """update_source returns False for a non-existent id."""
        result = await blocklist_repo.update_source(db, 9999, name="Ghost")
        assert result is False

    async def test_delete_source_removes_row(self, db: aiosqlite.Connection) -> None:
        """delete_source removes the row and returns True."""
        source_id = await blocklist_repo.create_source(db, "Del", "https://del.test/")
        deleted = await blocklist_repo.delete_source(db, source_id)
        assert deleted is True
        assert await blocklist_repo.get_source(db, source_id) is None

    async def test_delete_source_missing_returns_false(self, db: aiosqlite.Connection) -> None:
        """delete_source returns False for a non-existent id."""
        result = await blocklist_repo.delete_source(db, 9999)
        assert result is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# import_log_repo tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestImportLogRepo:
|
||||
async def test_add_log_returns_id(self, db: aiosqlite.Connection) -> None:
|
||||
"""add_log returns a positive integer id."""
|
||||
log_id = await import_log_repo.add_log(
|
||||
db,
|
||||
source_id=None,
|
||||
source_url="https://example.com/list.txt",
|
||||
ips_imported=10,
|
||||
ips_skipped=2,
|
||||
errors=None,
|
||||
)
|
||||
assert isinstance(log_id, int)
|
||||
assert log_id > 0
|
||||
|
||||
async def test_list_logs_returns_all(self, db: aiosqlite.Connection) -> None:
|
||||
"""list_logs returns all logs when no source_id filter is applied."""
|
||||
for i in range(3):
|
||||
await import_log_repo.add_log(
|
||||
db,
|
||||
source_id=None,
|
||||
source_url=f"https://s{i}.test/",
|
||||
ips_imported=i * 5,
|
||||
ips_skipped=0,
|
||||
errors=None,
|
||||
)
|
||||
items, total = await import_log_repo.list_logs(db)
|
||||
assert total == 3
|
||||
assert len(items) == 3
|
||||
|
||||
async def test_list_logs_pagination(self, db: aiosqlite.Connection) -> None:
|
||||
"""list_logs respects page and page_size."""
|
||||
for i in range(5):
|
||||
await import_log_repo.add_log(
|
||||
db,
|
||||
source_id=None,
|
||||
source_url=f"https://p{i}.test/",
|
||||
ips_imported=1,
|
||||
ips_skipped=0,
|
||||
errors=None,
|
||||
)
|
||||
items, total = await import_log_repo.list_logs(db, page=2, page_size=2)
|
||||
assert total == 5
|
||||
assert len(items) == 2
|
||||
|
||||
async def test_list_logs_source_filter(self, db: aiosqlite.Connection) -> None:
|
||||
"""list_logs filters by source_id."""
|
||||
source_id = await blocklist_repo.create_source(db, "Src", "https://s.test/")
|
||||
await import_log_repo.add_log(
|
||||
db,
|
||||
source_id=source_id,
|
||||
source_url="https://s.test/",
|
||||
ips_imported=5,
|
||||
ips_skipped=0,
|
||||
errors=None,
|
||||
)
|
||||
await import_log_repo.add_log(
|
||||
db,
|
||||
source_id=None,
|
||||
source_url="https://other.test/",
|
||||
ips_imported=3,
|
||||
ips_skipped=0,
|
||||
errors=None,
|
||||
)
|
||||
items, total = await import_log_repo.list_logs(db, source_id=source_id)
|
||||
assert total == 1
|
||||
assert items[0]["source_url"] == "https://s.test/"
|
||||
|
||||
async def test_get_last_log_empty(self, db: aiosqlite.Connection) -> None:
|
||||
"""get_last_log returns None when no logs exist."""
|
||||
result = await import_log_repo.get_last_log(db)
|
||||
assert result is None
|
||||
|
||||
async def test_get_last_log_returns_most_recent(self, db: aiosqlite.Connection) -> None:
|
||||
"""get_last_log returns the most recently inserted entry."""
|
||||
await import_log_repo.add_log(
|
||||
db,
|
||||
source_id=None,
|
||||
source_url="https://first.test/",
|
||||
ips_imported=1,
|
||||
ips_skipped=0,
|
||||
errors=None,
|
||||
)
|
||||
await import_log_repo.add_log(
|
||||
db,
|
||||
source_id=None,
|
||||
source_url="https://last.test/",
|
||||
ips_imported=2,
|
||||
ips_skipped=0,
|
||||
errors=None,
|
||||
)
|
||||
last = await import_log_repo.get_last_log(db)
|
||||
assert last is not None
|
||||
assert last["source_url"] == "https://last.test/"
|
||||
|
||||
async def test_compute_total_pages(self) -> None:
    """compute_total_pages rounds up and never reports zero pages."""
    cases = ((0, 10, 1), (10, 10, 1), (11, 10, 2), (20, 5, 4))
    for total, page_size, expected in cases:
        assert import_log_repo.compute_total_pages(total, page_size) == expected
|
||||
447
backend/tests/test_routers/test_blocklist.py
Normal file
447
backend/tests/test_routers/test_blocklist.py
Normal file
@@ -0,0 +1,447 @@
|
||||
"""Tests for the blocklist router (9 endpoints)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import aiosqlite
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from app.config import Settings
|
||||
from app.db import init_db
|
||||
from app.main import create_app
|
||||
from app.models.blocklist import (
|
||||
BlocklistListResponse,
|
||||
BlocklistSource,
|
||||
ImportLogListResponse,
|
||||
ImportRunResult,
|
||||
ImportSourceResult,
|
||||
PreviewResponse,
|
||||
ScheduleConfig,
|
||||
ScheduleFrequency,
|
||||
ScheduleInfo,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Request body used to complete the app's first-run setup before each
# router test; the master password is reused for the login call.
_SETUP_PAYLOAD = {
    "master_password": "testpassword1",
    "database_path": "bangui.db",
    "fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
    "timezone": "UTC",
    "session_duration_minutes": 60,
}
|
||||
|
||||
|
||||
def _make_source(source_id: int = 1) -> BlocklistSource:
    """Build a canned BlocklistSource fixture carrying *source_id*."""
    attrs = {
        "id": source_id,
        "name": "Test Source",
        "url": "https://test.test/ips.txt",
        "enabled": True,
        "created_at": "2026-01-01T00:00:00Z",
        "updated_at": "2026-01-01T00:00:00Z",
    }
    return BlocklistSource(**attrs)
|
||||
|
||||
|
||||
def _make_source_list() -> BlocklistListResponse:
    """Build a list response holding two canned sources (ids 1 and 2)."""
    return BlocklistListResponse(sources=[_make_source(i) for i in (1, 2)])
|
||||
|
||||
|
||||
def _make_schedule_info() -> ScheduleInfo:
    """Build a daily-at-03:00 schedule info fixture with no last run."""
    daily_config = ScheduleConfig(
        frequency=ScheduleFrequency.daily,
        interval_hours=24,
        hour=3,
        minute=0,
        day_of_week=0,
    )
    return ScheduleInfo(
        config=daily_config,
        next_run_at="2026-02-01T03:00:00+00:00",
        last_run_at=None,
    )
|
||||
|
||||
|
||||
def _make_import_result() -> ImportRunResult:
    """Build an import-run fixture with one successful source result."""
    single = ImportSourceResult(
        source_id=1,
        source_url="https://test.test/ips.txt",
        ips_imported=5,
        ips_skipped=1,
        error=None,
    )
    return ImportRunResult(
        results=[single],
        total_imported=5,
        total_skipped=1,
        errors_count=0,
    )
|
||||
|
||||
|
||||
def _make_log_response() -> ImportLogListResponse:
    """Build an empty first-page import-log response fixture."""
    return ImportLogListResponse(items=[], total=0, page=1, page_size=50, total_pages=1)
|
||||
|
||||
|
||||
def _make_preview() -> PreviewResponse:
    """Build a preview fixture: 10 total lines, 8 valid, 2 skipped."""
    sample_entries = ["1.2.3.4", "5.6.7.8"]
    return PreviewResponse(
        entries=sample_entries,
        total_lines=10,
        valid_count=8,
        skipped_count=2,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixture
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
async def bl_client(tmp_path: Path) -> AsyncClient:  # type: ignore[misc]
    """Provide an authenticated AsyncClient for blocklist endpoint tests."""
    # Isolated settings: a per-test DB file and a fake fail2ban socket path.
    settings = Settings(
        database_path=str(tmp_path / "bl_router_test.db"),
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        session_secret="test-bl-secret",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
    )
    app = create_app(settings=settings)

    # Open and initialise the SQLite DB directly and attach it to app.state,
    # since the ASGI transport below does not run the app's startup hooks.
    db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path)
    db.row_factory = aiosqlite.Row
    await init_db(db)
    app.state.db = db
    app.state.http_session = MagicMock()

    # Provide a minimal scheduler stub so the router can call .get_job().
    scheduler_stub = MagicMock()
    scheduler_stub.get_job = MagicMock(return_value=None)
    app.state.scheduler = scheduler_stub

    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        # First-run setup must succeed before a login is possible.
        resp = await ac.post("/api/setup", json=_SETUP_PAYLOAD)
        assert resp.status_code == 201

        # Authenticate so the yielded client carries a session cookie.
        login_resp = await ac.post(
            "/api/auth/login",
            json={"password": _SETUP_PAYLOAD["master_password"]},
        )
        assert login_resp.status_code == 200

        yield ac

    # Teardown: close the DB connection opened above.
    await db.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/blocklists
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestListBlocklists:
    async def test_authenticated_returns_200(self, bl_client: AsyncClient) -> None:
        """A logged-in client listing sources gets HTTP 200."""
        stub = AsyncMock(return_value=_make_source_list().sources)
        with patch("app.routers.blocklist.blocklist_service.list_sources", new=stub):
            response = await bl_client.get("/api/blocklists")
        assert response.status_code == 200

    async def test_returns_401_unauthenticated(self, client: AsyncClient) -> None:
        """A client without a session gets HTTP 401."""
        await client.post("/api/setup", json=_SETUP_PAYLOAD)
        response = await client.get("/api/blocklists")
        assert response.status_code == 401

    async def test_response_contains_sources_key(self, bl_client: AsyncClient) -> None:
        """The list body carries a 'sources' array."""
        stub = AsyncMock(return_value=[_make_source()])
        with patch("app.routers.blocklist.blocklist_service.list_sources", new=stub):
            payload = (await bl_client.get("/api/blocklists")).json()
        assert "sources" in payload
        assert isinstance(payload["sources"], list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/blocklists
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCreateBlocklist:
    async def test_create_returns_201(self, bl_client: AsyncClient) -> None:
        """Creating a source responds with HTTP 201."""
        stub = AsyncMock(return_value=_make_source())
        with patch("app.routers.blocklist.blocklist_service.create_source", new=stub):
            response = await bl_client.post(
                "/api/blocklists",
                json={"name": "Test", "url": "https://test.test/", "enabled": True},
            )
        assert response.status_code == 201

    async def test_create_source_id_in_response(self, bl_client: AsyncClient) -> None:
        """The creation response body carries the new source id."""
        stub = AsyncMock(return_value=_make_source(42))
        with patch("app.routers.blocklist.blocklist_service.create_source", new=stub):
            response = await bl_client.post(
                "/api/blocklists",
                json={"name": "Test", "url": "https://test.test/", "enabled": True},
            )
        assert response.json()["id"] == 42
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PUT /api/blocklists/{id}
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUpdateBlocklist:
    async def test_update_returns_200(self, bl_client: AsyncClient) -> None:
        """Updating an existing source responds with HTTP 200."""
        modified = _make_source()
        modified.enabled = False
        stub = AsyncMock(return_value=modified)
        with patch("app.routers.blocklist.blocklist_service.update_source", new=stub):
            response = await bl_client.put("/api/blocklists/1", json={"enabled": False})
        assert response.status_code == 200

    async def test_update_returns_404_for_missing(self, bl_client: AsyncClient) -> None:
        """Updating an unknown source id responds with HTTP 404."""
        stub = AsyncMock(return_value=None)
        with patch("app.routers.blocklist.blocklist_service.update_source", new=stub):
            response = await bl_client.put("/api/blocklists/999", json={"enabled": False})
        assert response.status_code == 404
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DELETE /api/blocklists/{id}
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDeleteBlocklist:
    async def test_delete_returns_204(self, bl_client: AsyncClient) -> None:
        """Deleting an existing source responds with HTTP 204."""
        stub = AsyncMock(return_value=True)
        with patch("app.routers.blocklist.blocklist_service.delete_source", new=stub):
            response = await bl_client.delete("/api/blocklists/1")
        assert response.status_code == 204

    async def test_delete_returns_404_for_missing(self, bl_client: AsyncClient) -> None:
        """Deleting an unknown source id responds with HTTP 404."""
        stub = AsyncMock(return_value=False)
        with patch("app.routers.blocklist.blocklist_service.delete_source", new=stub):
            response = await bl_client.delete("/api/blocklists/999")
        assert response.status_code == 404
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/blocklists/{id}/preview
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPreviewBlocklist:
    async def test_preview_returns_200(self, bl_client: AsyncClient) -> None:
        """Previewing an existing source responds with HTTP 200."""
        found = AsyncMock(return_value=_make_source())
        sample = AsyncMock(return_value=_make_preview())
        with patch("app.routers.blocklist.blocklist_service.get_source", new=found):
            with patch("app.routers.blocklist.blocklist_service.preview_source", new=sample):
                response = await bl_client.get("/api/blocklists/1/preview")
        assert response.status_code == 200

    async def test_preview_returns_404_for_missing(self, bl_client: AsyncClient) -> None:
        """Previewing an unknown source id responds with HTTP 404."""
        missing = AsyncMock(return_value=None)
        with patch("app.routers.blocklist.blocklist_service.get_source", new=missing):
            response = await bl_client.get("/api/blocklists/999/preview")
        assert response.status_code == 404

    async def test_preview_returns_502_on_download_error(
        self, bl_client: AsyncClient
    ) -> None:
        """A download failure surfaces as HTTP 502."""
        found = AsyncMock(return_value=_make_source())
        failing = AsyncMock(side_effect=ValueError("Connection refused"))
        with patch("app.routers.blocklist.blocklist_service.get_source", new=found):
            with patch("app.routers.blocklist.blocklist_service.preview_source", new=failing):
                response = await bl_client.get("/api/blocklists/1/preview")
        assert response.status_code == 502

    async def test_preview_response_shape(self, bl_client: AsyncClient) -> None:
        """The preview body exposes the entries sample plus all counters."""
        found = AsyncMock(return_value=_make_source())
        sample = AsyncMock(return_value=_make_preview())
        with patch("app.routers.blocklist.blocklist_service.get_source", new=found):
            with patch("app.routers.blocklist.blocklist_service.preview_source", new=sample):
                payload = (await bl_client.get("/api/blocklists/1/preview")).json()
        for key in ("entries", "valid_count", "skipped_count", "total_lines"):
            assert key in payload
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/blocklists/import
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunImport:
    async def test_import_returns_200(self, bl_client: AsyncClient) -> None:
        """Triggering a manual import responds with HTTP 200."""
        stub = AsyncMock(return_value=_make_import_result())
        with patch("app.routers.blocklist.blocklist_service.import_all", new=stub):
            response = await bl_client.post("/api/blocklists/import")
        assert response.status_code == 200

    async def test_import_response_shape(self, bl_client: AsyncClient) -> None:
        """The import body carries per-source results plus aggregate counters."""
        stub = AsyncMock(return_value=_make_import_result())
        with patch("app.routers.blocklist.blocklist_service.import_all", new=stub):
            payload = (await bl_client.post("/api/blocklists/import")).json()
        for key in ("total_imported", "total_skipped", "errors_count", "results"):
            assert key in payload
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/blocklists/schedule
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetSchedule:
    async def test_schedule_returns_200(self, bl_client: AsyncClient) -> None:
        """Fetching the schedule responds with HTTP 200."""
        stub = AsyncMock(return_value=_make_schedule_info())
        with patch("app.routers.blocklist.blocklist_service.get_schedule_info", new=stub):
            response = await bl_client.get("/api/blocklists/schedule")
        assert response.status_code == 200

    async def test_schedule_response_has_config(self, bl_client: AsyncClient) -> None:
        """The schedule body carries config plus both run timestamps."""
        stub = AsyncMock(return_value=_make_schedule_info())
        with patch("app.routers.blocklist.blocklist_service.get_schedule_info", new=stub):
            payload = (await bl_client.get("/api/blocklists/schedule")).json()
        for key in ("config", "next_run_at", "last_run_at"):
            assert key in payload
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PUT /api/blocklists/schedule
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUpdateSchedule:
    async def test_update_schedule_returns_200(self, bl_client: AsyncClient) -> None:
        """Persisting a new schedule config responds with HTTP 200."""
        hourly = ScheduleConfig(
            frequency=ScheduleFrequency.hourly,
            interval_hours=12,
            hour=0,
            minute=0,
            day_of_week=0,
        )
        new_info = ScheduleInfo(config=hourly, next_run_at=None, last_run_at=None)
        with patch("app.routers.blocklist.blocklist_service.set_schedule", new=AsyncMock()):
            with patch(
                "app.routers.blocklist.blocklist_service.get_schedule_info",
                new=AsyncMock(return_value=new_info),
            ):
                with patch("app.routers.blocklist.blocklist_import_task.reschedule"):
                    response = await bl_client.put(
                        "/api/blocklists/schedule",
                        json={
                            "frequency": "hourly",
                            "interval_hours": 12,
                            "hour": 0,
                            "minute": 0,
                            "day_of_week": 0,
                        },
                    )
        assert response.status_code == 200
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/blocklists/log
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestImportLog:
    async def test_log_returns_200(self, bl_client: AsyncClient) -> None:
        """Fetching the import log responds with HTTP 200."""
        assert (await bl_client.get("/api/blocklists/log")).status_code == 200

    async def test_log_response_shape(self, bl_client: AsyncClient) -> None:
        """The log body carries pagination metadata alongside the items."""
        payload = (await bl_client.get("/api/blocklists/log")).json()
        for key in ("items", "total", "page", "page_size", "total_pages"):
            assert key in payload

    async def test_log_empty_when_no_runs(self, bl_client: AsyncClient) -> None:
        """Without any import runs the log reports zero entries."""
        payload = (await bl_client.get("/api/blocklists/log")).json()
        assert payload["total"] == 0
        assert payload["items"] == []
|
||||
233
backend/tests/test_services/test_blocklist_service.py
Normal file
233
backend/tests/test_services/test_blocklist_service.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""Tests for blocklist_service — source CRUD, preview, import, schedule."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import aiosqlite
|
||||
import pytest
|
||||
|
||||
from app.db import init_db
|
||||
from app.models.blocklist import BlocklistSource, ScheduleConfig, ScheduleFrequency
|
||||
from app.services import blocklist_service
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixture
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
async def db(tmp_path: Path) -> aiosqlite.Connection:  # type: ignore[misc]
    """Provide an initialised aiosqlite connection."""
    # A fresh per-test database file keeps service tests fully isolated.
    conn: aiosqlite.Connection = await aiosqlite.connect(str(tmp_path / "bl_svc.db"))
    # Row factory gives dict-like access to columns by name.
    conn.row_factory = aiosqlite.Row
    await init_db(conn)
    yield conn
    # Teardown: close the connection once the test body is done.
    await conn.close()
|
||||
|
||||
|
||||
def _make_session(text: str, status: int = 200) -> MagicMock:
|
||||
"""Build a mock aiohttp session that returns *text* for GET requests."""
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = status
|
||||
mock_resp.text = AsyncMock(return_value=text)
|
||||
mock_resp.content = AsyncMock()
|
||||
mock_resp.content.read = AsyncMock(return_value=text.encode())
|
||||
|
||||
mock_ctx = AsyncMock()
|
||||
mock_ctx.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
session = MagicMock()
|
||||
session.get = MagicMock(return_value=mock_ctx)
|
||||
return session
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Source CRUD
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSourceCRUD:
    async def test_create_and_get(self, db: aiosqlite.Connection) -> None:
        """A created source round-trips through get_source."""
        created = await blocklist_service.create_source(db, "Test", "https://t.test/")
        assert isinstance(created, BlocklistSource)
        assert created.name == "Test"
        assert created.enabled is True

        reloaded = await blocklist_service.get_source(db, created.id)
        assert reloaded is not None
        assert reloaded.id == created.id

    async def test_get_missing_returns_none(self, db: aiosqlite.Connection) -> None:
        """Unknown ids yield None from get_source."""
        assert await blocklist_service.get_source(db, 9999) is None

    async def test_list_sources_empty(self, db: aiosqlite.Connection) -> None:
        """An empty table lists as an empty list."""
        assert await blocklist_service.list_sources(db) == []

    async def test_list_sources_returns_all(self, db: aiosqlite.Connection) -> None:
        """Every created source shows up in list_sources."""
        for name, url in (("A", "https://a.test/"), ("B", "https://b.test/")):
            await blocklist_service.create_source(db, name, url)
        assert len(await blocklist_service.list_sources(db)) == 2

    async def test_update_source_fields(self, db: aiosqlite.Connection) -> None:
        """update_source rewrites exactly the fields it is given."""
        created = await blocklist_service.create_source(db, "Original", "https://orig.test/")
        changed = await blocklist_service.update_source(
            db, created.id, name="Updated", enabled=False
        )
        assert changed is not None
        assert changed.name == "Updated"
        assert changed.enabled is False

    async def test_update_source_missing_returns_none(self, db: aiosqlite.Connection) -> None:
        """Updating an unknown id yields None."""
        assert await blocklist_service.update_source(db, 9999, name="Ghost") is None

    async def test_delete_source(self, db: aiosqlite.Connection) -> None:
        """delete_source removes the row and reports True."""
        created = await blocklist_service.create_source(db, "Del", "https://del.test/")
        assert await blocklist_service.delete_source(db, created.id) is True
        assert await blocklist_service.get_source(db, created.id) is None

    async def test_delete_source_missing_returns_false(self, db: aiosqlite.Connection) -> None:
        """Deleting an unknown id reports False."""
        assert await blocklist_service.delete_source(db, 9999) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Preview
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPreview:
    async def test_preview_valid_ips(self) -> None:
        """Valid IPs survive the preview; comment and junk lines do not."""
        blob = "1.2.3.4\n5.6.7.8\n# comment\ninvalid\n9.0.0.1\n"
        fake = _make_session(blob)
        preview = await blocklist_service.preview_source("https://test.test/ips.txt", fake)
        assert preview.valid_count == 3
        assert preview.skipped_count == 1  # only the "invalid" line counts as skipped
        assert "1.2.3.4" in preview.entries

    async def test_preview_http_error_raises(self) -> None:
        """Non-200 responses are surfaced as a ValueError."""
        fake = _make_session("", status=404)
        with pytest.raises(ValueError, match="HTTP 404"):
            await blocklist_service.preview_source("https://bad.test/", fake)

    async def test_preview_limits_entries(self) -> None:
        """The entries sample is capped at sample_lines; counts are not."""
        blob = "\n".join(f"1.2.3.{i}" for i in range(50))
        fake = _make_session(blob)
        preview = await blocklist_service.preview_source(
            "https://test.test/", fake, sample_lines=10
        )
        assert len(preview.entries) <= 10
        assert preview.valid_count == 50
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Import
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestImport:
    async def test_import_source_bans_valid_ips(self, db: aiosqlite.Connection) -> None:
        """import_source calls ban_ip for every valid IP in the blocklist."""
        # Two valid IPs plus a comment line that must be ignored.
        content = "1.2.3.4\n5.6.7.8\n# skip me\n"
        session = _make_session(content)

        source = await blocklist_service.create_source(db, "Import Test", "https://t.test/")

        # Patch the fail2ban client call so no real socket is touched.
        with patch(
            "app.services.jail_service.ban_ip", new_callable=AsyncMock
        ) as mock_ban:
            result = await blocklist_service.import_source(
                source, session, "/tmp/fake.sock", db
            )

        assert result.ips_imported == 2
        assert result.ips_skipped == 0
        assert result.error is None
        assert mock_ban.call_count == 2

    async def test_import_source_skips_cidrs(self, db: aiosqlite.Connection) -> None:
        """import_source skips CIDR ranges (fail2ban expects individual IPs)."""
        # One plain host IP and one CIDR network line.
        content = "1.2.3.4\n10.0.0.0/24\n"
        session = _make_session(content)
        source = await blocklist_service.create_source(db, "CIDR Test", "https://c.test/")

        with patch("app.services.jail_service.ban_ip", new_callable=AsyncMock):
            result = await blocklist_service.import_source(
                source, session, "/tmp/fake.sock", db
            )

        # The host IP is imported; the CIDR line is counted as skipped.
        assert result.ips_imported == 1
        assert result.ips_skipped == 1

    async def test_import_source_records_download_error(self, db: aiosqlite.Connection) -> None:
        """import_source records an error and returns 0 imported on HTTP failure."""
        # A 503 download: nothing is imported and the error must be captured.
        session = _make_session("", status=503)
        source = await blocklist_service.create_source(db, "Err Source", "https://err.test/")

        result = await blocklist_service.import_source(
            source, session, "/tmp/fake.sock", db
        )

        assert result.ips_imported == 0
        assert result.error is not None

    async def test_import_all_runs_all_enabled(self, db: aiosqlite.Connection) -> None:
        """import_all aggregates results across all enabled sources."""
        await blocklist_service.create_source(db, "S1", "https://s1.test/")
        # S2 is created disabled, so import_all must not touch it.
        s2 = await blocklist_service.create_source(db, "S2", "https://s2.test/", enabled=False)
        _ = s2  # noqa: F841

        content = "1.2.3.4\n5.6.7.8\n"
        session = _make_session(content)

        with patch(
            "app.services.jail_service.ban_ip", new_callable=AsyncMock
        ):
            result = await blocklist_service.import_all(db, session, "/tmp/fake.sock")

        # Only S1 is enabled, S2 is disabled.
        assert len(result.results) == 1
        assert result.results[0].source_url == "https://s1.test/"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schedule
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSchedule:
    async def test_get_schedule_default(self, db: aiosqlite.Connection) -> None:
        """Without a saved config the default is daily at 03:00."""
        default_cfg = await blocklist_service.get_schedule(db)
        assert default_cfg.frequency == ScheduleFrequency.daily
        assert default_cfg.hour == 3

    async def test_set_and_get_round_trip(self, db: aiosqlite.Connection) -> None:
        """A stored schedule config is read back unchanged."""
        stored = ScheduleConfig(
            frequency=ScheduleFrequency.hourly,
            interval_hours=6,
            hour=0,
            minute=0,
            day_of_week=0,
        )
        await blocklist_service.set_schedule(db, stored)
        reloaded = await blocklist_service.get_schedule(db)
        assert reloaded.frequency == ScheduleFrequency.hourly
        assert reloaded.interval_hours == 6

    async def test_get_schedule_info_no_log(self, db: aiosqlite.Connection) -> None:
        """With no scheduler and no import log, both timestamps are None."""
        info = await blocklist_service.get_schedule_info(db, None)
        assert info.last_run_at is None
        assert info.next_run_at is None
|
||||
Reference in New Issue
Block a user