feat: Stage 1 — backend and frontend scaffolding
Backend (tasks 1.1, 1.5–1.8): - pyproject.toml with FastAPI, Pydantic v2, aiosqlite, APScheduler 3.x, structlog, bcrypt; ruff + mypy strict configured - Pydantic Settings (BANGUI_ prefix env vars, fail-fast validation) - SQLite schema: settings, sessions, blocklist_sources, import_log; WAL mode + foreign keys; idempotent init_db() - FastAPI app factory with lifespan (DB, aiohttp session, scheduler), CORS, unhandled-exception handler, GET /api/health - Fail2BanClient: async Unix-socket wrapper using run_in_executor, custom error types, async context manager - Utility modules: ip_utils, time_utils, constants - 47 tests; ruff 0 errors; mypy --strict 0 errors Frontend (tasks 1.2–1.4): - Vite + React 18 + TypeScript strict; Fluent UI v9; ESLint + Prettier - Custom brand theme (#0F6CBD, WCAG AA contrast) with light/dark variants - Typed fetch API client (ApiError, get/post/put/del) + endpoints constants - tsc --noEmit 0 errors
This commit is contained in:
22
backend/.env.example
Normal file
22
backend/.env.example
Normal file
@@ -0,0 +1,22 @@
|
||||
# BanGUI Backend — Environment Variables
|
||||
# Copy this file to .env and fill in the values.
|
||||
# Never commit .env to version control.
|
||||
|
||||
# Path to the BanGUI application SQLite database.
|
||||
BANGUI_DATABASE_PATH=bangui.db
|
||||
|
||||
# Path to the fail2ban Unix domain socket.
|
||||
BANGUI_FAIL2BAN_SOCKET=/var/run/fail2ban/fail2ban.sock
|
||||
|
||||
# Secret key used to sign session tokens. Use a long, random string.
|
||||
# Generate with: python -c "import secrets; print(secrets.token_hex(64))"
|
||||
BANGUI_SESSION_SECRET=replace-this-with-a-long-random-secret
|
||||
|
||||
# Session duration in minutes. Default: 60 minutes.
|
||||
BANGUI_SESSION_DURATION_MINUTES=60
|
||||
|
||||
# Timezone for displaying timestamps in the UI (IANA tz name).
|
||||
BANGUI_TIMEZONE=UTC
|
||||
|
||||
# Application log level: debug | info | warning | error | critical
|
||||
BANGUI_LOG_LEVEL=info
|
||||
1
backend/app/__init__.py
Normal file
1
backend/app/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""BanGUI backend application package."""
|
||||
64
backend/app/config.py
Normal file
64
backend/app/config.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Application configuration loaded from environment variables and .env file.
|
||||
|
||||
Follows pydantic-settings patterns: all values are prefixed with BANGUI_
|
||||
and validated at startup via the Settings singleton.
|
||||
"""
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """BanGUI runtime configuration.

    All fields are loaded from environment variables prefixed with ``BANGUI_``
    or from a ``.env`` file located next to the process working directory.
    The application will raise a :class:`pydantic.ValidationError` on startup
    if any required field is missing or has an invalid value.
    """

    # Path to BanGUI's own SQLite database (distinct from fail2ban's database).
    database_path: str = Field(
        default="bangui.db",
        description="Filesystem path to the BanGUI SQLite application database.",
    )
    fail2ban_socket: str = Field(
        default="/var/run/fail2ban/fail2ban.sock",
        description="Path to the fail2ban Unix domain socket.",
    )
    # No default on purpose: startup must fail fast when the secret is unset.
    # min_length=1 additionally rejects an explicitly empty secret.
    session_secret: str = Field(
        ...,
        min_length=1,
        description=(
            "Secret key used when generating session tokens. "
            "Must be unique and never committed to source control."
        ),
    )
    session_duration_minutes: int = Field(
        default=60,
        ge=1,
        description="Number of minutes a session token remains valid after creation.",
    )
    timezone: str = Field(
        default="UTC",
        description="IANA timezone name used when displaying timestamps in the UI.",
    )
    # Validated here so a typo (e.g. "verbose") fails at startup with a clear
    # ValidationError instead of surfacing later from logging.basicConfig().
    # Case-insensitive so BANGUI_LOG_LEVEL=INFO keeps working.
    log_level: str = Field(
        default="info",
        pattern=r"(?i)^(debug|info|warning|error|critical)$",
        description="Application log level: debug | info | warning | error | critical.",
    )

    model_config = SettingsConfigDict(
        env_prefix="BANGUI_",
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
    )
|
||||
|
||||
def get_settings() -> Settings:
    """Build and validate a fresh :class:`Settings` instance from the environment.

    Returns:
        A validated :class:`Settings` object. Raises
        :class:`pydantic.ValidationError` if required keys are absent or
        values fail validation.
    """
    loaded: Settings = Settings()
    return loaded
100
backend/app/db.py
Normal file
100
backend/app/db.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""Application database schema definition and initialisation.
|
||||
|
||||
BanGUI maintains its own SQLite database that stores configuration, session
|
||||
state, blocklist source definitions, and import run logs. This module is
|
||||
the single source of truth for the schema — all ``CREATE TABLE`` statements
|
||||
live here and are applied on first run via :func:`init_db`.
|
||||
|
||||
The fail2ban database is separate and is accessed read-only by the history
|
||||
and ban services.
|
||||
"""
|
||||
|
||||
import aiosqlite
|
||||
import structlog
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DDL statements
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Key/value store for mutable application-level settings. Timestamps are
# generated by SQLite itself as ISO 8601 UTC strings with millisecond precision.
_CREATE_SETTINGS: str = """
CREATE TABLE IF NOT EXISTS settings (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    key TEXT NOT NULL UNIQUE,
    value TEXT NOT NULL,
    created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now')),
    updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))
);
"""

# Login sessions keyed by an opaque unique token; expires_at is stored as an
# ISO 8601 UTC string.
_CREATE_SESSIONS: str = """
CREATE TABLE IF NOT EXISTS sessions (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    token TEXT NOT NULL UNIQUE,
    created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now')),
    expires_at TEXT NOT NULL
);
"""

# NOTE(review): sessions.token is already declared UNIQUE above, which makes
# SQLite create an implicit unique index — this explicit index is redundant
# (though harmless and idempotent). Consider dropping it.
_CREATE_SESSIONS_TOKEN_INDEX: str = """
CREATE UNIQUE INDEX IF NOT EXISTS idx_sessions_token ON sessions (token);
"""

# Configured blocklist download sources; url is unique so the same list
# cannot be registered twice.
_CREATE_BLOCKLIST_SOURCES: str = """
CREATE TABLE IF NOT EXISTS blocklist_sources (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    name TEXT NOT NULL,
    url TEXT NOT NULL UNIQUE,
    enabled INTEGER NOT NULL DEFAULT 1,
    created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now')),
    updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))
);
"""

# One row per blocklist import run. source_id is nullable and set to NULL if
# the originating source row is deleted (ON DELETE SET NULL), while source_url
# preserves where the data came from.
_CREATE_IMPORT_LOG: str = """
CREATE TABLE IF NOT EXISTS import_log (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    source_id INTEGER REFERENCES blocklist_sources(id) ON DELETE SET NULL,
    source_url TEXT NOT NULL,
    timestamp TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now')),
    ips_imported INTEGER NOT NULL DEFAULT 0,
    ips_skipped INTEGER NOT NULL DEFAULT 0,
    errors TEXT
);
"""

# Ordered list of DDL statements to execute on initialisation.
_SCHEMA_STATEMENTS: list[str] = [
    _CREATE_SETTINGS,
    _CREATE_SESSIONS,
    _CREATE_SESSIONS_TOKEN_INDEX,
    _CREATE_BLOCKLIST_SOURCES,
    _CREATE_IMPORT_LOG,
]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def init_db(db: aiosqlite.Connection) -> None:
    """Create all BanGUI application tables if they do not already exist.

    This function is idempotent — calling it on an already-initialised
    database has no effect. It should be called once during application
    startup inside the FastAPI lifespan handler.

    Args:
        db: An open :class:`aiosqlite.Connection` to the application database.
    """
    log.info("initialising_database_schema")
    # Per-connection pragmas: WAL journalling and foreign-key enforcement.
    # execute() hands back a cursor, which we close explicitly.
    for pragma in ("PRAGMA journal_mode=WAL;", "PRAGMA foreign_keys=ON;"):
        cursor = await db.execute(pragma)
        await cursor.close()
    for ddl in _SCHEMA_STATEMENTS:
        await db.executescript(ddl)
    await db.commit()
    log.info("database_schema_ready")
56
backend/app/dependencies.py
Normal file
56
backend/app/dependencies.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""FastAPI dependency providers.
|
||||
|
||||
All ``Depends()`` callables that inject shared resources (database
|
||||
connection, settings, services, auth guard) are defined here.
|
||||
Routers import directly from this module — never from ``app.state``
|
||||
directly — to keep coupling explicit and testable.
|
||||
"""
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
import aiosqlite
|
||||
import structlog
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
|
||||
from app.config import Settings
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
|
||||
async def get_db(request: Request) -> aiosqlite.Connection:
    """Provide the shared :class:`aiosqlite.Connection` from ``app.state``.

    Args:
        request: The current FastAPI request (injected automatically).

    Returns:
        The application-wide aiosqlite connection opened during startup.

    Raises:
        HTTPException: 503 if the database has not been initialised.
    """
    connection: aiosqlite.Connection | None = getattr(request.app.state, "db", None)
    if connection is not None:
        return connection
    # Reaching this point means the lifespan handler never ran (or failed).
    log.error("database_not_initialised")
    raise HTTPException(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        detail="Database is not available.",
    )
|
||||
|
||||
async def get_settings(request: Request) -> Settings:
    """Provide the :class:`~app.config.Settings` instance from ``app.state``.

    Args:
        request: The current FastAPI request (injected automatically).

    Returns:
        The application settings loaded at startup.
    """
    # Binding to an annotated local narrows the Any coming off app.state,
    # so no type: ignore is needed.
    settings: Settings = request.app.state.settings
    return settings


# Convenience type aliases for route signatures.
DbDep = Annotated[aiosqlite.Connection, Depends(get_db)]
SettingsDep = Annotated[Settings, Depends(get_settings)]
208
backend/app/main.py
Normal file
208
backend/app/main.py
Normal file
@@ -0,0 +1,208 @@
|
||||
"""BanGUI FastAPI application factory.
|
||||
|
||||
Call :func:`create_app` to obtain a configured :class:`fastapi.FastAPI`
|
||||
instance suitable for direct use with an ASGI server (e.g. ``uvicorn``) or
|
||||
in tests via ``httpx.AsyncClient``.
|
||||
|
||||
The lifespan handler manages all shared resources — database connection, HTTP
|
||||
session, and scheduler — so every component can rely on them being available
|
||||
on ``app.state`` throughout the request lifecycle.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
import structlog
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler # type: ignore[import-untyped]
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.config import Settings, get_settings
|
||||
from app.db import init_db
|
||||
from app.routers import health
|
||||
|
||||
# ---------------------------------------------------------------------------
# Ensure the bundled fail2ban package is importable from fail2ban-master/
# ---------------------------------------------------------------------------
# NOTE(review): mutating sys.path at import time is a module-level side effect
# and assumes the fail2ban-master checkout sits two directories above this
# file — confirm this layout still holds in the deployed package.
_FAIL2BAN_MASTER: Path = Path(__file__).resolve().parents[2] / "fail2ban-master"
if str(_FAIL2BAN_MASTER) not in sys.path:
    sys.path.insert(0, str(_FAIL2BAN_MASTER))

# Module-level structured logger shared by the lifespan and handlers below.
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Logging configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _configure_logging(log_level: str) -> None:
    """Configure structlog for production JSON output.

    Args:
        log_level: One of ``debug``, ``info``, ``warning``, ``error``, ``critical``.
    """
    numeric_level: int = logging.getLevelName(log_level.upper())
    logging.basicConfig(level=numeric_level, stream=sys.stdout, format="%(message)s")

    # Processor chain: enrich with context, timestamp, and level, then render
    # every event as a single JSON line on stdout.
    processor_chain = [
        structlog.contextvars.merge_contextvars,
        structlog.stdlib.filter_by_level,
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.stdlib.add_logger_name,
        structlog.stdlib.add_log_level,
        structlog.stdlib.PositionalArgumentsFormatter(),
        structlog.processors.StackInfoRenderer(),
        structlog.processors.format_exc_info,
        structlog.processors.UnicodeDecoder(),
        structlog.processors.JSONRenderer(),
    ]
    structlog.configure(
        processors=processor_chain,
        wrapper_class=structlog.stdlib.BoundLogger,
        context_class=dict,
        logger_factory=structlog.stdlib.LoggerFactory(),
        cache_logger_on_first_use=True,
    )
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lifespan
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@asynccontextmanager
async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Manage the lifetime of all shared application resources.

    Resources are initialised in order on startup and released in reverse
    order on shutdown. They are stored on ``app.state`` so they are
    accessible to dependency providers and tests.

    Unlike the first draft, a failure partway through startup (e.g.
    :func:`init_db` or the scheduler raising) now closes every resource
    opened so far before the exception propagates, instead of leaking the
    database connection or HTTP session.

    Args:
        app: The :class:`fastapi.FastAPI` instance being started.
    """
    settings: Settings = app.state.settings
    _configure_logging(settings.log_level)

    log.info("bangui_starting_up", database_path=settings.database_path)

    # --- Application database ---
    db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path)
    try:
        db.row_factory = aiosqlite.Row
        await init_db(db)
        app.state.db = db

        # --- Shared HTTP client session ---
        http_session: aiohttp.ClientSession = aiohttp.ClientSession()
    except Exception:
        # Startup failed after the DB was opened — don't leak the connection.
        await db.close()
        raise

    try:
        app.state.http_session = http_session

        # --- Background task scheduler ---
        scheduler: AsyncIOScheduler = AsyncIOScheduler(timezone="UTC")
        scheduler.start()
        app.state.scheduler = scheduler
    except Exception:
        # Release both earlier resources before propagating the failure.
        await http_session.close()
        await db.close()
        raise

    log.info("bangui_started")

    try:
        yield
    finally:
        # Teardown in reverse order of acquisition.
        log.info("bangui_shutting_down")
        scheduler.shutdown(wait=False)
        await http_session.close()
        await db.close()
        log.info("bangui_shut_down")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Exception handlers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _unhandled_exception_handler(
    request: Request,
    exc: Exception,
) -> JSONResponse:
    """Return a sanitised 500 JSON response for any unhandled exception.

    The exception is logged with full context before the response is sent.
    No stack trace is leaked to the client.

    Args:
        request: The incoming FastAPI request.
        exc: The unhandled exception.

    Returns:
        A :class:`fastapi.responses.JSONResponse` with status 500.
    """
    # Log first so the failure is recorded even if response delivery fails.
    log.error(
        "unhandled_exception",
        path=request.url.path,
        method=request.method,
        exc_info=exc,
    )
    payload = {"detail": "An unexpected error occurred. Please try again later."}
    return JSONResponse(status_code=500, content=payload)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Application factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_app(settings: Settings | None = None) -> FastAPI:
    """Create and configure the BanGUI FastAPI application.

    This factory is the single entry point for creating the application.
    Tests can pass a custom ``settings`` object to override defaults
    without touching environment variables.

    Args:
        settings: Optional pre-built :class:`~app.config.Settings` instance.
            If ``None``, settings are loaded from the environment via
            :func:`~app.config.get_settings`.

    Returns:
        A fully configured :class:`fastapi.FastAPI` application ready for use.
    """
    if settings is None:
        settings = get_settings()

    application: FastAPI = FastAPI(
        title="BanGUI",
        description="Web interface for monitoring, managing, and configuring fail2ban.",
        version="0.1.0",
        lifespan=_lifespan,
    )

    # Store settings on app.state so the lifespan handler can access them.
    application.state.settings = settings

    # CORS: in production the frontend is served by the same origin, so this
    # list is intentionally permissive only for the Vite dev server.
    application.add_middleware(
        CORSMiddleware,
        allow_origins=["http://localhost:5173"],  # Vite dev server
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Exception handlers and routers.
    application.add_exception_handler(Exception, _unhandled_exception_handler)
    application.include_router(health.router)

    return application
||||
1
backend/app/models/__init__.py
Normal file
1
backend/app/models/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Pydantic request/response/domain models package."""
|
||||
46
backend/app/models/auth.py
Normal file
46
backend/app/models/auth.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Authentication Pydantic models.
|
||||
|
||||
Request, response, and domain models used by the auth router and service.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
    """Payload for ``POST /api/auth/login``."""

    # strict=True disables Pydantic type coercion: the body must supply a
    # real JSON string, not e.g. a number.
    model_config = ConfigDict(strict=True)

    password: str = Field(..., description="Master password to authenticate with.")


class LoginResponse(BaseModel):
    """Successful login response.

    The session token is also set as an ``HttpOnly`` cookie by the router.
    This model documents the JSON body for API-first consumers.
    """

    model_config = ConfigDict(strict=True)

    token: str = Field(..., description="Session token for use in subsequent requests.")
    # Timestamps travel as pre-formatted ISO 8601 strings, not datetime objects.
    expires_at: str = Field(..., description="ISO 8601 UTC expiry timestamp.")


class LogoutResponse(BaseModel):
    """Response body for ``POST /api/auth/logout``."""

    model_config = ConfigDict(strict=True)

    message: str = Field(default="Logged out successfully.")


class Session(BaseModel):
    """Internal domain model representing a persisted session record."""

    model_config = ConfigDict(strict=True)

    id: int = Field(..., description="Auto-incremented row ID.")
    token: str = Field(..., description="Opaque session token.")
    created_at: str = Field(..., description="ISO 8601 UTC creation timestamp.")
    expires_at: str = Field(..., description="ISO 8601 UTC expiry timestamp.")
||||
91
backend/app/models/ban.py
Normal file
91
backend/app/models/ban.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""Ban management Pydantic models.
|
||||
|
||||
Request, response, and domain models used by the ban router and service.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class BanRequest(BaseModel):
    """Payload for ``POST /api/bans`` (ban an IP)."""

    # strict=True: no type coercion on incoming JSON values.
    model_config = ConfigDict(strict=True)

    ip: str = Field(..., description="IP address to ban.")
    jail: str = Field(..., description="Jail in which to apply the ban.")


class UnbanRequest(BaseModel):
    """Payload for ``DELETE /api/bans`` (unban an IP)."""

    model_config = ConfigDict(strict=True)

    ip: str = Field(..., description="IP address to unban.")
    # jail=None and unban_all=True both widen the unban beyond a single jail;
    # their precedence is decided by the ban service, not this model.
    jail: str | None = Field(
        default=None,
        description="Jail to remove the ban from. ``null`` means all jails.",
    )
    unban_all: bool = Field(
        default=False,
        description="When ``true`` the IP is unbanned from every jail.",
    )


class Ban(BaseModel):
    """Domain model representing a single active or historical ban record."""

    model_config = ConfigDict(strict=True)

    ip: str = Field(..., description="Banned IP address.")
    jail: str = Field(..., description="Jail that issued the ban.")
    # All timestamps are pre-formatted ISO 8601 UTC strings.
    banned_at: str = Field(..., description="ISO 8601 UTC timestamp of the ban.")
    expires_at: str | None = Field(
        default=None,
        description="ISO 8601 UTC expiry timestamp, or ``null`` if permanent.",
    )
    ban_count: int = Field(..., ge=1, description="Number of times this IP was banned.")
    country: str | None = Field(
        default=None,
        description="ISO 3166-1 alpha-2 country code resolved from the IP.",
    )


class BanResponse(BaseModel):
    """Response containing a single ban record."""

    model_config = ConfigDict(strict=True)

    ban: Ban


class BanListResponse(BaseModel):
    """Paginated list of ban records."""

    model_config = ConfigDict(strict=True)

    bans: list[Ban] = Field(default_factory=list)
    # total counts all matching records, which may exceed len(bans) for a page.
    total: int = Field(..., ge=0, description="Total number of matching records.")


class ActiveBan(BaseModel):
    """A currently active ban entry returned by ``GET /api/bans/active``."""

    model_config = ConfigDict(strict=True)

    ip: str = Field(..., description="Banned IP address.")
    jail: str = Field(..., description="Jail holding the ban.")
    banned_at: str = Field(..., description="ISO 8601 UTC start of the ban.")
    expires_at: str | None = Field(
        default=None,
        description="ISO 8601 UTC expiry, or ``null`` if permanent.",
    )
    ban_count: int = Field(..., ge=1, description="Running ban count for this IP.")


class ActiveBanListResponse(BaseModel):
    """List of all currently active bans across all jails."""

    model_config = ConfigDict(strict=True)

    bans: list[ActiveBan] = Field(default_factory=list)
    total: int = Field(..., ge=0)
||||
84
backend/app/models/blocklist.py
Normal file
84
backend/app/models/blocklist.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Blocklist source and import log Pydantic models."""
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class BlocklistSource(BaseModel):
    """Domain model for a blocklist source definition."""

    model_config = ConfigDict(strict=True)

    id: int  # Auto-incremented row ID from the blocklist_sources table.
    name: str  # Human-readable source name.
    url: str  # URL of the blocklist file; unique per source at the DB level.
    enabled: bool  # Whether this source is included in imports.
    created_at: str  # ISO 8601 UTC creation timestamp.
    updated_at: str  # ISO 8601 UTC last-modified timestamp.


class BlocklistSourceCreate(BaseModel):
    """Payload for ``POST /api/blocklists``."""

    model_config = ConfigDict(strict=True)

    name: str = Field(..., min_length=1, description="Human-readable source name.")
    url: str = Field(..., description="URL of the blocklist file.")
    enabled: bool = Field(default=True)


class BlocklistSourceUpdate(BaseModel):
    """Payload for ``PUT /api/blocklists/{id}``.

    All fields are optional; ``None`` means "leave unchanged".
    """

    model_config = ConfigDict(strict=True)

    name: str | None = Field(default=None, min_length=1)
    url: str | None = Field(default=None)
    enabled: bool | None = Field(default=None)


class ImportLogEntry(BaseModel):
    """A single blocklist import run record."""

    model_config = ConfigDict(strict=True)

    id: int  # Auto-incremented row ID.
    # None when the originating source row was deleted (FK is ON DELETE SET NULL).
    source_id: int | None
    source_url: str  # URL the import was fetched from, preserved verbatim.
    timestamp: str  # ISO 8601 UTC time of the import run.
    ips_imported: int  # Number of IPs successfully imported.
    ips_skipped: int  # Number of IPs skipped.
    errors: str | None  # Error details, or None if the run was clean.


class BlocklistListResponse(BaseModel):
    """Response for ``GET /api/blocklists``."""

    model_config = ConfigDict(strict=True)

    sources: list[BlocklistSource] = Field(default_factory=list)


class ImportLogListResponse(BaseModel):
    """Response for ``GET /api/blocklists/log``."""

    model_config = ConfigDict(strict=True)

    entries: list[ImportLogEntry] = Field(default_factory=list)
    total: int = Field(..., ge=0)


class BlocklistSchedule(BaseModel):
    """Current import schedule and next run information."""

    model_config = ConfigDict(strict=True)

    hour: int = Field(..., ge=0, le=23, description="UTC hour for the daily import.")
    next_run_at: str | None = Field(default=None, description="ISO 8601 UTC timestamp of the next scheduled import.")


class BlocklistScheduleUpdate(BaseModel):
    """Payload for ``PUT /api/blocklists/schedule``."""

    model_config = ConfigDict(strict=True)

    hour: int = Field(..., ge=0, le=23)
||||
57
backend/app/models/config.py
Normal file
57
backend/app/models/config.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Configuration view/edit Pydantic models.
|
||||
|
||||
Request, response, and domain models for the config router and service.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class JailConfigUpdate(BaseModel):
    """Payload for ``PUT /api/config/jails/{name}``.

    All fields are optional; ``None`` means "leave unchanged".
    """

    model_config = ConfigDict(strict=True)

    ban_time: int | None = Field(default=None, description="Ban duration in seconds. -1 for permanent.")
    max_retry: int | None = Field(default=None, ge=1)
    find_time: int | None = Field(default=None, ge=1)
    fail_regex: list[str] | None = Field(default=None, description="Failure detection regex patterns.")
    ignore_regex: list[str] | None = Field(default=None)
    date_pattern: str | None = Field(default=None)
    dns_mode: str | None = Field(default=None, description="DNS lookup mode: raw | warn | no.")
    enabled: bool | None = Field(default=None)


class RegexTestRequest(BaseModel):
    """Payload for ``POST /api/config/regex-test``."""

    model_config = ConfigDict(strict=True)

    log_line: str = Field(..., description="Sample log line to test against.")
    fail_regex: str = Field(..., description="Regex pattern to match.")


class RegexTestResponse(BaseModel):
    """Result of a regex test."""

    model_config = ConfigDict(strict=True)

    matched: bool = Field(..., description="Whether the pattern matched the log line.")
    groups: list[str] = Field(
        default_factory=list,
        description="Named groups captured by a successful match.",
    )
    # Populated instead of matched/groups when the pattern fails to compile.
    error: str | None = Field(
        default=None,
        description="Compilation error message if the regex is invalid.",
    )


class GlobalConfigResponse(BaseModel):
    """Response for ``GET /api/config/global``."""

    model_config = ConfigDict(strict=True)

    log_level: str  # fail2ban server log level (not BanGUI's own).
    log_target: str  # fail2ban log destination.
    db_purge_age: int = Field(..., description="Seconds after which ban records are purged from the fail2ban DB.")
    db_max_matches: int = Field(..., description="Maximum stored log-line matches per ban record.")
||||
45
backend/app/models/history.py
Normal file
45
backend/app/models/history.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Ban history Pydantic models."""
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class HistoryEntry(BaseModel):
    """A single historical ban record from the fail2ban database."""

    model_config = ConfigDict(strict=True)

    ip: str  # Banned IP address.
    jail: str  # Jail that issued the ban.
    banned_at: str = Field(..., description="ISO 8601 UTC timestamp of the ban.")
    released_at: str | None = Field(default=None, description="ISO 8601 UTC timestamp when the ban expired.")
    ban_count: int = Field(..., ge=1, description="Total number of times this IP was banned.")
    country: str | None = None  # ISO country code, when resolvable.
    # Log lines that triggered the ban, as stored by fail2ban.
    matched_lines: list[str] = Field(default_factory=list)


class IpTimeline(BaseModel):
    """Per-IP ban history timeline."""

    model_config = ConfigDict(strict=True)

    ip: str
    total_bans: int = Field(..., ge=0)
    total_failures: int = Field(..., ge=0)
    events: list[HistoryEntry] = Field(default_factory=list)


class HistoryListResponse(BaseModel):
    """Paginated response for ``GET /api/history``."""

    model_config = ConfigDict(strict=True)

    entries: list[HistoryEntry] = Field(default_factory=list)
    total: int = Field(..., ge=0)


class IpHistoryResponse(BaseModel):
    """Response for ``GET /api/history/{ip}``."""

    model_config = ConfigDict(strict=True)

    timeline: IpTimeline
89
backend/app/models/jail.py
Normal file
89
backend/app/models/jail.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Jail management Pydantic models.
|
||||
|
||||
Request, response, and domain models used by the jails router and service.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class JailStatus(BaseModel):
|
||||
"""Runtime metrics for a single jail."""
|
||||
|
||||
model_config = ConfigDict(strict=True)
|
||||
|
||||
currently_banned: int = Field(..., ge=0)
|
||||
total_banned: int = Field(..., ge=0)
|
||||
currently_failed: int = Field(..., ge=0)
|
||||
total_failed: int = Field(..., ge=0)
|
||||
|
||||
|
||||
class Jail(BaseModel):
|
||||
"""Domain model for a single fail2ban jail with its full configuration."""
|
||||
|
||||
model_config = ConfigDict(strict=True)
|
||||
|
||||
name: str = Field(..., description="Jail name as configured in fail2ban.")
|
||||
enabled: bool = Field(..., description="Whether the jail is currently active.")
|
||||
running: bool = Field(..., description="Whether the jail backend is running.")
|
||||
idle: bool = Field(default=False, description="Whether the jail is in idle mode.")
|
||||
backend: str = Field(..., description="Log monitoring backend (e.g. polling, systemd).")
|
||||
log_paths: list[str] = Field(default_factory=list, description="Monitored log files.")
|
||||
fail_regex: list[str] = Field(default_factory=list, description="Failure detection regex patterns.")
|
||||
ignore_regex: list[str] = Field(default_factory=list, description="Regex patterns that bypass the ban logic.")
|
||||
ignore_ips: list[str] = Field(default_factory=list, description="IP addresses or CIDRs on the ignore list.")
|
||||
date_pattern: str | None = Field(default=None, description="Custom date pattern for log parsing.")
|
||||
log_encoding: str = Field(default="UTF-8", description="Log file encoding.")
|
||||
find_time: int = Field(..., description="Time window (seconds) for counting failures.")
|
||||
ban_time: int = Field(..., description="Duration (seconds) of a ban. -1 means permanent.")
|
||||
max_retry: int = Field(..., description="Number of failures before a ban is issued.")
|
||||
status: JailStatus | None = Field(default=None, description="Runtime counters.")
|
||||
|
||||
|
||||
class JailSummary(BaseModel):
    """Lightweight jail entry for the overview list.

    A trimmed-down projection of :class:`Jail` — same field names and
    types, minus the log-parsing configuration, for list endpoints.
    """

    model_config = ConfigDict(strict=True)

    name: str  # jail name as configured in fail2ban
    enabled: bool  # jail is configured as active
    running: bool  # jail backend is currently running
    idle: bool  # jail is in idle mode
    backend: str  # log monitoring backend (e.g. polling, systemd)
    find_time: int  # failure counting window, seconds
    ban_time: int  # ban duration, seconds (-1 means permanent, per Jail)
    max_retry: int  # failures before a ban is issued
    status: JailStatus | None = None  # runtime counters, when available
|
||||
|
||||
|
||||
class JailListResponse(BaseModel):
    """Response for ``GET /api/jails``."""

    model_config = ConfigDict(strict=True)

    jails: list[JailSummary] = Field(default_factory=list)
    # Total jail count; ge=0 guards against negative totals.
    total: int = Field(..., ge=0)
|
||||
|
||||
|
||||
class JailDetailResponse(BaseModel):
    """Response for ``GET /api/jails/{name}``."""

    model_config = ConfigDict(strict=True)

    # Full jail configuration, not the summary projection.
    jail: Jail
|
||||
|
||||
|
||||
class JailCommandResponse(BaseModel):
    """Generic response for jail control commands (start, stop, reload, idle)."""

    model_config = ConfigDict(strict=True)

    message: str  # human-readable outcome of the command
    jail: str  # name of the jail the command targeted
|
||||
|
||||
|
||||
class IgnoreIpRequest(BaseModel):
    """Payload for adding an IP or network to a jail's ignore list.

    NOTE(review): no format validation here — presumably the handler
    validates via ``app.utils.ip_utils``; confirm at the call site.
    """

    model_config = ConfigDict(strict=True)

    ip: str = Field(..., description="IP address or CIDR network to ignore.")
|
||||
58
backend/app/models/server.py
Normal file
58
backend/app/models/server.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Server status and health-check Pydantic models.
|
||||
|
||||
Used by the dashboard router, health service, and server settings router.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class ServerStatus(BaseModel):
    """Cached fail2ban server health snapshot.

    When ``online`` is ``False`` the remaining fields fall back to their
    defaults (``None`` version, zero counters).
    """

    model_config = ConfigDict(strict=True)

    online: bool = Field(..., description="Whether fail2ban is reachable via its socket.")
    version: str | None = Field(default=None, description="fail2ban version string.")
    active_jails: int = Field(default=0, ge=0, description="Number of currently active jails.")
    total_bans: int = Field(default=0, ge=0, description="Aggregated current ban count across all jails.")
    total_failures: int = Field(default=0, ge=0, description="Aggregated current failure count across all jails.")
|
||||
|
||||
|
||||
class ServerStatusResponse(BaseModel):
    """Response for ``GET /api/dashboard/status``."""

    model_config = ConfigDict(strict=True)

    status: ServerStatus
|
||||
|
||||
|
||||
class ServerSettings(BaseModel):
    """Domain model for fail2ban server-level settings.

    Mirrors the daemon-wide options (logging and ban-history database);
    jail-level options live on :class:`~app.models.jails.Jail`.
    """

    model_config = ConfigDict(strict=True)

    log_level: str = Field(..., description="fail2ban daemon log level.")
    log_target: str = Field(..., description="Log destination: STDOUT, STDERR, SYSLOG, or a file path.")
    # Only meaningful when log_target is SYSLOG — TODO confirm.
    syslog_socket: str | None = Field(default=None)
    db_path: str = Field(..., description="Path to the fail2ban ban history database.")
    db_purge_age: int = Field(..., description="Seconds before old records are purged.")
    db_max_matches: int = Field(..., description="Maximum stored matches per ban record.")
|
||||
|
||||
|
||||
class ServerSettingsUpdate(BaseModel):
    """Payload for ``PUT /api/server/settings``.

    All fields are optional: ``None`` means "leave this setting
    unchanged". Only a subset of :class:`ServerSettings` is writable —
    ``db_path`` and ``syslog_socket`` are deliberately absent.
    """

    model_config = ConfigDict(strict=True)

    log_level: str | None = Field(default=None)
    log_target: str | None = Field(default=None)
    db_purge_age: int | None = Field(default=None, ge=0)
    db_max_matches: int | None = Field(default=None, ge=0)
|
||||
|
||||
|
||||
class ServerSettingsResponse(BaseModel):
    """Response for ``GET /api/server/settings``."""

    model_config = ConfigDict(strict=True)

    settings: ServerSettings
|
||||
56
backend/app/models/setup.py
Normal file
56
backend/app/models/setup.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Setup wizard Pydantic models.
|
||||
|
||||
Request, response, and domain models for the first-run configuration wizard.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class SetupRequest(BaseModel):
    """Payload for ``POST /api/setup``.

    NOTE(review): the defaults below duplicate the values in
    ``app.utils.constants`` (DEFAULT_DATABASE_PATH,
    DEFAULT_FAIL2BAN_SOCKET, DEFAULT_SESSION_DURATION_MINUTES) —
    keep the two in sync, or import from there.
    """

    model_config = ConfigDict(strict=True)

    master_password: str = Field(
        ...,
        min_length=8,
        description="Master password that protects the BanGUI interface.",
    )
    database_path: str = Field(
        default="bangui.db",
        description="Filesystem path to the BanGUI SQLite application database.",
    )
    fail2ban_socket: str = Field(
        default="/var/run/fail2ban/fail2ban.sock",
        description="Path to the fail2ban Unix domain socket.",
    )
    timezone: str = Field(
        default="UTC",
        description="IANA timezone name used when displaying timestamps.",
    )
    session_duration_minutes: int = Field(
        default=60,
        ge=1,
        description="Number of minutes a user session remains valid.",
    )
|
||||
|
||||
|
||||
class SetupResponse(BaseModel):
    """Response returned after a successful initial setup."""

    model_config = ConfigDict(strict=True)

    # Fixed success message; callers may override but normally should not.
    message: str = Field(
        default="Setup completed successfully. Please log in.",
    )
|
||||
|
||||
|
||||
class SetupStatusResponse(BaseModel):
    """Response indicating whether setup has been completed."""

    model_config = ConfigDict(strict=True)

    completed: bool = Field(
        ...,
        description="``True`` if the initial setup has already been performed.",
    )
|
||||
1
backend/app/repositories/__init__.py
Normal file
1
backend/app/repositories/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Database access layer (repositories) package."""
|
||||
1
backend/app/routers/__init__.py
Normal file
1
backend/app/routers/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""FastAPI routers package."""
|
||||
21
backend/app/routers/health.py
Normal file
21
backend/app/routers/health.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Health check router.
|
||||
|
||||
A lightweight ``GET /api/health`` endpoint that verifies the application
|
||||
is running and can serve requests. It does not probe fail2ban — that
|
||||
responsibility belongs to the health service (Stage 4).
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
router: APIRouter = APIRouter(prefix="/api", tags=["Health"])
|
||||
|
||||
|
||||
@router.get("/health", summary="Application health check")
async def health_check() -> JSONResponse:
    """Confirm that the API process is up and able to serve requests.

    Deliberately does not probe fail2ban — this endpoint only reports
    on the web application itself.

    Returns:
        A JSON object with ``{"status": "ok"}``.
    """
    payload: dict[str, str] = {"status": "ok"}
    return JSONResponse(content=payload)
|
||||
1
backend/app/services/__init__.py
Normal file
1
backend/app/services/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Business logic services package."""
|
||||
1
backend/app/tasks/__init__.py
Normal file
1
backend/app/tasks/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""APScheduler background tasks package."""
|
||||
1
backend/app/utils/__init__.py
Normal file
1
backend/app/utils/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Shared utilities, helpers, and constants package."""
|
||||
78
backend/app/utils/constants.py
Normal file
78
backend/app/utils/constants.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Application-wide constants.
|
||||
|
||||
All magic numbers, default paths, and limit values live here.
|
||||
Import from this module rather than hard-coding values in business logic.
|
||||
"""
|
||||
|
||||
from typing import Final
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# fail2ban integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DEFAULT_FAIL2BAN_SOCKET: Final[str] = "/var/run/fail2ban/fail2ban.sock"
|
||||
"""Default path to the fail2ban Unix domain socket."""
|
||||
|
||||
FAIL2BAN_SOCKET_TIMEOUT_SECONDS: Final[float] = 5.0
|
||||
"""Maximum seconds to wait for a response from the fail2ban socket."""
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Database
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DEFAULT_DATABASE_PATH: Final[str] = "bangui.db"
|
||||
"""Default filename for the BanGUI application SQLite database."""
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Authentication
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DEFAULT_SESSION_DURATION_MINUTES: Final[int] = 60
|
||||
"""Default session lifetime in minutes."""
|
||||
|
||||
SESSION_TOKEN_BYTES: Final[int] = 64
|
||||
"""Number of random bytes used when generating a session token."""
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Time-range presets (used by dashboard and history endpoints)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
TIME_RANGE_24H: Final[str] = "24h"
|
||||
TIME_RANGE_7D: Final[str] = "7d"
|
||||
TIME_RANGE_30D: Final[str] = "30d"
|
||||
TIME_RANGE_365D: Final[str] = "365d"
|
||||
|
||||
VALID_TIME_RANGES: Final[frozenset[str]] = frozenset(
|
||||
{TIME_RANGE_24H, TIME_RANGE_7D, TIME_RANGE_30D, TIME_RANGE_365D}
|
||||
)
|
||||
|
||||
TIME_RANGE_HOURS: Final[dict[str, int]] = {
|
||||
TIME_RANGE_24H: 24,
|
||||
TIME_RANGE_7D: 7 * 24,
|
||||
TIME_RANGE_30D: 30 * 24,
|
||||
TIME_RANGE_365D: 365 * 24,
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pagination
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DEFAULT_PAGE_SIZE: Final[int] = 50
|
||||
MAX_PAGE_SIZE: Final[int] = 500
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Blocklist import
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
BLOCKLIST_IMPORT_DEFAULT_HOUR: Final[int] = 3
|
||||
"""Default hour (UTC) for the nightly blocklist import job."""
|
||||
|
||||
BLOCKLIST_PREVIEW_MAX_LINES: Final[int] = 100
|
||||
"""Maximum number of IP lines returned by the blocklist preview endpoint."""
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Health check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
HEALTH_CHECK_INTERVAL_SECONDS: Final[int] = 30
|
||||
"""How often the background health-check task polls fail2ban."""
|
||||
247
backend/app/utils/fail2ban_client.py
Normal file
247
backend/app/utils/fail2ban_client.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""Async wrapper around the fail2ban Unix domain socket protocol.
|
||||
|
||||
fail2ban uses a proprietary binary protocol over a Unix domain socket:
|
||||
commands are transmitted as pickle-serialised Python lists and responses
|
||||
are returned the same way. The protocol constants (``END``, ``CLOSE``)
|
||||
come from ``fail2ban.protocol.CSPROTO``.
|
||||
|
||||
Because the underlying socket is blocking, all I/O is dispatched to a
|
||||
thread-pool executor so the FastAPI event loop is never blocked.
|
||||
|
||||
Usage::
|
||||
|
||||
async with Fail2BanClient(socket_path="/var/run/fail2ban/fail2ban.sock") as client:
|
||||
status = await client.send(["status"])
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import socket
|
||||
from pickle import HIGHEST_PROTOCOL, dumps, loads
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
|
||||
import structlog
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
# fail2ban protocol constants — inline to avoid a hard import dependency
|
||||
# at module load time (the fail2ban-master path may not be on sys.path yet
|
||||
# in some test environments).
|
||||
_PROTO_END: bytes = b"<F2B_END_COMMAND>"
|
||||
_PROTO_CLOSE: bytes = b"<F2B_CLOSE_COMMAND>"
|
||||
_PROTO_EMPTY: bytes = b""
|
||||
|
||||
# Default receive buffer size (doubles on each iteration up to max).
|
||||
_RECV_BUFSIZE_START: int = 1024
|
||||
_RECV_BUFSIZE_MAX: int = 32768
|
||||
|
||||
|
||||
class Fail2BanConnectionError(Exception):
    """Raised when the fail2ban socket is unreachable or returns an error."""

    def __init__(self, message: str, socket_path: str) -> None:
        """Build the exception text from *message* and the targeted socket.

        Args:
            message: Description of the connection problem.
            socket_path: The fail2ban socket path that was targeted.
        """
        super().__init__(f"{message} (socket: {socket_path})")
        # Kept as an attribute so handlers can report which socket failed.
        self.socket_path: str = socket_path
|
||||
|
||||
|
||||
class Fail2BanProtocolError(Exception):
    """Raised when the response from fail2ban cannot be parsed.

    Carries only the standard ``Exception`` message; raised by the
    socket helpers when unpickling the daemon's reply fails.
    """
|
||||
|
||||
|
||||
def _send_command_sync(
    socket_path: str,
    command: list[Any],
    timeout: float,
) -> Any:
    """Send a command to fail2ban and return the parsed response.

    This is a **synchronous** function intended to be called from within
    :func:`asyncio.get_event_loop().run_in_executor` so that the event loop
    is not blocked.

    Args:
        socket_path: Path to the fail2ban Unix domain socket.
        command: List of command tokens, e.g. ``["status", "sshd"]``.
        timeout: Socket timeout in seconds.

    Returns:
        The deserialized Python object returned by fail2ban.

    Raises:
        Fail2BanConnectionError: If the socket cannot be reached.
        Fail2BanProtocolError: If the response cannot be unpickled.
    """
    sock: socket.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    try:
        sock.settimeout(timeout)
        sock.connect(socket_path)

        # Serialise and send the command. Tokens are coerced first because
        # pickle would happily serialise types fail2ban cannot convert.
        payload: bytes = dumps(
            list(map(_coerce_command_token, command)),
            HIGHEST_PROTOCOL,
        )
        sock.sendall(payload)
        sock.sendall(_PROTO_END)

        # Receive until we see the end marker. The marker is 17 bytes, so
        # scanning only the last 32 bytes of the accumulated buffer is
        # enough to catch it even when it straddles two recv() chunks.
        raw: bytes = _PROTO_EMPTY
        bufsize: int = _RECV_BUFSIZE_START
        while raw.rfind(_PROTO_END, -32) == -1:
            chunk: bytes = sock.recv(bufsize)
            if not chunk:
                # Peer closed the connection before sending the end marker.
                raise Fail2BanConnectionError(
                    "Connection closed unexpectedly by fail2ban",
                    socket_path,
                )
            if chunk == _PROTO_END:
                break
            raw += chunk
            # Grow the read size geometrically for large responses.
            if bufsize < _RECV_BUFSIZE_MAX:
                bufsize <<= 1

        try:
            # ``raw`` may still carry the trailing end marker; pickle stops
            # at its STOP opcode, so the extra bytes are ignored.
            return loads(raw)
        except Exception as exc:
            raise Fail2BanProtocolError(
                f"Failed to unpickle fail2ban response: {exc}"
            ) from exc
    except OSError as exc:
        raise Fail2BanConnectionError(str(exc), socket_path) from exc
    finally:
        # Best-effort polite shutdown: tell fail2ban we are done, then
        # close. Each step is suppressed independently because the socket
        # may never have connected in the first place.
        with contextlib.suppress(OSError):
            sock.sendall(_PROTO_CLOSE + _PROTO_END)
        with contextlib.suppress(OSError):
            sock.shutdown(socket.SHUT_RDWR)
        sock.close()
|
||||
|
||||
|
||||
def _coerce_command_token(token: Any) -> Any:
|
||||
"""Coerce a command token to a type that fail2ban understands.
|
||||
|
||||
fail2ban's ``CSocket.convert`` accepts ``str``, ``bool``, ``int``,
|
||||
``float``, ``list``, ``dict``, and ``set``. Any other type is
|
||||
stringified.
|
||||
|
||||
Args:
|
||||
token: A single token from the command list.
|
||||
|
||||
Returns:
|
||||
The token in a type safe for pickle transmission to fail2ban.
|
||||
"""
|
||||
if isinstance(token, (str, bool, int, float, list, dict, set)):
|
||||
return token
|
||||
return str(token)
|
||||
|
||||
|
||||
class Fail2BanClient:
    """Async client for communicating with the fail2ban daemon via its socket.

    All blocking socket I/O is offloaded to the default thread-pool executor
    so the asyncio event loop remains unblocked. A fresh socket is opened
    and closed per command, so the client itself holds no connection state.

    The client can be used as an async context manager::

        async with Fail2BanClient(socket_path) as client:
            result = await client.send(["status"])

    Or instantiated directly and closed manually::

        client = Fail2BanClient(socket_path)
        result = await client.send(["status"])
    """

    def __init__(
        self,
        socket_path: str,
        timeout: float = 5.0,
    ) -> None:
        """Initialise the client.

        Args:
            socket_path: Path to the fail2ban Unix domain socket.
            timeout: Socket I/O timeout in seconds.
        """
        self.socket_path: str = socket_path
        self.timeout: float = timeout

    async def send(self, command: list[Any]) -> Any:
        """Send a command to fail2ban and return the response.

        The command is serialised as a pickle list, sent to the socket, and
        the response is deserialised before being returned.

        Args:
            command: A list of command tokens, e.g. ``["status", "sshd"]``.

        Returns:
            The Python object returned by fail2ban (typically a list or dict).

        Raises:
            Fail2BanConnectionError: If the socket cannot be reached or the
                connection is unexpectedly closed.
            Fail2BanProtocolError: If the response cannot be decoded.
        """
        log.debug("fail2ban_sending_command", command=command)
        # get_running_loop() is the correct API inside a coroutine;
        # get_event_loop() is deprecated for this use since Python 3.10.
        loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
        try:
            response: Any = await loop.run_in_executor(
                None,
                _send_command_sync,
                self.socket_path,
                command,
                self.timeout,
            )
        except Fail2BanConnectionError:
            log.warning(
                "fail2ban_connection_error",
                socket_path=self.socket_path,
                command=command,
            )
            raise
        except Fail2BanProtocolError:
            log.error(
                "fail2ban_protocol_error",
                socket_path=self.socket_path,
                command=command,
            )
            raise
        log.debug("fail2ban_received_response", command=command)
        return response

    async def ping(self) -> bool:
        """Return ``True`` if the fail2ban daemon is reachable.

        Sends a ``ping`` command and checks the response; all client
        errors are swallowed and reported as "not reachable".

        Returns:
            ``True`` when the daemon responds correctly, ``False`` otherwise.
        """
        try:
            response: Any = await self.send(["ping"])
            return bool(response == 1)  # fail2ban returns 1 on successful ping
        except (Fail2BanConnectionError, Fail2BanProtocolError):
            return False

    async def __aenter__(self) -> Fail2BanClient:
        """Return self when used as an async context manager."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """No-op exit — each command opens and closes its own socket."""
||||
101
backend/app/utils/ip_utils.py
Normal file
101
backend/app/utils/ip_utils.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""IP address and CIDR range validation and normalisation utilities.
|
||||
|
||||
All IP handling in BanGUI goes through these helpers to enforce consistency
|
||||
and prevent malformed addresses from reaching fail2ban.
|
||||
"""
|
||||
|
||||
import ipaddress
|
||||
|
||||
|
||||
def is_valid_ip(address: str) -> bool:
    """Check whether *address* parses as an IPv4 or IPv6 address.

    Args:
        address: The string to validate.

    Returns:
        ``True`` if the string represents a valid IP address, ``False`` otherwise.
    """
    try:
        ipaddress.ip_address(address)
    except ValueError:
        return False
    return True
|
||||
|
||||
|
||||
def is_valid_network(cidr: str) -> bool:
    """Check whether *cidr* parses as an IPv4 or IPv6 network.

    Host bits are tolerated (``strict=False``), matching
    :func:`normalise_network`.

    Args:
        cidr: The string to validate, e.g. ``"192.168.0.0/24"``.

    Returns:
        ``True`` if the string is a valid CIDR network, ``False`` otherwise.
    """
    try:
        ipaddress.ip_network(cidr, strict=False)
    except ValueError:
        return False
    return True
|
||||
|
||||
|
||||
def is_valid_ip_or_network(value: str) -> bool:
    """Check whether *value* is a valid IP address or CIDR network.

    Tries the plain-address parse first, then the network parse with
    host bits tolerated.

    Args:
        value: The string to validate.

    Returns:
        ``True`` if the string is a valid IP address or CIDR range.
    """
    try:
        ipaddress.ip_address(value)
        return True
    except ValueError:
        pass
    try:
        ipaddress.ip_network(value, strict=False)
        return True
    except ValueError:
        return False
|
||||
|
||||
|
||||
def normalise_ip(address: str) -> str:
    """Canonicalise an IP address string.

    IPv6 addresses come back in their compressed short form; IPv4
    addresses are unaffected.

    Args:
        address: A valid IP address string.

    Returns:
        Normalised IP address string.

    Raises:
        ValueError: If *address* is not a valid IP address.
    """
    parsed = ipaddress.ip_address(address)
    return parsed.compressed
|
||||
|
||||
|
||||
def normalise_network(cidr: str) -> str:
    """Canonicalise a CIDR network string.

    Any set host bits are masked away so the result is the network
    address itself.

    Args:
        cidr: A valid CIDR network string, e.g. ``"192.168.1.5/24"``.

    Returns:
        Normalised network string, e.g. ``"192.168.1.0/24"``.

    Raises:
        ValueError: If *cidr* is not a valid network.
    """
    network = ipaddress.ip_network(cidr, strict=False)
    return format(network)
|
||||
|
||||
|
||||
def ip_version(address: str) -> int:
    """Determine the IP protocol version of *address*.

    Args:
        address: A valid IP address string.

    Returns:
        ``4`` for IPv4, ``6`` for IPv6.

    Raises:
        ValueError: If *address* is not a valid IP address.
    """
    parsed = ipaddress.ip_address(address)
    return parsed.version
|
||||
67
backend/app/utils/time_utils.py
Normal file
67
backend/app/utils/time_utils.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Timezone-aware datetime helpers.
|
||||
|
||||
All datetimes in BanGUI are stored and transmitted in UTC.
|
||||
Conversion to the user's display timezone happens only at the presentation
|
||||
layer (frontend). These utilities provide a consistent, safe foundation
|
||||
for working with time throughout the backend.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
|
||||
|
||||
def utc_now() -> datetime.datetime:
|
||||
"""Return the current UTC time as a timezone-aware :class:`datetime.datetime`.
|
||||
|
||||
Returns:
|
||||
Current UTC datetime with ``tzinfo=datetime.UTC``.
|
||||
"""
|
||||
return datetime.datetime.now(datetime.UTC)
|
||||
|
||||
|
||||
def utc_from_timestamp(ts: float) -> datetime.datetime:
|
||||
"""Convert a POSIX timestamp to a timezone-aware UTC datetime.
|
||||
|
||||
Args:
|
||||
ts: POSIX timestamp (seconds since Unix epoch).
|
||||
|
||||
Returns:
|
||||
Timezone-aware UTC :class:`datetime.datetime`.
|
||||
"""
|
||||
return datetime.datetime.fromtimestamp(ts, tz=datetime.UTC)
|
||||
|
||||
|
||||
def add_minutes(dt: datetime.datetime, minutes: int) -> datetime.datetime:
|
||||
"""Return a new datetime that is *minutes* ahead of *dt*.
|
||||
|
||||
Args:
|
||||
dt: The source datetime (must be timezone-aware).
|
||||
minutes: Number of minutes to add. May be negative.
|
||||
|
||||
Returns:
|
||||
A new timezone-aware :class:`datetime.datetime`.
|
||||
"""
|
||||
return dt + datetime.timedelta(minutes=minutes)
|
||||
|
||||
|
||||
def is_expired(expires_at: datetime.datetime) -> bool:
|
||||
"""Return ``True`` if *expires_at* is in the past relative to UTC now.
|
||||
|
||||
Args:
|
||||
expires_at: The expiry timestamp to check (must be timezone-aware).
|
||||
|
||||
Returns:
|
||||
``True`` when the timestamp is past, ``False`` otherwise.
|
||||
"""
|
||||
return utc_now() >= expires_at
|
||||
|
||||
|
||||
def hours_ago(hours: int) -> datetime.datetime:
|
||||
"""Return a timezone-aware UTC datetime *hours* before now.
|
||||
|
||||
Args:
|
||||
hours: Number of hours to subtract from the current time.
|
||||
|
||||
Returns:
|
||||
Timezone-aware UTC :class:`datetime.datetime`.
|
||||
"""
|
||||
return utc_now() - datetime.timedelta(hours=hours)
|
||||
59
backend/pyproject.toml
Normal file
59
backend/pyproject.toml
Normal file
@@ -0,0 +1,59 @@
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "bangui-backend"
|
||||
version = "0.1.0"
|
||||
description = "BanGUI backend — fail2ban web management interface"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"fastapi>=0.115.0",
|
||||
"uvicorn[standard]>=0.32.0",
|
||||
"pydantic>=2.9.0",
|
||||
"pydantic-settings>=2.6.0",
|
||||
"aiosqlite>=0.20.0",
|
||||
"aiohttp>=3.11.0",
|
||||
"apscheduler>=3.10,<4.0",
|
||||
"structlog>=24.4.0",
|
||||
"bcrypt>=4.2.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=8.3.0",
|
||||
"pytest-asyncio>=0.24.0",
|
||||
"httpx>=0.27.0",
|
||||
"ruff>=0.8.0",
|
||||
"mypy>=1.13.0",
|
||||
"pytest-cov>=6.0.0",
|
||||
"pytest-mock>=3.14.0",
|
||||
]
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["app"]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 120
|
||||
target-version = "py312"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "W", "I", "N", "UP", "B", "C4", "SIM", "TCH"]
|
||||
ignore = ["B008"] # FastAPI uses function calls in default arguments (Depends)
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"tests/**" = ["E402"] # sys.path manipulation before imports is intentional in test helpers
|
||||
|
||||
[tool.ruff.format]
|
||||
quote-style = "double"
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.12"
|
||||
strict = true
|
||||
plugins = ["pydantic.mypy"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
pythonpath = [".", "../fail2ban-master"]
|
||||
testpaths = ["tests"]
|
||||
addopts = "--cov=app --cov-report=term-missing"
|
||||
1
backend/tests/__init__.py
Normal file
1
backend/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests package."""
|
||||
64
backend/tests/conftest.py
Normal file
64
backend/tests/conftest.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Shared pytest fixtures for the BanGUI backend test suite.
|
||||
|
||||
All fixtures are async-compatible via pytest-asyncio. External dependencies
|
||||
(fail2ban socket, HTTP APIs) are always mocked so tests never touch real
|
||||
infrastructure.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Ensure the bundled fail2ban package is importable.
|
||||
_FAIL2BAN_MASTER: Path = Path(__file__).resolve().parents[2] / "fail2ban-master"
|
||||
if str(_FAIL2BAN_MASTER) not in sys.path:
|
||||
sys.path.insert(0, str(_FAIL2BAN_MASTER))
|
||||
|
||||
from collections.abc import AsyncIterator

import pytest
from httpx import ASGITransport, AsyncClient

from app.config import Settings
from app.main import create_app
|
||||
|
||||
|
||||
@pytest.fixture
def test_settings(tmp_path: Path) -> Settings:
    """Build a ``Settings`` object suitable for isolated tests.

    The database lives under a per-test temporary directory so tests
    never touch each other's data or the development database.

    Args:
        tmp_path: Pytest-provided temporary directory (unique per test).

    Returns:
        A :class:`~app.config.Settings` instance with overridden paths.
    """
    db_file = tmp_path / "test_bangui.db"
    return Settings(
        database_path=str(db_file),
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        session_secret="test-secret-key-do-not-use-in-production",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
    )
|
||||
|
||||
|
||||
@pytest.fixture
async def client(test_settings: Settings) -> AsyncIterator[AsyncClient]:
    """Provide an ``AsyncClient`` wired to a test instance of the BanGUI app.

    The client sends requests directly to the ASGI application (no network).
    A fresh database is created for each test.

    Note: this is an async *generator* fixture, so its return annotation
    must be ``AsyncIterator[AsyncClient]`` — annotating it ``-> AsyncClient``
    is rejected by mypy --strict for generator functions.

    Args:
        test_settings: Injected test settings fixture.

    Yields:
        An :class:`httpx.AsyncClient` with ``base_url="http://test"``.
    """
    app = create_app(settings=test_settings)
    transport: ASGITransport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        yield ac
|
||||
1
backend/tests/test_repositories/__init__.py
Normal file
1
backend/tests/test_repositories/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Repository test package."""
|
||||
69
backend/tests/test_repositories/test_db_init.py
Normal file
69
backend/tests/test_repositories/test_db_init.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Tests for app.db — database schema initialisation."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import aiosqlite
|
||||
import pytest
|
||||
|
||||
from app.db import init_db
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_init_db_creates_settings_table(tmp_path: Path) -> None:
    """After ``init_db``, the ``settings`` table exists."""
    database = tmp_path / "test.db"
    async with aiosqlite.connect(str(database)) as db:
        await init_db(db)
        query = "SELECT name FROM sqlite_master WHERE type='table' AND name='settings';"
        async with db.execute(query) as cursor:
            assert await cursor.fetchone() is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_init_db_creates_sessions_table(tmp_path: Path) -> None:
    """After ``init_db``, the ``sessions`` table exists."""
    database = tmp_path / "test.db"
    async with aiosqlite.connect(str(database)) as db:
        await init_db(db)
        query = "SELECT name FROM sqlite_master WHERE type='table' AND name='sessions';"
        async with db.execute(query) as cursor:
            assert await cursor.fetchone() is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_init_db_creates_blocklist_sources_table(tmp_path: Path) -> None:
    """After ``init_db``, the ``blocklist_sources`` table exists."""
    database = tmp_path / "test.db"
    async with aiosqlite.connect(str(database)) as db:
        await init_db(db)
        query = "SELECT name FROM sqlite_master WHERE type='table' AND name='blocklist_sources';"
        async with db.execute(query) as cursor:
            assert await cursor.fetchone() is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_init_db_creates_import_log_table(tmp_path: Path) -> None:
    """After ``init_db``, the ``import_log`` table exists."""
    database = tmp_path / "test.db"
    async with aiosqlite.connect(str(database)) as db:
        await init_db(db)
        query = "SELECT name FROM sqlite_master WHERE type='table' AND name='import_log';"
        async with db.execute(query) as cursor:
            assert await cursor.fetchone() is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_init_db_is_idempotent(tmp_path: Path) -> None:
    """Running ``init_db`` repeatedly on one database must not raise."""
    database = tmp_path / "test.db"
    async with aiosqlite.connect(str(database)) as db:
        for _ in range(2):
            await init_db(db)  # repeat runs must be no-ops
|
||||
1
backend/tests/test_routers/__init__.py
Normal file
1
backend/tests/test_routers/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Router test package."""
|
||||
26
backend/tests/test_routers/test_health.py
Normal file
26
backend/tests/test_routers/test_health.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""Tests for the health check router."""
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_health_check_returns_200(client: AsyncClient) -> None:
    """The health endpoint must answer with HTTP 200."""
    result = await client.get("/api/health")
    assert result.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_health_check_returns_ok_status(client: AsyncClient) -> None:
    """The health endpoint body must be exactly ``{"status": "ok"}``."""
    result = await client.get("/api/health")
    payload: dict[str, str] = result.json()
    assert payload == {"status": "ok"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_health_check_content_type_is_json(client: AsyncClient) -> None:
    """The health endpoint must advertise a JSON content type."""
    result = await client.get("/api/health")
    content_type = result.headers.get("content-type", "")
    assert "application/json" in content_type
|
||||
1
backend/tests/test_services/__init__.py
Normal file
1
backend/tests/test_services/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Service test package."""
|
||||
87
backend/tests/test_services/test_fail2ban_client.py
Normal file
87
backend/tests/test_services/test_fail2ban_client.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""Tests for app.utils.fail2ban_client."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.utils.fail2ban_client import (
|
||||
Fail2BanClient,
|
||||
Fail2BanConnectionError,
|
||||
Fail2BanProtocolError,
|
||||
_send_command_sync,
|
||||
)
|
||||
|
||||
|
||||
class TestFail2BanClientPing:
    """Behavioural tests for :meth:`Fail2BanClient.ping`."""

    @pytest.mark.asyncio
    async def test_ping_returns_true_when_daemon_responds(self) -> None:
        """A daemon reply of ``1`` must make ``ping()`` report ``True``."""
        instance = Fail2BanClient(socket_path="/fake/fail2ban.sock")
        patched = patch.object(instance, "send", new_callable=AsyncMock, return_value=1)
        with patched:
            assert await instance.ping() is True

    @pytest.mark.asyncio
    async def test_ping_returns_false_on_connection_error(self) -> None:
        """An unreachable daemon must make ``ping()`` report ``False``."""
        instance = Fail2BanClient(socket_path="/fake/fail2ban.sock")
        failure = Fail2BanConnectionError("refused", "/fake/fail2ban.sock")
        with patch.object(instance, "send", new_callable=AsyncMock, side_effect=failure):
            assert await instance.ping() is False

    @pytest.mark.asyncio
    async def test_ping_returns_false_on_protocol_error(self) -> None:
        """An unparseable daemon response must make ``ping()`` report ``False``."""
        instance = Fail2BanClient(socket_path="/fake/fail2ban.sock")
        failure = Fail2BanProtocolError("bad pickle")
        with patch.object(instance, "send", new_callable=AsyncMock, side_effect=failure):
            assert await instance.ping() is False
|
||||
|
||||
|
||||
class TestFail2BanClientContextManager:
    """Tests covering ``__aenter__``/``__aexit__`` on :class:`Fail2BanClient`."""

    @pytest.mark.asyncio
    async def test_context_manager_returns_self(self) -> None:
        """Entering the async context must hand back the very same client object."""
        instance = Fail2BanClient(socket_path="/fake/fail2ban.sock")
        async with instance as entered:
            assert entered is instance
|
||||
|
||||
|
||||
class TestSendCommandSync:
    """Tests for the blocking :func:`_send_command_sync` helper."""

    def test_send_command_sync_raises_connection_error_when_socket_absent(self) -> None:
        """A missing socket path must surface as :class:`Fail2BanConnectionError`."""
        with pytest.raises(Fail2BanConnectionError):
            _send_command_sync(
                socket_path="/nonexistent/fail2ban.sock",
                command=["ping"],
                timeout=1.0,
            )

    def test_send_command_sync_raises_connection_error_on_oserror(self) -> None:
        """A low-level :class:`OSError` on connect must be wrapped in
        :class:`Fail2BanConnectionError`.
        """
        # NOTE(review): patching "socket.socket" assumes the module under test
        # does `import socket` rather than `from socket import socket` — confirm.
        with patch("socket.socket") as fake_socket_cls:
            fake_sock = MagicMock()
            fake_sock.connect.side_effect = OSError("connection refused")
            fake_socket_cls.return_value = fake_sock
            with pytest.raises(Fail2BanConnectionError):
                _send_command_sync(
                    socket_path="/fake/fail2ban.sock",
                    command=["status"],
                    timeout=1.0,
                )
|
||||
106
backend/tests/test_services/test_ip_utils.py
Normal file
106
backend/tests/test_services/test_ip_utils.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""Tests for app.utils.ip_utils."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.utils.ip_utils import (
|
||||
ip_version,
|
||||
is_valid_ip,
|
||||
is_valid_ip_or_network,
|
||||
is_valid_network,
|
||||
normalise_ip,
|
||||
normalise_network,
|
||||
)
|
||||
|
||||
|
||||
class TestIsValidIp:
    """Tests for :func:`is_valid_ip`."""

    def test_is_valid_ip_with_valid_ipv4_returns_true(self) -> None:
        """A well-formed dotted-quad address is accepted."""
        assert is_valid_ip("192.168.1.1") is True

    def test_is_valid_ip_with_valid_ipv6_returns_true(self) -> None:
        """A well-formed compressed IPv6 address is accepted."""
        assert is_valid_ip("2001:db8::1") is True

    def test_is_valid_ip_with_cidr_returns_false(self) -> None:
        """CIDR notation is a network, not a single address, so it is rejected."""
        assert is_valid_ip("10.0.0.0/8") is False

    def test_is_valid_ip_with_empty_string_returns_false(self) -> None:
        """The empty string is never a valid address."""
        assert is_valid_ip("") is False

    def test_is_valid_ip_with_hostname_returns_false(self) -> None:
        """Hostnames are rejected — only literal addresses are valid."""
        assert is_valid_ip("example.com") is False

    def test_is_valid_ip_with_loopback_returns_true(self) -> None:
        """The IPv4 loopback address is a valid address."""
        assert is_valid_ip("127.0.0.1") is True
|
||||
|
||||
|
||||
class TestIsValidNetwork:
    """Tests for :func:`is_valid_network`."""

    def test_is_valid_network_with_valid_cidr_returns_true(self) -> None:
        """Canonical CIDR notation is accepted."""
        assert is_valid_network("192.168.0.0/24") is True

    def test_is_valid_network_with_host_bits_set_returns_true(self) -> None:
        """Host bits set below the mask are tolerated (non-strict parsing)."""
        assert is_valid_network("192.168.0.1/24") is True

    def test_is_valid_network_with_plain_ip_returns_true(self) -> None:
        """A bare address parses as a single-host network, which is valid."""
        assert is_valid_network("192.168.0.1") is True

    def test_is_valid_network_with_hostname_returns_false(self) -> None:
        """Hostnames are rejected — only literal networks are valid."""
        assert is_valid_network("example.com") is False

    def test_is_valid_network_with_invalid_prefix_returns_false(self) -> None:
        """A prefix length beyond 32 bits for IPv4 is rejected."""
        assert is_valid_network("10.0.0.0/99") is False
|
||||
|
||||
|
||||
class TestIsValidIpOrNetwork:
    """Tests for :func:`is_valid_ip_or_network`."""

    def test_accepts_plain_ip(self) -> None:
        """A bare address passes the combined check."""
        assert is_valid_ip_or_network("1.2.3.4") is True

    def test_accepts_cidr(self) -> None:
        """CIDR notation passes the combined check."""
        assert is_valid_ip_or_network("10.0.0.0/8") is True

    def test_rejects_garbage(self) -> None:
        """Arbitrary non-address text is rejected."""
        assert is_valid_ip_or_network("not-an-ip") is False
|
||||
|
||||
|
||||
class TestNormaliseIp:
    """Tests for :func:`normalise_ip`."""

    def test_normalise_ip_ipv4_unchanged(self) -> None:
        """An already-canonical IPv4 address round-trips untouched."""
        assert normalise_ip("10.20.30.40") == "10.20.30.40"

    def test_normalise_ip_ipv6_compressed(self) -> None:
        """A fully-expanded IPv6 address is compressed to canonical form."""
        assert normalise_ip("2001:0db8:0000:0000:0000:0000:0000:0001") == "2001:db8::1"

    def test_normalise_ip_invalid_raises_value_error(self) -> None:
        """Non-address input must raise :class:`ValueError`."""
        with pytest.raises(ValueError):
            normalise_ip("not-an-ip")
|
||||
|
||||
|
||||
class TestNormaliseNetwork:
    """Tests for :func:`normalise_network`."""

    def test_normalise_network_masks_host_bits(self) -> None:
        """Set host bits are zeroed out to yield the canonical network address."""
        assert normalise_network("192.168.1.5/24") == "192.168.1.0/24"

    def test_normalise_network_already_canonical(self) -> None:
        """A canonical network round-trips untouched."""
        assert normalise_network("10.0.0.0/8") == "10.0.0.0/8"
|
||||
|
||||
|
||||
class TestIpVersion:
    """Tests for :func:`ip_version`."""

    def test_ip_version_ipv4_returns_4(self) -> None:
        """A dotted-quad address reports version 4."""
        assert ip_version("8.8.8.8") == 4

    def test_ip_version_ipv6_returns_6(self) -> None:
        """The IPv6 loopback reports version 6."""
        assert ip_version("::1") == 6

    def test_ip_version_invalid_raises_value_error(self) -> None:
        """Non-address input must raise :class:`ValueError`."""
        with pytest.raises(ValueError):
            ip_version("garbage")
|
||||
79
backend/tests/test_services/test_time_utils.py
Normal file
79
backend/tests/test_services/test_time_utils.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""Tests for app.utils.time_utils."""
|
||||
|
||||
import datetime
|
||||
|
||||
from app.utils.time_utils import add_minutes, hours_ago, is_expired, utc_from_timestamp, utc_now
|
||||
|
||||
|
||||
class TestUtcNow:
    """Tests for :func:`utc_now`."""

    def test_utc_now_returns_timezone_aware_datetime(self) -> None:
        """The returned datetime must carry tzinfo (be timezone-aware)."""
        assert utc_now().tzinfo is not None

    def test_utc_now_timezone_is_utc(self) -> None:
        """The attached timezone must be UTC specifically."""
        assert utc_now().tzinfo == datetime.UTC

    def test_utc_now_is_recent(self) -> None:
        """The returned instant must fall between two surrounding wall-clock reads."""
        lower = datetime.datetime.now(datetime.UTC)
        observed = utc_now()
        upper = datetime.datetime.now(datetime.UTC)
        assert lower <= observed <= upper
|
||||
|
||||
|
||||
class TestUtcFromTimestamp:
    """Tests for :func:`utc_from_timestamp`."""

    def test_utc_from_timestamp_epoch_returns_utc_epoch(self) -> None:
        """Timestamp 0.0 must map to 1970-01-01T00:00:00 UTC."""
        epoch = datetime.datetime(1970, 1, 1, tzinfo=datetime.UTC)
        assert utc_from_timestamp(0.0) == epoch

    def test_utc_from_timestamp_returns_aware_datetime(self) -> None:
        """Any converted timestamp must be timezone-aware."""
        assert utc_from_timestamp(1_000_000_000.0).tzinfo is not None
|
||||
|
||||
|
||||
class TestAddMinutes:
    """Tests for :func:`add_minutes`."""

    def test_add_minutes_positive(self) -> None:
        """A positive offset moves the instant forward by that many minutes."""
        base = datetime.datetime(2024, 1, 1, 12, 0, 0, tzinfo=datetime.UTC)
        shifted = add_minutes(base, 30)
        assert shifted == datetime.datetime(2024, 1, 1, 12, 30, 0, tzinfo=datetime.UTC)

    def test_add_minutes_negative(self) -> None:
        """A negative offset moves the instant backward by that many minutes."""
        base = datetime.datetime(2024, 1, 1, 12, 0, 0, tzinfo=datetime.UTC)
        shifted = add_minutes(base, -60)
        assert shifted == datetime.datetime(2024, 1, 1, 11, 0, 0, tzinfo=datetime.UTC)
|
||||
|
||||
|
||||
class TestIsExpired:
    """Tests for :func:`is_expired`."""

    def test_is_expired_past_timestamp_returns_true(self) -> None:
        """An instant firmly in the past must be reported as expired."""
        long_ago = datetime.datetime(2000, 1, 1, tzinfo=datetime.UTC)
        assert is_expired(long_ago) is True

    def test_is_expired_future_timestamp_returns_false(self) -> None:
        """An instant firmly in the future must not be reported as expired."""
        far_ahead = datetime.datetime(2099, 1, 1, tzinfo=datetime.UTC)
        assert is_expired(far_ahead) is False
|
||||
|
||||
|
||||
class TestHoursAgo:
    """Tests for :func:`hours_ago`."""

    def test_hours_ago_returns_past_datetime(self) -> None:
        """``hours_ago(24)`` must produce an instant earlier than now."""
        assert hours_ago(24) < utc_now()

    def test_hours_ago_correct_delta(self) -> None:
        """``hours_ago(1)`` must land within one second of "now minus one hour"."""
        one_hour = datetime.timedelta(hours=1)
        slack = datetime.timedelta(seconds=1)
        lower = utc_now() - one_hour - slack
        observed = hours_ago(1)
        upper = utc_now() - one_hour + slack
        assert lower <= observed <= upper
|
||||
Reference in New Issue
Block a user