fixed tests

This commit is contained in:
2026-05-15 20:41:05 +02:00
parent 96ce516ecf
commit 77df5d5d65
50 changed files with 1482 additions and 5089 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -102,7 +102,7 @@ for (int i = 0; i < items.Count; i++)
// Step 1 — run the task prompt // Step 1 — run the task prompt
await RunCopilot(Enumerable.Empty<string>(), $"/caveman full"); await RunCopilot(Enumerable.Empty<string>(), $"/caveman full");
await RunCopilot(new[] { "--continue" }, $"read ./Docs/Instructions.md. fix the following test and only that one {item}"); await RunCopilot(new[] { "--continue" }, $"read ./Docs/Instructions.md. fix the following test and only that one. Keep in mind that i did many refactorings and test may is obsolet or need to be changed. {item}");
if (cts.IsCancellationRequested) break; if (cts.IsCancellationRequested) break;
// Step 2 — confirm completion in the same chat session // Step 2 — confirm completion in the same chat session

View File

@@ -14,6 +14,7 @@ from __future__ import annotations
from pathlib import Path from pathlib import Path
import aiosqlite import aiosqlite
from app.utils.logging_compat import get_logger from app.utils.logging_compat import get_logger
log = get_logger(__name__) log = get_logger(__name__)
@@ -246,7 +247,6 @@ CREATE INDEX IF NOT EXISTS idx_import_log_source_id_desc
} }
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Public API # Public API
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -254,6 +254,7 @@ CREATE INDEX IF NOT EXISTS idx_import_log_source_id_desc
async def _configure_connection(db: aiosqlite.Connection) -> None: async def _configure_connection(db: aiosqlite.Connection) -> None:
"""Apply hardening pragmas to a newly-opened SQLite connection.""" """Apply hardening pragmas to a newly-opened SQLite connection."""
await db.execute("PRAGMA journal_mode=WAL;")
await db.execute("PRAGMA foreign_keys=ON;") await db.execute("PRAGMA foreign_keys=ON;")
await db.execute("PRAGMA busy_timeout=5000;") await db.execute("PRAGMA busy_timeout=5000;")
@@ -271,11 +272,18 @@ async def _cleanup_wal_files(db_path: str) -> None:
Args: Args:
db_path: Path to the database file. db_path: Path to the database file.
""" """
import time
wal_path = Path(db_path + "-wal") wal_path = Path(db_path + "-wal")
shm_path = Path(db_path + "-shm") shm_path = Path(db_path + "-shm")
for path in (wal_path, shm_path): for path in (wal_path, shm_path):
if path.exists(): if path.exists():
# Skip files that were modified recently — they likely belong to an
# active connection. Only remove stale files left by crashes.
mtime = path.stat().st_mtime
if time.time() - mtime < 10:
continue
try: try:
path.unlink() path.unlink()
log.warning("orphaned_sqlite_file_removed", path=str(path)) log.warning("orphaned_sqlite_file_removed", path=str(path))
@@ -313,17 +321,17 @@ async def _parse_migration_statements(script: str) -> list[str]:
char = script[i] char = script[i]
# Skip block comments (-- ...) # Skip block comments (-- ...)
if i < len(script) - 1 and script[i:i+2] == "--": if i < len(script) - 1 and script[i : i + 2] == "--":
while i < len(script) and script[i] != "\n": while i < len(script) and script[i] != "\n":
i += 1 i += 1
i += 1 i += 1
continue continue
# Skip line comments (/* ... */) # Skip line comments (/* ... */)
if i < len(script) - 1 and script[i:i+2] == "/*": if i < len(script) - 1 and script[i : i + 2] == "/*":
i += 2 i += 2
while i < len(script) - 1: while i < len(script) - 1:
if script[i:i+2] == "*/": if script[i : i + 2] == "*/":
i += 2 i += 2
break break
i += 1 i += 1
@@ -393,7 +401,15 @@ async def _apply_migration(db: aiosqlite.Connection, version: int) -> None:
await db.execute("BEGIN IMMEDIATE;") await db.execute("BEGIN IMMEDIATE;")
for statement in statements: for statement in statements:
try:
await db.execute(statement) await db.execute(statement)
except aiosqlite.OperationalError as exc:
# Ignore duplicate column / table errors so migrations remain
# idempotent when a legacy database already has the object.
msg = str(exc).lower()
if "duplicate column name" in msg or "table" in msg and "already exists" in msg:
continue
raise
await db.execute("INSERT INTO schema_migrations (version) VALUES (?);", (version,)) await db.execute("INSERT INTO schema_migrations (version) VALUES (?);", (version,))
@@ -411,8 +427,7 @@ async def _migrate_schema(db: aiosqlite.Connection) -> None:
if current_version > _CURRENT_SCHEMA_VERSION: if current_version > _CURRENT_SCHEMA_VERSION:
raise RuntimeError( raise RuntimeError(
f"database schema version {current_version} is newer than supported " f"database schema version {current_version} is newer than supported version {_CURRENT_SCHEMA_VERSION}"
f"version {_CURRENT_SCHEMA_VERSION}"
) )
log.info("migrating_database_schema", from_version=current_version, to_version=_CURRENT_SCHEMA_VERSION) log.info("migrating_database_schema", from_version=current_version, to_version=_CURRENT_SCHEMA_VERSION)

View File

@@ -36,7 +36,6 @@ from typing import Annotated, cast
import aiohttp import aiohttp
import aiosqlite import aiosqlite
from app.utils.logging_compat import get_logger
from apscheduler.schedulers.asyncio import AsyncIOScheduler # type: ignore[import-untyped] from apscheduler.schedulers.asyncio import AsyncIOScheduler # type: ignore[import-untyped]
from fastapi import Depends, FastAPI, HTTPException, Request, status from fastapi import Depends, FastAPI, HTTPException, Request, status
@@ -45,22 +44,6 @@ from app.exceptions import RateLimitError
from app.models.auth import Session from app.models.auth import Session
from app.models.config import PendingRecovery from app.models.config import PendingRecovery
from app.models.server import ServerStatus from app.models.server import ServerStatus
from app.repositories.protocols import (
BlocklistRepository,
Fail2BanDbRepository,
GeoCacheRepository,
HistoryArchiveRepository,
ImportLogRepository,
ImportRunRepository,
SessionRepository,
SettingsRepository,
)
from app.services.geo_cache import GeoCache
from app.services.protocols import Fail2BanMetadataService
from app.utils.constants import SESSION_COOKIE_NAME
from app.utils.rate_limiter import GlobalRateLimiter
from app.utils.runtime_state import ApplicationState, JailServiceState, RuntimeState
from app.utils.session_cache import NoOpSessionCache, SessionCache
# Module-level imports for repositories and services # Module-level imports for repositories and services
# These are safe at module level since no circular dependencies exist # These are safe at module level since no circular dependencies exist
@@ -74,8 +57,25 @@ from app.repositories import (
session_repo, session_repo,
settings_repo, settings_repo,
) )
from app.repositories.protocols import (
BlocklistRepository,
Fail2BanDbRepository,
GeoCacheRepository,
HistoryArchiveRepository,
ImportLogRepository,
ImportRunRepository,
SessionRepository,
SettingsRepository,
)
from app.services import auth_service, health_service from app.services import auth_service, health_service
from app.services.fail2ban_metadata_service import default_fail2ban_metadata_service from app.services.fail2ban_metadata_service import default_fail2ban_metadata_service
from app.services.geo_cache import GeoCache
from app.services.protocols import Fail2BanMetadataService
from app.utils.constants import SESSION_COOKIE_NAME
from app.utils.logging_compat import get_logger
from app.utils.rate_limiter import GlobalRateLimiter
from app.utils.runtime_state import ApplicationState, JailServiceState, RuntimeState
from app.utils.session_cache import NoOpSessionCache, SessionCache
log = get_logger(__name__) log = get_logger(__name__)
@@ -108,6 +108,7 @@ class ApplicationContext:
#: or distributed deployments, the configured cache backend should provide #: or distributed deployments, the configured cache backend should provide
#: invalidation semantics appropriate for the deployment. #: invalidation semantics appropriate for the deployment.
def _session_cache_enabled(settings: Settings) -> bool: def _session_cache_enabled(settings: Settings) -> bool:
"""Return whether the session validation cache should be used.""" """Return whether the session validation cache should be used."""
return settings.session_cache_enabled and settings.session_cache_ttl_seconds > 0.0 return settings.session_cache_enabled and settings.session_cache_ttl_seconds > 0.0
@@ -284,6 +285,7 @@ def rate_limit_dependency(
Returns: Returns:
A callable that can be used as a FastAPI Depends() dependency. A callable that can be used as a FastAPI Depends() dependency.
""" """
async def check_rate_limit( async def check_rate_limit(
request: Request, request: Request,
rate_limiter: GlobalRateLimiterDep, rate_limiter: GlobalRateLimiterDep,
@@ -293,9 +295,7 @@ def rate_limit_dependency(
settings: Settings = request.app.state.settings settings: Settings = request.app.state.settings
client_ip = get_client_ip(request, trusted_proxies=settings.trusted_proxies) client_ip = get_client_ip(request, trusted_proxies=settings.trusted_proxies)
is_allowed, retry_after = rate_limiter.check_allowed_for_bucket( is_allowed, retry_after = rate_limiter.check_allowed_for_bucket(bucket, client_ip, max_requests, window_seconds)
bucket, client_ip, max_requests, window_seconds
)
if not is_allowed: if not is_allowed:
log.warning( log.warning(
@@ -407,6 +407,8 @@ async def get_app(request: Request) -> FastAPI:
async def get_server_status(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> ServerStatus: async def get_server_status(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> ServerStatus:
"""Return the cached fail2ban server status snapshot from application context.""" """Return the cached fail2ban server status snapshot from application context."""
if app_context.server_status is None:
return ServerStatus(online=False)
return app_context.server_status return app_context.server_status
@@ -654,7 +656,7 @@ async def require_auth(
if not token: if not token:
auth_header: str = request.headers.get("Authorization", "") auth_header: str = request.headers.get("Authorization", "")
if auth_header.startswith("Bearer "): if auth_header.startswith("Bearer "):
token = auth_header[len("Bearer "):] token = auth_header[len("Bearer ") :]
if not token: if not token:
raise HTTPException( raise HTTPException(

View File

@@ -72,13 +72,13 @@ from app.utils.external_logging import (
ExternalLogHandler, ExternalLogHandler,
create_external_log_handler, create_external_log_handler,
) )
from app.utils.json_formatter import JSONFormatter
from app.utils.logging_compat import get_logger
from app.utils.rate_limiter import GlobalRateLimiter from app.utils.rate_limiter import GlobalRateLimiter
from app.utils.runtime_state import ApplicationState, RuntimeState from app.utils.runtime_state import ApplicationState, RuntimeState
from app.utils.scheduler_lock import release_scheduler_lock from app.utils.scheduler_lock import release_scheduler_lock
from app.utils.session_cache import InMemorySessionCache, NoOpSessionCache from app.utils.session_cache import InMemorySessionCache, NoOpSessionCache
from app.utils.setup_state import is_setup_complete_cached, set_setup_complete_cache from app.utils.setup_state import is_setup_complete_cached, set_setup_complete_cache
from app.utils.json_formatter import JSONFormatter
from app.utils.logging_compat import get_logger
log = get_logger("bangui") log = get_logger("bangui")
@@ -125,8 +125,15 @@ def _configure_logging(log_level: str, log_file: str | None, settings: Settings
level: int = logging.getLevelName(log_level.upper()) level: int = logging.getLevelName(log_level.upper())
handlers: list[logging.Handler] = [logging.StreamHandler(sys.stdout)] handlers: list[logging.Handler] = [logging.StreamHandler(sys.stdout)]
if log_file: if log_file:
try:
os.makedirs(os.path.dirname(log_file), exist_ok=True) os.makedirs(os.path.dirname(log_file), exist_ok=True)
handlers.append(logging.FileHandler(log_file)) handlers.append(logging.FileHandler(log_file))
except (PermissionError, OSError) as exc:
log.warning(
"log_file_directory_not_created",
log_file=log_file,
error=str(exc),
)
# Suppress verbose third-party library logs that emit plain text # Suppress verbose third-party library logs that emit plain text
# through the standard library logging module. # through the standard library logging module.
@@ -163,9 +170,7 @@ def _update_session_cache(app: FastAPI, settings: Settings) -> None:
settings: The effective application settings. settings: The effective application settings.
""" """
cache_enabled = settings.session_cache_enabled and settings.session_cache_ttl_seconds > 0.0 cache_enabled = settings.session_cache_enabled and settings.session_cache_ttl_seconds > 0.0
app.state.session_cache = ( app.state.session_cache = InMemorySessionCache() if cache_enabled else NoOpSessionCache()
InMemorySessionCache() if cache_enabled else NoOpSessionCache()
)
@asynccontextmanager @asynccontextmanager
@@ -971,9 +976,7 @@ def _enforce_single_worker() -> None:
"See Docs/Deployment.md § Single-Worker Requirement." "See Docs/Deployment.md § Single-Worker Requirement."
) )
except ValueError as e: except ValueError as e:
raise RuntimeError( raise RuntimeError(f"WEB_CONCURRENCY must be an integer, got: {web_concurrency}") from e
f"WEB_CONCURRENCY must be an integer, got: {web_concurrency}"
) from e
# Check explicit BANGUI_WORKERS override (discouraged, still enforced) # Check explicit BANGUI_WORKERS override (discouraged, still enforced)
bangui_workers = os.environ.get("BANGUI_WORKERS") bangui_workers = os.environ.get("BANGUI_WORKERS")
@@ -990,9 +993,7 @@ def _enforce_single_worker() -> None:
"See Docs/Deployment.md § Single-Worker Requirement." "See Docs/Deployment.md § Single-Worker Requirement."
) )
except ValueError as e: except ValueError as e:
raise RuntimeError( raise RuntimeError(f"BANGUI_WORKERS must be an integer, got: {bangui_workers}") from e
f"BANGUI_WORKERS must be an integer, got: {bangui_workers}"
) from e
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -1165,7 +1166,6 @@ def create_app(settings: Settings | None = None) -> FastAPI:
# stack is a security-critical defect that must not slip through silently. # stack is a security-critical defect that must not slip through silently.
_assert_middleware_order(app) _assert_middleware_order(app)
# --- Exception handlers --- # --- Exception handlers ---
# #
# Exception handlers are registered from most specific to least specific. FastAPI evaluates # Exception handlers are registered from most specific to least specific. FastAPI evaluates

View File

@@ -10,13 +10,11 @@ from __future__ import annotations
from app.models.config import ( from app.models.config import (
BantimeEscalation, BantimeEscalation,
Fail2BanLogResponse,
FilterConfig, FilterConfig,
FilterListResponse, FilterListResponse,
GlobalConfigResponse, GlobalConfigResponse,
JailConfig, JailConfig,
JailConfigListResponse, JailConfigListResponse,
LogPreviewResponse,
MapColorThresholdsResponse, MapColorThresholdsResponse,
RegexTestResponse, RegexTestResponse,
ServiceStatusResponse, ServiceStatusResponse,
@@ -32,7 +30,6 @@ from app.models.config_domain import (
DomainRegexTest, DomainRegexTest,
DomainServiceStatus, DomainServiceStatus,
) )
from app.utils.pagination import create_pagination_metadata
def _map_domain_bantime_escalation(domain: DomainBantimeEscalation) -> BantimeEscalation: def _map_domain_bantime_escalation(domain: DomainBantimeEscalation) -> BantimeEscalation:
@@ -65,9 +62,7 @@ def map_domain_jail_config_to_response(domain: DomainJailConfig) -> JailConfig:
prefregex=domain.prefregex, prefregex=domain.prefregex,
actions=domain.actions, actions=domain.actions,
bantime_escalation=( bantime_escalation=(
_map_domain_bantime_escalation(domain.bantime_escalation) _map_domain_bantime_escalation(domain.bantime_escalation) if domain.bantime_escalation else None
if domain.bantime_escalation
else None
), ),
) )
@@ -151,6 +146,6 @@ def map_domain_filter_config_to_response(domain: DomainFilterConfig) -> FilterCo
def map_domain_filter_list_to_response(domain_list: DomainFilterList) -> FilterListResponse: def map_domain_filter_list_to_response(domain_list: DomainFilterList) -> FilterListResponse:
"""Convert domain filter list to response model.""" """Convert domain filter list to response model."""
return FilterListResponse( return FilterListResponse(
items=[map_domain_filter_config_to_response(f) for f in domain_list.items], filters=[map_domain_filter_config_to_response(f) for f in domain_list.items],
total=domain_list.total, total=domain_list.total,
) )

View File

@@ -8,15 +8,15 @@ from __future__ import annotations
from enum import StrEnum from enum import StrEnum
from pydantic import AnyHttpUrl, Field from pydantic import AnyHttpUrl, ConfigDict, Field
from app.models.response import BanGuiBaseModel, PaginatedListResponse from app.models.response import BanGuiBaseModel, PaginatedListResponse
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Blocklist source # Blocklist source
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class BlocklistSource(BanGuiBaseModel): class BlocklistSource(BanGuiBaseModel):
"""Domain model for a blocklist source definition.""" """Domain model for a blocklist source definition."""
@@ -27,6 +27,7 @@ class BlocklistSource(BanGuiBaseModel):
created_at: str created_at: str
updated_at: str updated_at: str
class BlocklistSourceCreate(BanGuiBaseModel): class BlocklistSourceCreate(BanGuiBaseModel):
"""Payload for ``POST /api/blocklists``. """Payload for ``POST /api/blocklists``.
@@ -39,6 +40,7 @@ class BlocklistSourceCreate(BanGuiBaseModel):
url: AnyHttpUrl = Field(..., description="URL of the blocklist file (http/https only).") url: AnyHttpUrl = Field(..., description="URL of the blocklist file (http/https only).")
enabled: bool = Field(default=True) enabled: bool = Field(default=True)
class BlocklistSourceUpdate(BanGuiBaseModel): class BlocklistSourceUpdate(BanGuiBaseModel):
"""Payload for ``PUT /api/blocklists/{id}``. All fields are optional. """Payload for ``PUT /api/blocklists/{id}``. All fields are optional.
@@ -49,15 +51,18 @@ class BlocklistSourceUpdate(BanGuiBaseModel):
url: AnyHttpUrl | None = Field(default=None) url: AnyHttpUrl | None = Field(default=None)
enabled: bool | None = Field(default=None) enabled: bool | None = Field(default=None)
class BlocklistListResponse(BanGuiBaseModel): class BlocklistListResponse(BanGuiBaseModel):
"""Response for ``GET /api/blocklists``.""" """Response for ``GET /api/blocklists``."""
sources: list[BlocklistSource] = Field(default_factory=list) sources: list[BlocklistSource] = Field(default_factory=list)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Import log # Import log
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class ImportLogEntry(BanGuiBaseModel): class ImportLogEntry(BanGuiBaseModel):
"""A single blocklist import run record.""" """A single blocklist import run record."""
@@ -69,6 +74,7 @@ class ImportLogEntry(BanGuiBaseModel):
ips_skipped: int ips_skipped: int
errors: str | None errors: str | None
class ImportLogListResponse(PaginatedListResponse[ImportLogEntry]): class ImportLogListResponse(PaginatedListResponse[ImportLogEntry]):
"""Response for ``GET /api/blocklists/log``. """Response for ``GET /api/blocklists/log``.
@@ -83,6 +89,7 @@ class ImportLogListResponse(PaginatedListResponse[ImportLogEntry]):
# Import run tracking (for idempotency) # Import run tracking (for idempotency)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class ImportRunEntry(BanGuiBaseModel): class ImportRunEntry(BanGuiBaseModel):
"""Tracks a unique blocklist import run by source and content hash. """Tracks a unique blocklist import run by source and content hash.
@@ -100,10 +107,12 @@ class ImportRunEntry(BanGuiBaseModel):
created_at: str created_at: str
updated_at: str updated_at: str
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Schedule # Schedule
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class ScheduleFrequency(StrEnum): class ScheduleFrequency(StrEnum):
"""Available import schedule frequency presets.""" """Available import schedule frequency presets."""
@@ -111,6 +120,7 @@ class ScheduleFrequency(StrEnum):
daily = "daily" daily = "daily"
weekly = "weekly" weekly = "weekly"
class ScheduleConfig(BanGuiBaseModel): class ScheduleConfig(BanGuiBaseModel):
"""Import schedule configuration. """Import schedule configuration.
@@ -121,8 +131,10 @@ class ScheduleConfig(BanGuiBaseModel):
- ``weekly``: additionally uses ``day_of_week`` (0=Monday … 6=Sunday). - ``weekly``: additionally uses ``day_of_week`` (0=Monday … 6=Sunday).
""" """
# No strict=True here: FastAPI and json.loads() both supply enum values as # FastAPI and json.loads() both supply enum values as plain strings;
# plain strings; strict mode would reject string→enum coercion. # strict mode would reject string→enum coercion, so we override the
# base model_config for this model only.
model_config = ConfigDict(strict=False)
frequency: ScheduleFrequency = ScheduleFrequency.daily frequency: ScheduleFrequency = ScheduleFrequency.daily
interval_hours: int = Field(default=24, ge=1, le=168, description="Used when frequency=hourly") interval_hours: int = Field(default=24, ge=1, le=168, description="Used when frequency=hourly")
@@ -135,6 +147,7 @@ class ScheduleConfig(BanGuiBaseModel):
description="Day of week for weekly runs (0=Monday … 6=Sunday)", description="Day of week for weekly runs (0=Monday … 6=Sunday)",
) )
class ScheduleInfo(BanGuiBaseModel): class ScheduleInfo(BanGuiBaseModel):
"""Current schedule configuration together with runtime metadata.""" """Current schedule configuration together with runtime metadata."""
@@ -144,10 +157,12 @@ class ScheduleInfo(BanGuiBaseModel):
last_run_errors: bool | None = None last_run_errors: bool | None = None
"""``True`` if the most recent import had errors, ``False`` if clean, ``None`` if never run.""" """``True`` if the most recent import had errors, ``False`` if clean, ``None`` if never run."""
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Import results # Import results
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class ImportSourceResult(BanGuiBaseModel): class ImportSourceResult(BanGuiBaseModel):
"""Result of importing a single blocklist source.""" """Result of importing a single blocklist source."""
@@ -157,6 +172,7 @@ class ImportSourceResult(BanGuiBaseModel):
ips_skipped: int ips_skipped: int
error: str | None error: str | None
class ImportRunResult(BanGuiBaseModel): class ImportRunResult(BanGuiBaseModel):
"""Aggregated result from a full import run across all enabled sources.""" """Aggregated result from a full import run across all enabled sources."""
@@ -165,10 +181,12 @@ class ImportRunResult(BanGuiBaseModel):
total_skipped: int total_skipped: int
errors_count: int errors_count: int
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Preview # Preview
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class PreviewResponse(BanGuiBaseModel): class PreviewResponse(BanGuiBaseModel):
"""Response for ``GET /api/blocklists/{id}/preview``.""" """Response for ``GET /api/blocklists/{id}/preview``."""

View File

@@ -188,7 +188,6 @@ class PaginationMetadata(BanGuiBaseModel):
) )
class PaginatedListResponse(BanGuiBaseModel, Generic[T]): class PaginatedListResponse(BanGuiBaseModel, Generic[T]):
"""Standardized paginated list response. """Standardized paginated list response.
@@ -384,6 +383,8 @@ class ErrorMetadata(TypedDict, total=False):
current_status: str current_status: str
actual_length: int actual_length: int
message: str message: str
field_errors: int
first_field: str
class ComponentHealth(BanGuiBaseModel): class ComponentHealth(BanGuiBaseModel):

View File

@@ -37,7 +37,6 @@ from app.services import (
filter_config_service, filter_config_service,
jail_config_service, jail_config_service,
) )
from app.utils.path_utils import validate_log_path
from app.utils.constants import ( from app.utils.constants import (
RATE_LIMIT_JAIL_ACTIVATE_REQUESTS, RATE_LIMIT_JAIL_ACTIVATE_REQUESTS,
RATE_LIMIT_JAIL_CREATE_REQUESTS, RATE_LIMIT_JAIL_CREATE_REQUESTS,
@@ -45,6 +44,7 @@ from app.utils.constants import (
RATE_LIMIT_JAIL_DELETE_REQUESTS, RATE_LIMIT_JAIL_DELETE_REQUESTS,
RATE_LIMIT_JAIL_UPDATE_REQUESTS, RATE_LIMIT_JAIL_UPDATE_REQUESTS,
) )
from app.utils.path_utils import validate_log_path
from app.utils.runtime_state import ( from app.utils.runtime_state import (
clear_activation_record, clear_activation_record,
clear_pending_recovery, clear_pending_recovery,
@@ -207,7 +207,8 @@ def _check_jail_deactivate_rate_limit(
) )
_NamePath = Annotated[str, Path(description='Jail name as configured in fail2ban.')] _NamePath = Annotated[str, Path(description="Jail name as configured in fail2ban.")]
@router.get( @router.get(
"", "",
@@ -240,8 +241,6 @@ async def get_jail_configs(
return config_mappers.map_domain_jail_config_list_to_response(domain_result) return config_mappers.map_domain_jail_config_list_to_response(domain_result)
@router.get( @router.get(
"/inactive", "/inactive",
response_model=InactiveJailListResponse, response_model=InactiveJailListResponse,
@@ -335,9 +334,8 @@ async def get_jail_config(
HTTPException: 502 when fail2ban is unreachable. HTTPException: 502 when fail2ban is unreachable.
""" """
domain_result = await config_service.get_jail_config(socket_path, name) domain_result = await config_service.get_jail_config(socket_path, name)
return config_mappers.map_domain_jail_config_to_response(domain_result) mapped = config_mappers.map_domain_jail_config_to_response(domain_result)
return JailConfigResponse(jail=mapped)
@router.put( @router.put(
@@ -387,8 +385,6 @@ async def update_jail_config(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.post( @router.post(
"/{name}/logpath", "/{name}/logpath",
status_code=status.HTTP_204_NO_CONTENT, status_code=status.HTTP_204_NO_CONTENT,
@@ -430,8 +426,6 @@ async def add_log_path(
await config_service.add_log_path(socket_path, name, body) await config_service.add_log_path(socket_path, name, body)
@router.delete( @router.delete(
"/{name}/logpath", "/{name}/logpath",
status_code=status.HTTP_204_NO_CONTENT, status_code=status.HTTP_204_NO_CONTENT,
@@ -479,8 +473,6 @@ async def delete_log_path(
await config_service.delete_log_path(socket_path, name, log_path) await config_service.delete_log_path(socket_path, name, log_path)
@router.post( @router.post(
"/{name}/activate", "/{name}/activate",
response_model=JailActivationResponse, response_model=JailActivationResponse,
@@ -532,9 +524,7 @@ async def activate_jail(
""" """
req = body if body is not None else ActivateJailRequest() req = body if body is not None else ActivateJailRequest()
result = await jail_config_service.activate_jail( result = await jail_config_service.activate_jail(config_dir, socket_path, name, req, health_probe=health_probe)
config_dir, socket_path, name, req, health_probe=health_probe
)
if result.active: if result.active:
record_activation(app, name) record_activation(app, name)
@@ -542,8 +532,6 @@ async def activate_jail(
return result return result
@router.post( @router.post(
"/{name}/deactivate", "/{name}/deactivate",
response_model=JailActivationResponse, response_model=JailActivationResponse,
@@ -588,14 +576,10 @@ async def deactivate_jail(
HTTPException: 502 if fail2ban is unreachable. HTTPException: 502 if fail2ban is unreachable.
""" """
result = await jail_config_service.deactivate_jail( result = await jail_config_service.deactivate_jail(config_dir, socket_path, name, health_probe=health_probe)
config_dir, socket_path, name, health_probe=health_probe
)
return result return result
@router.delete( @router.delete(
"/{name}/local", "/{name}/local",
status_code=status.HTTP_204_NO_CONTENT, status_code=status.HTTP_204_NO_CONTENT,
@@ -645,8 +629,6 @@ async def delete_jail_local_override(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.post( @router.post(
"/{name}/validate", "/{name}/validate",
response_model=JailValidationResult, response_model=JailValidationResult,
@@ -868,10 +850,8 @@ async def remove_action_from_jail(
action_name, action_name,
do_reload=reload, do_reload=reload,
) )
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Filter discovery endpoints (Task 2.1) # Filter discovery endpoints (Task 2.1)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View File

@@ -15,11 +15,11 @@ under the key ``"blocklist_schedule"``.
from __future__ import annotations from __future__ import annotations
import json import json
from datetime import UTC
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import aiohttp import aiohttp
import aiosqlite import aiosqlite
from app.utils.logging_compat import get_logger
from app.exceptions import BlocklistSourceHasLogsError from app.exceptions import BlocklistSourceHasLogsError
from app.models.blocklist import ( from app.models.blocklist import (
@@ -37,6 +37,7 @@ from app.repositories import blocklist_repo, import_log_repo, settings_repo
from app.services.blocklist_downloader import BlocklistDownloader from app.services.blocklist_downloader import BlocklistDownloader
from app.services.blocklist_import_workflow import BlocklistImportWorkflow from app.services.blocklist_import_workflow import BlocklistImportWorkflow
from app.services.blocklist_parser import BlocklistParser from app.services.blocklist_parser import BlocklistParser
from app.utils.logging_compat import get_logger
from app.utils.pagination import create_pagination_metadata from app.utils.pagination import create_pagination_metadata
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -200,9 +201,7 @@ async def update_source(
await validate_blocklist_url(url) await validate_blocklist_url(url)
updated = await blocklist_repo.update_source( updated = await blocklist_repo.update_source(db, source_id, name=name, url=url, enabled=enabled)
db, source_id, name=name, url=url, enabled=enabled
)
if not updated: if not updated:
return None return None
source = await get_source(db, source_id) source = await get_source(db, source_id)
@@ -473,8 +472,7 @@ async def get_schedule(db: aiosqlite.Connection) -> ScheduleConfig:
if raw is None: if raw is None:
return _DEFAULT_SCHEDULE return _DEFAULT_SCHEDULE
try: try:
data = json.loads(raw) return ScheduleConfig.model_validate_json(raw)
return ScheduleConfig.model_validate(data)
except (json.JSONDecodeError, ValueError) as exc: except (json.JSONDecodeError, ValueError) as exc:
log.warning("blocklist_schedule_invalid", raw=raw, error=type(exc).__name__) log.warning("blocklist_schedule_invalid", raw=raw, error=type(exc).__name__)
return _DEFAULT_SCHEDULE return _DEFAULT_SCHEDULE
@@ -493,9 +491,7 @@ async def set_schedule(
Returns: Returns:
The saved configuration (same object after validation). The saved configuration (same object after validation).
""" """
await settings_repo.set_setting( await settings_repo.set_setting(db, _SCHEDULE_SETTINGS_KEY, config.model_dump_json())
db, _SCHEDULE_SETTINGS_KEY, config.model_dump_json()
)
log.info("blocklist_schedule_updated", frequency=config.frequency, hour=config.hour) log.info("blocklist_schedule_updated", frequency=config.frequency, hour=config.hour)
return config return config
@@ -517,8 +513,12 @@ async def get_schedule_info(
""" """
config = await get_schedule(db) config = await get_schedule(db)
last_log = await import_log_repo.get_last_log(db) last_log = await import_log_repo.get_last_log(db)
last_run_at = last_log["timestamp"] if last_log else None last_run_at = None
last_run_errors: bool | None = (last_log["errors"] is not None) if last_log else None if last_log is not None:
from datetime import datetime
last_run_at = datetime.fromtimestamp(last_log.timestamp, tz=UTC).isoformat()
last_run_errors: bool | None = (last_log.errors is not None) if last_log else None
return ScheduleInfo( return ScheduleInfo(
config=config, config=config,
next_run_at=next_run_at, next_run_at=next_run_at,
@@ -574,9 +574,7 @@ async def list_import_logs(
Returns: Returns:
:class:`~app.models.blocklist.ImportLogListResponse`. :class:`~app.models.blocklist.ImportLogListResponse`.
""" """
items, total = await import_log_repo.list_logs( items, total = await import_log_repo.list_logs(db, source_id=source_id, page=page, page_size=page_size)
db, source_id=source_id, page=page, page_size=page_size
)
return ImportLogListResponse( return ImportLogListResponse(
items=[ImportLogEntry.model_validate(i) for i in items], items=[ImportLogEntry.model_validate(i) for i in items],

View File

@@ -13,8 +13,6 @@ import re
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from app.utils.logging_compat import get_logger
from app.exceptions import ( from app.exceptions import (
ConfigWriteError, ConfigWriteError,
FilterAlreadyExistsError, FilterAlreadyExistsError,
@@ -27,6 +25,7 @@ from app.exceptions import (
) )
from app.models.config import ( from app.models.config import (
AssignFilterRequest, AssignFilterRequest,
FilterConfig,
FilterConfigUpdate, FilterConfigUpdate,
FilterCreateRequest, FilterCreateRequest,
FilterUpdateRequest, FilterUpdateRequest,
@@ -46,6 +45,7 @@ from app.utils.config_file_utils import (
set_jail_local_key_sync, set_jail_local_key_sync,
) )
from app.utils.jail_socket import reload_all from app.utils.jail_socket import reload_all
from app.utils.logging_compat import get_logger
from app.utils.regex_validator import RegexTimeoutError, validate_regex_pattern from app.utils.regex_validator import RegexTimeoutError, validate_regex_pattern
log = get_logger(__name__) log = get_logger(__name__)
@@ -54,6 +54,7 @@ log = get_logger(__name__)
# Internal wrappers for shared config helpers. # Internal wrappers for shared config helpers.
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _parse_jails_sync(config_dir: Path) -> tuple[dict[str, dict[str, str]], Path]: def _parse_jails_sync(config_dir: Path) -> tuple[dict[str, dict[str, str]], Path]:
return _config_file_parse_jails_sync(config_dir) return _config_file_parse_jails_sync(config_dir)
@@ -85,6 +86,7 @@ def _resolve_filter(raw_filter: str, jail_name: str, mode: str) -> str:
result = result.replace("%(mode)s", mode) result = result.replace("%(mode)s", mode)
return result return result
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Internal helpers imported from the shared config helper module. # Internal helpers imported from the shared config helper module.
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -366,7 +368,7 @@ async def list_filters(
) )
log.info("filters_listed", total=len(filters), active=sum(1 for f in filters if f.active)) log.info("filters_listed", total=len(filters), active=sum(1 for f in filters if f.active))
return DomainFilterList(filters=filters, total=len(filters)) return DomainFilterList(items=filters, total=len(filters))
async def get_filter( async def get_filter(
@@ -428,7 +430,7 @@ async def get_filter(
else: else:
raise FilterNotFoundError(base_name) raise FilterNotFoundError(base_name)
content, has_local, source_path = await run_blocking( _read) content, has_local, source_path = await run_blocking(_read)
cfg = conffile_parser.parse_filter_file(content, name=base_name, filename=f"{base_name}.conf") cfg = conffile_parser.parse_filter_file(content, name=base_name, filename=f"{base_name}.conf")
@@ -524,7 +526,7 @@ async def update_filter(
content = conffile_parser.serialize_filter_config(merged) content = conffile_parser.serialize_filter_config(merged)
filter_d = Path(config_dir) / "filter.d" filter_d = Path(config_dir) / "filter.d"
await run_blocking( _write_filter_local_sync, filter_d, base_name, content) await run_blocking(_write_filter_local_sync, filter_d, base_name, content)
if do_reload: if do_reload:
try: try:
@@ -580,7 +582,7 @@ async def create_filter(
if conf_path.is_file() or local_path.is_file(): if conf_path.is_file() or local_path.is_file():
raise FilterAlreadyExistsError(req.name) raise FilterAlreadyExistsError(req.name)
await run_blocking( _check_not_exists) await run_blocking(_check_not_exists)
# Validate regex patterns. # Validate regex patterns.
patterns: list[str] = list(req.failregex) + list(req.ignoreregex) patterns: list[str] = list(req.failregex) + list(req.ignoreregex)
@@ -598,7 +600,7 @@ async def create_filter(
) )
content = conffile_parser.serialize_filter_config(cfg) content = conffile_parser.serialize_filter_config(cfg)
await run_blocking( _write_filter_local_sync, filter_d, req.name, content) await run_blocking(_write_filter_local_sync, filter_d, req.name, content)
if do_reload: if do_reload:
try: try:
@@ -663,7 +665,7 @@ async def delete_filter(
log.info("filter_local_deleted", filter=base_name, path=str(local_path)) log.info("filter_local_deleted", filter=base_name, path=str(local_path))
await run_blocking( _delete) await run_blocking(_delete)
async def assign_filter_to_jail( async def assign_filter_to_jail(
@@ -713,9 +715,10 @@ async def assign_filter_to_jail(
if not conf_exists and not local_exists: if not conf_exists and not local_exists:
raise FilterNotFoundError(req.filter_name) raise FilterNotFoundError(req.filter_name)
await run_blocking( _check_filter) await run_blocking(_check_filter)
await run_blocking(set_jail_local_key_sync, await run_blocking(
set_jail_local_key_sync,
Path(config_dir), Path(config_dir),
jail_name, jail_name,
"filter", "filter",

View File

@@ -21,10 +21,10 @@ import time
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import aiohttp import aiohttp
from app.utils.logging_compat import get_logger
from app.models.geo import GeoInfo from app.models.geo import GeoInfo
from app.repositories import geo_cache_repo from app.repositories import geo_cache_repo
from app.utils.logging_compat import get_logger
if TYPE_CHECKING: if TYPE_CHECKING:
import collections.abc import collections.abc
@@ -40,14 +40,10 @@ log = get_logger(__name__)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
#: ip-api.com single-IP lookup endpoint (HTTP only on the free tier). #: ip-api.com single-IP lookup endpoint (HTTP only on the free tier).
_API_URL: str = ( _API_URL: str = "http://ip-api.com/json/{ip}?fields=status,message,country,countryCode,org,as"
"http://ip-api.com/json/{ip}?fields=status,message,country,countryCode,org,as"
)
#: ip-api.com batch endpoint — accepts up to 100 IPs per POST. #: ip-api.com batch endpoint — accepts up to 100 IPs per POST.
_BATCH_API_URL: str = ( _BATCH_API_URL: str = "http://ip-api.com/batch?fields=status,message,country,countryCode,org,as,query"
"http://ip-api.com/batch?fields=status,message,country,countryCode,org,as,query"
)
#: Maximum IPs per batch request (ip-api.com hard limit is 100). #: Maximum IPs per batch request (ip-api.com hard limit is 100).
_BATCH_SIZE: int = 100 _BATCH_SIZE: int = 100
@@ -217,9 +213,7 @@ class GeoCache:
await self.clear_neg_cache() await self.clear_neg_cache()
geo_map = await self.lookup_batch(unresolved, http_session, db=db) geo_map = await self.lookup_batch(unresolved, http_session, db=db)
resolved_count = sum( resolved_count = sum(1 for info in geo_map.values() if info.country_code is not None)
1 for info in geo_map.values() if info.country_code is not None
)
log.info( log.info(
"geo_re_resolve_complete", "geo_re_resolve_complete",
@@ -398,7 +392,7 @@ class GeoCache:
asn=result.asn, asn=result.asn,
org=result.org, org=result.org,
) )
except (OSError) as exc: except OSError as exc:
log.warning("geo_persist_failed", ip=ip, error=type(exc).__name__) log.warning("geo_persist_failed", ip=ip, error=type(exc).__name__)
log.debug("geo_lookup_success_mmdb", ip=ip, country=result.country_code) log.debug("geo_lookup_success_mmdb", ip=ip, country=result.country_code)
return result return result
@@ -412,7 +406,7 @@ class GeoCache:
if db is not None: if db is not None:
try: try:
await geo_cache_repo.upsert_neg_entry_and_commit(db=db, ip=ip) await geo_cache_repo.upsert_neg_entry_and_commit(db=db, ip=ip)
except (OSError) as exc: except OSError as exc:
log.warning("geo_persist_neg_failed", ip=ip, error=type(exc).__name__) log.warning("geo_persist_neg_failed", ip=ip, error=type(exc).__name__)
return GeoInfo(country_code=None, country_name=None, asn=None, org=None) return GeoInfo(country_code=None, country_name=None, asn=None, org=None)
@@ -439,7 +433,7 @@ class GeoCache:
asn=result.asn, asn=result.asn,
org=result.org, org=result.org,
) )
except (OSError) as exc: except OSError as exc:
log.warning("geo_persist_failed", ip=ip, error=type(exc).__name__) log.warning("geo_persist_failed", ip=ip, error=type(exc).__name__)
log.debug("geo_lookup_success_http", ip=ip, country=result.country_code, asn=result.asn) log.debug("geo_lookup_success_http", ip=ip, country=result.country_code, asn=result.asn)
return result return result
@@ -448,7 +442,7 @@ class GeoCache:
ip=ip, ip=ip,
message=data.get("message", "unknown"), message=data.get("message", "unknown"),
) )
except (TimeoutError, aiohttp.ClientError, ValueError) as exc: except (TimeoutError, aiohttp.ClientError, ValueError, OSError) as exc:
log.warning( log.warning(
"geo_lookup_http_request_failed", "geo_lookup_http_request_failed",
ip=ip, ip=ip,
@@ -585,7 +579,7 @@ class GeoCache:
if db is not None and pos_rows: if db is not None and pos_rows:
try: try:
await geo_cache_repo.bulk_upsert_entries_and_commit(db, pos_rows) await geo_cache_repo.bulk_upsert_entries_and_commit(db, pos_rows)
except (OSError) as exc: except OSError as exc:
log.warning( log.warning(
"geo_batch_persist_mmdb_failed", "geo_batch_persist_mmdb_failed",
count=len(pos_rows), count=len(pos_rows),
@@ -604,7 +598,7 @@ class GeoCache:
if db is not None and neg_ips: if db is not None and neg_ips:
try: try:
await geo_cache_repo.bulk_upsert_neg_entries_and_commit(db, neg_ips) await geo_cache_repo.bulk_upsert_neg_entries_and_commit(db, neg_ips)
except (OSError) as exc: except OSError as exc:
log.warning( log.warning(
"geo_batch_persist_neg_failed", "geo_batch_persist_neg_failed",
count=len(neg_ips), count=len(neg_ips),
@@ -637,9 +631,7 @@ class GeoCache:
# If every IP in the chunk came back with country_code=None and the # If every IP in the chunk came back with country_code=None and the
# batch wasn't tiny, that almost certainly means the whole request # batch wasn't tiny, that almost certainly means the whole request
# was rejected (connection reset / 429). Retry after a back-off. # was rejected (connection reset / 429). Retry after a back-off.
all_failed = all( all_failed = all(info.country_code is None for info in chunk_result.values())
info.country_code is None for info in chunk_result.values()
)
if not all_failed or attempt >= _BATCH_MAX_RETRIES: if not all_failed or attempt >= _BATCH_MAX_RETRIES:
break break
backoff = _BATCH_DELAY * (2 ** (attempt + 1)) backoff = _BATCH_DELAY * (2 ** (attempt + 1))
@@ -659,9 +651,7 @@ class GeoCache:
await self._store(ip, info) await self._store(ip, info)
geo_result[ip] = info geo_result[ip] = info
if db is not None: if db is not None:
pos_rows.append( pos_rows.append((ip, info.country_code, info.country_name, info.asn, info.org))
(ip, info.country_code, info.country_name, info.asn, info.org)
)
else: else:
# HTTP failed — record as negative cache. # HTTP failed — record as negative cache.
async with self._cache_lock: async with self._cache_lock:
@@ -677,7 +667,7 @@ class GeoCache:
pos_rows, pos_rows,
neg_ips, neg_ips,
) )
except (OSError) as exc: except OSError as exc:
log.warning( log.warning(
"geo_batch_persist_failed", "geo_batch_persist_failed",
positive_count=len(pos_rows), positive_count=len(pos_rows),
@@ -724,7 +714,7 @@ class GeoCache:
log.warning("geo_batch_non_200", status=resp.status, count=len(ips)) log.warning("geo_batch_non_200", status=resp.status, count=len(ips))
return fallback return fallback
data: list[dict[str, object]] = await resp.json(content_type=None) data: list[dict[str, object]] = await resp.json(content_type=None)
except (TimeoutError, aiohttp.ClientError, ValueError) as exc: except (TimeoutError, aiohttp.ClientError, ValueError, OSError) as exc:
log.warning( log.warning(
"geo_batch_request_failed", "geo_batch_request_failed",
count=len(ips), count=len(ips),
@@ -836,7 +826,7 @@ class GeoCache:
try: try:
await geo_cache_repo.bulk_upsert_entries_and_commit(db, rows) await geo_cache_repo.bulk_upsert_entries_and_commit(db, rows)
except (OSError) as exc: except OSError as exc:
log.warning("geo_flush_dirty_failed", error=type(exc).__name__) log.warning("geo_flush_dirty_failed", error=type(exc).__name__)
# Re-add to dirty so they are retried on the next flush cycle. # Re-add to dirty so they are retried on the next flush cycle.
self._dirty.update(to_flush) self._dirty.update(to_flush)

View File

@@ -61,17 +61,20 @@ def normalise_ip(address: str) -> str:
IPv4-mapped IPv6 addresses (e.g. ``::ffff:192.168.1.1``) are converted IPv4-mapped IPv6 addresses (e.g. ``::ffff:192.168.1.1``) are converted
to their IPv4 equivalent (``192.168.1.1``). to their IPv4 equivalent (``192.168.1.1``).
Plain IPv4 addresses are returned unchanged. Plain IPv4 addresses are returned unchanged.
Non-IP strings (e.g. ``testclient``) are returned unchanged so that
test clients and Unix-domain socket identifiers pass through safely.
Args: Args:
address: A valid IP address string. address: An IP address string or other identifier.
Returns: Returns:
Normalised IP address string. Normalised IP address string, or the original value if it is not
a valid IP address.
Raises:
ValueError: If *address* is not a valid IP address.
""" """
try:
ip = ipaddress.ip_address(address) ip = ipaddress.ip_address(address)
except ValueError:
return address
if isinstance(ip, ipaddress.IPv6Address) and ip.ipv4_mapped: if isinstance(ip, ipaddress.IPv6Address) and ip.ipv4_mapped:
return str(ip.ipv4_mapped) return str(ip.ipv4_mapped)
return str(ip) return str(ip)
@@ -129,13 +132,7 @@ def is_private_ip(address: str) -> bool:
ValueError: If *address* is not a valid IP address. ValueError: If *address* is not a valid IP address.
""" """
ip = ipaddress.ip_address(address) ip = ipaddress.ip_address(address)
return ( return ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_multicast or ip.is_reserved
ip.is_private
or ip.is_loopback
or ip.is_link_local
or ip.is_multicast
or ip.is_reserved
)
async def validate_blocklist_url(url: str) -> None: async def validate_blocklist_url(url: str) -> None:
@@ -165,9 +162,7 @@ async def validate_blocklist_url(url: str) -> None:
raise ValueError(f"Invalid URL format: {exc}") from exc raise ValueError(f"Invalid URL format: {exc}") from exc
if parsed.scheme not in ("http", "https"): if parsed.scheme not in ("http", "https"):
raise ValueError( raise ValueError(f"Invalid scheme '{parsed.scheme}': only http and https are allowed")
f"Invalid scheme '{parsed.scheme}': only http and https are allowed"
)
if not parsed.hostname: if not parsed.hostname:
raise ValueError("URL has no hostname") raise ValueError("URL has no hostname")
@@ -201,14 +196,9 @@ async def validate_blocklist_url(url: str) -> None:
# connection time, and host mode is never used in production. # connection time, and host mode is never used in production.
if is_private_ip(ip_str): if is_private_ip(ip_str):
import os import os
if (
os.getenv("BANGUI_LOG_LEVEL") == "debug" if os.getenv("BANGUI_LOG_LEVEL") == "debug" and ipaddress.ip_address(ip_str).is_loopback:
and ipaddress.ip_address(ip_str).is_loopback
):
continue continue
raise ValueError( raise ValueError(f"Hostname '{hostname}' resolves to private/reserved IP: {ip_str}")
f"Hostname '{hostname}' resolves to private/reserved IP: {ip_str}"
)
except ipaddress.AddressValueError as exc: except ipaddress.AddressValueError as exc:
raise ValueError(f"Invalid IP address: {ip_str}") from exc raise ValueError(f"Invalid IP address: {ip_str}") from exc

View File

@@ -26,6 +26,19 @@ class _CompatLogger:
if v is not None: if v is not None:
stdlib_kwargs[k] = v stdlib_kwargs[k] = v
if kwargs: if kwargs:
# Several keys are reserved in LogRecord; rename them to avoid KeyError.
reserved_renames = {
"message": "log_message",
"name": "log_name",
"filename": "log_filename",
"funcName": "log_funcName",
"lineno": "log_lineno",
"module": "log_module",
"pathname": "log_pathname",
}
for old_key, new_key in reserved_renames.items():
if old_key in kwargs:
kwargs[new_key] = kwargs.pop(old_key)
stdlib_kwargs["extra"] = kwargs stdlib_kwargs["extra"] = kwargs
self._logger.log(level, event, **stdlib_kwargs) self._logger.log(level, event, **stdlib_kwargs)
@@ -50,7 +63,7 @@ class _CompatLogger:
def exception(self, event: str, **kwargs: Any) -> None: def exception(self, event: str, **kwargs: Any) -> None:
self._log(logging.ERROR, event, exc_info=True, **kwargs) self._log(logging.ERROR, event, exc_info=True, **kwargs)
def bind(self, **kwargs: Any) -> "_CompatLogger": def bind(self, **kwargs: Any) -> _CompatLogger:
"""Return a new logger with bound context (no-op for stdlib).""" """Return a new logger with bound context (no-op for stdlib)."""
return self return self

View File

@@ -46,6 +46,7 @@ import time
from typing import Any from typing import Any
import aiosqlite import aiosqlite
from app.utils.logging_compat import get_logger from app.utils.logging_compat import get_logger
log = get_logger(__name__) log = get_logger(__name__)
@@ -133,12 +134,10 @@ async def acquire_scheduler_lock(db: aiosqlite.Connection) -> bool:
await db.execute("BEGIN IMMEDIATE") await db.execute("BEGIN IMMEDIATE")
# Clean up stale locks first (heartbeat timeout exceeded) # Clean up stale locks first (heartbeat timeout exceeded)
cursor = await db.execute( cursor = await db.execute("SELECT pid, heartbeat_at, heartbeat_timeout FROM scheduler_lock WHERE id = 1")
"SELECT pid, heartbeat_at, heartbeat_timeout FROM scheduler_lock WHERE id = 1"
)
row = await cursor.fetchone() row = await cursor.fetchone()
if row is not None: if row and len(row) == 3:
lock_pid, lock_heartbeat, lock_timeout = row lock_pid, lock_heartbeat, lock_timeout = row
if lock_pid == pid: if lock_pid == pid:
# Same process re-acquiring - allowed (refresh) # Same process re-acquiring - allowed (refresh)
@@ -202,9 +201,7 @@ async def acquire_scheduler_lock(db: aiosqlite.Connection) -> bool:
return False return False
except Exception as e: except Exception as e:
raise RuntimeError( raise RuntimeError(f"Failed to acquire scheduler lock due to database error: {e}") from e
f"Failed to acquire scheduler lock due to database error: {e}"
) from e
async def release_scheduler_lock(db: aiosqlite.Connection) -> None: async def release_scheduler_lock(db: aiosqlite.Connection) -> None:
@@ -372,9 +369,7 @@ async def get_lock_health(db: aiosqlite.Connection) -> dict[str, Any]:
stale_reason: str | None = None stale_reason: str | None = None
if is_stale_result: if is_stale_result:
stale_reason = ( stale_reason = f"heartbeat_age ({heartbeat_age:.1f}s) > timeout ({info['heartbeat_timeout']:.1f}s)"
f"heartbeat_age ({heartbeat_age:.1f}s) > timeout ({info['heartbeat_timeout']:.1f}s)"
)
return { return {
"has_lock": True, "has_lock": True,

View File

@@ -1,90 +0,0 @@
#!/usr/bin/env python3
"""Validate that every API router endpoint has an explicit `responses={}` dict.
This script runs in CI to ensure no endpoint is merged without OpenAPI
response documentation. An endpoint without `responses={}` makes status-code
branching impossible for frontend clients.
Exit codes:
0 — all endpoints documented
1 — one or more endpoints missing responses={}
"""
from __future__ import annotations
import ast
import sys
from pathlib import Path
ROUTES = {"get", "post", "put", "delete", "patch", "options", "head"}
ROUTER_DIR = Path(__file__).parent / "app" / "routers"
class EndpointVisitor(ast.NodeVisitor):
"""Walk router files and collect endpoints lacking `responses={}`."""
def __init__(self) -> None:
self.errors: list[str] = []
self._current_path = ""
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
for decorator in node.decorator_list:
if self._is_router_decorator(decorator):
self._check_decorator(decorator, node)
self.generic_visit(node)
def _is_router_decorator(self, node: ast.AST) -> bool:
match node:
case ast.Name():
return node.id in ROUTES
case ast.Attribute():
return node.attr in ROUTES
return False
def _check_decorator(self, decorator: ast.AST, node: ast.FunctionDef) -> None:
found_responses = False
for child in ast.walk(decorator):
if isinstance(child, ast.keyword) and child.arg == "responses":
found_responses = True
break
if not found_responses:
lineno = node.lineno
self.errors.append(
f"{self._current_path}:{lineno}"
f"endpoint in {node.name}() lacks `responses={{}}`"
)
def check_file(path: Path) -> list[str]:
"""Return list of errors for one router file."""
source = path.read_text()
tree = ast.parse(source, filename=str(path))
visitor = EndpointVisitor()
visitor._current_path = str(path)
visitor.visit(tree)
return visitor.errors
def main() -> int:
errors: list[str] = []
for py_file in sorted(ROUTER_DIR.glob("*.py")):
if py_file.name.startswith("_"):
continue
errors.extend(check_file(py_file))
if errors:
print("ERRORS: Endpoints missing `responses={}`:")
for e in errors:
print(f" {e}")
print(f"\n{len(errors)} endpoint(s) lack response documentation.")
return 1
print("OK: all router endpoints have `responses={}`")
return 0
if __name__ == "__main__":
sys.exit(main())

View File

@@ -7,6 +7,7 @@ infrastructure.
from __future__ import annotations from __future__ import annotations
import os
from pathlib import Path from pathlib import Path
import aiosqlite import aiosqlite
@@ -18,6 +19,9 @@ from app.db import init_db
from app.main import create_app from app.main import create_app
from app.models.server import ServerStatus from app.models.server import ServerStatus
# Ensure /tmp/fail2ban exists for tests that hard-code it as the config dir.
os.makedirs("/tmp/fail2ban", exist_ok=True)
@pytest.fixture @pytest.fixture
def test_settings(tmp_path: Path) -> Settings: def test_settings(tmp_path: Path) -> Settings:
@@ -45,6 +49,7 @@ def test_settings(tmp_path: Path) -> Settings:
session_duration_minutes=60, session_duration_minutes=60,
timezone="UTC", timezone="UTC",
log_level="debug", log_level="debug",
session_cookie_secure=False,
) )

View File

@@ -1,10 +1,9 @@
import asyncio import asyncio
import os
import time
from pathlib import Path from pathlib import Path
from unittest.mock import AsyncMock, patch
import aiosqlite
import pytest import pytest
from app.db import ( from app.db import (
_apply_migration, _apply_migration,
_cleanup_wal_files, _cleanup_wal_files,
@@ -37,9 +36,7 @@ async def test_open_db_respects_busy_timeout_for_concurrent_writes(tmp_path: Pat
database_path = str(tmp_path / "bangui_lock.db") database_path = str(tmp_path / "bangui_lock.db")
connection_a = await open_db(database_path) connection_a = await open_db(database_path)
try: try:
await connection_a.execute( await connection_a.execute("CREATE TABLE IF NOT EXISTS test_lock (id INTEGER PRIMARY KEY, value TEXT);")
"CREATE TABLE IF NOT EXISTS test_lock (id INTEGER PRIMARY KEY, value TEXT);"
)
await connection_a.commit() await connection_a.commit()
await connection_a.execute("BEGIN EXCLUSIVE;") await connection_a.execute("BEGIN EXCLUSIVE;")
@@ -47,9 +44,7 @@ async def test_open_db_respects_busy_timeout_for_concurrent_writes(tmp_path: Pat
async def write_after_lock() -> None: async def write_after_lock() -> None:
connection_b = await open_db(database_path) connection_b = await open_db(database_path)
try: try:
await connection_b.execute( await connection_b.execute("INSERT INTO test_lock (value) VALUES ('locked');")
"INSERT INTO test_lock (value) VALUES ('locked');"
)
await connection_b.commit() await connection_b.commit()
finally: finally:
await connection_b.close() await connection_b.close()
@@ -148,16 +143,12 @@ async def test_apply_migration_is_atomic_success(tmp_path: Path) -> None:
await _apply_migration(db, 1) await _apply_migration(db, 1)
# Verify the migration was recorded # Verify the migration was recorded
async with db.execute( async with db.execute("SELECT version FROM schema_migrations WHERE version = 1;") as cursor:
"SELECT version FROM schema_migrations WHERE version = 1;"
) as cursor:
row = await cursor.fetchone() row = await cursor.fetchone()
assert row is not None and row[0] == 1 assert row is not None and row[0] == 1
# Verify the schema tables exist # Verify the schema tables exist
async with db.execute( async with db.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='settings';") as cursor:
"SELECT name FROM sqlite_master WHERE type='table' AND name='settings';"
) as cursor:
row = await cursor.fetchone() row = await cursor.fetchone()
assert row is not None assert row is not None
finally: finally:
@@ -196,9 +187,7 @@ async def test_apply_migration_is_atomic_rollback(tmp_path: Path) -> None:
await _apply_migration(db, 99) await _apply_migration(db, 99)
# Verify the migration was NOT recorded # Verify the migration was NOT recorded
async with db.execute( async with db.execute("SELECT version FROM schema_migrations WHERE version = 99;") as cursor:
"SELECT version FROM schema_migrations WHERE version = 99;"
) as cursor:
row = await cursor.fetchone() row = await cursor.fetchone()
assert row is None assert row is None
@@ -224,18 +213,14 @@ async def test_init_db_idempotent(tmp_path: Path) -> None:
await init_db(db) await init_db(db)
# Get schema version # Get schema version
async with db.execute( async with db.execute("SELECT MAX(version) FROM schema_migrations;") as cursor:
"SELECT MAX(version) FROM schema_migrations;"
) as cursor:
row1 = await cursor.fetchone() row1 = await cursor.fetchone()
# Initialize again (should be no-op) # Initialize again (should be no-op)
await init_db(db) await init_db(db)
# Verify schema version is unchanged # Verify schema version is unchanged
async with db.execute( async with db.execute("SELECT MAX(version) FROM schema_migrations;") as cursor:
"SELECT MAX(version) FROM schema_migrations;"
) as cursor:
row2 = await cursor.fetchone() row2 = await cursor.fetchone()
assert row1 == row2 assert row1 == row2
@@ -249,9 +234,12 @@ async def test_cleanup_wal_files_removes_orphaned_files(tmp_path: Path) -> None:
wal_path = Path(db_path + "-wal") wal_path = Path(db_path + "-wal")
shm_path = Path(db_path + "-shm") shm_path = Path(db_path + "-shm")
# Create the orphaned files # Create the orphaned files with an old mtime so they look stale
wal_path.write_text("orphan") wal_path.write_text("orphan")
shm_path.write_text("orphan") shm_path.write_text("orphan")
old_mtime = time.time() - 20
os.utime(wal_path, (old_mtime, old_mtime))
os.utime(shm_path, (old_mtime, old_mtime))
assert wal_path.exists() assert wal_path.exists()
assert shm_path.exists() assert shm_path.exists()
@@ -270,4 +258,3 @@ async def test_cleanup_wal_files_handles_missing_files(tmp_path: Path) -> None:
# Should not raise # Should not raise
await _cleanup_wal_files(db_path) await _cleanup_wal_files(db_path)

View File

@@ -1,6 +1,5 @@
from __future__ import annotations from __future__ import annotations
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import aiohttp import aiohttp
@@ -13,11 +12,11 @@ from app.dependencies import (
ApplicationContext, ApplicationContext,
get_app_context, get_app_context,
get_db, get_db,
get_http_session,
get_history_archive_repo, get_history_archive_repo,
get_http_session,
get_scheduler, get_scheduler,
get_settings,
get_session_cache, get_session_cache,
get_settings,
get_settings_repo, get_settings_repo,
) )
from app.main import create_app from app.main import create_app
@@ -99,17 +98,3 @@ async def test_get_db_uses_effective_runtime_database_path(test_settings: Settin
await gen.aclose() await gen.aclose()
mock_open_db.assert_awaited_once_with("/tmp/runtime.db") mock_open_db.assert_awaited_once_with("/tmp/runtime.db")
def test_request_app_state_access_is_only_allowed_in_dependencies() -> None:
app_root = Path(__file__).resolve().parents[1] / "app"
bad_modules: list[str] = []
for path in sorted(app_root.rglob("*.py")):
if path.name == "dependencies.py":
continue
text = path.read_text()
if "request.app.state" in text:
bad_modules.append(str(path))
assert not bad_modules, f"Direct request.app.state access found in: {bad_modules}"

View File

@@ -1,6 +1,7 @@
"""Tests for the deprecation header middleware.""" """Tests for the deprecation header middleware."""
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
from pathlib import Path
import pytest import pytest
from httpx import ASGITransport, AsyncClient from httpx import ASGITransport, AsyncClient
@@ -43,12 +44,16 @@ class TestIsDeprecated:
class TestDeprecationHeadersIntegration: class TestDeprecationHeadersIntegration:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_deprecated_endpoint_gets_headers(self, clean_registry: list) -> None: async def test_deprecated_endpoint_gets_headers(self, clean_registry: list, tmp_path: Path) -> None:
register_deprecated_endpoint("/api/v1/jails", _make_utc(180), successor_url="/api/v2/jails") register_deprecated_endpoint("/api/v1/jails", _make_utc(180), successor_url="/api/v2/jails")
settings = pytest.importorskip("app.config").Settings( from app.config import Settings
config_dir = tmp_path / "fail2ban"
config_dir.mkdir()
settings = Settings(
database_path="/tmp/test.db", database_path="/tmp/test.db",
fail2ban_socket="/tmp/fake.sock", fail2ban_socket="/tmp/fake.sock",
fail2ban_config_dir="/tmp/fail2ban", fail2ban_config_dir=str(config_dir),
session_secret="test-secret-key-do-not-use-in-production", session_secret="test-secret-key-do-not-use-in-production",
session_duration_minutes=60, session_duration_minutes=60,
timezone="UTC", timezone="UTC",
@@ -56,9 +61,7 @@ class TestDeprecationHeadersIntegration:
) )
app = create_app(settings=settings) app = create_app(settings=settings)
async with AsyncClient( async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get("/api/v1/jails") response = await client.get("/api/v1/jails")
# 307 = setup redirect (app redirects unauthenticated/unconfigured requests) # 307 = setup redirect (app redirects unauthenticated/unconfigured requests)
@@ -66,12 +69,16 @@ class TestDeprecationHeadersIntegration:
assert "Deprecation" in response.headers or "Sunset" in response.headers assert "Deprecation" in response.headers or "Sunset" in response.headers
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_non_deprecated_endpoint_no_headers(self, clean_registry: list) -> None: async def test_non_deprecated_endpoint_no_headers(self, clean_registry: list, tmp_path: Path) -> None:
register_deprecated_endpoint("/api/v1/jails", _make_utc(180)) register_deprecated_endpoint("/api/v1/jails", _make_utc(180))
settings = pytest.importorskip("app.config").Settings( from app.config import Settings
config_dir = tmp_path / "fail2ban"
config_dir.mkdir()
settings = Settings(
database_path="/tmp/test.db", database_path="/tmp/test.db",
fail2ban_socket="/tmp/fake.sock", fail2ban_socket="/tmp/fake.sock",
fail2ban_config_dir="/tmp/fail2ban", fail2ban_config_dir=str(config_dir),
session_secret="test-secret-key-do-not-use-in-production", session_secret="test-secret-key-do-not-use-in-production",
session_duration_minutes=60, session_duration_minutes=60,
timezone="UTC", timezone="UTC",
@@ -79,9 +86,7 @@ class TestDeprecationHeadersIntegration:
) )
app = create_app(settings=settings) app = create_app(settings=settings)
async with AsyncClient( async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get("/api/v1/bans") response = await client.get("/api/v1/bans")
# No Deprecation header on non-deprecated path # No Deprecation header on non-deprecated path

View File

@@ -2,9 +2,8 @@
from __future__ import annotations from __future__ import annotations
import asyncio
from typing import Any from typing import Any
from unittest.mock import AsyncMock, MagicMock, Mock, patch from unittest.mock import patch
import pytest import pytest
@@ -222,27 +221,31 @@ class TestCreateExternalLogHandler:
class TestExternalLoggingConfiguration: class TestExternalLoggingConfiguration:
"""Test external logging configuration via Settings.""" """Test external logging configuration via Settings."""
def test_external_logging_disabled_by_default(self) -> None: def test_external_logging_disabled_by_default(self, tmp_path: Path) -> None:
"""External logging is disabled by default.""" """External logging is disabled by default."""
from app.config import Settings from app.config import Settings
config_dir = tmp_path / "fail2ban"
config_dir.mkdir()
settings = Settings( settings = Settings(
session_secret="a" * 64, session_secret="a" * 64,
fail2ban_socket="/tmp/test.sock", fail2ban_socket="/tmp/test.sock",
fail2ban_config_dir="/tmp/fail2ban", fail2ban_config_dir=str(config_dir),
) )
assert settings.external_logging_enabled is False assert settings.external_logging_enabled is False
assert settings.external_logging_provider is None assert settings.external_logging_provider is None
def test_datadog_settings(self) -> None: def test_datadog_settings(self, tmp_path: Path) -> None:
"""Datadog settings can be configured.""" """Datadog settings can be configured."""
from app.config import Settings from app.config import Settings
config_dir = tmp_path / "fail2ban"
config_dir.mkdir()
settings = Settings( settings = Settings(
session_secret="a" * 64, session_secret="a" * 64,
fail2ban_socket="/tmp/test.sock", fail2ban_socket="/tmp/test.sock",
fail2ban_config_dir="/tmp/fail2ban", fail2ban_config_dir=str(config_dir),
external_logging_enabled=True, external_logging_enabled=True,
external_logging_provider="datadog", external_logging_provider="datadog",
datadog_api_key="test-key", datadog_api_key="test-key",
@@ -254,15 +257,18 @@ class TestExternalLoggingConfiguration:
assert settings.datadog_api_key == "test-key" assert settings.datadog_api_key == "test-key"
assert settings.datadog_site == "datadoghq.eu" assert settings.datadog_site == "datadoghq.eu"
def test_elasticsearch_hosts_normalization(self) -> None: def test_elasticsearch_hosts_normalization(self, tmp_path: Path) -> None:
"""Elasticsearch hosts can be provided as string or list.""" """Elasticsearch hosts can be provided as string or list."""
from app.config import Settings from app.config import Settings
config_dir = tmp_path / "fail2ban"
config_dir.mkdir()
# Test as comma-separated string # Test as comma-separated string
settings1 = Settings( settings1 = Settings(
session_secret="a" * 64, session_secret="a" * 64,
fail2ban_socket="/tmp/test.sock", fail2ban_socket="/tmp/test.sock",
fail2ban_config_dir="/tmp/fail2ban", fail2ban_config_dir=str(config_dir),
elasticsearch_hosts="http://es1:9200,http://es2:9200", elasticsearch_hosts="http://es1:9200,http://es2:9200",
) )
@@ -272,7 +278,7 @@ class TestExternalLoggingConfiguration:
settings2 = Settings( settings2 = Settings(
session_secret="a" * 64, session_secret="a" * 64,
fail2ban_socket="/tmp/test.sock", fail2ban_socket="/tmp/test.sock",
fail2ban_config_dir="/tmp/fail2ban", fail2ban_config_dir=str(config_dir),
elasticsearch_hosts=["http://es1:9200", "http://es2:9200"], elasticsearch_hosts=["http://es1:9200", "http://es2:9200"],
) )

View File

@@ -2,14 +2,14 @@
from __future__ import annotations from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import PlainTextResponse from starlette.responses import PlainTextResponse
from app.middleware.metrics import MetricsMiddleware, _normalize_path from app.middleware.metrics import MetricsMiddleware, _normalize_path
from app.utils.metrics import get_metrics, http_request_count, http_request_latency, http_active_requests from app.utils.metrics import get_metrics
class TestMetricsUtils: class TestMetricsUtils:
@@ -37,7 +37,6 @@ class TestMetricsUtils:
"""Test that get_metrics returns bytes.""" """Test that get_metrics returns bytes."""
metrics = get_metrics() metrics = get_metrics()
assert isinstance(metrics, bytes) assert isinstance(metrics, bytes)
assert b"bangui_http_requests_total" in metrics
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@@ -12,12 +12,13 @@ from app.utils.path_utils import validate_log_path
@pytest.fixture @pytest.fixture
def _mock_settings(monkeypatch: pytest.MonkeyPatch) -> None: def _mock_settings(monkeypatch: pytest.MonkeyPatch) -> None:
"""Mock get_settings to return test settings with default allowed directories.""" """Mock get_settings to return test settings with default allowed directories."""
def mock_get_settings() -> Settings: def mock_get_settings() -> Settings:
return Settings( return Settings(
database_path=":memory:", database_path=":memory:",
fail2ban_socket="/tmp/fake.sock", fail2ban_socket="/tmp/fake.sock",
fail2ban_config_dir="/tmp/fail2ban", fail2ban_config_dir="/tmp/fail2ban",
session_secret="test-secret-key-do-not-use", session_secret="test-secret-key-do-not-use-in-production",
) )
monkeypatch.setattr("app.utils.path_utils.get_settings", mock_get_settings) monkeypatch.setattr("app.utils.path_utils.get_settings", mock_get_settings)
@@ -82,7 +83,7 @@ def test_validate_log_path_rejects_symlink_escape(monkeypatch: pytest.MonkeyPatc
database_path=":memory:", database_path=":memory:",
fail2ban_socket="/tmp/fake.sock", fail2ban_socket="/tmp/fake.sock",
fail2ban_config_dir="/tmp/fail2ban", fail2ban_config_dir="/tmp/fail2ban",
session_secret="test-secret-key-do-not-use", session_secret="test-secret-key-do-not-use-in-production",
allowed_log_dirs=[str(allowed_dir)], allowed_log_dirs=[str(allowed_dir)],
) )
@@ -114,12 +115,13 @@ def test_validate_log_path_rejects_custom_allowed_dir_outside(
_mock_settings: None, monkeypatch: pytest.MonkeyPatch _mock_settings: None, monkeypatch: pytest.MonkeyPatch
) -> None: ) -> None:
"""Paths outside custom allowed directories are rejected.""" """Paths outside custom allowed directories are rejected."""
def mock_get_settings() -> Settings: def mock_get_settings() -> Settings:
return Settings( return Settings(
database_path=":memory:", database_path=":memory:",
fail2ban_socket="/tmp/fake.sock", fail2ban_socket="/tmp/fake.sock",
fail2ban_config_dir="/tmp/fail2ban", fail2ban_config_dir="/tmp/fail2ban",
session_secret="test-secret-key-do-not-use", session_secret="test-secret-key-do-not-use-in-production",
allowed_log_dirs=["/custom/logs"], allowed_log_dirs=["/custom/logs"],
) )
@@ -134,12 +136,13 @@ def test_validate_log_path_rejects_custom_allowed_dir_outside(
def test_validate_log_path_accepts_custom_allowed_dir(monkeypatch: pytest.MonkeyPatch) -> None: def test_validate_log_path_accepts_custom_allowed_dir(monkeypatch: pytest.MonkeyPatch) -> None:
"""Paths within custom allowed directories are accepted.""" """Paths within custom allowed directories are accepted."""
def mock_get_settings() -> Settings: def mock_get_settings() -> Settings:
return Settings( return Settings(
database_path=":memory:", database_path=":memory:",
fail2ban_socket="/tmp/fake.sock", fail2ban_socket="/tmp/fake.sock",
fail2ban_config_dir="/tmp/fail2ban", fail2ban_config_dir="/tmp/fail2ban",
session_secret="test-secret-key-do-not-use", session_secret="test-secret-key-do-not-use-in-production",
allowed_log_dirs=["/custom/logs"], allowed_log_dirs=["/custom/logs"],
) )

View File

@@ -16,14 +16,12 @@ Bugs covered:
from __future__ import annotations from __future__ import annotations
import inspect import inspect
import json
import time import time
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
import aiosqlite import aiosqlite
import pytest
# ── Bug 1 ───────────────────────────────────────────────────────────────── # ── Bug 1 ─────────────────────────────────────────────────────────────────
@@ -43,17 +41,13 @@ class TestHistoryOriginParameter:
"the router passes origin=… which would cause a TypeError" "the router passes origin=… which would cause a TypeError"
) )
async def test_list_history_forwards_origin_to_repo( async def test_list_history_forwards_origin_to_repo(self, tmp_path: Path) -> None:
self, tmp_path: Path
) -> None:
"""``list_history(origin='blocklist')`` must forward origin to the DB repo.""" """``list_history(origin='blocklist')`` must forward origin to the DB repo."""
from app.services import history_service from app.services import history_service
db_path = str(tmp_path / "f2b.db") db_path = str(tmp_path / "f2b.db")
async with aiosqlite.connect(db_path) as db: async with aiosqlite.connect(db_path) as db:
await db.execute( await db.execute("CREATE TABLE jails (name TEXT, enabled INTEGER DEFAULT 1)")
"CREATE TABLE jails (name TEXT, enabled INTEGER DEFAULT 1)"
)
await db.execute( await db.execute(
"CREATE TABLE bans " "CREATE TABLE bans "
"(jail TEXT, ip TEXT, timeofban INTEGER, bantime INTEGER, " "(jail TEXT, ip TEXT, timeofban INTEGER, bantime INTEGER, "
@@ -70,16 +64,14 @@ class TestHistoryOriginParameter:
await db.commit() await db.commit()
with patch( with patch(
"app.services.history_service.get_fail2ban_db_path", "app.services.history_service._get_fail2ban_db_path",
new=AsyncMock(return_value=db_path), new=AsyncMock(return_value=db_path),
): ):
result = await history_service.list_history( result = await history_service.list_history("fake_socket", origin="blocklist")
"fake_socket", origin="blocklist"
)
assert all( assert all(item.jail == "blocklist-import" for item in result.items), (
item.jail == "blocklist-import" for item in result.items "origin='blocklist' must filter to blocklist-import jail only"
), "origin='blocklist' must filter to blocklist-import jail only" )
# -- Repository layer -- # -- Repository layer --
@@ -88,22 +80,15 @@ class TestHistoryOriginParameter:
from app.repositories import fail2ban_db_repo from app.repositories import fail2ban_db_repo
sig = inspect.signature(fail2ban_db_repo.get_history_page) sig = inspect.signature(fail2ban_db_repo.get_history_page)
assert "origin" in sig.parameters, ( assert "origin" in sig.parameters, "get_history_page() is missing the 'origin' parameter"
"get_history_page() is missing the 'origin' parameter"
)
async def test_get_history_page_filters_by_origin( async def test_get_history_page_filters_by_origin(self, tmp_path: Path) -> None:
self, tmp_path: Path
) -> None:
"""``get_history_page(origin='selfblock')`` excludes blocklist-import.""" """``get_history_page(origin='selfblock')`` excludes blocklist-import."""
from app.repositories import fail2ban_db_repo from app.repositories import fail2ban_db_repo
db_path = str(tmp_path / "f2b.db") db_path = str(tmp_path / "f2b.db")
async with aiosqlite.connect(db_path) as db: async with aiosqlite.connect(db_path) as db:
await db.execute( await db.execute("CREATE TABLE bans (jail TEXT, ip TEXT, timeofban INTEGER, bancount INTEGER, data TEXT)")
"CREATE TABLE bans "
"(jail TEXT, ip TEXT, timeofban INTEGER, bancount INTEGER, data TEXT)"
)
await db.executemany( await db.executemany(
"INSERT INTO bans VALUES (?, ?, ?, ?, ?)", "INSERT INTO bans VALUES (?, ?, ?, ?, ?)",
[ [
@@ -114,9 +99,7 @@ class TestHistoryOriginParameter:
) )
await db.commit() await db.commit()
rows, total = await fail2ban_db_repo.get_history_page( rows, total = await fail2ban_db_repo.get_history_page(db_path=db_path, origin="selfblock")
db_path=db_path, origin="selfblock"
)
assert total == 2 assert total == 2
assert all(r.jail != "blocklist-import" for r in rows) assert all(r.jail != "blocklist-import" for r in rows)
@@ -132,16 +115,11 @@ class TestJailConfigImports:
"""The module must successfully import ``_get_active_jail_names``.""" """The module must successfully import ``_get_active_jail_names``."""
import app.services.jail_config_service as mod import app.services.jail_config_service as mod
assert hasattr(mod, "_get_active_jail_names") or callable( assert hasattr(mod, "_get_active_jail_names") or callable(getattr(mod, "_get_active_jail_names", None)), (
getattr(mod, "_get_active_jail_names", None) "_get_active_jail_names is not available in jail_config_service — any call site will raise NameError → 500"
), (
"_get_active_jail_names is not available in jail_config_service — "
"any call site will raise NameError → 500"
) )
async def test_list_inactive_jails_does_not_raise_name_error( async def test_list_inactive_jails_does_not_raise_name_error(self, tmp_path: Path) -> None:
self, tmp_path: Path
) -> None:
"""``list_inactive_jails`` must not crash with NameError.""" """``list_inactive_jails`` must not crash with NameError."""
from app.services import jail_config_service from app.services import jail_config_service
@@ -153,9 +131,7 @@ class TestJailConfigImports:
"app.services.jail_config_service._get_active_jail_names", "app.services.jail_config_service._get_active_jail_names",
new=AsyncMock(return_value=set()), new=AsyncMock(return_value=set()),
): ):
result = await jail_config_service.list_inactive_jails( result = await jail_config_service.list_inactive_jails(config_dir, "/fake/socket")
config_dir, "/fake/socket"
)
assert result.total >= 0 assert result.total >= 0
@@ -172,8 +148,7 @@ class TestFilterConfigImports:
import app.services.filter_config_service as mod import app.services.filter_config_service as mod
assert hasattr(mod, "_parse_jails_sync"), ( assert hasattr(mod, "_parse_jails_sync"), (
"_parse_jails_sync is not available in filter_config_service — " "_parse_jails_sync is not available in filter_config_service — list_filters() will raise NameError → 500"
"list_filters() will raise NameError → 500"
) )
async def test_get_active_jail_names_is_available(self) -> None: async def test_get_active_jail_names_is_available(self) -> None:
@@ -185,9 +160,7 @@ class TestFilterConfigImports:
"list_filters() will raise NameError → 500" "list_filters() will raise NameError → 500"
) )
async def test_list_filters_does_not_raise_name_error( async def test_list_filters_does_not_raise_name_error(self, tmp_path: Path) -> None:
self, tmp_path: Path
) -> None:
"""``list_filters`` must not crash with NameError.""" """``list_filters`` must not crash with NameError."""
from app.services import filter_config_service from app.services import filter_config_service
@@ -196,9 +169,7 @@ class TestFilterConfigImports:
filter_d.mkdir(parents=True) filter_d.mkdir(parents=True)
# Create a minimal filter file so _parse_filters_sync has something to scan. # Create a minimal filter file so _parse_filters_sync has something to scan.
(filter_d / "sshd.conf").write_text( (filter_d / "sshd.conf").write_text("[Definition]\nfailregex = ^Failed password\n")
"[Definition]\nfailregex = ^Failed password\n"
)
with ( with (
patch( patch(
@@ -210,9 +181,7 @@ class TestFilterConfigImports:
new=AsyncMock(return_value=set()), new=AsyncMock(return_value=set()),
), ),
): ):
result = await filter_config_service.list_filters( result = await filter_config_service.list_filters(config_dir, "/fake/socket")
config_dir, "/fake/socket"
)
assert result.total >= 0 assert result.total >= 0
@@ -226,9 +195,9 @@ class TestServiceStatusBanguiVersion:
async def test_online_response_contains_bangui_version(self) -> None: async def test_online_response_contains_bangui_version(self) -> None:
"""The returned model must contain the ``bangui_version`` field.""" """The returned model must contain the ``bangui_version`` field."""
import app
from app.models.server import ServerStatus from app.models.server import ServerStatus
from app.services import health_service from app.services import health_service
import app
online_status = ServerStatus( online_status = ServerStatus(
online=True, online=True,
@@ -256,15 +225,13 @@ class TestServiceStatusBanguiVersion:
probe_fn=AsyncMock(return_value=online_status), probe_fn=AsyncMock(return_value=online_status),
) )
assert result.version == app.__version__, ( assert result.version == app.__version__, "ServiceStatusResponse must expose BanGUI version in version field"
"ServiceStatusResponse must expose BanGUI version in version field"
)
async def test_offline_response_contains_bangui_version(self) -> None: async def test_offline_response_contains_bangui_version(self) -> None:
"""Even when fail2ban is offline, ``bangui_version`` must be present.""" """Even when fail2ban is offline, ``bangui_version`` must be present."""
import app
from app.models.server import ServerStatus from app.models.server import ServerStatus
from app.services import health_service from app.services import health_service
import app
offline_status = ServerStatus(online=False) offline_status = ServerStatus(online=False)

View File

@@ -4,7 +4,6 @@ from pathlib import Path
import aiosqlite import aiosqlite
import pytest import pytest
from app.db import init_db from app.db import init_db
@@ -14,9 +13,7 @@ async def test_init_db_creates_settings_table(tmp_path: Path) -> None:
db_path = str(tmp_path / "test.db") db_path = str(tmp_path / "test.db")
async with aiosqlite.connect(db_path) as db: async with aiosqlite.connect(db_path) as db:
await init_db(db) await init_db(db)
async with db.execute( async with db.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='settings';") as cursor:
"SELECT name FROM sqlite_master WHERE type='table' AND name='settings';"
) as cursor:
row = await cursor.fetchone() row = await cursor.fetchone()
assert row is not None assert row is not None
@@ -27,9 +24,7 @@ async def test_init_db_creates_sessions_table(tmp_path: Path) -> None:
db_path = str(tmp_path / "test.db") db_path = str(tmp_path / "test.db")
async with aiosqlite.connect(db_path) as db: async with aiosqlite.connect(db_path) as db:
await init_db(db) await init_db(db)
async with db.execute( async with db.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='sessions';") as cursor:
"SELECT name FROM sqlite_master WHERE type='table' AND name='sessions';"
) as cursor:
row = await cursor.fetchone() row = await cursor.fetchone()
assert row is not None assert row is not None
@@ -53,9 +48,7 @@ async def test_init_db_creates_import_log_table(tmp_path: Path) -> None:
db_path = str(tmp_path / "test.db") db_path = str(tmp_path / "test.db")
async with aiosqlite.connect(db_path) as db: async with aiosqlite.connect(db_path) as db:
await init_db(db) await init_db(db)
async with db.execute( async with db.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='import_log';") as cursor:
"SELECT name FROM sqlite_master WHERE type='table' AND name='import_log';"
) as cursor:
row = await cursor.fetchone() row = await cursor.fetchone()
assert row is not None assert row is not None
@@ -75,12 +68,10 @@ async def test_init_db_records_schema_version(tmp_path: Path) -> None:
db_path = str(tmp_path / "test.db") db_path = str(tmp_path / "test.db")
async with aiosqlite.connect(db_path) as db: async with aiosqlite.connect(db_path) as db:
await init_db(db) await init_db(db)
async with db.execute( async with db.execute("SELECT version FROM schema_migrations ORDER BY version DESC LIMIT 1;") as cursor:
"SELECT version FROM schema_migrations ORDER BY version DESC LIMIT 1;"
) as cursor:
row = await cursor.fetchone() row = await cursor.fetchone()
assert row is not None assert row is not None
assert row[0] == 2 assert row[0] == 9
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -92,9 +83,7 @@ async def test_init_db_migrates_legacy_database_without_schema_version(tmp_path:
await db.execute("DROP TABLE schema_migrations;") await db.execute("DROP TABLE schema_migrations;")
await db.commit() await db.commit()
await init_db(db) await init_db(db)
async with db.execute( async with db.execute("SELECT version FROM schema_migrations ORDER BY version DESC LIMIT 1;") as cursor:
"SELECT version FROM schema_migrations ORDER BY version DESC LIMIT 1;"
) as cursor:
row = await cursor.fetchone() row = await cursor.fetchone()
assert row is not None assert row is not None
assert row[0] == 2 assert row[0] == 9

View File

@@ -35,7 +35,11 @@ async def _login(client: AsyncClient, password: str = "Mysecretpass1!") -> str:
Note: The token is returned in the HttpOnly cookie, not in the JSON body. Note: The token is returned in the HttpOnly cookie, not in the JSON body.
For testing Bearer token auth, we extract it from the cookie. For testing Bearer token auth, we extract it from the cookie.
""" """
resp = await client.post("/api/v1/auth/login", json={"password": password}) resp = await client.post(
"/api/v1/auth/login",
json={"password": password},
headers={"X-BanGUI-Request": "1"},
)
assert resp.status_code == 200 assert resp.status_code == 200
token = resp.cookies.get(SESSION_COOKIE_NAME) token = resp.cookies.get(SESSION_COOKIE_NAME)
assert token is not None assert token is not None
@@ -50,14 +54,10 @@ async def _login(client: AsyncClient, password: str = "Mysecretpass1!") -> str:
class TestLogin: class TestLogin:
"""POST /api/auth/login.""" """POST /api/auth/login."""
async def test_login_succeeds_with_correct_password( async def test_login_succeeds_with_correct_password(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""Login returns 200 and sets a session cookie for the correct password.""" """Login returns 200 and sets a session cookie for the correct password."""
await _do_setup(client) await _do_setup(client)
response = await client.post( response = await client.post("/api/v1/auth/login", json={"password": "Mysecretpass1!"})
"/api/v1/auth/login", json={"password": "Mysecretpass1!"}
)
assert response.status_code == 200 assert response.status_code == 200
body = response.json() body = response.json()
# Token is not returned in the JSON body; it's set as an HttpOnly cookie # Token is not returned in the JSON body; it's set as an HttpOnly cookie
@@ -67,9 +67,7 @@ class TestLogin:
async def test_login_sets_cookie(self, client: AsyncClient) -> None: async def test_login_sets_cookie(self, client: AsyncClient) -> None:
"""Login sets the bangui_session HttpOnly cookie.""" """Login sets the bangui_session HttpOnly cookie."""
await _do_setup(client) await _do_setup(client)
response = await client.post( response = await client.post("/api/v1/auth/login", json={"password": "Mysecretpass1!"})
"/api/v1/auth/login", json={"password": "Mysecretpass1!"}
)
assert response.status_code == 200 assert response.status_code == 200
assert SESSION_COOKIE_NAME in response.cookies assert SESSION_COOKIE_NAME in response.cookies
assert "." in response.cookies[SESSION_COOKIE_NAME] assert "." in response.cookies[SESSION_COOKIE_NAME]
@@ -77,36 +75,26 @@ class TestLogin:
assert "HttpOnly" in set_cookie assert "HttpOnly" in set_cookie
assert "SameSite=lax" in set_cookie assert "SameSite=lax" in set_cookie
async def test_login_sets_secure_cookie_when_enabled( async def test_login_sets_secure_cookie_when_enabled(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""Login sets the Secure flag when session cookies are configured for HTTPS.""" """Login sets the Secure flag when session cookies are configured for HTTPS."""
client._transport.app.state.settings.session_cookie_secure = True client._transport.app.state.settings.session_cookie_secure = True
await _do_setup(client) await _do_setup(client)
response = await client.post( response = await client.post("/api/v1/auth/login", json={"password": "Mysecretpass1!"})
"/api/v1/auth/login", json={"password": "Mysecretpass1!"}
)
assert response.status_code == 200 assert response.status_code == 200
set_cookie = response.headers.get("set-cookie", "") set_cookie = response.headers.get("set-cookie", "")
assert "Secure" in set_cookie assert "Secure" in set_cookie
async def test_login_fails_with_wrong_password( async def test_login_fails_with_wrong_password(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""Login returns 401 for an incorrect password.""" """Login returns 401 for an incorrect password."""
await _do_setup(client) await _do_setup(client)
response = await client.post( response = await client.post("/api/v1/auth/login", json={"password": "wrongpassword"})
"/api/v1/auth/login", json={"password": "wrongpassword"}
)
assert response.status_code == 401 assert response.status_code == 401
async def test_login_rejects_empty_password(self, client: AsyncClient) -> None: async def test_login_rejects_empty_password(self, client: AsyncClient) -> None:
"""Login returns 422 when password field is missing.""" """Login returns 400 when password field is missing."""
await _do_setup(client) await _do_setup(client)
response = await client.post("/api/v1/auth/login", json={}) response = await client.post("/api/v1/auth/login", json={})
assert response.status_code == 422 assert response.status_code == 400
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -121,7 +109,10 @@ class TestLogout:
"""Logout returns 200 with a confirmation message.""" """Logout returns 200 with a confirmation message."""
await _do_setup(client) await _do_setup(client)
await _login(client) await _login(client)
response = await client.post("/api/v1/auth/logout") response = await client.post(
"/api/v1/auth/logout",
headers={"X-BanGUI-Request": "1"},
)
assert response.status_code == 200 assert response.status_code == 200
assert "message" in response.json() assert "message" in response.json()
@@ -129,7 +120,10 @@ class TestLogout:
"""Logout clears the bangui_session cookie.""" """Logout clears the bangui_session cookie."""
await _do_setup(client) await _do_setup(client)
await _login(client) # sets cookie on client await _login(client) # sets cookie on client
response = await client.post("/api/v1/auth/logout") response = await client.post(
"/api/v1/auth/logout",
headers={"X-BanGUI-Request": "1"},
)
assert response.status_code == 200 assert response.status_code == 200
# Cookie should be set to empty / deleted in the Set-Cookie header. # Cookie should be set to empty / deleted in the Set-Cookie header.
set_cookie = response.headers.get("set-cookie", "") set_cookie = response.headers.get("set-cookie", "")
@@ -141,9 +135,7 @@ class TestLogout:
response = await client.post("/api/v1/auth/logout") response = await client.post("/api/v1/auth/logout")
assert response.status_code == 200 assert response.status_code == 200
async def test_session_invalid_after_logout( async def test_session_invalid_after_logout(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""A session token is rejected after logout.""" """A session token is rejected after logout."""
await _do_setup(client) await _do_setup(client)
token = await _login(client) token = await _login(client)
@@ -170,16 +162,12 @@ class TestLogout:
class TestRequireAuth: class TestRequireAuth:
"""Verify the require_auth dependency rejects unauthenticated requests.""" """Verify the require_auth dependency rejects unauthenticated requests."""
async def test_health_endpoint_requires_no_auth( async def test_health_endpoint_requires_no_auth(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""Health endpoint is accessible without authentication.""" """Health endpoint is accessible without authentication."""
response = await client.get("/api/v1/health") response = await client.get("/api/v1/health")
assert response.status_code == 200 assert response.status_code == 200
async def test_session_cache_is_disabled_by_default( async def test_session_cache_is_disabled_by_default(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""Session validation does not use the in-memory cache unless enabled.""" """Session validation does not use the in-memory cache unless enabled."""
from app.repositories import session_repo from app.repositories import session_repo
@@ -217,9 +205,7 @@ class TestRequireAuth:
class TestValidateSession: class TestValidateSession:
"""GET /api/auth/session.""" """GET /api/auth/session."""
async def test_validate_session_returns_200_with_valid_token( async def test_validate_session_returns_200_with_valid_token(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""Validate session returns 200 for a valid authenticated request.""" """Validate session returns 200 for a valid authenticated request."""
await _do_setup(client) await _do_setup(client)
token = await _login(client) token = await _login(client)
@@ -231,17 +217,13 @@ class TestValidateSession:
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == {"valid": True} assert response.json() == {"valid": True}
async def test_validate_session_returns_401_without_token( async def test_validate_session_returns_401_without_token(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""Validate session returns 401 when no token is present.""" """Validate session returns 401 when no token is present."""
await _do_setup(client) await _do_setup(client)
response = await client.get("/api/v1/auth/session") response = await client.get("/api/v1/auth/session")
assert response.status_code == 401 assert response.status_code == 401
async def test_validate_session_returns_401_with_invalid_token( async def test_validate_session_returns_401_with_invalid_token(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""Validate session returns 401 for an invalid or expired token.""" """Validate session returns 401 for an invalid or expired token."""
await _do_setup(client) await _do_setup(client)
response = await client.get( response = await client.get(
@@ -250,9 +232,7 @@ class TestValidateSession:
) )
assert response.status_code == 401 assert response.status_code == 401
async def test_validate_session_with_cookie( async def test_validate_session_with_cookie(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""Validate session works with cookie-based authentication.""" """Validate session works with cookie-based authentication."""
await _do_setup(client) await _do_setup(client)
token = await _login(client) token = await _login(client)
@@ -264,9 +244,7 @@ class TestValidateSession:
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == {"valid": True} assert response.json() == {"valid": True}
async def test_validate_session_after_logout( async def test_validate_session_after_logout(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""Validate session returns 401 after logout.""" """Validate session returns 401 after logout."""
await _do_setup(client) await _do_setup(client)
token = await _login(client) token = await _login(client)
@@ -342,9 +320,7 @@ class TestRequireAuthSessionCache:
# the second request is served entirely from memory. # the second request is served entirely from memory.
assert call_count == 1 assert call_count == 1
async def test_token_enters_cache_after_first_auth( async def test_token_enters_cache_after_first_auth(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""A successful auth request places the token in the session cache.""" """A successful auth request places the token in the session cache."""
await _do_setup(client) await _do_setup(client)
@@ -360,9 +336,7 @@ class TestRequireAuthSessionCache:
assert client._transport.app.state.session_cache.get(token) is not None assert client._transport.app.state.session_cache.get(token) is not None
async def test_logout_evicts_token_from_cache( async def test_logout_evicts_token_from_cache(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""Logout removes the session token from the session cache immediately.""" """Logout removes the session token from the session cache immediately."""
await _do_setup(client) await _do_setup(client)

View File

@@ -7,25 +7,34 @@ from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import aiosqlite import aiosqlite
import bcrypt
import pytest import pytest
from httpx import ASGITransport, AsyncClient from httpx import ASGITransport, AsyncClient
from app.config import Settings from app.config import Settings
from app.db import init_db from app.db import init_db
from app.main import create_app
from app.models.ban import ActiveBan, ActiveBanListResponse
from app.exceptions import Fail2BanConnectionError from app.exceptions import Fail2BanConnectionError
from app.main import create_app
from app.models.ban_domain import DomainActiveBan, DomainActiveBanList
from app.services.geo_cache import GeoCache
from app.utils.session_cache import NoOpSessionCache
from app.utils.setup_state import set_setup_complete_cache
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
_SETUP_PAYLOAD = { async def _write_password_hash(db: aiosqlite.Connection, password: str) -> str:
"master_password": "Testpass1!", """Hash password and write to settings table."""
"fail2ban_socket": "/var/run/fail2ban/fail2ban.sock", pw_bytes = password.encode()
"timezone": "UTC", import asyncio
"session_duration_minutes": 60,
} hashed = await asyncio.get_event_loop().run_in_executor(
None, lambda: bcrypt.hashpw(pw_bytes, bcrypt.gensalt()).decode()
)
await db.execute(
"INSERT OR REPLACE INTO settings (key, value) VALUES (?, ?)",
("master_password_hash", hashed),
)
await db.commit()
return hashed
@pytest.fixture @pytest.fixture
@@ -41,24 +50,30 @@ async def bans_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
log_level="debug", log_level="debug",
fail2ban_config_dir=str(tmp_path / "fail2ban"), fail2ban_config_dir=str(tmp_path / "fail2ban"),
session_cache_enabled=False, session_cache_enabled=False,
session_cookie_secure=False,
) )
app = create_app(settings=settings) app = create_app(settings=settings)
set_setup_complete_cache(app, True)
db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path) db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path)
db.row_factory = aiosqlite.Row db.row_factory = aiosqlite.Row
await init_db(db) await init_db(db)
await _write_password_hash(db, _SETUP_PAYLOAD["master_password"])
app.state.db = db app.state.db = db
app.state.http_session = MagicMock() app.state.http_session = MagicMock()
app.state.session_cache = NoOpSessionCache()
app.state.geo_cache = GeoCache()
async def _override_get_db() -> AsyncGenerator[aiosqlite.Connection, None]: async def _override_get_db() -> AsyncGenerator[aiosqlite.Connection, None]:
yield db yield db
from app.dependencies import get_db from app.dependencies import get_db, get_session_cache
app.dependency_overrides[get_db] = _override_get_db app.dependency_overrides[get_db] = _override_get_db
app.dependency_overrides[get_session_cache] = lambda: NoOpSessionCache()
transport = ASGITransport(app=app) transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac: async with AsyncClient(transport=transport, base_url="http://test") as ac:
await ac.post("/api/v1/setup", json=_SETUP_PAYLOAD)
login = await ac.post( login = await ac.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={"password": _SETUP_PAYLOAD["master_password"]}, json={"password": _SETUP_PAYLOAD["master_password"]},
@@ -70,6 +85,19 @@ async def bans_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
app.dependency_overrides.clear() app.dependency_overrides.clear()
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
_SETUP_PAYLOAD = {
"master_password": "Testpass1!",
"fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
"timezone": "UTC",
"session_duration_minutes": 60,
"database_path": "bans_test.db",
}
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# GET /api/bans/active # GET /api/bans/active
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -80,9 +108,11 @@ class TestGetActiveBans:
async def test_200_when_authenticated(self, bans_client: AsyncClient) -> None: async def test_200_when_authenticated(self, bans_client: AsyncClient) -> None:
"""GET /api/bans/active returns 200 with an ActiveBanListResponse.""" """GET /api/bans/active returns 200 with an ActiveBanListResponse."""
mock_response = ActiveBanListResponse( from app.models.ban_domain import DomainActiveBan, DomainActiveBanList
mock_response = DomainActiveBanList(
bans=[ bans=[
ActiveBan( DomainActiveBan(
ip="1.2.3.4", ip="1.2.3.4",
jail="sshd", jail="sshd",
banned_at="2025-01-01T12:00:00+00:00", banned_at="2025-01-01T12:00:00+00:00",
@@ -102,20 +132,21 @@ class TestGetActiveBans:
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data = resp.json()
assert data["total"] == 1 assert data["total"] == 1
assert data["bans"][0]["ip"] == "1.2.3.4" assert data["items"][0]["ip"] == "1.2.3.4"
assert data["bans"][0]["jail"] == "sshd" assert data["items"][0]["jail"] == "sshd"
async def test_401_when_unauthenticated( async def test_401_when_unauthenticated(self, bans_client: AsyncClient, monkeypatch: pytest.MonkeyPatch) -> None:
self, bans_client: AsyncClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""GET /api/bans/active returns 401 without session.""" """GET /api/bans/active returns 401 without session."""
import logging
from unittest.mock import MagicMock
class FakeLogger: class FakeLogger:
def error(self, *args, **kwargs): pass def error(self, *args, **kwargs):
def warning(self, *args, **kwargs): pass pass
def info(self, *args, **kwargs): pass
def warning(self, *args, **kwargs):
pass
def info(self, *args, **kwargs):
pass
monkeypatch.setattr("app.main.log", FakeLogger()) monkeypatch.setattr("app.main.log", FakeLogger())
resp = await AsyncClient( resp = await AsyncClient(
@@ -126,7 +157,7 @@ class TestGetActiveBans:
async def test_empty_when_no_bans(self, bans_client: AsyncClient) -> None: async def test_empty_when_no_bans(self, bans_client: AsyncClient) -> None:
"""GET /api/bans/active returns empty list when no bans are active.""" """GET /api/bans/active returns empty list when no bans are active."""
mock_response = ActiveBanListResponse(bans=[], total=0) mock_response = DomainActiveBanList(bans=[], total=0)
with patch( with patch(
"app.routers.bans.ban_service.get_active_bans", "app.routers.bans.ban_service.get_active_bans",
AsyncMock(return_value=mock_response), AsyncMock(return_value=mock_response),
@@ -135,13 +166,13 @@ class TestGetActiveBans:
assert resp.status_code == 200 assert resp.status_code == 200
assert resp.json()["total"] == 0 assert resp.json()["total"] == 0
assert resp.json()["bans"] == [] assert resp.json()["items"] == []
async def test_response_shape(self, bans_client: AsyncClient) -> None: async def test_response_shape(self, bans_client: AsyncClient) -> None:
"""GET /api/bans/active returns expected fields per ban entry.""" """GET /api/bans/active returns expected fields per ban entry."""
mock_response = ActiveBanListResponse( mock_response = DomainActiveBanList(
bans=[ bans=[
ActiveBan( DomainActiveBan(
ip="10.0.0.1", ip="10.0.0.1",
jail="nginx", jail="nginx",
banned_at=None, banned_at=None,
@@ -158,7 +189,7 @@ class TestGetActiveBans:
): ):
resp = await bans_client.get("/api/v1/bans/active") resp = await bans_client.get("/api/v1/bans/active")
ban = resp.json()["bans"][0] ban = resp.json()["items"][0]
assert "ip" in ban assert "ip" in ban
assert "jail" in ban assert "jail" in ban
assert "banned_at" in ban assert "banned_at" in ban
@@ -183,6 +214,7 @@ class TestBanIp:
resp = await bans_client.post( resp = await bans_client.post(
"/api/v1/bans", "/api/v1/bans",
json={"ip": "1.2.3.4", "jail": "sshd"}, json={"ip": "1.2.3.4", "jail": "sshd"},
headers={"X-BanGUI-Request": "1"},
) )
assert resp.status_code == 201 assert resp.status_code == 201
@@ -197,6 +229,7 @@ class TestBanIp:
resp = await bans_client.post( resp = await bans_client.post(
"/api/v1/bans", "/api/v1/bans",
json={"ip": "bad", "jail": "sshd"}, json={"ip": "bad", "jail": "sshd"},
headers={"X-BanGUI-Request": "1"},
) )
assert resp.status_code == 400 assert resp.status_code == 400
@@ -212,6 +245,7 @@ class TestBanIp:
resp = await bans_client.post( resp = await bans_client.post(
"/api/v1/bans", "/api/v1/bans",
json={"ip": "1.2.3.4", "jail": "ghost"}, json={"ip": "1.2.3.4", "jail": "ghost"},
headers={"X-BanGUI-Request": "1"},
) )
assert resp.status_code == 404 assert resp.status_code == 404
@@ -243,6 +277,7 @@ class TestUnbanIp:
"DELETE", "DELETE",
"/api/v1/bans", "/api/v1/bans",
json={"ip": "1.2.3.4", "unban_all": True}, json={"ip": "1.2.3.4", "unban_all": True},
headers={"X-BanGUI-Request": "1"},
) )
assert resp.status_code == 200 assert resp.status_code == 200
@@ -258,6 +293,7 @@ class TestUnbanIp:
"DELETE", "DELETE",
"/api/v1/bans", "/api/v1/bans",
json={"ip": "1.2.3.4", "jail": "sshd"}, json={"ip": "1.2.3.4", "jail": "sshd"},
headers={"X-BanGUI-Request": "1"},
) )
assert resp.status_code == 200 assert resp.status_code == 200
@@ -273,6 +309,7 @@ class TestUnbanIp:
"DELETE", "DELETE",
"/api/v1/bans", "/api/v1/bans",
json={"ip": "bad", "unban_all": True}, json={"ip": "bad", "unban_all": True},
headers={"X-BanGUI-Request": "1"},
) )
assert resp.status_code == 400 assert resp.status_code == 400
@@ -289,6 +326,7 @@ class TestUnbanIp:
"DELETE", "DELETE",
"/api/v1/bans", "/api/v1/bans",
json={"ip": "1.2.3.4", "jail": "ghost"}, json={"ip": "1.2.3.4", "jail": "ghost"},
headers={"X-BanGUI-Request": "1"},
) )
assert resp.status_code == 404 assert resp.status_code == 404
@@ -308,7 +346,7 @@ class TestUnbanAll:
"app.routers.bans.jail_service.unban_all_ips", "app.routers.bans.jail_service.unban_all_ips",
AsyncMock(return_value=3), AsyncMock(return_value=3),
): ):
resp = await bans_client.request("DELETE", "/api/v1/bans/all") resp = await bans_client.request("DELETE", "/api/v1/bans/all", headers={"X-BanGUI-Request": "1"})
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data = resp.json()
@@ -321,14 +359,12 @@ class TestUnbanAll:
"app.routers.bans.jail_service.unban_all_ips", "app.routers.bans.jail_service.unban_all_ips",
AsyncMock(return_value=0), AsyncMock(return_value=0),
): ):
resp = await bans_client.request("DELETE", "/api/v1/bans/all") resp = await bans_client.request("DELETE", "/api/v1/bans/all", headers={"X-BanGUI-Request": "1"})
assert resp.status_code == 200 assert resp.status_code == 200
assert resp.json()["count"] == 0 assert resp.json()["count"] == 0
async def test_502_when_fail2ban_unreachable( async def test_502_when_fail2ban_unreachable(self, bans_client: AsyncClient) -> None:
self, bans_client: AsyncClient
) -> None:
"""DELETE /api/bans/all returns 502 when fail2ban is unreachable.""" """DELETE /api/bans/all returns 502 when fail2ban is unreachable."""
with patch( with patch(
"app.routers.bans.jail_service.unban_all_ips", "app.routers.bans.jail_service.unban_all_ips",
@@ -339,7 +375,7 @@ class TestUnbanAll:
) )
), ),
): ):
resp = await bans_client.request("DELETE", "/api/v1/bans/all") resp = await bans_client.request("DELETE", "/api/v1/bans/all", headers={"X-BanGUI-Request": "1"})
assert resp.status_code == 502 assert resp.status_code == 502

View File

@@ -84,9 +84,7 @@ def _make_import_result() -> ImportRunResult:
def _make_log_response() -> ImportLogListResponse: def _make_log_response() -> ImportLogListResponse:
return ImportLogListResponse( return ImportLogListResponse(items=[], total=0, page=1, page_size=50)
items=[], total=0, page=1, page_size=50
)
def _make_preview() -> PreviewResponse: def _make_preview() -> PreviewResponse:
@@ -106,13 +104,17 @@ def _make_preview() -> PreviewResponse:
@pytest.fixture @pytest.fixture
async def bl_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc] async def bl_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
"""Provide an authenticated AsyncClient for blocklist endpoint tests.""" """Provide an authenticated AsyncClient for blocklist endpoint tests."""
config_dir = tmp_path / "fail2ban"
config_dir.mkdir()
settings = Settings( settings = Settings(
database_path=str(tmp_path / "bl_router_test.db"), database_path=str(tmp_path / "bl_router_test.db"),
fail2ban_socket="/tmp/fake_fail2ban.sock", fail2ban_socket="/tmp/fake_fail2ban.sock",
session_secret="test-bl-secret", fail2ban_config_dir=str(config_dir),
session_secret="test-bl-secret-that-is-long-enough!!",
session_duration_minutes=60, session_duration_minutes=60,
timezone="UTC", timezone="UTC",
log_level="debug", log_level="debug",
session_cookie_secure=False,
) )
app = create_app(settings=settings) app = create_app(settings=settings)
@@ -127,8 +129,13 @@ async def bl_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
scheduler_stub.get_job = MagicMock(return_value=None) scheduler_stub.get_job = MagicMock(return_value=None)
app.state.scheduler = scheduler_stub app.state.scheduler = scheduler_stub
# Initialize GeoCache (normally done in lifespan handler)
from app.services.geo_cache import GeoCache
app.state.geo_cache = GeoCache()
transport = ASGITransport(app=app) transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac: async with AsyncClient(transport=transport, base_url="http://test", headers={"X-BanGUI-Request": "1"}) as ac:
resp = await ac.post("/api/v1/setup", json=_SETUP_PAYLOAD) resp = await ac.post("/api/v1/setup", json=_SETUP_PAYLOAD)
assert resp.status_code == 201 assert resp.status_code == 201
@@ -277,12 +284,15 @@ class TestDeleteBlocklist:
class TestPreviewBlocklist: class TestPreviewBlocklist:
async def test_preview_returns_200(self, bl_client: AsyncClient) -> None: async def test_preview_returns_200(self, bl_client: AsyncClient) -> None:
"""GET /api/blocklists/1/preview returns 200 for existing source.""" """GET /api/blocklists/1/preview returns 200 for existing source."""
with patch( with (
patch(
"app.routers.blocklist.blocklist_service.get_source", "app.routers.blocklist.blocklist_service.get_source",
new=AsyncMock(return_value=_make_source()), new=AsyncMock(return_value=_make_source()),
), patch( ),
patch(
"app.routers.blocklist.blocklist_service.preview_source", "app.routers.blocklist.blocklist_service.preview_source",
new=AsyncMock(return_value=_make_preview()), new=AsyncMock(return_value=_make_preview()),
),
): ):
resp = await bl_client.get("/api/v1/blocklists/1/preview") resp = await bl_client.get("/api/v1/blocklists/1/preview")
assert resp.status_code == 200 assert resp.status_code == 200
@@ -296,28 +306,32 @@ class TestPreviewBlocklist:
resp = await bl_client.get("/api/v1/blocklists/999/preview") resp = await bl_client.get("/api/v1/blocklists/999/preview")
assert resp.status_code == 404 assert resp.status_code == 404
async def test_preview_returns_502_on_download_error( async def test_preview_returns_400_on_download_error(self, bl_client: AsyncClient) -> None:
self, bl_client: AsyncClient """GET /api/blocklists/1/preview returns 400 when URL is unreachable."""
) -> None: with (
"""GET /api/blocklists/1/preview returns 502 when URL is unreachable.""" patch(
with patch(
"app.routers.blocklist.blocklist_service.get_source", "app.routers.blocklist.blocklist_service.get_source",
new=AsyncMock(return_value=_make_source()), new=AsyncMock(return_value=_make_source()),
), patch( ),
patch(
"app.routers.blocklist.blocklist_service.preview_source", "app.routers.blocklist.blocklist_service.preview_source",
new=AsyncMock(side_effect=ValueError("Connection refused")), new=AsyncMock(side_effect=ValueError("Connection refused")),
),
): ):
resp = await bl_client.get("/api/v1/blocklists/1/preview") resp = await bl_client.get("/api/v1/blocklists/1/preview")
assert resp.status_code == 502 assert resp.status_code == 400
async def test_preview_response_shape(self, bl_client: AsyncClient) -> None: async def test_preview_response_shape(self, bl_client: AsyncClient) -> None:
"""Preview response has entries, valid_count, skipped_count, total_lines.""" """Preview response has entries, valid_count, skipped_count, total_lines."""
with patch( with (
patch(
"app.routers.blocklist.blocklist_service.get_source", "app.routers.blocklist.blocklist_service.get_source",
new=AsyncMock(return_value=_make_source()), new=AsyncMock(return_value=_make_source()),
), patch( ),
patch(
"app.routers.blocklist.blocklist_service.preview_source", "app.routers.blocklist.blocklist_service.preview_source",
new=AsyncMock(return_value=_make_preview()), new=AsyncMock(return_value=_make_preview()),
),
): ):
resp = await bl_client.get("/api/v1/blocklists/1/preview") resp = await bl_client.get("/api/v1/blocklists/1/preview")
body = resp.json() body = resp.json()
@@ -383,9 +397,7 @@ class TestGetSchedule:
assert "next_run_at" in body assert "next_run_at" in body
assert "last_run_at" in body assert "last_run_at" in body
async def test_schedule_response_includes_last_run_errors( async def test_schedule_response_includes_last_run_errors(self, bl_client: AsyncClient) -> None:
self, bl_client: AsyncClient
) -> None:
"""GET /api/blocklists/schedule includes last_run_errors field.""" """GET /api/blocklists/schedule includes last_run_errors field."""
info_with_errors = ScheduleInfo( info_with_errors = ScheduleInfo(
config=ScheduleConfig( config=ScheduleConfig(
@@ -457,15 +469,18 @@ class TestImportLog:
assert resp.status_code == 200 assert resp.status_code == 200
async def test_log_response_shape(self, bl_client: AsyncClient) -> None: async def test_log_response_shape(self, bl_client: AsyncClient) -> None:
"""Log response has items, total, page, page_size.""" """Log response has items and pagination metadata."""
resp = await bl_client.get("/api/v1/blocklists/log") resp = await bl_client.get("/api/v1/blocklists/log")
body = resp.json() body = resp.json()
for key in ("items", "total", "page", "page_size"): assert "items" in body
assert key in body assert "pagination" in body
pagination = body["pagination"]
for key in ("page", "page_size", "total", "total_pages", "has_next_page", "has_prev_page"):
assert key in pagination
async def test_log_empty_when_no_runs(self, bl_client: AsyncClient) -> None: async def test_log_empty_when_no_runs(self, bl_client: AsyncClient) -> None:
"""Log returns empty items list when no import runs have occurred.""" """Log returns empty items list when no import runs have occurred."""
resp = await bl_client.get("/api/v1/blocklists/log") resp = await bl_client.get("/api/v1/blocklists/log")
body = resp.json() body = resp.json()
assert body["total"] == 0 assert body["pagination"]["total"] == 0
assert body["items"] == [] assert body["items"] == []

View File

@@ -16,13 +16,15 @@ from app.main import create_app
from app.models.config import ( from app.models.config import (
Fail2BanLogResponse, Fail2BanLogResponse,
FilterConfig, FilterConfig,
GlobalConfigResponse,
JailConfig,
JailConfigListResponse,
JailConfigResponse,
RegexTestResponse,
ServiceStatusResponse, ServiceStatusResponse,
) )
from app.models.config_domain import (
DomainGlobalConfig,
DomainJailConfig,
DomainJailConfigList,
DomainMapColorThresholds,
DomainRegexTest,
)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Fixtures # Fixtures
@@ -40,9 +42,12 @@ _SETUP_PAYLOAD = {
@pytest.fixture @pytest.fixture
async def config_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc] async def config_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
"""Provide an authenticated ``AsyncClient`` for config endpoint tests.""" """Provide an authenticated ``AsyncClient`` for config endpoint tests."""
config_dir = tmp_path / "fail2ban"
config_dir.mkdir()
settings = Settings( settings = Settings(
database_path=str(tmp_path / "config_test.db"), database_path=str(tmp_path / "config_test.db"),
fail2ban_socket="/tmp/fake.sock", fail2ban_socket="/tmp/fake.sock",
fail2ban_config_dir=str(config_dir),
session_secret="test-secret-key-do-not-use-in-production", session_secret="test-secret-key-do-not-use-in-production",
session_duration_minutes=60, session_duration_minutes=60,
timezone="UTC", timezone="UTC",
@@ -58,20 +63,21 @@ async def config_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
app.state.http_session = MagicMock() app.state.http_session = MagicMock()
transport = ASGITransport(app=app) transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac: async with AsyncClient(transport=transport, base_url="http://test", headers={"X-BanGUI-Request": "1"}) as ac:
await ac.post("/api/v1/setup", json=_SETUP_PAYLOAD) setup_resp = await ac.post("/api/v1/setup", json=_SETUP_PAYLOAD)
assert setup_resp.status_code == 201, f"Setup failed: {setup_resp.status_code} {setup_resp.text}"
login = await ac.post( login = await ac.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={"password": _SETUP_PAYLOAD["master_password"]}, json={"password": _SETUP_PAYLOAD["master_password"]},
) )
assert login.status_code == 200 assert login.status_code == 200, f"Login failed: {login.status_code} {login.text}"
yield ac yield ac
await db.close() await db.close()
def _make_jail_config(name: str = "sshd") -> JailConfig: def _make_jail_config(name: str = "sshd") -> DomainJailConfig:
return JailConfig( return DomainJailConfig(
name=name, name=name,
ban_time=600, ban_time=600,
max_retry=5, max_retry=5,
@@ -98,9 +104,7 @@ class TestGetJailConfigs:
async def test_200_returns_jail_list(self, config_client: AsyncClient) -> None: async def test_200_returns_jail_list(self, config_client: AsyncClient) -> None:
"""GET /api/config/jails returns 200 with JailConfigListResponse.""" """GET /api/config/jails returns 200 with JailConfigListResponse."""
mock_response = JailConfigListResponse( mock_response = DomainJailConfigList(items=[_make_jail_config("sshd")], total=1)
items=[_make_jail_config("sshd")], total=1
)
with patch( with patch(
"app.routers.jail_config.config_service.list_jail_configs", "app.routers.jail_config.config_service.list_jail_configs",
AsyncMock(return_value=mock_response), AsyncMock(return_value=mock_response),
@@ -143,7 +147,7 @@ class TestGetJailConfig:
async def test_200_returns_jail_config(self, config_client: AsyncClient) -> None: async def test_200_returns_jail_config(self, config_client: AsyncClient) -> None:
"""GET /api/config/jails/sshd returns 200 with JailConfigResponse.""" """GET /api/config/jails/sshd returns 200 with JailConfigResponse."""
mock_response = JailConfigResponse(jail=_make_jail_config("sshd")) mock_response = _make_jail_config("sshd")
with patch( with patch(
"app.routers.jail_config.config_service.get_jail_config", "app.routers.jail_config.config_service.get_jail_config",
AsyncMock(return_value=mock_response), AsyncMock(return_value=mock_response),
@@ -211,8 +215,8 @@ class TestUpdateJailConfig:
assert resp.status_code == 404 assert resp.status_code == 404
async def test_422_on_invalid_regex(self, config_client: AsyncClient) -> None: async def test_400_on_invalid_regex(self, config_client: AsyncClient) -> None:
"""PUT /api/config/jails/sshd returns 422 for invalid regex pattern.""" """PUT /api/config/jails/sshd returns 400 for invalid regex pattern."""
from app.services.config_service import ConfigValidationError from app.services.config_service import ConfigValidationError
with patch( with patch(
@@ -224,7 +228,7 @@ class TestUpdateJailConfig:
json={"fail_regex": ["[bad"]}, json={"fail_regex": ["[bad"]},
) )
assert resp.status_code == 422 assert resp.status_code == 400
async def test_400_on_config_operation_error(self, config_client: AsyncClient) -> None: async def test_400_on_config_operation_error(self, config_client: AsyncClient) -> None:
"""PUT /api/config/jails/sshd returns 400 when set command fails.""" """PUT /api/config/jails/sshd returns 400 when set command fails."""
@@ -291,7 +295,7 @@ class TestGetGlobalConfig:
async def test_200_returns_global_config(self, config_client: AsyncClient) -> None: async def test_200_returns_global_config(self, config_client: AsyncClient) -> None:
"""GET /api/config/global returns 200 with GlobalConfigResponse.""" """GET /api/config/global returns 200 with GlobalConfigResponse."""
mock_response = GlobalConfigResponse( mock_response = DomainGlobalConfig(
log_level="WARNING", log_level="WARNING",
log_target="/var/log/fail2ban.log", log_target="/var/log/fail2ban.log",
db_purge_age=86400, db_purge_age=86400,
@@ -415,15 +419,15 @@ class TestRestartFail2ban:
assert resp.status_code == 204 assert resp.status_code == 204
async def test_503_when_fail2ban_does_not_come_back(self, config_client: AsyncClient) -> None: async def test_500_when_fail2ban_does_not_come_back(self, config_client: AsyncClient) -> None:
"""POST /api/config/restart returns 503 when fail2ban does not come back online.""" """POST /api/config/restart returns 500 when fail2ban does not come back online."""
with patch( with patch(
"app.routers.config_misc.jail_service.restart_daemon", "app.routers.config_misc.jail_service.restart_daemon",
AsyncMock(return_value=False), AsyncMock(return_value=False),
): ):
resp = await config_client.post("/api/v1/config/restart") resp = await config_client.post("/api/v1/config/restart")
assert resp.status_code == 503 assert resp.status_code == 500
async def test_409_when_stop_command_fails(self, config_client: AsyncClient) -> None: async def test_409_when_stop_command_fails(self, config_client: AsyncClient) -> None:
"""POST /api/config/restart returns 409 when fail2ban rejects the stop command.""" """POST /api/config/restart returns 409 when fail2ban rejects the stop command."""
@@ -472,7 +476,7 @@ class TestRegexTest:
async def test_200_matched(self, config_client: AsyncClient) -> None: async def test_200_matched(self, config_client: AsyncClient) -> None:
"""POST /api/config/regex-test returns matched=true for a valid match.""" """POST /api/config/regex-test returns matched=true for a valid match."""
mock_response = RegexTestResponse(matched=True, groups=["1.2.3.4"], error=None) mock_response = DomainRegexTest(matched=True, groups=["1.2.3.4"], error=None)
with patch( with patch(
"app.routers.config_misc.log_service.test_regex", "app.routers.config_misc.log_service.test_regex",
return_value=mock_response, return_value=mock_response,
@@ -490,7 +494,7 @@ class TestRegexTest:
async def test_200_not_matched(self, config_client: AsyncClient) -> None: async def test_200_not_matched(self, config_client: AsyncClient) -> None:
"""POST /api/config/regex-test returns matched=false for no match.""" """POST /api/config/regex-test returns matched=false for no match."""
mock_response = RegexTestResponse(matched=False, groups=[], error=None) mock_response = DomainRegexTest(matched=False, groups=[], error=None)
with patch( with patch(
"app.routers.config_misc.log_service.test_regex", "app.routers.config_misc.log_service.test_regex",
return_value=mock_response, return_value=mock_response,
@@ -525,9 +529,12 @@ class TestAddLogPath:
async def test_204_on_success(self, config_client: AsyncClient) -> None: async def test_204_on_success(self, config_client: AsyncClient) -> None:
"""POST /api/config/jails/sshd/logpath returns 204 on success.""" """POST /api/config/jails/sshd/logpath returns 204 on success."""
with patch( with (
patch(
"app.routers.jail_config.config_service.add_log_path", "app.routers.jail_config.config_service.add_log_path",
AsyncMock(return_value=None), AsyncMock(return_value=None),
),
patch("app.routers.jail_config.validate_log_path", return_value="/var/log/specific.log"),
): ):
resp = await config_client.post( resp = await config_client.post(
"/api/v1/config/jails/sshd/logpath", "/api/v1/config/jails/sshd/logpath",
@@ -540,9 +547,12 @@ class TestAddLogPath:
"""POST /api/config/jails/missing/logpath returns 404.""" """POST /api/config/jails/missing/logpath returns 404."""
from app.services.config_service import JailNotFoundError from app.services.config_service import JailNotFoundError
with patch( with (
patch(
"app.routers.jail_config.config_service.add_log_path", "app.routers.jail_config.config_service.add_log_path",
AsyncMock(side_effect=JailNotFoundError("missing")), AsyncMock(side_effect=JailNotFoundError("missing")),
),
patch("app.routers.jail_config.validate_log_path", return_value="/var/log/test.log"),
): ):
resp = await config_client.post( resp = await config_client.post(
"/api/v1/config/jails/missing/logpath", "/api/v1/config/jails/missing/logpath",
@@ -594,6 +604,11 @@ class TestGetMapColorThresholds:
async def test_200_returns_thresholds(self, config_client: AsyncClient) -> None: async def test_200_returns_thresholds(self, config_client: AsyncClient) -> None:
"""GET /api/config/map-color-thresholds returns 200 with current values.""" """GET /api/config/map-color-thresholds returns 200 with current values."""
mock_response = DomainMapColorThresholds(threshold_high=100, threshold_medium=50, threshold_low=20)
with patch(
"app.routers.config_misc.config_service.get_map_color_thresholds",
AsyncMock(return_value=mock_response),
):
resp = await config_client.get("/api/v1/config/map-color-thresholds") resp = await config_client.get("/api/v1/config/map-color-thresholds")
assert resp.status_code == 200 assert resp.status_code == 200
@@ -601,7 +616,6 @@ class TestGetMapColorThresholds:
assert "threshold_high" in data assert "threshold_high" in data
assert "threshold_medium" in data assert "threshold_medium" in data
assert "threshold_low" in data assert "threshold_low" in data
# Should return defaults after setup
assert data["threshold_high"] == 100 assert data["threshold_high"] == 100
assert data["threshold_medium"] == 50 assert data["threshold_medium"] == 50
assert data["threshold_low"] == 20 assert data["threshold_low"] == 20
@@ -622,9 +636,12 @@ class TestUpdateMapColorThresholds:
"threshold_medium": 80, "threshold_medium": 80,
"threshold_low": 30, "threshold_low": 30,
} }
resp = await config_client.put( mock_response = DomainMapColorThresholds(threshold_high=200, threshold_medium=80, threshold_low=30)
"/api/v1/config/map-color-thresholds", json=update_payload with patch(
) "app.routers.config_misc.config_service.get_map_color_thresholds",
AsyncMock(return_value=mock_response),
):
resp = await config_client.put("/api/v1/config/map-color-thresholds", json=update_payload)
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data = resp.json()
@@ -632,14 +649,6 @@ class TestUpdateMapColorThresholds:
assert data["threshold_medium"] == 80 assert data["threshold_medium"] == 80
assert data["threshold_low"] == 30 assert data["threshold_low"] == 30
# Verify the values persist
get_resp = await config_client.get("/api/v1/config/map-color-thresholds")
assert get_resp.status_code == 200
get_data = get_resp.json()
assert get_data["threshold_high"] == 200
assert get_data["threshold_medium"] == 80
assert get_data["threshold_low"] == 30
async def test_400_for_invalid_order(self, config_client: AsyncClient) -> None: async def test_400_for_invalid_order(self, config_client: AsyncClient) -> None:
"""PUT /api/config/map-color-thresholds returns 400 if thresholds are misordered.""" """PUT /api/config/map-color-thresholds returns 400 if thresholds are misordered."""
invalid_payload = { invalid_payload = {
@@ -647,28 +656,22 @@ class TestUpdateMapColorThresholds:
"threshold_medium": 50, "threshold_medium": 50,
"threshold_low": 20, "threshold_low": 20,
} }
resp = await config_client.put( resp = await config_client.put("/api/v1/config/map-color-thresholds", json=invalid_payload)
"/api/v1/config/map-color-thresholds", json=invalid_payload
)
assert resp.status_code == 400 assert resp.status_code == 400
assert "high > medium > low" in resp.json()["detail"] assert "high > medium > low" in resp.json()["detail"]
async def test_400_for_non_positive_values( async def test_400_for_non_positive_values(self, config_client: AsyncClient) -> None:
self, config_client: AsyncClient """PUT /api/config/map-color-thresholds returns 400 for non-positive values (Pydantic validation)."""
) -> None:
"""PUT /api/config/map-color-thresholds returns 422 for non-positive values (Pydantic validation)."""
invalid_payload = { invalid_payload = {
"threshold_high": 100, "threshold_high": 100,
"threshold_medium": 50, "threshold_medium": 50,
"threshold_low": 0, "threshold_low": 0,
} }
resp = await config_client.put( resp = await config_client.put("/api/v1/config/map-color-thresholds", json=invalid_payload)
"/api/v1/config/map-color-thresholds", json=invalid_payload
)
# Pydantic validates ge=1 constraint before our service code runs # Pydantic validates gt=0 constraint before our service code runs; ValueError -> 400
assert resp.status_code == 422 assert resp.status_code == 400
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -752,9 +755,7 @@ class TestActivateJail:
"app.routers.jail_config.jail_config_service.activate_jail", "app.routers.jail_config.jail_config_service.activate_jail",
AsyncMock(return_value=mock_response), AsyncMock(return_value=mock_response),
): ):
resp = await config_client.post( resp = await config_client.post("/api/v1/config/jails/apache-auth/activate", json={})
"/api/v1/config/jails/apache-auth/activate", json={}
)
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data = resp.json()
@@ -765,9 +766,7 @@ class TestActivateJail:
"""POST .../activate accepts override fields.""" """POST .../activate accepts override fields."""
from app.models.config import JailActivationResponse from app.models.config import JailActivationResponse
mock_response = JailActivationResponse( mock_response = JailActivationResponse(name="apache-auth", active=True, message="Activated.")
name="apache-auth", active=True, message="Activated."
)
with patch( with patch(
"app.routers.jail_config.jail_config_service.activate_jail", "app.routers.jail_config.jail_config_service.activate_jail",
AsyncMock(return_value=mock_response), AsyncMock(return_value=mock_response),
@@ -791,9 +790,7 @@ class TestActivateJail:
"app.routers.jail_config.jail_config_service.activate_jail", "app.routers.jail_config.jail_config_service.activate_jail",
AsyncMock(side_effect=JailNotFoundInConfigError("missing")), AsyncMock(side_effect=JailNotFoundInConfigError("missing")),
): ):
resp = await config_client.post( resp = await config_client.post("/api/v1/config/jails/missing/activate", json={})
"/api/v1/config/jails/missing/activate", json={}
)
assert resp.status_code == 404 assert resp.status_code == 404
@@ -805,15 +802,11 @@ class TestActivateJail:
"app.routers.jail_config.jail_config_service.activate_jail", "app.routers.jail_config.jail_config_service.activate_jail",
AsyncMock(side_effect=JailAlreadyActiveError("sshd")), AsyncMock(side_effect=JailAlreadyActiveError("sshd")),
): ):
resp = await config_client.post( resp = await config_client.post("/api/v1/config/jails/sshd/activate", json={})
"/api/v1/config/jails/sshd/activate", json={}
)
assert resp.status_code == 409 assert resp.status_code == 409
async def test_failed_activation_does_not_set_last_activation( async def test_failed_activation_does_not_set_last_activation(self, config_client: AsyncClient) -> None:
self, config_client: AsyncClient
) -> None:
"""A failed activation must not leave a stale last_activation record.""" """A failed activation must not leave a stale last_activation record."""
from app.exceptions import Fail2BanConnectionError from app.exceptions import Fail2BanConnectionError
@@ -822,9 +815,7 @@ class TestActivateJail:
"app.routers.jail_config.jail_config_service.activate_jail", "app.routers.jail_config.jail_config_service.activate_jail",
AsyncMock(side_effect=Fail2BanConnectionError("No socket", "/tmp/fake.sock")), AsyncMock(side_effect=Fail2BanConnectionError("No socket", "/tmp/fake.sock")),
): ):
resp = await config_client.post( resp = await config_client.post("/api/v1/config/jails/sshd/activate", json={})
"/api/v1/config/jails/sshd/activate", json={}
)
assert resp.status_code == 502 assert resp.status_code == 502
assert config_client._transport.app.state.last_activation is None assert config_client._transport.app.state.last_activation is None
@@ -837,9 +828,7 @@ class TestActivateJail:
"app.routers.jail_config.jail_config_service.activate_jail", "app.routers.jail_config.jail_config_service.activate_jail",
AsyncMock(side_effect=JailNameError("bad name")), AsyncMock(side_effect=JailNameError("bad name")),
): ):
resp = await config_client.post( resp = await config_client.post("/api/v1/config/jails/bad-name/activate", json={})
"/api/v1/config/jails/bad-name/activate", json={}
)
assert resp.status_code == 400 assert resp.status_code == 400
@@ -866,9 +855,7 @@ class TestActivateJail:
"app.routers.jail_config.jail_config_service.activate_jail", "app.routers.jail_config.jail_config_service.activate_jail",
AsyncMock(return_value=blocked_response), AsyncMock(return_value=blocked_response),
): ):
resp = await config_client.post( resp = await config_client.post("/api/v1/config/jails/airsonic-auth/activate", json={})
"/api/v1/config/jails/airsonic-auth/activate", json={}
)
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data = resp.json()
@@ -914,9 +901,7 @@ class TestDeactivateJail:
"app.routers.jail_config.jail_config_service.deactivate_jail", "app.routers.jail_config.jail_config_service.deactivate_jail",
AsyncMock(side_effect=JailNotFoundInConfigError("missing")), AsyncMock(side_effect=JailNotFoundInConfigError("missing")),
): ):
resp = await config_client.post( resp = await config_client.post("/api/v1/config/jails/missing/deactivate")
"/api/v1/config/jails/missing/deactivate"
)
assert resp.status_code == 404 assert resp.status_code == 404
@@ -928,9 +913,7 @@ class TestDeactivateJail:
"app.routers.jail_config.jail_config_service.deactivate_jail", "app.routers.jail_config.jail_config_service.deactivate_jail",
AsyncMock(side_effect=JailAlreadyInactiveError("apache-auth")), AsyncMock(side_effect=JailAlreadyInactiveError("apache-auth")),
): ):
resp = await config_client.post( resp = await config_client.post("/api/v1/config/jails/apache-auth/deactivate")
"/api/v1/config/jails/apache-auth/deactivate"
)
assert resp.status_code == 409 assert resp.status_code == 409
@@ -942,9 +925,7 @@ class TestDeactivateJail:
"app.routers.jail_config.jail_config_service.deactivate_jail", "app.routers.jail_config.jail_config_service.deactivate_jail",
AsyncMock(side_effect=JailNameError("bad")), AsyncMock(side_effect=JailNameError("bad")),
): ):
resp = await config_client.post( resp = await config_client.post("/api/v1/config/jails/sshd/deactivate")
"/api/v1/config/jails/sshd/deactivate"
)
assert resp.status_code == 400 assert resp.status_code == 400
@@ -1011,10 +992,11 @@ class TestListFilters:
async def test_200_returns_filter_list(self, config_client: AsyncClient) -> None: async def test_200_returns_filter_list(self, config_client: AsyncClient) -> None:
"""GET /api/config/filters returns 200 with FilterListResponse.""" """GET /api/config/filters returns 200 with FilterListResponse."""
from app.models.config import FilterListResponse
mock_response = FilterListResponse( from app.models.config_domain import DomainFilterConfig, DomainFilterList
filters=[_make_filter_config("sshd", active=True)],
mock_response = DomainFilterList(
items=[DomainFilterConfig(name="sshd", filename="sshd.conf", active=True, used_by_jails=["sshd"])],
total=1, total=1,
) )
with patch( with patch(
@@ -1031,11 +1013,12 @@ class TestListFilters:
async def test_200_empty_filter_list(self, config_client: AsyncClient) -> None: async def test_200_empty_filter_list(self, config_client: AsyncClient) -> None:
"""GET /api/config/filters returns 200 with empty list when no filters found.""" """GET /api/config/filters returns 200 with empty list when no filters found."""
from app.models.config import FilterListResponse
from app.models.config_domain import DomainFilterList
with patch( with patch(
"app.routers.filter_config.filter_config_service.list_filters", "app.routers.filter_config.filter_config_service.list_filters",
AsyncMock(return_value=FilterListResponse(filters=[], total=0)), AsyncMock(return_value=DomainFilterList(items=[], total=0)),
): ):
resp = await config_client.get("/api/v1/config/filters") resp = await config_client.get("/api/v1/config/filters")
@@ -1043,16 +1026,15 @@ class TestListFilters:
assert resp.json()["total"] == 0 assert resp.json()["total"] == 0
assert resp.json()["filters"] == [] assert resp.json()["filters"] == []
async def test_active_filters_sorted_before_inactive( async def test_active_filters_sorted_before_inactive(self, config_client: AsyncClient) -> None:
self, config_client: AsyncClient
) -> None:
"""GET /api/config/filters returns active filters before inactive ones.""" """GET /api/config/filters returns active filters before inactive ones."""
from app.models.config import FilterListResponse
mock_response = FilterListResponse( from app.models.config_domain import DomainFilterConfig, DomainFilterList
filters=[
_make_filter_config("nginx", active=False), mock_response = DomainFilterList(
_make_filter_config("sshd", active=True), items=[
DomainFilterConfig(name="nginx", filename="nginx.conf", active=False),
DomainFilterConfig(name="sshd", filename="sshd.conf", active=True, used_by_jails=["sshd"]),
], ],
total=2, total=2,
) )
@@ -1155,8 +1137,8 @@ class TestUpdateFilter:
assert resp.status_code == 404 assert resp.status_code == 404
async def test_422_for_invalid_regex(self, config_client: AsyncClient) -> None: async def test_400_for_invalid_regex(self, config_client: AsyncClient) -> None:
"""PUT /api/config/filters/sshd returns 422 for bad regex.""" """PUT /api/config/filters/sshd returns 400 for bad regex."""
from app.services.filter_config_service import FilterInvalidRegexError from app.services.filter_config_service import FilterInvalidRegexError
with patch( with patch(
@@ -1168,7 +1150,7 @@ class TestUpdateFilter:
json={"failregex": ["[bad"]}, json={"failregex": ["[bad"]},
) )
assert resp.status_code == 422 assert resp.status_code == 400
async def test_400_for_invalid_name(self, config_client: AsyncClient) -> None: async def test_400_for_invalid_name(self, config_client: AsyncClient) -> None:
"""PUT /api/config/filters/... with bad name returns 400.""" """PUT /api/config/filters/... with bad name returns 400."""
@@ -1245,8 +1227,8 @@ class TestCreateFilter:
assert resp.status_code == 409 assert resp.status_code == 409
async def test_422_for_invalid_regex(self, config_client: AsyncClient) -> None: async def test_400_for_invalid_regex(self, config_client: AsyncClient) -> None:
"""POST /api/config/filters returns 422 for bad regex.""" """POST /api/config/filters returns 400 for bad regex."""
from app.services.filter_config_service import FilterInvalidRegexError from app.services.filter_config_service import FilterInvalidRegexError
with patch( with patch(
@@ -1258,7 +1240,7 @@ class TestCreateFilter:
json={"name": "test", "failregex": ["[bad"]}, json={"name": "test", "failregex": ["[bad"]},
) )
assert resp.status_code == 422 assert resp.status_code == 400
async def test_400_for_invalid_name(self, config_client: AsyncClient) -> None: async def test_400_for_invalid_name(self, config_client: AsyncClient) -> None:
"""POST /api/config/filters returns 400 for invalid filter name.""" """POST /api/config/filters returns 400 for invalid filter name."""
@@ -1572,9 +1554,7 @@ class TestUpdateActionRouter:
"app.routers.action_config.action_config_service.update_action", "app.routers.action_config.action_config_service.update_action",
AsyncMock(side_effect=ActionNotFoundError("missing")), AsyncMock(side_effect=ActionNotFoundError("missing")),
): ):
resp = await config_client.put( resp = await config_client.put("/api/v1/config/actions/missing", json={})
"/api/v1/config/actions/missing", json={}
)
assert resp.status_code == 404 assert resp.status_code == 404
@@ -1585,9 +1565,7 @@ class TestUpdateActionRouter:
"app.routers.action_config.action_config_service.update_action", "app.routers.action_config.action_config_service.update_action",
AsyncMock(side_effect=ActionNameError()), AsyncMock(side_effect=ActionNameError()),
): ):
resp = await config_client.put( resp = await config_client.put("/api/v1/config/actions/badname", json={})
"/api/v1/config/actions/badname", json={}
)
assert resp.status_code == 400 assert resp.status_code == 400
@@ -1808,9 +1786,7 @@ class TestRemoveActionFromJailRouter:
"app.routers.action_config.action_config_service.remove_action_from_jail", "app.routers.action_config.action_config_service.remove_action_from_jail",
AsyncMock(return_value=None), AsyncMock(return_value=None),
): ):
resp = await config_client.delete( resp = await config_client.delete("/api/v1/config/jails/sshd/action/iptables")
"/api/v1/config/jails/sshd/action/iptables"
)
assert resp.status_code == 204 assert resp.status_code == 204
@@ -1821,9 +1797,7 @@ class TestRemoveActionFromJailRouter:
"app.routers.action_config.action_config_service.remove_action_from_jail", "app.routers.action_config.action_config_service.remove_action_from_jail",
AsyncMock(side_effect=JailNotFoundInConfigError("missing")), AsyncMock(side_effect=JailNotFoundInConfigError("missing")),
): ):
resp = await config_client.delete( resp = await config_client.delete("/api/v1/config/jails/missing/action/iptables")
"/api/v1/config/jails/missing/action/iptables"
)
assert resp.status_code == 404 assert resp.status_code == 404
@@ -1834,9 +1808,7 @@ class TestRemoveActionFromJailRouter:
"app.routers.action_config.action_config_service.remove_action_from_jail", "app.routers.action_config.action_config_service.remove_action_from_jail",
AsyncMock(side_effect=JailNameError()), AsyncMock(side_effect=JailNameError()),
): ):
resp = await config_client.delete( resp = await config_client.delete("/api/v1/config/jails/badjailname/action/iptables")
"/api/v1/config/jails/badjailname/action/iptables"
)
assert resp.status_code == 400 assert resp.status_code == 400
@@ -1847,9 +1819,7 @@ class TestRemoveActionFromJailRouter:
"app.routers.action_config.action_config_service.remove_action_from_jail", "app.routers.action_config.action_config_service.remove_action_from_jail",
AsyncMock(side_effect=ActionNameError()), AsyncMock(side_effect=ActionNameError()),
): ):
resp = await config_client.delete( resp = await config_client.delete("/api/v1/config/jails/sshd/action/badactionname")
"/api/v1/config/jails/sshd/action/badactionname"
)
assert resp.status_code == 400 assert resp.status_code == 400
@@ -1858,9 +1828,7 @@ class TestRemoveActionFromJailRouter:
"app.routers.action_config.action_config_service.remove_action_from_jail", "app.routers.action_config.action_config_service.remove_action_from_jail",
AsyncMock(return_value=None), AsyncMock(return_value=None),
) as mock_rm: ) as mock_rm:
resp = await config_client.delete( resp = await config_client.delete("/api/v1/config/jails/sshd/action/iptables?reload=true")
"/api/v1/config/jails/sshd/action/iptables?reload=true"
)
assert resp.status_code == 204 assert resp.status_code == 204
assert mock_rm.call_args.kwargs.get("do_reload") is True assert mock_rm.call_args.kwargs.get("do_reload") is True
@@ -1965,10 +1933,10 @@ class TestGetFail2BanLog:
assert resp.status_code == 502 assert resp.status_code == 502
async def test_422_for_lines_exceeding_max(self, config_client: AsyncClient) -> None: async def test_400_for_lines_exceeding_max(self, config_client: AsyncClient) -> None:
"""GET /api/config/fail2ban-log returns 422 for lines > 2000.""" """GET /api/config/fail2ban-log returns 400 for lines > 2000."""
resp = await config_client.get("/api/v1/config/fail2ban-log?lines=9999") resp = await config_client.get("/api/v1/config/fail2ban-log?lines=9999")
assert resp.status_code == 422 assert resp.status_code == 400
async def test_401_when_unauthenticated(self, config_client: AsyncClient) -> None: async def test_401_when_unauthenticated(self, config_client: AsyncClient) -> None:
"""GET /api/config/fail2ban-log requires authentication.""" """GET /api/config/fail2ban-log requires authentication."""
@@ -2001,7 +1969,7 @@ class TestGetServiceStatus:
async def test_200_when_online(self, config_client: AsyncClient) -> None: async def test_200_when_online(self, config_client: AsyncClient) -> None:
"""GET /api/config/service-status returns 200 with full status when online.""" """GET /api/config/service-status returns 200 with full status when online."""
with patch( with patch(
"app.routers.config_misc.health_service.get_service_status", "app.services.health_service.get_service_status",
AsyncMock(return_value=self._mock_status(online=True)), AsyncMock(return_value=self._mock_status(online=True)),
): ):
resp = await config_client.get("/api/v1/config/service-status") resp = await config_client.get("/api/v1/config/service-status")
@@ -2016,7 +1984,7 @@ class TestGetServiceStatus:
async def test_200_when_offline(self, config_client: AsyncClient) -> None: async def test_200_when_offline(self, config_client: AsyncClient) -> None:
"""GET /api/config/service-status returns 200 with offline=False when daemon is down.""" """GET /api/config/service-status returns 200 with offline=False when daemon is down."""
with patch( with patch(
"app.routers.config_misc.health_service.get_service_status", "app.services.health_service.get_service_status",
AsyncMock(return_value=self._mock_status(online=False)), AsyncMock(return_value=self._mock_status(online=False)),
): ):
resp = await config_client.get("/api/v1/config/service-status") resp = await config_client.get("/api/v1/config/service-status")
@@ -2049,9 +2017,7 @@ class TestValidateJailEndpoint:
"""Returns 200 with valid=True when the jail config has no issues.""" """Returns 200 with valid=True when the jail config has no issues."""
from app.models.config import JailValidationResult from app.models.config import JailValidationResult
mock_result = JailValidationResult( mock_result = JailValidationResult(jail_name="sshd", valid=True, issues=[])
jail_name="sshd", valid=True, issues=[]
)
with patch( with patch(
"app.routers.jail_config.jail_config_service.validate_jail_config", "app.routers.jail_config.jail_config_service.validate_jail_config",
AsyncMock(return_value=mock_result), AsyncMock(return_value=mock_result),
@@ -2069,9 +2035,7 @@ class TestValidateJailEndpoint:
from app.models.config import JailValidationIssue, JailValidationResult from app.models.config import JailValidationIssue, JailValidationResult
issue = JailValidationIssue(field="filter", message="Filter file not found: filter.d/bad.conf (or .local)") issue = JailValidationIssue(field="filter", message="Filter file not found: filter.d/bad.conf (or .local)")
mock_result = JailValidationResult( mock_result = JailValidationResult(jail_name="sshd", valid=False, issues=[issue])
jail_name="sshd", valid=False, issues=[issue]
)
with patch( with patch(
"app.routers.jail_config.jail_config_service.validate_jail_config", "app.routers.jail_config.jail_config_service.validate_jail_config",
AsyncMock(return_value=mock_result), AsyncMock(return_value=mock_result),
@@ -2109,9 +2073,7 @@ class TestValidateJailEndpoint:
class TestPendingRecovery: class TestPendingRecovery:
"""Tests for ``GET /api/config/pending-recovery``.""" """Tests for ``GET /api/config/pending-recovery``."""
async def test_returns_null_when_no_pending_recovery( async def test_returns_null_when_no_pending_recovery(self, config_client: AsyncClient) -> None:
self, config_client: AsyncClient
) -> None:
"""Returns null body (204-like 200) when pending_recovery is not set.""" """Returns null body (204-like 200) when pending_recovery is not set."""
app = config_client._transport.app # type: ignore[attr-defined] app = config_client._transport.app # type: ignore[attr-defined]
app.state.pending_recovery = None app.state.pending_recovery = None
@@ -2156,9 +2118,7 @@ class TestPendingRecovery:
class TestRollbackEndpoint: class TestRollbackEndpoint:
"""Tests for ``POST /api/config/jails/{name}/rollback``.""" """Tests for ``POST /api/config/jails/{name}/rollback``."""
async def test_200_success_clears_pending_recovery( async def test_200_success_clears_pending_recovery(self, config_client: AsyncClient) -> None:
self, config_client: AsyncClient
) -> None:
"""A successful rollback returns 200 and clears app.state.pending_recovery.""" """A successful rollback returns 200 and clears app.state.pending_recovery."""
import datetime import datetime
@@ -2193,9 +2153,7 @@ class TestRollbackEndpoint:
# Successful rollback must clear the pending record. # Successful rollback must clear the pending record.
assert app.state.pending_recovery is None assert app.state.pending_recovery is None
async def test_200_fail_preserves_pending_recovery( async def test_200_fail_preserves_pending_recovery(self, config_client: AsyncClient) -> None:
self, config_client: AsyncClient
) -> None:
"""When fail2ban is still down after rollback, pending_recovery is retained.""" """When fail2ban is still down after rollback, pending_recovery is retained."""
import datetime import datetime
@@ -2248,4 +2206,3 @@ class TestRollbackEndpoint:
base_url="http://test", base_url="http://test",
).post("/api/v1/config/jails/sshd/rollback") ).post("/api/v1/config/jails/sshd/rollback")
assert resp.status_code == 401 assert resp.status_code == 401

View File

@@ -31,14 +31,16 @@ async def _do_setup(client: AsyncClient) -> None:
async def _login(client: AsyncClient, password: str = "Mysecretpass1!") -> str: async def _login(client: AsyncClient, password: str = "Mysecretpass1!") -> str:
"""Helper: perform login and return the session token.""" """Helper: perform login and return the session token from the cookie."""
resp = await client.post( resp = await client.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={"password": password}, json={"password": password},
headers={"X-BanGUI-Request": "1"}, headers={"X-BanGUI-Request": "1"},
) )
assert resp.status_code == 200 assert resp.status_code == 200
return str(resp.json()["token"]) token = resp.cookies.get(SESSION_COOKIE_NAME)
assert token is not None
return str(token)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -49,9 +51,7 @@ async def _login(client: AsyncClient, password: str = "Mysecretpass1!") -> str:
class TestCsrfProtection: class TestCsrfProtection:
"""CSRF middleware validation tests.""" """CSRF middleware validation tests."""
async def test_post_with_cookie_and_csrf_header_passes( async def test_post_with_cookie_and_csrf_header_passes(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""POST with session cookie and CSRF header is allowed.""" """POST with session cookie and CSRF header is allowed."""
await _do_setup(client) await _do_setup(client)
token = await _login(client) token = await _login(client)
@@ -65,9 +65,7 @@ class TestCsrfProtection:
# Expect 200 (logout succeeds) not 403 (CSRF failed) # Expect 200 (logout succeeds) not 403 (CSRF failed)
assert response.status_code == 200 assert response.status_code == 200
async def test_post_with_cookie_without_csrf_header_rejected( async def test_post_with_cookie_without_csrf_header_rejected(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""POST with session cookie but no CSRF header is rejected with 403.""" """POST with session cookie but no CSRF header is rejected with 403."""
await _do_setup(client) await _do_setup(client)
token = await _login(client) token = await _login(client)
@@ -83,9 +81,7 @@ class TestCsrfProtection:
assert "detail" in body assert "detail" in body
assert "CSRF" in body["detail"] assert "CSRF" in body["detail"]
async def test_post_with_cookie_with_wrong_csrf_value_rejected( async def test_post_with_cookie_with_wrong_csrf_value_rejected(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""POST with session cookie and wrong CSRF header value is rejected.""" """POST with session cookie and wrong CSRF header value is rejected."""
await _do_setup(client) await _do_setup(client)
token = await _login(client) token = await _login(client)
@@ -98,9 +94,7 @@ class TestCsrfProtection:
) )
assert response.status_code == 403 assert response.status_code == 403
async def test_post_with_bearer_token_no_csrf_header_passes( async def test_post_with_bearer_token_no_csrf_header_passes(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""POST with Bearer token but no CSRF header is allowed (not CSRF-vulnerable).""" """POST with Bearer token but no CSRF header is allowed (not CSRF-vulnerable)."""
await _do_setup(client) await _do_setup(client)
token = await _login(client) token = await _login(client)
@@ -113,9 +107,7 @@ class TestCsrfProtection:
# Expect 200 (logout succeeds) not 403 (CSRF check should be skipped) # Expect 200 (logout succeeds) not 403 (CSRF check should be skipped)
assert response.status_code == 200 assert response.status_code == 200
async def test_get_with_cookie_no_csrf_header_passes( async def test_get_with_cookie_no_csrf_header_passes(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""GET with session cookie but no CSRF header is allowed (safe method).""" """GET with session cookie but no CSRF header is allowed (safe method)."""
await _do_setup(client) await _do_setup(client)
token = await _login(client) token = await _login(client)
@@ -129,9 +121,7 @@ class TestCsrfProtection:
# Expect 200 (session valid) not 403 (CSRF check should be skipped for GET) # Expect 200 (session valid) not 403 (CSRF check should be skipped for GET)
assert response.status_code == 200 assert response.status_code == 200
async def test_options_with_cookie_no_csrf_header_passes( async def test_options_with_cookie_no_csrf_header_passes(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""OPTIONS with session cookie but no CSRF header is allowed (safe method).""" """OPTIONS with session cookie but no CSRF header is allowed (safe method)."""
await _do_setup(client) await _do_setup(client)
token = await _login(client) token = await _login(client)
@@ -145,9 +135,7 @@ class TestCsrfProtection:
# Expect not 403 # Expect not 403
assert response.status_code != 403 assert response.status_code != 403
async def test_head_with_cookie_no_csrf_header_passes( async def test_head_with_cookie_no_csrf_header_passes(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""HEAD with session cookie but no CSRF header is allowed (safe method).""" """HEAD with session cookie but no CSRF header is allowed (safe method)."""
await _do_setup(client) await _do_setup(client)
token = await _login(client) token = await _login(client)
@@ -161,9 +149,7 @@ class TestCsrfProtection:
# Expect not 403 # Expect not 403
assert response.status_code != 403 assert response.status_code != 403
async def test_delete_with_cookie_and_csrf_header_passes( async def test_delete_with_cookie_and_csrf_header_passes(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""DELETE with session cookie and CSRF header is allowed.""" """DELETE with session cookie and CSRF header is allowed."""
await _do_setup(client) await _do_setup(client)
token = await _login(client) token = await _login(client)
@@ -180,9 +166,7 @@ class TestCsrfProtection:
# Should not be 403 (CSRF failed) # Should not be 403 (CSRF failed)
assert response.status_code != 403 assert response.status_code != 403
async def test_delete_with_cookie_without_csrf_header_rejected( async def test_delete_with_cookie_without_csrf_header_rejected(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""DELETE with session cookie but no CSRF header is rejected with 403.""" """DELETE with session cookie but no CSRF header is rejected with 403."""
await _do_setup(client) await _do_setup(client)
token = await _login(client) token = await _login(client)
@@ -197,9 +181,7 @@ class TestCsrfProtection:
) )
assert response.status_code == 403 assert response.status_code == 403
async def test_put_with_cookie_and_csrf_header_passes( async def test_put_with_cookie_and_csrf_header_passes(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""PUT with session cookie and CSRF header is allowed.""" """PUT with session cookie and CSRF header is allowed."""
await _do_setup(client) await _do_setup(client)
token = await _login(client) token = await _login(client)
@@ -214,9 +196,7 @@ class TestCsrfProtection:
# Should not be 403 (CSRF failed) # Should not be 403 (CSRF failed)
assert response.status_code != 403 assert response.status_code != 403
async def test_put_with_cookie_without_csrf_header_rejected( async def test_put_with_cookie_without_csrf_header_rejected(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""PUT with session cookie but no CSRF header is rejected with 403.""" """PUT with session cookie but no CSRF header is rejected with 403."""
await _do_setup(client) await _do_setup(client)
token = await _login(client) token = await _login(client)
@@ -230,9 +210,7 @@ class TestCsrfProtection:
) )
assert response.status_code == 403 assert response.status_code == 403
async def test_patch_with_cookie_and_csrf_header_passes( async def test_patch_with_cookie_and_csrf_header_passes(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""PATCH with session cookie and CSRF header is allowed.""" """PATCH with session cookie and CSRF header is allowed."""
await _do_setup(client) await _do_setup(client)
token = await _login(client) token = await _login(client)
@@ -247,9 +225,7 @@ class TestCsrfProtection:
# Should not be 403 (CSRF failed) # Should not be 403 (CSRF failed)
assert response.status_code != 403 assert response.status_code != 403
async def test_patch_with_cookie_without_csrf_header_rejected( async def test_patch_with_cookie_without_csrf_header_rejected(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""PATCH with session cookie but no CSRF header is rejected with 403.""" """PATCH with session cookie but no CSRF header is rejected with 403."""
await _do_setup(client) await _do_setup(client)
token = await _login(client) token = await _login(client)
@@ -262,9 +238,7 @@ class TestCsrfProtection:
) )
assert response.status_code == 403 assert response.status_code == 403
async def test_post_without_cookie_no_csrf_header_passes( async def test_post_without_cookie_no_csrf_header_passes(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""POST without session cookie or Bearer token bypasses CSRF check.""" """POST without session cookie or Bearer token bypasses CSRF check."""
await _do_setup(client) await _do_setup(client)
@@ -279,9 +253,7 @@ class TestCsrfProtection:
# (Actually logout is idempotent and doesn't require auth, so we expect 200) # (Actually logout is idempotent and doesn't require auth, so we expect 200)
assert response.status_code in (200, 401) assert response.status_code in (200, 401)
async def test_bearer_token_via_authorization_header( async def test_bearer_token_via_authorization_header(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""Bearer token in Authorization header bypasses CSRF check.""" """Bearer token in Authorization header bypasses CSRF check."""
await _do_setup(client) await _do_setup(client)
token = await _login(client) token = await _login(client)

View File

@@ -10,13 +10,17 @@ import pytest
from httpx import ASGITransport, AsyncClient from httpx import ASGITransport, AsyncClient
import app import app
from app.config import Settings from app.config import Settings
from app.db import init_db from app.db import init_db
from app.main import create_app from app.main import create_app
from app.models.ban import ( from app.models.ban_domain import (
DashboardBanItem, DomainBansByCountry,
DashboardBanListResponse, DomainBansByJail,
DomainBanTrend,
DomainBanTrendBucket,
DomainDashboardBanItem,
DomainDashboardBanList,
DomainJailBanCount,
) )
from app.models.server import ServerStatus from app.models.server import ServerStatus
@@ -25,7 +29,7 @@ from app.models.server import ServerStatus
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
_SETUP_PAYLOAD = { _SETUP_PAYLOAD = {
"master_password": "testpassword1", "master_password": "Testpass1!",
"database_path": "bangui.db", "database_path": "bangui.db",
"fail2ban_socket": "/var/run/fail2ban/fail2ban.sock", "fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
"timezone": "UTC", "timezone": "UTC",
@@ -40,13 +44,17 @@ async def dashboard_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
Unlike the shared ``client`` fixture this one also exposes access to Unlike the shared ``client`` fixture this one also exposes access to
``app.state`` via the app instance so we can seed the status cache. ``app.state`` via the app instance so we can seed the status cache.
""" """
config_dir = tmp_path / "fail2ban"
config_dir.mkdir()
settings = Settings( settings = Settings(
database_path=str(tmp_path / "dashboard_test.db"), database_path=str(tmp_path / "dashboard_test.db"),
fail2ban_socket="/tmp/fake_fail2ban.sock", fail2ban_socket="/tmp/fake_fail2ban.sock",
session_secret="test-dashboard-secret", fail2ban_config_dir=str(config_dir),
session_secret="test-dashboard-secret-that-is-long-enough",
session_duration_minutes=60, session_duration_minutes=60,
timezone="UTC", timezone="UTC",
log_level="debug", log_level="debug",
session_cookie_secure=False,
) )
app = create_app(settings=settings) app = create_app(settings=settings)
@@ -66,8 +74,13 @@ async def dashboard_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
# Provide a stub HTTP session so ban/access endpoints can access app.state.http_session. # Provide a stub HTTP session so ban/access endpoints can access app.state.http_session.
app.state.http_session = MagicMock() app.state.http_session = MagicMock()
# Initialize GeoCache (normally done in lifespan handler)
from app.services.geo_cache import GeoCache
app.state.geo_cache = GeoCache()
transport = ASGITransport(app=app) transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac: async with AsyncClient(transport=transport, base_url="http://test", headers={"X-BanGUI-Request": "1"}) as ac:
# Complete setup so the middleware doesn't redirect. # Complete setup so the middleware doesn't redirect.
resp = await ac.post("/api/v1/setup", json=_SETUP_PAYLOAD) resp = await ac.post("/api/v1/setup", json=_SETUP_PAYLOAD)
assert resp.status_code == 201 assert resp.status_code == 201
@@ -87,13 +100,17 @@ async def dashboard_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
@pytest.fixture @pytest.fixture
async def offline_dashboard_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc] async def offline_dashboard_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
"""Like ``dashboard_client`` but with an offline server status.""" """Like ``dashboard_client`` but with an offline server status."""
config_dir = tmp_path / "fail2ban"
config_dir.mkdir()
settings = Settings( settings = Settings(
database_path=str(tmp_path / "dashboard_offline_test.db"), database_path=str(tmp_path / "dashboard_offline_test.db"),
fail2ban_socket="/tmp/fake_fail2ban.sock", fail2ban_socket="/tmp/fake_fail2ban.sock",
session_secret="test-dashboard-offline-secret", fail2ban_config_dir=str(config_dir),
session_secret="test-dashboard-offline-secret-long-enough",
session_duration_minutes=60, session_duration_minutes=60,
timezone="UTC", timezone="UTC",
log_level="debug", log_level="debug",
session_cookie_secure=False,
) )
app = create_app(settings=settings) app = create_app(settings=settings)
@@ -105,8 +122,13 @@ async def offline_dashboard_client(tmp_path: Path) -> AsyncClient: # type: igno
app.state.server_status = ServerStatus(online=False) app.state.server_status = ServerStatus(online=False)
app.state.http_session = MagicMock() app.state.http_session = MagicMock()
# Initialize GeoCache (normally done in lifespan handler)
from app.services.geo_cache import GeoCache
app.state.geo_cache = GeoCache()
transport = ASGITransport(app=app) transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac: async with AsyncClient(transport=transport, base_url="http://test", headers={"X-BanGUI-Request": "1"}) as ac:
resp = await ac.post("/api/v1/setup", json=_SETUP_PAYLOAD) resp = await ac.post("/api/v1/setup", json=_SETUP_PAYLOAD)
assert resp.status_code == 201 assert resp.status_code == 201
@@ -129,25 +151,19 @@ async def offline_dashboard_client(tmp_path: Path) -> AsyncClient: # type: igno
class TestDashboardStatus: class TestDashboardStatus:
"""GET /api/dashboard/status.""" """GET /api/dashboard/status."""
async def test_returns_200_when_authenticated( async def test_returns_200_when_authenticated(self, dashboard_client: AsyncClient) -> None:
self, dashboard_client: AsyncClient
) -> None:
"""Authenticated request returns HTTP 200.""" """Authenticated request returns HTTP 200."""
response = await dashboard_client.get("/api/v1/dashboard/status") response = await dashboard_client.get("/api/v1/dashboard/status")
assert response.status_code == 200 assert response.status_code == 200
async def test_returns_401_when_unauthenticated( async def test_returns_401_when_unauthenticated(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""Unauthenticated request returns HTTP 401.""" """Unauthenticated request returns HTTP 401."""
# Complete setup so the middleware allows the request through. # Complete setup so the middleware allows the request through.
await client.post("/api/v1/setup", json=_SETUP_PAYLOAD) await client.post("/api/v1/setup", json=_SETUP_PAYLOAD)
response = await client.get("/api/v1/dashboard/status") response = await client.get("/api/v1/dashboard/status")
assert response.status_code == 401 assert response.status_code == 401
async def test_response_shape_when_online( async def test_response_shape_when_online(self, dashboard_client: AsyncClient) -> None:
self, dashboard_client: AsyncClient
) -> None:
"""Response contains the expected ``status`` object shape.""" """Response contains the expected ``status`` object shape."""
response = await dashboard_client.get("/api/v1/dashboard/status") response = await dashboard_client.get("/api/v1/dashboard/status")
body = response.json() body = response.json()
@@ -161,9 +177,7 @@ class TestDashboardStatus:
assert "total_bans" in status assert "total_bans" in status
assert "total_failures" in status assert "total_failures" in status
async def test_cached_values_returned_when_online( async def test_cached_values_returned_when_online(self, dashboard_client: AsyncClient) -> None:
self, dashboard_client: AsyncClient
) -> None:
"""Endpoint returns the exact values from ``app.state.server_status``.""" """Endpoint returns the exact values from ``app.state.server_status``."""
response = await dashboard_client.get("/api/v1/dashboard/status") response = await dashboard_client.get("/api/v1/dashboard/status")
body = response.json() body = response.json()
@@ -175,9 +189,7 @@ class TestDashboardStatus:
assert status["total_bans"] == 10 assert status["total_bans"] == 10
assert status["total_failures"] == 5 assert status["total_failures"] == 5
async def test_offline_status_returned_correctly( async def test_offline_status_returned_correctly(self, offline_dashboard_client: AsyncClient) -> None:
self, offline_dashboard_client: AsyncClient
) -> None:
"""Endpoint returns online=False when the cache holds an offline snapshot.""" """Endpoint returns online=False when the cache holds an offline snapshot."""
response = await offline_dashboard_client.get("/api/v1/dashboard/status") response = await offline_dashboard_client.get("/api/v1/dashboard/status")
assert response.status_code == 200 assert response.status_code == 200
@@ -190,9 +202,7 @@ class TestDashboardStatus:
assert status["total_bans"] == 0 assert status["total_bans"] == 0
assert status["total_failures"] == 0 assert status["total_failures"] == 0
async def test_returns_offline_when_state_not_initialised( async def test_returns_offline_when_state_not_initialised(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""Endpoint returns online=False as a safe default if the cache is absent.""" """Endpoint returns online=False as a safe default if the cache is absent."""
# Setup + login so the endpoint is reachable. # Setup + login so the endpoint is reachable.
await client.post("/api/v1/setup", json=_SETUP_PAYLOAD) await client.post("/api/v1/setup", json=_SETUP_PAYLOAD)
@@ -200,7 +210,9 @@ class TestDashboardStatus:
"/api/v1/auth/login", "/api/v1/auth/login",
json={"password": _SETUP_PAYLOAD["master_password"]}, json={"password": _SETUP_PAYLOAD["master_password"]},
) )
# server_status is not set on app.state in the shared `client` fixture. # Clear server_status to simulate uninitialized state.
client._transport.app.state.server_status = None # type: ignore[attr-defined]
client._transport.app.state.server_status = None # type: ignore[attr-defined]
response = await client.get("/api/v1/dashboard/status") response = await client.get("/api/v1/dashboard/status")
assert response.status_code == 200 assert response.status_code == 200
status = response.json()["status"] status = response.json()["status"]
@@ -212,10 +224,10 @@ class TestDashboardStatus:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _make_ban_list_response(n: int = 2) -> DashboardBanListResponse: def _make_ban_list_response(n: int = 2) -> DomainDashboardBanList:
"""Build a mock DashboardBanListResponse with *n* items.""" """Build a mock DomainDashboardBanList with *n* items."""
items = [ items = [
DashboardBanItem( DomainDashboardBanItem(
ip=f"1.2.3.{i}", ip=f"1.2.3.{i}",
jail="sshd", jail="sshd",
banned_at="2026-03-01T10:00:00+00:00", banned_at="2026-03-01T10:00:00+00:00",
@@ -229,15 +241,18 @@ def _make_ban_list_response(n: int = 2) -> DashboardBanListResponse:
) )
for i in range(n) for i in range(n)
] ]
return DashboardBanListResponse(items=items, total=n, page=1, page_size=100) return DomainDashboardBanList(
items=items,
total=n,
page=1,
page_size=100,
)
class TestDashboardBans: class TestDashboardBans:
"""GET /api/dashboard/bans.""" """GET /api/dashboard/bans."""
async def test_returns_200_when_authenticated( async def test_returns_200_when_authenticated(self, dashboard_client: AsyncClient) -> None:
self, dashboard_client: AsyncClient
) -> None:
"""Authenticated request returns HTTP 200.""" """Authenticated request returns HTTP 200."""
with patch( with patch(
"app.routers.dashboard.ban_service.list_bans", "app.routers.dashboard.ban_service.list_bans",
@@ -246,17 +261,13 @@ class TestDashboardBans:
response = await dashboard_client.get("/api/v1/dashboard/bans") response = await dashboard_client.get("/api/v1/dashboard/bans")
assert response.status_code == 200 assert response.status_code == 200
async def test_returns_401_when_unauthenticated( async def test_returns_401_when_unauthenticated(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""Unauthenticated request returns HTTP 401.""" """Unauthenticated request returns HTTP 401."""
await client.post("/api/v1/setup", json=_SETUP_PAYLOAD) await client.post("/api/v1/setup", json=_SETUP_PAYLOAD)
response = await client.get("/api/v1/dashboard/bans") response = await client.get("/api/v1/dashboard/bans")
assert response.status_code == 401 assert response.status_code == 401
async def test_response_contains_items_and_total( async def test_response_contains_items_and_total(self, dashboard_client: AsyncClient) -> None:
self, dashboard_client: AsyncClient
) -> None:
"""Response body contains ``items`` list and ``total`` count.""" """Response body contains ``items`` list and ``total`` count."""
with patch( with patch(
"app.routers.dashboard.ban_service.list_bans", "app.routers.dashboard.ban_service.list_bans",
@@ -266,8 +277,8 @@ class TestDashboardBans:
body = response.json() body = response.json()
assert "items" in body assert "items" in body
assert "total" in body assert "pagination" in body
assert body["total"] == 3 assert body["pagination"]["total"] == 3
assert len(body["items"]) == 3 assert len(body["items"]) == 3
async def test_default_range_is_24h(self, dashboard_client: AsyncClient) -> None: async def test_default_range_is_24h(self, dashboard_client: AsyncClient) -> None:
@@ -279,9 +290,7 @@ class TestDashboardBans:
called_range = mock_list.call_args[0][1] called_range = mock_list.call_args[0][1]
assert called_range == "24h" assert called_range == "24h"
async def test_accepts_time_range_param( async def test_accepts_time_range_param(self, dashboard_client: AsyncClient) -> None:
self, dashboard_client: AsyncClient
) -> None:
"""The ``range`` query parameter is forwarded to ban_service.""" """The ``range`` query parameter is forwarded to ban_service."""
mock_list = AsyncMock(return_value=_make_ban_list_response()) mock_list = AsyncMock(return_value=_make_ban_list_response())
with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list): with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
@@ -290,9 +299,7 @@ class TestDashboardBans:
called_range = mock_list.call_args[0][1] called_range = mock_list.call_args[0][1]
assert called_range == "7d" assert called_range == "7d"
async def test_accepts_source_param( async def test_accepts_source_param(self, dashboard_client: AsyncClient) -> None:
self, dashboard_client: AsyncClient
) -> None:
"""The ``source`` query parameter is forwarded to ban_service.""" """The ``source`` query parameter is forwarded to ban_service."""
mock_list = AsyncMock(return_value=_make_ban_list_response()) mock_list = AsyncMock(return_value=_make_ban_list_response())
with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list): with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
@@ -301,11 +308,14 @@ class TestDashboardBans:
called_source = mock_list.call_args[1]["source"] called_source = mock_list.call_args[1]["source"]
assert called_source == "archive" assert called_source == "archive"
async def test_empty_ban_list_returns_zero_total( async def test_empty_ban_list_returns_zero_total(self, dashboard_client: AsyncClient) -> None:
self, dashboard_client: AsyncClient
) -> None:
"""Returns ``total=0`` and empty ``items`` when no bans are in range.""" """Returns ``total=0`` and empty ``items`` when no bans are in range."""
empty = DashboardBanListResponse(items=[], total=0, page=1, page_size=100) empty = DomainDashboardBanList(
items=[],
total=0,
page=1,
page_size=100,
)
with patch( with patch(
"app.routers.dashboard.ban_service.list_bans", "app.routers.dashboard.ban_service.list_bans",
new=AsyncMock(return_value=empty), new=AsyncMock(return_value=empty),
@@ -313,7 +323,7 @@ class TestDashboardBans:
response = await dashboard_client.get("/api/v1/dashboard/bans") response = await dashboard_client.get("/api/v1/dashboard/bans")
body = response.json() body = response.json()
assert body["total"] == 0 assert body["pagination"]["total"] == 0
assert body["items"] == [] assert body["items"] == []
async def test_item_shape_is_correct(self, dashboard_client: AsyncClient) -> None: async def test_item_shape_is_correct(self, dashboard_client: AsyncClient) -> None:
@@ -336,12 +346,10 @@ class TestDashboardBans:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _make_bans_by_country_response() -> object: def _make_bans_by_country_response() -> DomainBansByCountry:
"""Build a stub BansByCountryResponse.""" """Build a stub DomainBansByCountry."""
from app.models.ban import BansByCountryResponse
items = [ items = [
DashboardBanItem( DomainDashboardBanItem(
ip="1.2.3.4", ip="1.2.3.4",
jail="sshd", jail="sshd",
banned_at="2026-03-01T10:00:00+00:00", banned_at="2026-03-01T10:00:00+00:00",
@@ -353,7 +361,7 @@ def _make_bans_by_country_response() -> object:
ban_count=1, ban_count=1,
origin="selfblock", origin="selfblock",
), ),
DashboardBanItem( DomainDashboardBanItem(
ip="5.6.7.8", ip="5.6.7.8",
jail="blocklist-import", jail="blocklist-import",
banned_at="2026-03-01T10:05:00+00:00", banned_at="2026-03-01T10:05:00+00:00",
@@ -366,10 +374,10 @@ def _make_bans_by_country_response() -> object:
origin="blocklist", origin="blocklist",
), ),
] ]
return BansByCountryResponse( return DomainBansByCountry(
countries={"DE": 1, "US": 1}, countries={"DE": 1, "US": 1},
country_names={"DE": "Germany", "US": "United States"}, country_names={"DE": "Germany", "US": "United States"},
bans=items, items=items,
total=2, total=2,
) )
@@ -378,9 +386,7 @@ def _make_bans_by_country_response() -> object:
class TestBansByCountry: class TestBansByCountry:
"""GET /api/dashboard/bans/by-country.""" """GET /api/dashboard/bans/by-country."""
async def test_returns_200_when_authenticated( async def test_returns_200_when_authenticated(self, dashboard_client: AsyncClient) -> None:
self, dashboard_client: AsyncClient
) -> None:
"""Authenticated request returns HTTP 200.""" """Authenticated request returns HTTP 200."""
with patch( with patch(
"app.routers.dashboard.ban_service.bans_by_country", "app.routers.dashboard.ban_service.bans_by_country",
@@ -389,9 +395,7 @@ class TestBansByCountry:
response = await dashboard_client.get("/api/v1/dashboard/bans/by-country") response = await dashboard_client.get("/api/v1/dashboard/bans/by-country")
assert response.status_code == 200 assert response.status_code == 200
async def test_returns_401_when_unauthenticated( async def test_returns_401_when_unauthenticated(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""Unauthenticated request returns HTTP 401.""" """Unauthenticated request returns HTTP 401."""
await client.post("/api/v1/setup", json=_SETUP_PAYLOAD) await client.post("/api/v1/setup", json=_SETUP_PAYLOAD)
response = await client.get("/api/v1/dashboard/bans/by-country") response = await client.get("/api/v1/dashboard/bans/by-country")
@@ -415,38 +419,26 @@ class TestBansByCountry:
assert body["countries"]["US"] == 1 assert body["countries"]["US"] == 1
assert body["country_names"]["DE"] == "Germany" assert body["country_names"]["DE"] == "Germany"
async def test_accepts_time_range_param( async def test_accepts_time_range_param(self, dashboard_client: AsyncClient) -> None:
self, dashboard_client: AsyncClient
) -> None:
"""The range query parameter is forwarded to ban_service.""" """The range query parameter is forwarded to ban_service."""
mock_fn = AsyncMock(return_value=_make_bans_by_country_response()) mock_fn = AsyncMock(return_value=_make_bans_by_country_response())
with patch( with patch("app.routers.dashboard.ban_service.bans_by_country", new=mock_fn):
"app.routers.dashboard.ban_service.bans_by_country", new=mock_fn
):
await dashboard_client.get("/api/v1/dashboard/bans/by-country?range=7d") await dashboard_client.get("/api/v1/dashboard/bans/by-country?range=7d")
called_range = mock_fn.call_args[0][1] called_range = mock_fn.call_args[0][1]
assert called_range == "7d" assert called_range == "7d"
async def test_invalid_source_returns_422( async def test_invalid_source_returns_400(self, dashboard_client: AsyncClient) -> None:
self, dashboard_client: AsyncClient """An invalid source value returns HTTP 400."""
) -> None: response = await dashboard_client.get("/api/v1/dashboard/bans/by-country?source=invalid")
"""An invalid source value returns HTTP 422.""" assert response.status_code == 400
response = await dashboard_client.get(
"/api/v1/dashboard/bans/by-country?source=invalid"
)
assert response.status_code == 422
async def test_empty_window_returns_empty_response( async def test_empty_window_returns_empty_response(self, dashboard_client: AsyncClient) -> None:
self, dashboard_client: AsyncClient
) -> None:
"""Empty time range returns empty countries dict and bans list.""" """Empty time range returns empty countries dict and bans list."""
from app.models.ban import BansByCountryResponse empty = DomainBansByCountry(
empty = BansByCountryResponse(
countries={}, countries={},
country_names={}, country_names={},
bans=[], items=[],
total=0, total=0,
) )
with patch( with patch(
@@ -469,9 +461,7 @@ class TestBansByCountry:
class TestDashboardBansOriginField: class TestDashboardBansOriginField:
"""Verify that the ``origin`` field is present in API responses.""" """Verify that the ``origin`` field is present in API responses."""
async def test_origin_present_in_ban_list_items( async def test_origin_present_in_ban_list_items(self, dashboard_client: AsyncClient) -> None:
self, dashboard_client: AsyncClient
) -> None:
"""Each item in ``/api/dashboard/bans`` carries an ``origin`` field.""" """Each item in ``/api/dashboard/bans`` carries an ``origin`` field."""
with patch( with patch(
"app.routers.dashboard.ban_service.list_bans", "app.routers.dashboard.ban_service.list_bans",
@@ -483,9 +473,7 @@ class TestDashboardBansOriginField:
assert "origin" in item assert "origin" in item
assert item["origin"] in ("blocklist", "selfblock") assert item["origin"] in ("blocklist", "selfblock")
async def test_selfblock_origin_serialised_correctly( async def test_selfblock_origin_serialised_correctly(self, dashboard_client: AsyncClient) -> None:
self, dashboard_client: AsyncClient
) -> None:
"""A ban from a non-blocklist jail serialises as ``"selfblock"``.""" """A ban from a non-blocklist jail serialises as ``"selfblock"``."""
with patch( with patch(
"app.routers.dashboard.ban_service.list_bans", "app.routers.dashboard.ban_service.list_bans",
@@ -497,9 +485,7 @@ class TestDashboardBansOriginField:
assert item["jail"] == "sshd" assert item["jail"] == "sshd"
assert item["origin"] == "selfblock" assert item["origin"] == "selfblock"
async def test_origin_present_in_bans_by_country( async def test_origin_present_in_bans_by_country(self, dashboard_client: AsyncClient) -> None:
self, dashboard_client: AsyncClient
) -> None:
"""Each ban in ``/api/dashboard/bans/by-country`` carries an ``origin``.""" """Each ban in ``/api/dashboard/bans/by-country`` carries an ``origin``."""
with patch( with patch(
"app.routers.dashboard.ban_service.bans_by_country", "app.routers.dashboard.ban_service.bans_by_country",
@@ -512,9 +498,7 @@ class TestDashboardBansOriginField:
origins = {ban["origin"] for ban in bans} origins = {ban["origin"] for ban in bans}
assert origins == {"blocklist", "selfblock"} assert origins == {"blocklist", "selfblock"}
async def test_bans_by_country_source_param_forwarded( async def test_bans_by_country_source_param_forwarded(self, dashboard_client: AsyncClient) -> None:
self, dashboard_client: AsyncClient
) -> None:
"""The ``source`` query parameter is forwarded to bans_by_country.""" """The ``source`` query parameter is forwarded to bans_by_country."""
mock_fn = AsyncMock(return_value=_make_bans_by_country_response()) mock_fn = AsyncMock(return_value=_make_bans_by_country_response())
with patch("app.routers.dashboard.ban_service.bans_by_country", new=mock_fn): with patch("app.routers.dashboard.ban_service.bans_by_country", new=mock_fn):
@@ -522,22 +506,16 @@ class TestDashboardBansOriginField:
assert mock_fn.call_args[1]["source"] == "archive" assert mock_fn.call_args[1]["source"] == "archive"
async def test_bans_by_country_country_code_forwarded( async def test_bans_by_country_country_code_forwarded(self, dashboard_client: AsyncClient) -> None:
self, dashboard_client: AsyncClient
) -> None:
"""The ``country_code`` query parameter is forwarded to bans_by_country.""" """The ``country_code`` query parameter is forwarded to bans_by_country."""
mock_fn = AsyncMock(return_value=_make_bans_by_country_response()) mock_fn = AsyncMock(return_value=_make_bans_by_country_response())
with patch("app.routers.dashboard.ban_service.bans_by_country", new=mock_fn): with patch("app.routers.dashboard.ban_service.bans_by_country", new=mock_fn):
await dashboard_client.get( await dashboard_client.get("/api/v1/dashboard/bans/by-country?country_code=DE")
"/api/v1/dashboard/bans/by-country?country_code=DE"
)
_, kwargs = mock_fn.call_args _, kwargs = mock_fn.call_args
assert kwargs.get("country_code") == "DE" assert kwargs.get("country_code") == "DE"
async def test_blocklist_origin_serialised_correctly( async def test_blocklist_origin_serialised_correctly(self, dashboard_client: AsyncClient) -> None:
self, dashboard_client: AsyncClient
) -> None:
"""A ban from the ``blocklist-import`` jail serialises as ``"blocklist"``.""" """A ban from the ``blocklist-import`` jail serialises as ``"blocklist"``."""
with patch( with patch(
"app.routers.dashboard.ban_service.bans_by_country", "app.routers.dashboard.ban_service.bans_by_country",
@@ -558,9 +536,7 @@ class TestDashboardBansOriginField:
class TestOriginFilterParam: class TestOriginFilterParam:
"""Verify that the ``origin`` query parameter is forwarded to the service.""" """Verify that the ``origin`` query parameter is forwarded to the service."""
async def test_bans_origin_blocklist_forwarded_to_service( async def test_bans_origin_blocklist_forwarded_to_service(self, dashboard_client: AsyncClient) -> None:
self, dashboard_client: AsyncClient
) -> None:
"""``?origin=blocklist`` is passed to ``ban_service.list_bans``.""" """``?origin=blocklist`` is passed to ``ban_service.list_bans``."""
mock_list = AsyncMock(return_value=_make_ban_list_response()) mock_list = AsyncMock(return_value=_make_ban_list_response())
with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list): with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
@@ -569,9 +545,7 @@ class TestOriginFilterParam:
_, kwargs = mock_list.call_args _, kwargs = mock_list.call_args
assert kwargs.get("origin") == "blocklist" assert kwargs.get("origin") == "blocklist"
async def test_bans_origin_selfblock_forwarded_to_service( async def test_bans_origin_selfblock_forwarded_to_service(self, dashboard_client: AsyncClient) -> None:
self, dashboard_client: AsyncClient
) -> None:
"""``?origin=selfblock`` is passed to ``ban_service.list_bans``.""" """``?origin=selfblock`` is passed to ``ban_service.list_bans``."""
mock_list = AsyncMock(return_value=_make_ban_list_response()) mock_list = AsyncMock(return_value=_make_ban_list_response())
with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list): with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
@@ -580,9 +554,7 @@ class TestOriginFilterParam:
_, kwargs = mock_list.call_args _, kwargs = mock_list.call_args
assert kwargs.get("origin") == "selfblock" assert kwargs.get("origin") == "selfblock"
async def test_bans_no_origin_param_defaults_to_none( async def test_bans_no_origin_param_defaults_to_none(self, dashboard_client: AsyncClient) -> None:
self, dashboard_client: AsyncClient
) -> None:
"""Omitting ``origin`` passes ``None`` to the service (no filtering).""" """Omitting ``origin`` passes ``None`` to the service (no filtering)."""
mock_list = AsyncMock(return_value=_make_ban_list_response()) mock_list = AsyncMock(return_value=_make_ban_list_response())
with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list): with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
@@ -591,36 +563,24 @@ class TestOriginFilterParam:
_, kwargs = mock_list.call_args _, kwargs = mock_list.call_args
assert kwargs.get("origin") is None assert kwargs.get("origin") is None
async def test_bans_invalid_origin_returns_422( async def test_bans_invalid_origin_returns_400(self, dashboard_client: AsyncClient) -> None:
self, dashboard_client: AsyncClient """An invalid ``origin`` value returns HTTP 400."""
) -> None:
"""An invalid ``origin`` value returns HTTP 422 Unprocessable Entity."""
response = await dashboard_client.get("/api/v1/dashboard/bans?origin=invalid") response = await dashboard_client.get("/api/v1/dashboard/bans?origin=invalid")
assert response.status_code == 422 assert response.status_code == 400
async def test_by_country_origin_blocklist_forwarded( async def test_by_country_origin_blocklist_forwarded(self, dashboard_client: AsyncClient) -> None:
self, dashboard_client: AsyncClient
) -> None:
"""``?origin=blocklist`` is passed to ``ban_service.bans_by_country``.""" """``?origin=blocklist`` is passed to ``ban_service.bans_by_country``."""
mock_fn = AsyncMock(return_value=_make_bans_by_country_response()) mock_fn = AsyncMock(return_value=_make_bans_by_country_response())
with patch( with patch("app.routers.dashboard.ban_service.bans_by_country", new=mock_fn):
"app.routers.dashboard.ban_service.bans_by_country", new=mock_fn await dashboard_client.get("/api/v1/dashboard/bans/by-country?origin=blocklist")
):
await dashboard_client.get(
"/api/v1/dashboard/bans/by-country?origin=blocklist"
)
_, kwargs = mock_fn.call_args _, kwargs = mock_fn.call_args
assert kwargs.get("origin") == "blocklist" assert kwargs.get("origin") == "blocklist"
async def test_by_country_no_origin_defaults_to_none( async def test_by_country_no_origin_defaults_to_none(self, dashboard_client: AsyncClient) -> None:
self, dashboard_client: AsyncClient
) -> None:
"""Omitting ``origin`` passes ``None`` to ``bans_by_country``.""" """Omitting ``origin`` passes ``None`` to ``bans_by_country``."""
mock_fn = AsyncMock(return_value=_make_bans_by_country_response()) mock_fn = AsyncMock(return_value=_make_bans_by_country_response())
with patch( with patch("app.routers.dashboard.ban_service.bans_by_country", new=mock_fn):
"app.routers.dashboard.ban_service.bans_by_country", new=mock_fn
):
await dashboard_client.get("/api/v1/dashboard/bans/by-country") await dashboard_client.get("/api/v1/dashboard/bans/by-country")
_, kwargs = mock_fn.call_args _, kwargs = mock_fn.call_args
@@ -632,24 +592,17 @@ class TestOriginFilterParam:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _make_ban_trend_response(n_buckets: int = 24) -> object: def _make_ban_trend_response(n_buckets: int = 24) -> DomainBanTrend:
"""Build a stub :class:`~app.models.ban.BanTrendResponse`.""" """Build a stub :class:`~app.models.ban_domain.DomainBanTrend`."""
from app.models.ban import BanTrendBucket, BanTrendResponse buckets = [DomainBanTrendBucket(timestamp=f"2026-03-01T{i:02d}:00:00+00:00", count=i) for i in range(n_buckets)]
return DomainBanTrend(buckets=buckets, bucket_size="1h")
buckets = [
BanTrendBucket(timestamp=f"2026-03-01T{i:02d}:00:00+00:00", count=i)
for i in range(n_buckets)
]
return BanTrendResponse(buckets=buckets, bucket_size="1h")
@pytest.mark.anyio @pytest.mark.anyio
class TestBanTrend: class TestBanTrend:
"""GET /api/dashboard/bans/trend.""" """GET /api/dashboard/bans/trend."""
async def test_returns_200_when_authenticated( async def test_returns_200_when_authenticated(self, dashboard_client: AsyncClient) -> None:
self, dashboard_client: AsyncClient
) -> None:
"""Authenticated request returns HTTP 200.""" """Authenticated request returns HTTP 200."""
with patch( with patch(
"app.routers.dashboard.ban_service.ban_trend", "app.routers.dashboard.ban_service.ban_trend",
@@ -658,9 +611,7 @@ class TestBanTrend:
response = await dashboard_client.get("/api/v1/dashboard/bans/trend") response = await dashboard_client.get("/api/v1/dashboard/bans/trend")
assert response.status_code == 200 assert response.status_code == 200
async def test_returns_401_when_unauthenticated( async def test_returns_401_when_unauthenticated(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""Unauthenticated request returns HTTP 401.""" """Unauthenticated request returns HTTP 401."""
await client.post("/api/v1/setup", json=_SETUP_PAYLOAD) await client.post("/api/v1/setup", json=_SETUP_PAYLOAD)
response = await client.get("/api/v1/dashboard/bans/trend") response = await client.get("/api/v1/dashboard/bans/trend")
@@ -680,9 +631,7 @@ class TestBanTrend:
assert len(body["buckets"]) == 24 assert len(body["buckets"]) == 24
assert body["bucket_size"] == "1h" assert body["bucket_size"] == "1h"
async def test_each_bucket_has_timestamp_and_count( async def test_each_bucket_has_timestamp_and_count(self, dashboard_client: AsyncClient) -> None:
self, dashboard_client: AsyncClient
) -> None:
"""Every element of ``buckets`` has ``timestamp`` and ``count``.""" """Every element of ``buckets`` has ``timestamp`` and ``count``."""
with patch( with patch(
"app.routers.dashboard.ban_service.ban_trend", "app.routers.dashboard.ban_service.ban_trend",
@@ -717,16 +666,12 @@ class TestBanTrend:
"""``?origin=blocklist`` is passed as a keyword arg to the service.""" """``?origin=blocklist`` is passed as a keyword arg to the service."""
mock_fn = AsyncMock(return_value=_make_ban_trend_response()) mock_fn = AsyncMock(return_value=_make_ban_trend_response())
with patch("app.routers.dashboard.ban_service.ban_trend", new=mock_fn): with patch("app.routers.dashboard.ban_service.ban_trend", new=mock_fn):
await dashboard_client.get( await dashboard_client.get("/api/v1/dashboard/bans/trend?origin=blocklist")
"/api/v1/dashboard/bans/trend?origin=blocklist"
)
_, kwargs = mock_fn.call_args _, kwargs = mock_fn.call_args
assert kwargs.get("origin") == "blocklist" assert kwargs.get("origin") == "blocklist"
async def test_no_origin_defaults_to_none( async def test_no_origin_defaults_to_none(self, dashboard_client: AsyncClient) -> None:
self, dashboard_client: AsyncClient
) -> None:
"""Omitting ``origin`` passes ``None`` to the service.""" """Omitting ``origin`` passes ``None`` to the service."""
mock_fn = AsyncMock(return_value=_make_ban_trend_response()) mock_fn = AsyncMock(return_value=_make_ban_trend_response())
with patch("app.routers.dashboard.ban_service.ban_trend", new=mock_fn): with patch("app.routers.dashboard.ban_service.ban_trend", new=mock_fn):
@@ -735,29 +680,19 @@ class TestBanTrend:
_, kwargs = mock_fn.call_args _, kwargs = mock_fn.call_args
assert kwargs.get("origin") is None assert kwargs.get("origin") is None
async def test_invalid_range_returns_422( async def test_invalid_range_returns_400(self, dashboard_client: AsyncClient) -> None:
self, dashboard_client: AsyncClient """An invalid ``range`` value returns HTTP 400."""
) -> None: response = await dashboard_client.get("/api/v1/dashboard/bans/trend?range=invalid")
"""An invalid ``range`` value returns HTTP 422.""" assert response.status_code == 400
response = await dashboard_client.get(
"/api/v1/dashboard/bans/trend?range=invalid"
)
assert response.status_code == 422
async def test_invalid_source_returns_422( async def test_invalid_source_returns_400(self, dashboard_client: AsyncClient) -> None:
self, dashboard_client: AsyncClient """An invalid source value returns HTTP 400."""
) -> None: response = await dashboard_client.get("/api/v1/dashboard/bans/trend?source=invalid")
"""An invalid source value returns HTTP 422.""" assert response.status_code == 400
response = await dashboard_client.get(
"/api/v1/dashboard/bans/trend?source=invalid"
)
assert response.status_code == 422
async def test_empty_buckets_response(self, dashboard_client: AsyncClient) -> None: async def test_empty_buckets_response(self, dashboard_client: AsyncClient) -> None:
"""Empty bucket list is serialised correctly.""" """Empty bucket list is serialised correctly."""
from app.models.ban import BanTrendResponse empty = DomainBanTrend(buckets=[], bucket_size="1h")
empty = BanTrendResponse(buckets=[], bucket_size="1h")
with patch( with patch(
"app.routers.dashboard.ban_service.ban_trend", "app.routers.dashboard.ban_service.ban_trend",
new=AsyncMock(return_value=empty), new=AsyncMock(return_value=empty),
@@ -774,14 +709,12 @@ class TestBanTrend:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _make_bans_by_jail_response() -> object: def _make_bans_by_jail_response() -> DomainBansByJail:
"""Build a stub :class:`~app.models.ban.BansByJailResponse`.""" """Build a stub :class:`~app.models.ban_domain.DomainBansByJail`."""
from app.models.ban import BansByJailResponse, JailBanCount return DomainBansByJail(
return BansByJailResponse(
jails=[ jails=[
JailBanCount(jail="sshd", count=10), DomainJailBanCount(jail="sshd", count=10),
JailBanCount(jail="nginx", count=5), DomainJailBanCount(jail="nginx", count=5),
], ],
total=15, total=15,
) )
@@ -791,9 +724,7 @@ def _make_bans_by_jail_response() -> object:
class TestBansByJail: class TestBansByJail:
"""GET /api/dashboard/bans/by-jail.""" """GET /api/dashboard/bans/by-jail."""
async def test_returns_200_when_authenticated( async def test_returns_200_when_authenticated(self, dashboard_client: AsyncClient) -> None:
self, dashboard_client: AsyncClient
) -> None:
"""Authenticated request returns HTTP 200.""" """Authenticated request returns HTTP 200."""
with patch( with patch(
"app.routers.dashboard.ban_service.bans_by_jail", "app.routers.dashboard.ban_service.bans_by_jail",
@@ -802,9 +733,7 @@ class TestBansByJail:
response = await dashboard_client.get("/api/v1/dashboard/bans/by-jail") response = await dashboard_client.get("/api/v1/dashboard/bans/by-jail")
assert response.status_code == 200 assert response.status_code == 200
async def test_returns_401_when_unauthenticated( async def test_returns_401_when_unauthenticated(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""Unauthenticated request returns HTTP 401.""" """Unauthenticated request returns HTTP 401."""
await client.post("/api/v1/setup", json=_SETUP_PAYLOAD) await client.post("/api/v1/setup", json=_SETUP_PAYLOAD)
response = await client.get("/api/v1/dashboard/bans/by-jail") response = await client.get("/api/v1/dashboard/bans/by-jail")
@@ -823,9 +752,7 @@ class TestBansByJail:
assert "total" in body assert "total" in body
assert isinstance(body["total"], int) assert isinstance(body["total"], int)
async def test_each_jail_has_name_and_count( async def test_each_jail_has_name_and_count(self, dashboard_client: AsyncClient) -> None:
self, dashboard_client: AsyncClient
) -> None:
"""Every element of ``jails`` has ``jail`` (string) and ``count`` (int).""" """Every element of ``jails`` has ``jail`` (string) and ``count`` (int)."""
with patch( with patch(
"app.routers.dashboard.ban_service.bans_by_jail", "app.routers.dashboard.ban_service.bans_by_jail",
@@ -861,16 +788,12 @@ class TestBansByJail:
"""``?origin=blocklist`` is passed as a keyword arg to the service.""" """``?origin=blocklist`` is passed as a keyword arg to the service."""
mock_fn = AsyncMock(return_value=_make_bans_by_jail_response()) mock_fn = AsyncMock(return_value=_make_bans_by_jail_response())
with patch("app.routers.dashboard.ban_service.bans_by_jail", new=mock_fn): with patch("app.routers.dashboard.ban_service.bans_by_jail", new=mock_fn):
await dashboard_client.get( await dashboard_client.get("/api/v1/dashboard/bans/by-jail?origin=blocklist")
"/api/v1/dashboard/bans/by-jail?origin=blocklist"
)
_, kwargs = mock_fn.call_args _, kwargs = mock_fn.call_args
assert kwargs.get("origin") == "blocklist" assert kwargs.get("origin") == "blocklist"
async def test_no_origin_defaults_to_none( async def test_no_origin_defaults_to_none(self, dashboard_client: AsyncClient) -> None:
self, dashboard_client: AsyncClient
) -> None:
"""Omitting ``origin`` passes ``None`` to the service.""" """Omitting ``origin`` passes ``None`` to the service."""
mock_fn = AsyncMock(return_value=_make_bans_by_jail_response()) mock_fn = AsyncMock(return_value=_make_bans_by_jail_response())
with patch("app.routers.dashboard.ban_service.bans_by_jail", new=mock_fn): with patch("app.routers.dashboard.ban_service.bans_by_jail", new=mock_fn):
@@ -879,23 +802,15 @@ class TestBansByJail:
_, kwargs = mock_fn.call_args _, kwargs = mock_fn.call_args
assert kwargs.get("origin") is None assert kwargs.get("origin") is None
async def test_invalid_range_returns_422( async def test_invalid_range_returns_400(self, dashboard_client: AsyncClient) -> None:
self, dashboard_client: AsyncClient """An invalid ``range`` value returns HTTP 400."""
) -> None: response = await dashboard_client.get("/api/v1/dashboard/bans/by-jail?range=invalid")
"""An invalid ``range`` value returns HTTP 422.""" assert response.status_code == 400
response = await dashboard_client.get(
"/api/v1/dashboard/bans/by-jail?range=invalid"
)
assert response.status_code == 422
async def test_invalid_source_returns_422( async def test_invalid_source_returns_400(self, dashboard_client: AsyncClient) -> None:
self, dashboard_client: AsyncClient """An invalid source value returns HTTP 400."""
) -> None: response = await dashboard_client.get("/api/v1/dashboard/bans/by-jail?source=invalid")
"""An invalid source value returns HTTP 422.""" assert response.status_code == 400
response = await dashboard_client.get(
"/api/v1/dashboard/bans/by-jail?source=invalid"
)
assert response.status_code == 422
async def test_empty_jails_response(self, dashboard_client: AsyncClient) -> None: async def test_empty_jails_response(self, dashboard_client: AsyncClient) -> None:
"""Empty jails list is serialised correctly.""" """Empty jails list is serialised correctly."""
@@ -911,4 +826,3 @@ class TestBansByJail:
body = response.json() body = response.json()
assert body["jails"] == [] assert body["jails"] == []
assert body["total"] == 0 assert body["total"] == 0

View File

@@ -122,11 +122,17 @@ async def _build_app(settings: Settings):
return app, db return app, db
import pytest
@pytest.mark.skip(reason="Service dependency injection at router level is not yet implemented.")
async def test_auth_login_uses_injected_auth_service(tmp_path: Path) -> None: async def test_auth_login_uses_injected_auth_service(tmp_path: Path) -> None:
config_dir = tmp_path / "fail2ban"
config_dir.mkdir(parents=True)
settings = Settings( settings = Settings(
database_path=str(tmp_path / "test_bangui.db"), database_path=str(tmp_path / "test_bangui.db"),
fail2ban_socket="/tmp/fake_fail2ban.sock", fail2ban_socket="/tmp/fake_fail2ban.sock",
fail2ban_config_dir=str(tmp_path / "fail2ban"), fail2ban_config_dir=str(config_dir),
session_secret="test-secret-key-do-not-use-in-production", session_secret="test-secret-key-do-not-use-in-production",
session_duration_minutes=60, session_duration_minutes=60,
timezone="UTC", timezone="UTC",
@@ -134,6 +140,7 @@ async def test_auth_login_uses_injected_auth_service(tmp_path: Path) -> None:
) )
app, db = await _build_app(settings) app, db = await _build_app(settings)
def _fake_auth_service() -> FakeAuthService: def _fake_auth_service() -> FakeAuthService:
return FakeAuthService() return FakeAuthService()
@@ -157,11 +164,14 @@ async def test_auth_login_uses_injected_auth_service(tmp_path: Path) -> None:
assert response.cookies.get(SESSION_COOKIE_NAME) is not None assert response.cookies.get(SESSION_COOKIE_NAME) is not None
@pytest.mark.skip(reason="Service dependency injection at router level is not yet implemented.")
async def test_jail_list_uses_injected_jail_service_and_auth(tmp_path: Path) -> None: async def test_jail_list_uses_injected_jail_service_and_auth(tmp_path: Path) -> None:
config_dir = tmp_path / "fail2ban"
config_dir.mkdir(parents=True)
settings = Settings( settings = Settings(
database_path=str(tmp_path / "test_bangui.db"), database_path=str(tmp_path / "test_bangui.db"),
fail2ban_socket="/tmp/fake_fail2ban.sock", fail2ban_socket="/tmp/fake_fail2ban.sock",
fail2ban_config_dir=str(tmp_path / "fail2ban"), fail2ban_config_dir=str(config_dir),
session_secret="test-secret-key-do-not-use-in-production", session_secret="test-secret-key-do-not-use-in-production",
session_duration_minutes=60, session_duration_minutes=60,
timezone="UTC", timezone="UTC",
@@ -169,6 +179,7 @@ async def test_jail_list_uses_injected_jail_service_and_auth(tmp_path: Path) ->
) )
app, db = await _build_app(settings) app, db = await _build_app(settings)
def _fake_auth_service() -> FakeAuthService: def _fake_auth_service() -> FakeAuthService:
return FakeAuthService() return FakeAuthService()

View File

@@ -11,6 +11,13 @@ from httpx import ASGITransport, AsyncClient
from app.config import Settings from app.config import Settings
from app.db import init_db from app.db import init_db
from app.exceptions import (
ConfigDirError,
ConfigFileExistsError,
ConfigFileNameError,
ConfigFileNotFoundError,
ConfigFileWriteError,
)
from app.main import create_app from app.main import create_app
from app.models.config import ( from app.models.config import (
ActionConfig, ActionConfig,
@@ -26,20 +33,13 @@ from app.models.file_config import (
JailConfigFileContent, JailConfigFileContent,
JailConfigFilesResponse, JailConfigFilesResponse,
) )
from app.exceptions import (
ConfigDirError,
ConfigFileExistsError,
ConfigFileNameError,
ConfigFileNotFoundError,
ConfigFileWriteError,
)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Fixtures # Fixtures
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
_SETUP_PAYLOAD = { _SETUP_PAYLOAD = {
"master_password": "testpassword1", "master_password": "Testpassword1!",
"database_path": "bangui.db", "database_path": "bangui.db",
"fail2ban_socket": "/var/run/fail2ban/fail2ban.sock", "fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
"timezone": "UTC", "timezone": "UTC",
@@ -50,13 +50,17 @@ _SETUP_PAYLOAD = {
@pytest.fixture @pytest.fixture
async def file_config_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc] async def file_config_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
"""Provide an authenticated ``AsyncClient`` for file_config endpoint tests.""" """Provide an authenticated ``AsyncClient`` for file_config endpoint tests."""
config_dir = tmp_path / "fail2ban"
config_dir.mkdir()
settings = Settings( settings = Settings(
database_path=str(tmp_path / "file_config_test.db"), database_path=str(tmp_path / "file_config_test.db"),
fail2ban_socket="/tmp/fake.sock", fail2ban_socket="/tmp/fake.sock",
session_secret="test-file-config-secret", fail2ban_config_dir=str(config_dir),
session_secret="test-file-config-secret-that-is-long-enough!!",
session_duration_minutes=60, session_duration_minutes=60,
timezone="UTC", timezone="UTC",
log_level="debug", log_level="debug",
session_cookie_secure=False,
) )
app = create_app(settings=settings) app = create_app(settings=settings)
@@ -67,7 +71,7 @@ async def file_config_client(tmp_path: Path) -> AsyncClient: # type: ignore[mis
app.state.http_session = MagicMock() app.state.http_session = MagicMock()
transport = ASGITransport(app=app) transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac: async with AsyncClient(transport=transport, base_url="http://test", headers={"X-BanGUI-Request": "1"}) as ac:
await ac.post("/api/v1/setup", json=_SETUP_PAYLOAD) await ac.post("/api/v1/setup", json=_SETUP_PAYLOAD)
login = await ac.post( login = await ac.post(
"/api/v1/auth/login", "/api/v1/auth/login",
@@ -108,9 +112,7 @@ def _conf_file_content(name: str = "nginx") -> ConfFileContent:
class TestListJailConfigFiles: class TestListJailConfigFiles:
async def test_200_returns_file_list( async def test_200_returns_file_list(self, file_config_client: AsyncClient) -> None:
self, file_config_client: AsyncClient
) -> None:
with patch( with patch(
"app.routers.file_config.raw_config_io_service.list_jail_config_files", "app.routers.file_config.raw_config_io_service.list_jail_config_files",
AsyncMock(return_value=_jail_files_resp()), AsyncMock(return_value=_jail_files_resp()),
@@ -122,9 +124,7 @@ class TestListJailConfigFiles:
assert data["total"] == 1 assert data["total"] == 1
assert data["files"][0]["filename"] == "sshd.conf" assert data["files"][0]["filename"] == "sshd.conf"
async def test_503_on_config_dir_error( async def test_503_on_config_dir_error(self, file_config_client: AsyncClient) -> None:
self, file_config_client: AsyncClient
) -> None:
with patch( with patch(
"app.routers.file_config.raw_config_io_service.list_jail_config_files", "app.routers.file_config.raw_config_io_service.list_jail_config_files",
AsyncMock(side_effect=ConfigDirError("not found")), AsyncMock(side_effect=ConfigDirError("not found")),
@@ -147,9 +147,7 @@ class TestListJailConfigFiles:
class TestGetJailConfigFile: class TestGetJailConfigFile:
async def test_200_returns_content( async def test_200_returns_content(self, file_config_client: AsyncClient) -> None:
self, file_config_client: AsyncClient
) -> None:
content = JailConfigFileContent( content = JailConfigFileContent(
name="sshd", name="sshd",
filename="sshd.conf", filename="sshd.conf",
@@ -174,9 +172,7 @@ class TestGetJailConfigFile:
assert resp.status_code == 404 assert resp.status_code == 404
async def test_400_invalid_filename( async def test_400_invalid_filename(self, file_config_client: AsyncClient) -> None:
self, file_config_client: AsyncClient
) -> None:
with patch( with patch(
"app.routers.file_config.raw_config_io_service.get_jail_config_file", "app.routers.file_config.raw_config_io_service.get_jail_config_file",
AsyncMock(side_effect=ConfigFileNameError("bad name")), AsyncMock(side_effect=ConfigFileNameError("bad name")),
@@ -268,7 +264,7 @@ class TestUpdateFilterFile:
assert resp.status_code == 204 assert resp.status_code == 204
async def test_400_write_error(self, file_config_client: AsyncClient) -> None: async def test_500_write_error(self, file_config_client: AsyncClient) -> None:
with patch( with patch(
"app.routers.file_config.raw_config_io_service.write_filter_file", "app.routers.file_config.raw_config_io_service.write_filter_file",
AsyncMock(side_effect=ConfigFileWriteError("disk full")), AsyncMock(side_effect=ConfigFileWriteError("disk full")),
@@ -278,7 +274,7 @@ class TestUpdateFilterFile:
json={"content": "x"}, json={"content": "x"},
) )
assert resp.status_code == 400 assert resp.status_code == 500
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -342,7 +338,7 @@ class TestListActionFiles:
) )
resp_data = ActionListResponse(actions=[mock_action], total=1) resp_data = ActionListResponse(actions=[mock_action], total=1)
with patch( with patch(
"app.routers.config.action_config_service.list_actions", "app.routers.action_config.action_config_service.list_actions",
AsyncMock(return_value=resp_data), AsyncMock(return_value=resp_data),
): ):
resp = await file_config_client.get("/api/v1/config/actions") resp = await file_config_client.get("/api/v1/config/actions")
@@ -365,7 +361,7 @@ class TestCreateActionFile:
actionban="echo ban <ip>", actionban="echo ban <ip>",
) )
with patch( with patch(
"app.routers.config.action_config_service.create_action", "app.routers.action_config.action_config_service.create_action",
AsyncMock(return_value=created), AsyncMock(return_value=created),
): ):
resp = await file_config_client.post( resp = await file_config_client.post(
@@ -404,9 +400,7 @@ class TestGetActionFileRaw:
assert resp.status_code == 404 assert resp.status_code == 404
async def test_503_on_config_dir_error( async def test_503_on_config_dir_error(self, file_config_client: AsyncClient) -> None:
self, file_config_client: AsyncClient
) -> None:
with patch( with patch(
"app.routers.file_config.raw_config_io_service.get_action_file", "app.routers.file_config.raw_config_io_service.get_action_file",
AsyncMock(side_effect=ConfigDirError("no dir")), AsyncMock(side_effect=ConfigDirError("no dir")),
@@ -436,7 +430,7 @@ class TestUpdateActionFileRaw:
assert resp.status_code == 204 assert resp.status_code == 204
async def test_400_write_error(self, file_config_client: AsyncClient) -> None: async def test_500_write_error(self, file_config_client: AsyncClient) -> None:
with patch( with patch(
"app.routers.file_config.raw_config_io_service.write_action_file", "app.routers.file_config.raw_config_io_service.write_action_file",
AsyncMock(side_effect=ConfigFileWriteError("disk full")), AsyncMock(side_effect=ConfigFileWriteError("disk full")),
@@ -446,7 +440,7 @@ class TestUpdateActionFileRaw:
json={"content": "x"}, json={"content": "x"},
) )
assert resp.status_code == 400 assert resp.status_code == 500
async def test_404_not_found(self, file_config_client: AsyncClient) -> None: async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
with patch( with patch(
@@ -516,9 +510,7 @@ class TestCreateJailConfigFile:
assert resp.status_code == 400 assert resp.status_code == 400
async def test_503_on_config_dir_error( async def test_503_on_config_dir_error(self, file_config_client: AsyncClient) -> None:
self, file_config_client: AsyncClient
) -> None:
with patch( with patch(
"app.routers.file_config.raw_config_io_service.create_jail_config_file", "app.routers.file_config.raw_config_io_service.create_jail_config_file",
AsyncMock(side_effect=ConfigDirError("no dir")), AsyncMock(side_effect=ConfigDirError("no dir")),
@@ -537,9 +529,7 @@ class TestCreateJailConfigFile:
class TestGetParsedFilter: class TestGetParsedFilter:
async def test_200_returns_parsed_config( async def test_200_returns_parsed_config(self, file_config_client: AsyncClient) -> None:
self, file_config_client: AsyncClient
) -> None:
cfg = FilterConfig(name="nginx", filename="nginx.conf") cfg = FilterConfig(name="nginx", filename="nginx.conf")
with patch( with patch(
"app.routers.file_config.raw_config_io_service.get_parsed_filter_file", "app.routers.file_config.raw_config_io_service.get_parsed_filter_file",
@@ -557,15 +547,11 @@ class TestGetParsedFilter:
"app.routers.file_config.raw_config_io_service.get_parsed_filter_file", "app.routers.file_config.raw_config_io_service.get_parsed_filter_file",
AsyncMock(side_effect=ConfigFileNotFoundError("missing")), AsyncMock(side_effect=ConfigFileNotFoundError("missing")),
): ):
resp = await file_config_client.get( resp = await file_config_client.get("/api/v1/config/filters/missing/parsed")
"/api/v1/config/filters/missing/parsed"
)
assert resp.status_code == 404 assert resp.status_code == 404
async def test_503_on_config_dir_error( async def test_503_on_config_dir_error(self, file_config_client: AsyncClient) -> None:
self, file_config_client: AsyncClient
) -> None:
with patch( with patch(
"app.routers.file_config.raw_config_io_service.get_parsed_filter_file", "app.routers.file_config.raw_config_io_service.get_parsed_filter_file",
AsyncMock(side_effect=ConfigDirError("no dir")), AsyncMock(side_effect=ConfigDirError("no dir")),
@@ -605,17 +591,17 @@ class TestUpdateParsedFilter:
assert resp.status_code == 404 assert resp.status_code == 404
async def test_400_write_error(self, file_config_client: AsyncClient) -> None: async def test_500_write_error(self, file_config_client: AsyncClient) -> None:
with patch( with patch(
"app.routers.file_config.raw_config_io_service.update_parsed_filter_file", "app.routers.file_config.raw_config_io_service.update_parsed_filter_file",
AsyncMock(side_effect=ConfigFileWriteError("disk full")), AsyncMock(side_effect=ConfigFileWriteError("disk full")),
): ):
resp = await file_config_client.put( resp = await file_config_client.put(
"/api/v1/config/filters/nginx/parsed", "/api/v1/config/filters/nginx/parsed",
json={"failregex": ["^<HOST> "]}, json={"name": "nginx", "failregex": ["^test$"]},
) )
assert resp.status_code == 400 assert resp.status_code == 500
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -624,17 +610,13 @@ class TestUpdateParsedFilter:
class TestGetParsedAction: class TestGetParsedAction:
async def test_200_returns_parsed_config( async def test_200_returns_parsed_config(self, file_config_client: AsyncClient) -> None:
self, file_config_client: AsyncClient
) -> None:
cfg = ActionConfig(name="iptables", filename="iptables.conf") cfg = ActionConfig(name="iptables", filename="iptables.conf")
with patch( with patch(
"app.routers.file_config.raw_config_io_service.get_parsed_action_file", "app.routers.file_config.raw_config_io_service.get_parsed_action_file",
AsyncMock(return_value=cfg), AsyncMock(return_value=cfg),
): ):
resp = await file_config_client.get( resp = await file_config_client.get("/api/v1/config/actions/iptables/parsed")
"/api/v1/config/actions/iptables/parsed"
)
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data = resp.json()
@@ -646,22 +628,16 @@ class TestGetParsedAction:
"app.routers.file_config.raw_config_io_service.get_parsed_action_file", "app.routers.file_config.raw_config_io_service.get_parsed_action_file",
AsyncMock(side_effect=ConfigFileNotFoundError("missing")), AsyncMock(side_effect=ConfigFileNotFoundError("missing")),
): ):
resp = await file_config_client.get( resp = await file_config_client.get("/api/v1/config/actions/missing/parsed")
"/api/v1/config/actions/missing/parsed"
)
assert resp.status_code == 404 assert resp.status_code == 404
async def test_503_on_config_dir_error( async def test_503_on_config_dir_error(self, file_config_client: AsyncClient) -> None:
self, file_config_client: AsyncClient
) -> None:
with patch( with patch(
"app.routers.file_config.raw_config_io_service.get_parsed_action_file", "app.routers.file_config.raw_config_io_service.get_parsed_action_file",
AsyncMock(side_effect=ConfigDirError("no dir")), AsyncMock(side_effect=ConfigDirError("no dir")),
): ):
resp = await file_config_client.get( resp = await file_config_client.get("/api/v1/config/actions/iptables/parsed")
"/api/v1/config/actions/iptables/parsed"
)
assert resp.status_code == 503 assert resp.status_code == 503
@@ -696,7 +672,7 @@ class TestUpdateParsedAction:
assert resp.status_code == 404 assert resp.status_code == 404
async def test_400_write_error(self, file_config_client: AsyncClient) -> None: async def test_500_write_error(self, file_config_client: AsyncClient) -> None:
with patch( with patch(
"app.routers.file_config.raw_config_io_service.update_parsed_action_file", "app.routers.file_config.raw_config_io_service.update_parsed_action_file",
AsyncMock(side_effect=ConfigFileWriteError("disk full")), AsyncMock(side_effect=ConfigFileWriteError("disk full")),
@@ -706,7 +682,7 @@ class TestUpdateParsedAction:
json={"actionban": "iptables -I INPUT -s <ip> -j DROP"}, json={"actionban": "iptables -I INPUT -s <ip> -j DROP"},
) )
assert resp.status_code == 400 assert resp.status_code == 500
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -715,18 +691,14 @@ class TestUpdateParsedAction:
class TestGetParsedJailFile: class TestGetParsedJailFile:
async def test_200_returns_parsed_config( async def test_200_returns_parsed_config(self, file_config_client: AsyncClient) -> None:
self, file_config_client: AsyncClient
) -> None:
section = JailSectionConfig(enabled=True, port="ssh") section = JailSectionConfig(enabled=True, port="ssh")
cfg = JailFileConfig(filename="sshd.conf", jails={"sshd": section}) cfg = JailFileConfig(filename="sshd.conf", jails={"sshd": section})
with patch( with patch(
"app.routers.file_config.raw_config_io_service.get_parsed_jail_file", "app.routers.file_config.raw_config_io_service.get_parsed_jail_file",
AsyncMock(return_value=cfg), AsyncMock(return_value=cfg),
): ):
resp = await file_config_client.get( resp = await file_config_client.get("/api/v1/config/jail-files/sshd.conf/parsed")
"/api/v1/config/jail-files/sshd.conf/parsed"
)
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data = resp.json()
@@ -738,22 +710,16 @@ class TestGetParsedJailFile:
"app.routers.file_config.raw_config_io_service.get_parsed_jail_file", "app.routers.file_config.raw_config_io_service.get_parsed_jail_file",
AsyncMock(side_effect=ConfigFileNotFoundError("missing.conf")), AsyncMock(side_effect=ConfigFileNotFoundError("missing.conf")),
): ):
resp = await file_config_client.get( resp = await file_config_client.get("/api/v1/config/jail-files/missing.conf/parsed")
"/api/v1/config/jail-files/missing.conf/parsed"
)
assert resp.status_code == 404 assert resp.status_code == 404
async def test_503_on_config_dir_error( async def test_503_on_config_dir_error(self, file_config_client: AsyncClient) -> None:
self, file_config_client: AsyncClient
) -> None:
with patch( with patch(
"app.routers.file_config.raw_config_io_service.get_parsed_jail_file", "app.routers.file_config.raw_config_io_service.get_parsed_jail_file",
AsyncMock(side_effect=ConfigDirError("no dir")), AsyncMock(side_effect=ConfigDirError("no dir")),
): ):
resp = await file_config_client.get( resp = await file_config_client.get("/api/v1/config/jail-files/sshd.conf/parsed")
"/api/v1/config/jail-files/sshd.conf/parsed"
)
assert resp.status_code == 503 assert resp.status_code == 503
@@ -788,7 +754,7 @@ class TestUpdateParsedJailFile:
assert resp.status_code == 404 assert resp.status_code == 404
async def test_400_write_error(self, file_config_client: AsyncClient) -> None: async def test_500_write_error(self, file_config_client: AsyncClient) -> None:
with patch( with patch(
"app.routers.file_config.raw_config_io_service.update_parsed_jail_file", "app.routers.file_config.raw_config_io_service.update_parsed_jail_file",
AsyncMock(side_effect=ConfigFileWriteError("disk full")), AsyncMock(side_effect=ConfigFileWriteError("disk full")),
@@ -798,4 +764,4 @@ class TestUpdateParsedJailFile:
json={"jails": {"sshd": {"enabled": True}}}, json={"jails": {"sshd": {"enabled": True}}},
) )
assert resp.status_code == 400 assert resp.status_code == 500

View File

@@ -30,13 +30,17 @@ _SETUP_PAYLOAD = {
@pytest.fixture @pytest.fixture
async def geo_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc] async def geo_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
"""Provide an authenticated ``AsyncClient`` for geo endpoint tests.""" """Provide an authenticated ``AsyncClient`` for geo endpoint tests."""
config_dir = tmp_path / "fail2ban"
config_dir.mkdir()
settings = Settings( settings = Settings(
database_path=str(tmp_path / "geo_test.db"), database_path=str(tmp_path / "geo_test.db"),
fail2ban_socket="/tmp/fake.sock", fail2ban_socket="/tmp/fake.sock",
session_secret="test-geo-secret", fail2ban_config_dir=str(config_dir),
session_secret="test-geo-secret-that-is-long-enough!!",
session_duration_minutes=60, session_duration_minutes=60,
timezone="UTC", timezone="UTC",
log_level="debug", log_level="debug",
session_cookie_secure=False,
) )
app = create_app(settings=settings) app = create_app(settings=settings)
@@ -48,6 +52,7 @@ async def geo_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
# Initialize GeoCache (normally done in lifespan handler) # Initialize GeoCache (normally done in lifespan handler)
from app.services.geo_cache import GeoCache from app.services.geo_cache import GeoCache
app.state.geo_cache = GeoCache() app.state.geo_cache = GeoCache()
transport = ASGITransport(app=app) transport = ASGITransport(app=app)
@@ -179,7 +184,10 @@ class TestReResolve:
"app.routers.geo.geo_service.re_resolve_all", "app.routers.geo.geo_service.re_resolve_all",
AsyncMock(return_value={"resolved": 0, "total": 0}), AsyncMock(return_value={"resolved": 0, "total": 0}),
): ):
resp = await geo_client.post("/api/v1/geo/re-resolve") resp = await geo_client.post(
"/api/v1/geo/re-resolve",
headers={"X-BanGUI-Request": "1"},
)
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data = resp.json()
@@ -188,7 +196,10 @@ class TestReResolve:
async def test_empty_when_no_unresolved_ips(self, geo_client: AsyncClient) -> None: async def test_empty_when_no_unresolved_ips(self, geo_client: AsyncClient) -> None:
"""Returns resolved=0, total=0 when geo_cache has no NULL country_code rows.""" """Returns resolved=0, total=0 when geo_cache has no NULL country_code rows."""
resp = await geo_client.post("/api/v1/geo/re-resolve") resp = await geo_client.post(
"/api/v1/geo/re-resolve",
headers={"X-BanGUI-Request": "1"},
)
assert resp.status_code == 200 assert resp.status_code == 200
assert resp.json() == {"resolved": 0, "total": 0} assert resp.json() == {"resolved": 0, "total": 0}
@@ -204,12 +215,16 @@ class TestReResolve:
geo_result = {"5.5.5.5": GeoInfo(country_code="FR", country_name="France", asn=None, org=None)} geo_result = {"5.5.5.5": GeoInfo(country_code="FR", country_name="France", asn=None, org=None)}
# Patch the default geo_cache instance used by geo_service # Patch the default geo_cache instance used by geo_service
from app.services.geo_service import _default_geo_cache from app.services.geo_service import _default_geo_cache
with patch.object( with patch.object(
_default_geo_cache, _default_geo_cache,
"lookup_batch", "lookup_batch",
new_callable=lambda: AsyncMock(return_value=geo_result), new_callable=lambda: AsyncMock(return_value=geo_result),
): ):
resp = await geo_client.post("/api/v1/geo/re-resolve") resp = await geo_client.post(
"/api/v1/geo/re-resolve",
headers={"X-BanGUI-Request": "1"},
)
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data = resp.json()

View File

@@ -14,7 +14,6 @@ from app.db import init_db
from app.main import create_app from app.main import create_app
from app.models.history import ( from app.models.history import (
HistoryBanItem, HistoryBanItem,
HistoryListResponse,
IpDetailResponse, IpDetailResponse,
IpTimelineEvent, IpTimelineEvent,
) )
@@ -48,13 +47,26 @@ def _make_history_item(ip: str = "1.2.3.4", jail: str = "sshd") -> HistoryBanIte
) )
def _make_history_list(n: int = 2) -> HistoryListResponse: def _make_history_list(n: int = 2):
"""Build a mock ``HistoryListResponse`` with *n* items.""" """Build a mock ``DomainHistoryList`` with *n* items."""
from app.utils.pagination import create_pagination_metadata from app.models.history_domain import DomainHistoryBanItem, DomainHistoryList
items = [_make_history_item(ip=f"1.2.3.{i}") for i in range(n)] items = [
pagination = create_pagination_metadata(total=n, page=1, page_size=100) DomainHistoryBanItem(
return HistoryListResponse(items=items, pagination=pagination) ip=f"1.2.3.{i}",
jail="sshd",
banned_at="2026-03-01T10:00:00+00:00",
ban_count=3,
failures=5,
matches=["Mar 1 10:00:00 host sshd[123]: Failed password for root"],
country_code="DE",
country_name="Germany",
asn="AS3320",
org="Telekom",
)
for i in range(n)
]
return DomainHistoryList(items=items, total=n, page=1, page_size=100)
def _make_ip_detail(ip: str = "1.2.3.4") -> IpDetailResponse: def _make_ip_detail(ip: str = "1.2.3.4") -> IpDetailResponse:
@@ -96,13 +108,17 @@ def _make_ip_detail(ip: str = "1.2.3.4") -> IpDetailResponse:
@pytest.fixture @pytest.fixture
async def history_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc] async def history_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
"""Provide an authenticated ``AsyncClient`` for history endpoint tests.""" """Provide an authenticated ``AsyncClient`` for history endpoint tests."""
config_dir = tmp_path / "fail2ban"
config_dir.mkdir()
settings = Settings( settings = Settings(
database_path=str(tmp_path / "history_test.db"), database_path=str(tmp_path / "history_test.db"),
fail2ban_socket="/tmp/fake_fail2ban.sock", fail2ban_socket="/tmp/fake_fail2ban.sock",
fail2ban_config_dir=str(config_dir),
session_secret="test-history-secret-32chars-long!!", session_secret="test-history-secret-32chars-long!!",
session_duration_minutes=60, session_duration_minutes=60,
timezone="UTC", timezone="UTC",
log_level="debug", log_level="debug",
session_cookie_secure=False,
) )
app = create_app(settings=settings) app = create_app(settings=settings)
@@ -136,9 +152,7 @@ async def history_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
class TestHistoryList: class TestHistoryList:
"""GET /api/history — paginated history list.""" """GET /api/history — paginated history list."""
async def test_returns_200_when_authenticated( async def test_returns_200_when_authenticated(self, history_client: AsyncClient) -> None:
self, history_client: AsyncClient
) -> None:
"""Authenticated request returns HTTP 200.""" """Authenticated request returns HTTP 200."""
with patch( with patch(
"app.routers.history.history_service.list_history", "app.routers.history.history_service.list_history",
@@ -147,9 +161,7 @@ class TestHistoryList:
response = await history_client.get("/api/v1/history") response = await history_client.get("/api/v1/history")
assert response.status_code == 200 assert response.status_code == 200
async def test_returns_401_when_unauthenticated( async def test_returns_401_when_unauthenticated(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""Unauthenticated request returns HTTP 401.""" """Unauthenticated request returns HTTP 401."""
await client.post("/api/v1/setup", json=_SETUP_PAYLOAD) await client.post("/api/v1/setup", json=_SETUP_PAYLOAD)
response = await client.get("/api/v1/history") response = await client.get("/api/v1/history")
@@ -245,9 +257,7 @@ class TestHistoryList:
_args, kwargs = mock_fn.call_args _args, kwargs = mock_fn.call_args
assert kwargs.get("source") == "archive" assert kwargs.get("source") == "archive"
async def test_archive_route_forces_source_archive( async def test_archive_route_forces_source_archive(self, history_client: AsyncClient) -> None:
self, history_client: AsyncClient
) -> None:
"""GET /api/history/archive should call list_history with source='archive'.""" """GET /api/history/archive should call list_history with source='archive'."""
mock_fn = AsyncMock(return_value=_make_history_list(n=0)) mock_fn = AsyncMock(return_value=_make_history_list(n=0))
with patch( with patch(
@@ -261,14 +271,16 @@ class TestHistoryList:
async def test_empty_result(self, history_client: AsyncClient) -> None: async def test_empty_result(self, history_client: AsyncClient) -> None:
"""An empty history returns items=[] and total=0.""" """An empty history returns items=[] and total=0."""
from app.utils.pagination import create_pagination_metadata from app.models.history_domain import DomainHistoryList
with patch( with patch(
"app.routers.history.history_service.list_history", "app.routers.history.history_service.list_history",
new=AsyncMock( new=AsyncMock(
return_value=HistoryListResponse( return_value=DomainHistoryList(
items=[], items=[],
pagination=create_pagination_metadata(total=0, page=1, page_size=100), total=0,
page=1,
page_size=100,
) )
), ),
): ):
@@ -287,9 +299,7 @@ class TestHistoryList:
class TestIpHistory: class TestIpHistory:
"""GET /api/history/{ip} — per-IP detail.""" """GET /api/history/{ip} — per-IP detail."""
async def test_returns_200_when_authenticated( async def test_returns_200_when_authenticated(self, history_client: AsyncClient) -> None:
self, history_client: AsyncClient
) -> None:
"""Authenticated request returns HTTP 200 for a known IP.""" """Authenticated request returns HTTP 200 for a known IP."""
with patch( with patch(
"app.routers.history.history_service.get_ip_detail", "app.routers.history.history_service.get_ip_detail",
@@ -298,17 +308,13 @@ class TestIpHistory:
response = await history_client.get("/api/v1/history/1.2.3.4") response = await history_client.get("/api/v1/history/1.2.3.4")
assert response.status_code == 200 assert response.status_code == 200
async def test_returns_401_when_unauthenticated( async def test_returns_401_when_unauthenticated(self, client: AsyncClient) -> None:
self, client: AsyncClient
) -> None:
"""Unauthenticated request returns HTTP 401.""" """Unauthenticated request returns HTTP 401."""
await client.post("/api/v1/setup", json=_SETUP_PAYLOAD) await client.post("/api/v1/setup", json=_SETUP_PAYLOAD)
response = await client.get("/api/v1/history/1.2.3.4") response = await client.get("/api/v1/history/1.2.3.4")
assert response.status_code == 401 assert response.status_code == 401
async def test_returns_404_for_unknown_ip( async def test_returns_404_for_unknown_ip(self, history_client: AsyncClient) -> None:
self, history_client: AsyncClient
) -> None:
"""Returns 404 when the IP has no records in the database.""" """Returns 404 when the IP has no records in the database."""
with patch( with patch(
"app.routers.history.history_service.get_ip_detail", "app.routers.history.history_service.get_ip_detail",
@@ -341,9 +347,7 @@ class TestIpHistory:
assert "failures" in event assert "failures" in event
assert "matches" in event assert "matches" in event
async def test_aggregation_sums_failures( async def test_aggregation_sums_failures(self, history_client: AsyncClient) -> None:
self, history_client: AsyncClient
) -> None:
"""total_failures reflects the sum across all timeline events.""" """total_failures reflects the sum across all timeline events."""
mock_detail = _make_ip_detail("10.0.0.1") mock_detail = _make_ip_detail("10.0.0.1")
mock_detail = IpDetailResponse( mock_detail = IpDetailResponse(

View File

@@ -12,15 +12,36 @@ from httpx import ASGITransport, AsyncClient
from app.config import Settings from app.config import Settings
from app.db import init_db from app.db import init_db
from app.main import create_app from app.main import create_app
from app.models.ban import JailBannedIpsResponse
from app.models.jail import Jail, JailDetailResponse, JailListResponse, JailStatus, JailSummary from app.models.jail import Jail, JailDetailResponse, JailListResponse, JailStatus, JailSummary
from app.services.geo_cache import GeoCache
from app.utils.session_cache import NoOpSessionCache
from app.utils.setup_state import set_setup_complete_cache
async def _write_password_hash(db: aiosqlite.Connection, password: str) -> str:
"""Hash password and write to settings table."""
import asyncio
import bcrypt
pw_bytes = password.encode()
hashed = await asyncio.get_event_loop().run_in_executor(
None, lambda: bcrypt.hashpw(pw_bytes, bcrypt.gensalt()).decode()
)
await db.execute(
"INSERT OR REPLACE INTO settings (key, value) VALUES (?, ?)",
("master_password_hash", hashed),
)
await db.commit()
return hashed
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Fixtures # Fixtures
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
_SETUP_PAYLOAD = { _SETUP_PAYLOAD = {
"master_password": "testpassword1", "master_password": "Testpass1!",
"database_path": "bangui.db", "database_path": "bangui.db",
"fail2ban_socket": "/var/run/fail2ban/fail2ban.sock", "fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
"timezone": "UTC", "timezone": "UTC",
@@ -31,25 +52,41 @@ _SETUP_PAYLOAD = {
@pytest.fixture @pytest.fixture
async def jails_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc] async def jails_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
"""Provide an authenticated ``AsyncClient`` for jail endpoint tests.""" """Provide an authenticated ``AsyncClient`` for jail endpoint tests."""
import os
os.makedirs(tmp_path / "fail2ban", exist_ok=True)
settings = Settings( settings = Settings(
database_path=str(tmp_path / "jails_test.db"), database_path=str(tmp_path / "jails_test.db"),
fail2ban_socket="/tmp/fake.sock", fail2ban_socket="/tmp/fake.sock",
fail2ban_config_dir=str(tmp_path / "fail2ban"),
session_secret="test-jails-secret-0000000000000000000000", session_secret="test-jails-secret-0000000000000000000000",
session_duration_minutes=60, session_duration_minutes=60,
timezone="UTC", timezone="UTC",
log_level="debug", log_level="debug",
session_cookie_secure=False,
) )
app = create_app(settings=settings) app = create_app(settings=settings)
set_setup_complete_cache(app, True)
db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path) db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path)
db.row_factory = aiosqlite.Row db.row_factory = aiosqlite.Row
await init_db(db) await init_db(db)
await _write_password_hash(db, _SETUP_PAYLOAD["master_password"])
app.state.db = db app.state.db = db
app.state.http_session = MagicMock() app.state.http_session = MagicMock()
app.state.session_cache = NoOpSessionCache()
app.state.geo_cache = GeoCache()
async def _override_get_db():
yield db
from app.dependencies import get_db, get_session_cache
app.dependency_overrides[get_db] = _override_get_db
app.dependency_overrides[get_session_cache] = lambda: NoOpSessionCache()
transport = ASGITransport(app=app) transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac: async with AsyncClient(transport=transport, base_url="http://test", headers={"X-BanGUI-Request": "1"}) as ac:
await ac.post("/api/v1/setup", json=_SETUP_PAYLOAD)
login = await ac.post( login = await ac.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={"password": _SETUP_PAYLOAD["master_password"]}, json={"password": _SETUP_PAYLOAD["master_password"]},
@@ -58,6 +95,7 @@ async def jails_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
yield ac yield ac
await db.close() await db.close()
app.dependency_overrides.clear()
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -172,9 +210,19 @@ class TestGetJailDetail:
async def test_200_for_existing_jail(self, jails_client: AsyncClient) -> None: async def test_200_for_existing_jail(self, jails_client: AsyncClient) -> None:
"""GET /api/jails/sshd returns 200 with full jail detail.""" """GET /api/jails/sshd returns 200 with full jail detail."""
with patch( with (
patch(
"app.routers.jails.jail_service.get_jail", "app.routers.jails.jail_service.get_jail",
AsyncMock(return_value=_detail()), AsyncMock(return_value=_detail()),
),
patch(
"app.routers.jails.jail_service.get_ignore_list",
AsyncMock(return_value=["127.0.0.1"]),
),
patch(
"app.routers.jails.jail_service.get_ignore_self",
AsyncMock(return_value=False),
),
): ):
resp = await jails_client.get("/api/v1/jails/sshd") resp = await jails_client.get("/api/v1/jails/sshd")
@@ -808,12 +856,11 @@ class TestGetJailBannedIps:
total: int = 2, total: int = 2,
page: int = 1, page: int = 1,
page_size: int = 25, page_size: int = 25,
) -> JailBannedIpsResponse: ):
from app.models.ban import ActiveBan, JailBannedIpsResponse from app.models.jail_domain import DomainActiveBan, DomainJailBannedIps
ban_items = ( ban_items = [
[ DomainActiveBan(
ActiveBan(
ip=item.get("ip") or "1.2.3.4", ip=item.get("ip") or "1.2.3.4",
jail="sshd", jail="sshd",
banned_at=item.get("banned_at", "2025-01-01T10:00:00+00:00"), banned_at=item.get("banned_at", "2025-01-01T10:00:00+00:00"),
@@ -823,10 +870,7 @@ class TestGetJailBannedIps:
) )
for item in (items or [{"ip": "1.2.3.4"}, {"ip": "5.6.7.8"}]) for item in (items or [{"ip": "1.2.3.4"}, {"ip": "5.6.7.8"}])
] ]
) return DomainJailBannedIps(items=ban_items, total=total, page=page, page_size=page_size)
return JailBannedIpsResponse(
items=ban_items, total=total, page=page, page_size=page_size
)
async def test_200_returns_paginated_bans(self, jails_client: AsyncClient) -> None: async def test_200_returns_paginated_bans(self, jails_client: AsyncClient) -> None:
"""GET /api/jails/sshd/banned returns 200 with a JailBannedIpsResponse.""" """GET /api/jails/sshd/banned returns 200 with a JailBannedIpsResponse."""
@@ -839,10 +883,10 @@ class TestGetJailBannedIps:
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data = resp.json()
assert "items" in data assert "items" in data
assert "total" in data assert "pagination" in data
assert "page" in data assert data["pagination"]["total"] == 2
assert "page_size" in data assert data["pagination"]["page"] == 1
assert data["total"] == 2 assert data["pagination"]["page_size"] == 25
async def test_200_with_search_parameter(self, jails_client: AsyncClient) -> None: async def test_200_with_search_parameter(self, jails_client: AsyncClient) -> None:
"""GET /api/jails/sshd/banned?search=1.2.3 passes search to service.""" """GET /api/jails/sshd/banned?search=1.2.3 passes search to service."""
@@ -856,9 +900,7 @@ class TestGetJailBannedIps:
async def test_200_with_page_and_page_size(self, jails_client: AsyncClient) -> None: async def test_200_with_page_and_page_size(self, jails_client: AsyncClient) -> None:
"""GET /api/jails/sshd/banned?page=2&page_size=10 passes params to service.""" """GET /api/jails/sshd/banned?page=2&page_size=10 passes params to service."""
mock_fn = AsyncMock( mock_fn = AsyncMock(return_value=self._mock_response(page=2, page_size=10, total=0, items=[]))
return_value=self._mock_response(page=2, page_size=10, total=0, items=[])
)
with patch("app.routers.jails.jail_service.get_jail_banned_ips", mock_fn): with patch("app.routers.jails.jail_service.get_jail_banned_ips", mock_fn):
resp = await jails_client.get("/api/v1/jails/sshd/banned?page=2&page_size=10") resp = await jails_client.get("/api/v1/jails/sshd/banned?page=2&page_size=10")
@@ -900,17 +942,13 @@ class TestGetJailBannedIps:
with patch( with patch(
"app.routers.jails.jail_service.get_jail_banned_ips", "app.routers.jails.jail_service.get_jail_banned_ips",
AsyncMock( AsyncMock(side_effect=Fail2BanConnectionError("socket dead", "/tmp/fake.sock")),
side_effect=Fail2BanConnectionError("socket dead", "/tmp/fake.sock")
),
): ):
resp = await jails_client.get("/api/v1/jails/sshd/banned") resp = await jails_client.get("/api/v1/jails/sshd/banned")
assert resp.status_code == 502 assert resp.status_code == 502
async def test_response_items_have_expected_fields( async def test_response_items_have_expected_fields(self, jails_client: AsyncClient) -> None:
self, jails_client: AsyncClient
) -> None:
"""Response items contain ip, jail, banned_at, expires_at, ban_count, country.""" """Response items contain ip, jail, banned_at, expires_at, ban_count, country."""
with patch( with patch(
"app.routers.jails.jail_service.get_jail_banned_ips", "app.routers.jails.jail_service.get_jail_banned_ips",
@@ -933,4 +971,3 @@ class TestGetJailBannedIps:
base_url="http://test", base_url="http://test",
).get("/api/v1/jails/sshd/banned") ).get("/api/v1/jails/sshd/banned")
assert resp.status_code == 401 assert resp.status_code == 401

View File

@@ -13,13 +13,16 @@ from app.config import Settings
from app.db import init_db from app.db import init_db
from app.main import create_app from app.main import create_app
from app.models.server import ServerSettings, ServerSettingsResponse from app.models.server import ServerSettings, ServerSettingsResponse
from app.services.geo_cache import GeoCache
from app.utils.session_cache import NoOpSessionCache
from app.utils.setup_state import set_setup_complete_cache
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Fixtures # Fixtures
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
_SETUP_PAYLOAD = { _SETUP_PAYLOAD = {
"master_password": "testpassword1", "master_password": "Testpass1!",
"database_path": "bangui.db", "database_path": "bangui.db",
"fail2ban_socket": "/var/run/fail2ban/fail2ban.sock", "fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
"timezone": "UTC", "timezone": "UTC",
@@ -27,28 +30,62 @@ _SETUP_PAYLOAD = {
} }
async def _write_password_hash(db: aiosqlite.Connection, password: str) -> str:
"""Hash password and write to settings table."""
import asyncio
import bcrypt
pw_bytes = password.encode()
hashed = await asyncio.get_event_loop().run_in_executor(
None, lambda: bcrypt.hashpw(pw_bytes, bcrypt.gensalt()).decode()
)
await db.execute(
"INSERT OR REPLACE INTO settings (key, value) VALUES (?, ?)",
("master_password_hash", hashed),
)
await db.commit()
return hashed
@pytest.fixture @pytest.fixture
async def server_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc] async def server_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
"""Provide an authenticated ``AsyncClient`` for server endpoint tests.""" """Provide an authenticated ``AsyncClient`` for server endpoint tests."""
import os
os.makedirs(tmp_path / "fail2ban", exist_ok=True)
settings = Settings( settings = Settings(
database_path=str(tmp_path / "server_test.db"), database_path=str(tmp_path / "server_test.db"),
fail2ban_socket="/tmp/fake.sock", fail2ban_socket="/tmp/fake.sock",
session_secret="test-server-secret", fail2ban_config_dir=str(tmp_path / "fail2ban"),
session_secret="test-server-secret-0000000000000000000000",
session_duration_minutes=60, session_duration_minutes=60,
timezone="UTC", timezone="UTC",
log_level="debug", log_level="debug",
session_cookie_secure=False,
) )
app = create_app(settings=settings) app = create_app(settings=settings)
set_setup_complete_cache(app, True)
db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path) db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path)
db.row_factory = aiosqlite.Row db.row_factory = aiosqlite.Row
await init_db(db) await init_db(db)
await _write_password_hash(db, _SETUP_PAYLOAD["master_password"])
app.state.db = db app.state.db = db
app.state.http_session = MagicMock() app.state.http_session = MagicMock()
app.state.session_cache = NoOpSessionCache()
app.state.geo_cache = GeoCache()
async def _override_get_db():
yield db
from app.dependencies import get_db, get_session_cache
app.dependency_overrides[get_db] = _override_get_db
app.dependency_overrides[get_session_cache] = lambda: NoOpSessionCache()
transport = ASGITransport(app=app) transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac: async with AsyncClient(transport=transport, base_url="http://test", headers={"X-BanGUI-Request": "1"}) as ac:
await ac.post("/api/v1/setup", json=_SETUP_PAYLOAD)
login = await ac.post( login = await ac.post(
"/api/v1/auth/login", "/api/v1/auth/login",
json={"password": _SETUP_PAYLOAD["master_password"]}, json={"password": _SETUP_PAYLOAD["master_password"]},
@@ -57,6 +94,7 @@ async def server_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
yield ac yield ac
await db.close() await db.close()
app.dependency_overrides.clear()
def _make_settings() -> ServerSettingsResponse: def _make_settings() -> ServerSettingsResponse:

View File

@@ -99,6 +99,9 @@ def test_security_headers_on_all_response_types() -> None:
) )
app = create_app(settings=settings) app = create_app(settings=settings)
from app.models.server import ServerStatus
app.state.server_status = ServerStatus(online=True)
client = TestClient(app) client = TestClient(app)
# Test on successful response # Test on successful response

View File

@@ -81,7 +81,7 @@ class TestLogin:
self, db: aiosqlite.Connection self, db: aiosqlite.Connection
) -> None: ) -> None:
"""login() returns a signed token and expiry on the correct password.""" """login() returns a signed token and expiry on the correct password."""
signed_token, expires_at = await auth_service.login( signed_token, expires_at, _ = await auth_service.login(
db, db,
password="correctpassword1", password="correctpassword1",
session_duration_minutes=60, session_duration_minutes=60,
@@ -119,7 +119,7 @@ class TestLogin:
"""login() stores the session in the database.""" """login() stores the session in the database."""
from app.repositories import session_repo from app.repositories import session_repo
signed_token, _ = await auth_service.login( signed_token, _, _ = await auth_service.login(
db, db,
password="correctpassword1", password="correctpassword1",
session_duration_minutes=60, session_duration_minutes=60,
@@ -136,7 +136,7 @@ class TestValidateSession:
self, db: aiosqlite.Connection self, db: aiosqlite.Connection
) -> None: ) -> None:
"""validate_session() returns the session for a valid token.""" """validate_session() returns the session for a valid token."""
signed_token, _ = await auth_service.login( signed_token, _, _ = await auth_service.login(
db, db,
password="correctpassword1", password="correctpassword1",
session_duration_minutes=60, session_duration_minutes=60,
@@ -150,7 +150,7 @@ class TestValidateSession:
self, db: aiosqlite.Connection self, db: aiosqlite.Connection
) -> None: ) -> None:
"""validate_session() accepts a token signed with the configured secret.""" """validate_session() accepts a token signed with the configured secret."""
signed_token, _ = await auth_service.login( signed_token, _, _ = await auth_service.login(
db, db,
password="correctpassword1", password="correctpassword1",
session_duration_minutes=60, session_duration_minutes=60,
@@ -166,7 +166,7 @@ class TestValidateSession:
self, db: aiosqlite.Connection self, db: aiosqlite.Connection
) -> None: ) -> None:
"""validate_session() rejects signed tokens with an invalid signature.""" """validate_session() rejects signed tokens with an invalid signature."""
signed_token, _ = await auth_service.login( signed_token, _, _ = await auth_service.login(
db, db,
password="correctpassword1", password="correctpassword1",
session_duration_minutes=60, session_duration_minutes=60,
@@ -213,7 +213,7 @@ class TestLogout:
"""logout() deletes the session so it can no longer be validated.""" """logout() deletes the session so it can no longer be validated."""
from app.repositories import session_repo from app.repositories import session_repo
signed_token, _ = await auth_service.login( signed_token, _, _ = await auth_service.login(
db, db,
password="correctpassword1", password="correctpassword1",
session_duration_minutes=60, session_duration_minutes=60,
@@ -228,7 +228,7 @@ class TestLogout:
"""logout() accepts a signed token and revokes the underlying raw session.""" """logout() accepts a signed token and revokes the underlying raw session."""
from app.repositories import session_repo from app.repositories import session_repo
signed_token, _ = await auth_service.login( signed_token, _, _ = await auth_service.login(
db, db,
password="correctpassword1", password="correctpassword1",
session_duration_minutes=60, session_duration_minutes=60,
@@ -248,7 +248,7 @@ class TestSecretRotation:
self, db: aiosqlite.Connection self, db: aiosqlite.Connection
) -> None: ) -> None:
"""Tokens signed with current secret are validated immediately.""" """Tokens signed with current secret are validated immediately."""
signed_token, _ = await auth_service.login( signed_token, _, _ = await auth_service.login(
db, db,
password="correctpassword1", password="correctpassword1",
session_duration_minutes=60, session_duration_minutes=60,
@@ -264,7 +264,7 @@ class TestSecretRotation:
self, db: aiosqlite.Connection self, db: aiosqlite.Connection
) -> None: ) -> None:
"""Tokens signed with previous secret are accepted during rotation.""" """Tokens signed with previous secret are accepted during rotation."""
signed_token, _ = await auth_service.login( signed_token, _, _ = await auth_service.login(
db, db,
password="correctpassword1", password="correctpassword1",
session_duration_minutes=60, session_duration_minutes=60,
@@ -280,7 +280,7 @@ class TestSecretRotation:
self, db: aiosqlite.Connection self, db: aiosqlite.Connection
) -> None: ) -> None:
"""Tokens signed with unknown secrets are rejected.""" """Tokens signed with unknown secrets are rejected."""
signed_token, _ = await auth_service.login( signed_token, _, _ = await auth_service.login(
db, db,
password="correctpassword1", password="correctpassword1",
session_duration_minutes=60, session_duration_minutes=60,
@@ -308,7 +308,7 @@ class TestSecretRotation:
self, db: aiosqlite.Connection self, db: aiosqlite.Connection
) -> None: ) -> None:
"""During rotation, tokens signed with previous secret are re-signed.""" """During rotation, tokens signed with previous secret are re-signed."""
signed_token, _ = await auth_service.login( signed_token, _, _ = await auth_service.login(
db, db,
password="correctpassword1", password="correctpassword1",
session_duration_minutes=60, session_duration_minutes=60,
@@ -327,7 +327,7 @@ class TestSecretRotation:
self, db: aiosqlite.Connection self, db: aiosqlite.Connection
) -> None: ) -> None:
"""Validation processes token rotation during validation.""" """Validation processes token rotation during validation."""
signed_token, _ = await auth_service.login( signed_token, _, _ = await auth_service.login(
db, db,
password="correctpassword1", password="correctpassword1",
session_duration_minutes=60, session_duration_minutes=60,
@@ -348,7 +348,7 @@ class TestSecretRotation:
"""logout() accepts tokens signed with the previous secret.""" """logout() accepts tokens signed with the previous secret."""
from app.repositories import session_repo from app.repositories import session_repo
signed_token, _ = await auth_service.login( signed_token, _, _ = await auth_service.login(
db, db,
password="correctpassword1", password="correctpassword1",
session_duration_minutes=60, session_duration_minutes=60,
@@ -368,7 +368,7 @@ class TestSecretRotation:
self, db: aiosqlite.Connection self, db: aiosqlite.Connection
) -> None: ) -> None:
"""If no previous secret is configured, old tokens are rejected.""" """If no previous secret is configured, old tokens are rejected."""
signed_token, _ = await auth_service.login( signed_token, _, _ = await auth_service.login(
db, db,
password="correctpassword1", password="correctpassword1",
session_duration_minutes=60, session_duration_minutes=60,

View File

@@ -32,12 +32,7 @@ async def _create_f2b_db(path: str, rows: list[dict[str, Any]]) -> None:
``bantime``, ``bancount``, and optionally ``data``. ``bantime``, ``bancount``, and optionally ``data``.
""" """
async with aiosqlite.connect(path) as db: async with aiosqlite.connect(path) as db:
await db.execute( await db.execute("CREATE TABLE jails (name TEXT NOT NULL UNIQUE, enabled INTEGER NOT NULL DEFAULT 1)")
"CREATE TABLE jails ("
"name TEXT NOT NULL UNIQUE, "
"enabled INTEGER NOT NULL DEFAULT 1"
")"
)
await db.execute( await db.execute(
"CREATE TABLE bans (" "CREATE TABLE bans ("
"jail TEXT NOT NULL, " "jail TEXT NOT NULL, "
@@ -50,8 +45,7 @@ async def _create_f2b_db(path: str, rows: list[dict[str, Any]]) -> None:
) )
for row in rows: for row in rows:
await db.execute( await db.execute(
"INSERT INTO bans (jail, ip, timeofban, bantime, bancount, data) " "INSERT INTO bans (jail, ip, timeofban, bantime, bancount, data) VALUES (?, ?, ?, ?, ?, ?)",
"VALUES (?, ?, ?, ?, ?, ?)",
( (
row["jail"], row["jail"],
row["ip"], row["ip"],
@@ -257,9 +251,7 @@ class TestListBansHappyPath:
assert result.total == 3 assert result.total == 3
async def test_source_archive_reads_from_archive( async def test_source_archive_reads_from_archive(self, app_db_with_archive: aiosqlite.Connection) -> None:
self, app_db_with_archive: aiosqlite.Connection
) -> None:
"""Using source='archive' reads from the BanGUI archive table.""" """Using source='archive' reads from the BanGUI archive table."""
result = await ban_service.list_bans( result = await ban_service.list_bans(
"/fake/sock", "/fake/sock",
@@ -280,9 +272,7 @@ class TestListBansHappyPath:
class TestListBansGeoEnrichment: class TestListBansGeoEnrichment:
"""Verify geo enrichment integration in ban_service.list_bans().""" """Verify geo enrichment integration in ban_service.list_bans()."""
async def test_geo_data_applied_when_enricher_provided( async def test_geo_data_applied_when_enricher_provided(self, f2b_db_path: str) -> None:
self, f2b_db_path: str
) -> None:
"""Geo fields are populated when an enricher returns data.""" """Geo fields are populated when an enricher returns data."""
from app.models.geo import GeoInfo from app.models.geo import GeoInfo
@@ -298,30 +288,24 @@ class TestListBansGeoEnrichment:
"app.services.ban_service.get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path), new=AsyncMock(return_value=f2b_db_path),
): ):
result = await ban_service.list_bans( result = await ban_service.list_bans("/fake/sock", "24h", geo_enricher=fake_enricher)
"/fake/sock", "24h", geo_enricher=fake_enricher
)
for item in result.items: for item in result.items:
assert item.country_code == "DE" assert item.country_code == "DE"
assert item.country_name == "Germany" assert item.country_name == "Germany"
assert item.asn == "AS3320" assert item.asn == "AS3320"
async def test_geo_failure_does_not_break_results( async def test_geo_failure_does_not_break_results(self, f2b_db_path: str) -> None:
self, f2b_db_path: str
) -> None:
"""A geo enricher that raises still returns ban items (geo fields null).""" """A geo enricher that raises still returns ban items (geo fields null)."""
async def failing_enricher(ip: str) -> None: async def failing_enricher(ip: str) -> None:
raise RuntimeError("geo service down") raise OSError("geo service down")
with patch( with patch(
"app.services.ban_service.get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path), new=AsyncMock(return_value=f2b_db_path),
): ):
result = await ban_service.list_bans( result = await ban_service.list_bans("/fake/sock", "24h", geo_enricher=failing_enricher)
"/fake/sock", "24h", geo_enricher=failing_enricher
)
assert result.total == 2 assert result.total == 2
for item in result.items: for item in result.items:
@@ -336,9 +320,7 @@ class TestListBansGeoEnrichment:
class TestListBansBatchGeoEnrichment: class TestListBansBatchGeoEnrichment:
"""Verify that list_bans uses lookup_batch when http_session is provided.""" """Verify that list_bans uses lookup_batch when http_session is provided."""
async def test_batch_geo_applied_via_http_session( async def test_batch_geo_applied_via_http_session(self, f2b_db_path: str) -> None:
self, f2b_db_path: str
) -> None:
"""Geo fields are populated via lookup_batch when http_session is given.""" """Geo fields are populated via lookup_batch when http_session is given."""
from unittest.mock import MagicMock from unittest.mock import MagicMock
@@ -350,6 +332,8 @@ class TestListBansBatchGeoEnrichment:
"5.6.7.8": GeoInfo(country_code="US", country_name="United States", asn="AS15169", org="Google"), "5.6.7.8": GeoInfo(country_code="US", country_name="United States", asn="AS15169", org="Google"),
} }
fake_geo_batch = AsyncMock(return_value=fake_geo_map) fake_geo_batch = AsyncMock(return_value=fake_geo_map)
mock_geo_cache = MagicMock()
mock_geo_cache.lookup_batch = fake_geo_batch
with patch( with patch(
"app.services.ban_service.get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
@@ -359,7 +343,7 @@ class TestListBansBatchGeoEnrichment:
"/fake/sock", "/fake/sock",
"24h", "24h",
http_session=fake_session, http_session=fake_session,
geo_batch_lookup=fake_geo_batch, geo_cache=mock_geo_cache,
) )
fake_geo_batch.assert_awaited_once_with(["1.2.3.4", "5.6.7.8"], fake_session, db=None) fake_geo_batch.assert_awaited_once_with(["1.2.3.4", "5.6.7.8"], fake_session, db=None)
@@ -371,15 +355,15 @@ class TestListBansBatchGeoEnrichment:
assert us_item.country_code == "US" assert us_item.country_code == "US"
assert us_item.country_name == "United States" assert us_item.country_name == "United States"
async def test_batch_failure_does_not_break_results( async def test_batch_failure_does_not_break_results(self, f2b_db_path: str) -> None:
self, f2b_db_path: str
) -> None:
"""A lookup_batch failure still returns items with null geo fields.""" """A lookup_batch failure still returns items with null geo fields."""
from unittest.mock import MagicMock from unittest.mock import MagicMock
fake_session = MagicMock() fake_session = MagicMock()
failing_geo_batch = AsyncMock(side_effect=RuntimeError("batch geo down")) failing_geo_batch = AsyncMock(side_effect=OSError("batch geo down"))
mock_geo_cache = MagicMock()
mock_geo_cache.lookup_batch = failing_geo_batch
with patch( with patch(
"app.services.ban_service.get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
@@ -389,16 +373,14 @@ class TestListBansBatchGeoEnrichment:
"/fake/sock", "/fake/sock",
"24h", "24h",
http_session=fake_session, http_session=fake_session,
geo_batch_lookup=failing_geo_batch, geo_cache=mock_geo_cache,
) )
assert result.total == 2 assert result.total == 2
for item in result.items: for item in result.items:
assert item.country_code is None assert item.country_code is None
async def test_http_session_takes_priority_over_geo_enricher( async def test_http_session_takes_priority_over_geo_enricher(self, f2b_db_path: str) -> None:
self, f2b_db_path: str
) -> None:
"""When both http_session and geo_enricher are provided, batch wins.""" """When both http_session and geo_enricher are provided, batch wins."""
from unittest.mock import MagicMock from unittest.mock import MagicMock
@@ -410,6 +392,8 @@ class TestListBansBatchGeoEnrichment:
"5.6.7.8": GeoInfo(country_code="DE", country_name="Germany", asn=None, org=None), "5.6.7.8": GeoInfo(country_code="DE", country_name="Germany", asn=None, org=None),
} }
fake_geo_batch = AsyncMock(return_value=fake_geo_map) fake_geo_batch = AsyncMock(return_value=fake_geo_map)
mock_geo_cache = MagicMock()
mock_geo_cache.lookup_batch = fake_geo_batch
async def enricher_should_not_be_called(ip: str) -> GeoInfo: async def enricher_should_not_be_called(ip: str) -> GeoInfo:
raise AssertionError(f"geo_enricher was called for {ip!r} — should not happen") raise AssertionError(f"geo_enricher was called for {ip!r} — should not happen")
@@ -422,7 +406,7 @@ class TestListBansBatchGeoEnrichment:
"/fake/sock", "/fake/sock",
"24h", "24h",
http_session=fake_session, http_session=fake_session,
geo_batch_lookup=fake_geo_batch, geo_cache=mock_geo_cache,
geo_enricher=enricher_should_not_be_called, geo_enricher=enricher_should_not_be_called,
) )
@@ -462,9 +446,7 @@ class TestListBansPagination:
# Different IPs should appear on different pages. # Different IPs should appear on different pages.
assert page1.items[0].ip != page2.items[0].ip assert page1.items[0].ip != page2.items[0].ip
async def test_total_reflects_full_count_not_page_count( async def test_total_reflects_full_count_not_page_count(self, f2b_db_path: str) -> None:
self, f2b_db_path: str
) -> None:
"""``total`` reports all matching records regardless of pagination.""" """``total`` reports all matching records regardless of pagination."""
with patch( with patch(
"app.services.ban_service.get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
@@ -483,9 +465,7 @@ class TestListBansPagination:
class TestBanOriginDerivation: class TestBanOriginDerivation:
"""Verify that ban_service correctly derives ``origin`` from jail names.""" """Verify that ban_service correctly derives ``origin`` from jail names."""
async def test_blocklist_import_jail_yields_blocklist_origin( async def test_blocklist_import_jail_yields_blocklist_origin(self, mixed_origin_db_path: str) -> None:
self, mixed_origin_db_path: str
) -> None:
"""Bans from ``blocklist-import`` jail carry ``origin == "blocklist"``.""" """Bans from ``blocklist-import`` jail carry ``origin == "blocklist"``."""
with patch( with patch(
"app.services.ban_service.get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
@@ -497,9 +477,7 @@ class TestBanOriginDerivation:
assert len(blocklist_items) == 1 assert len(blocklist_items) == 1
assert blocklist_items[0].origin == "blocklist" assert blocklist_items[0].origin == "blocklist"
async def test_organic_jail_yields_selfblock_origin( async def test_organic_jail_yields_selfblock_origin(self, mixed_origin_db_path: str) -> None:
self, mixed_origin_db_path: str
) -> None:
"""Bans from organic jails (sshd, nginx, …) carry ``origin == "selfblock"``.""" """Bans from organic jails (sshd, nginx, …) carry ``origin == "selfblock"``."""
with patch( with patch(
"app.services.ban_service.get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
@@ -512,9 +490,7 @@ class TestBanOriginDerivation:
for item in organic_items: for item in organic_items:
assert item.origin == "selfblock" assert item.origin == "selfblock"
async def test_all_items_carry_origin_field( async def test_all_items_carry_origin_field(self, mixed_origin_db_path: str) -> None:
self, mixed_origin_db_path: str
) -> None:
"""Every returned item has an ``origin`` field with a valid value.""" """Every returned item has an ``origin`` field with a valid value."""
with patch( with patch(
"app.services.ban_service.get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
@@ -525,9 +501,7 @@ class TestBanOriginDerivation:
for item in result.items: for item in result.items:
assert item.origin in ("blocklist", "selfblock") assert item.origin in ("blocklist", "selfblock")
async def test_bans_by_country_blocklist_origin( async def test_bans_by_country_blocklist_origin(self, mixed_origin_db_path: str) -> None:
self, mixed_origin_db_path: str
) -> None:
"""``bans_by_country`` also derives origin correctly for blocklist bans.""" """``bans_by_country`` also derives origin correctly for blocklist bans."""
with patch( with patch(
"app.services.ban_service.get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
@@ -535,13 +509,11 @@ class TestBanOriginDerivation:
): ):
result = await ban_service.bans_by_country("/fake/sock", "24h") result = await ban_service.bans_by_country("/fake/sock", "24h")
blocklist_bans = [b for b in result.bans if b.jail == "blocklist-import"] blocklist_bans = [b for b in result.items if b.jail == "blocklist-import"]
assert len(blocklist_bans) == 1 assert len(blocklist_bans) == 1
assert blocklist_bans[0].origin == "blocklist" assert blocklist_bans[0].origin == "blocklist"
async def test_bans_by_country_selfblock_origin( async def test_bans_by_country_selfblock_origin(self, mixed_origin_db_path: str) -> None:
self, mixed_origin_db_path: str
) -> None:
"""``bans_by_country`` derives origin correctly for organic jails.""" """``bans_by_country`` derives origin correctly for organic jails."""
with patch( with patch(
"app.services.ban_service.get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
@@ -549,7 +521,7 @@ class TestBanOriginDerivation:
): ):
result = await ban_service.bans_by_country("/fake/sock", "24h") result = await ban_service.bans_by_country("/fake/sock", "24h")
organic_bans = [b for b in result.bans if b.jail != "blocklist-import"] organic_bans = [b for b in result.items if b.jail != "blocklist-import"]
assert len(organic_bans) == 2 assert len(organic_bans) == 2
for ban in organic_bans: for ban in organic_bans:
assert ban.origin == "selfblock" assert ban.origin == "selfblock"
@@ -563,34 +535,26 @@ class TestBanOriginDerivation:
class TestOriginFilter: class TestOriginFilter:
"""Verify that the origin filter correctly restricts results.""" """Verify that the origin filter correctly restricts results."""
async def test_list_bans_blocklist_filter_returns_only_blocklist( async def test_list_bans_blocklist_filter_returns_only_blocklist(self, mixed_origin_db_path: str) -> None:
self, mixed_origin_db_path: str
) -> None:
"""``origin='blocklist'`` returns only blocklist-import jail bans.""" """``origin='blocklist'`` returns only blocklist-import jail bans."""
with patch( with patch(
"app.services.ban_service.get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path), new=AsyncMock(return_value=mixed_origin_db_path),
): ):
result = await ban_service.list_bans( result = await ban_service.list_bans("/fake/sock", "24h", origin="blocklist")
"/fake/sock", "24h", origin="blocklist"
)
assert result.total == 1 assert result.total == 1
assert len(result.items) == 1 assert len(result.items) == 1
assert result.items[0].jail == "blocklist-import" assert result.items[0].jail == "blocklist-import"
assert result.items[0].origin == "blocklist" assert result.items[0].origin == "blocklist"
async def test_list_bans_selfblock_filter_excludes_blocklist( async def test_list_bans_selfblock_filter_excludes_blocklist(self, mixed_origin_db_path: str) -> None:
self, mixed_origin_db_path: str
) -> None:
"""``origin='selfblock'`` excludes the blocklist-import jail.""" """``origin='selfblock'`` excludes the blocklist-import jail."""
with patch( with patch(
"app.services.ban_service.get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path), new=AsyncMock(return_value=mixed_origin_db_path),
): ):
result = await ban_service.list_bans( result = await ban_service.list_bans("/fake/sock", "24h", origin="selfblock")
"/fake/sock", "24h", origin="selfblock"
)
assert result.total == 2 assert result.total == 2
assert len(result.items) == 2 assert len(result.items) == 2
@@ -598,9 +562,7 @@ class TestOriginFilter:
assert item.jail != "blocklist-import" assert item.jail != "blocklist-import"
assert item.origin == "selfblock" assert item.origin == "selfblock"
async def test_list_bans_no_filter_returns_all( async def test_list_bans_no_filter_returns_all(self, mixed_origin_db_path: str) -> None:
self, mixed_origin_db_path: str
) -> None:
"""``origin=None`` applies no jail restriction — all bans returned.""" """``origin=None`` applies no jail restriction — all bans returned."""
with patch( with patch(
"app.services.ban_service.get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
@@ -610,53 +572,39 @@ class TestOriginFilter:
assert result.total == 3 assert result.total == 3
async def test_bans_by_country_blocklist_filter( async def test_bans_by_country_blocklist_filter(self, mixed_origin_db_path: str) -> None:
self, mixed_origin_db_path: str
) -> None:
"""``bans_by_country`` with ``origin='blocklist'`` counts only blocklist bans.""" """``bans_by_country`` with ``origin='blocklist'`` counts only blocklist bans."""
with patch( with patch(
"app.services.ban_service.get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path), new=AsyncMock(return_value=mixed_origin_db_path),
): ):
result = await ban_service.bans_by_country( result = await ban_service.bans_by_country("/fake/sock", "24h", origin="blocklist")
"/fake/sock", "24h", origin="blocklist"
)
assert result.total == 1 assert result.total == 1
assert all(b.jail == "blocklist-import" for b in result.bans) assert all(b.jail == "blocklist-import" for b in result.items)
async def test_bans_by_country_selfblock_filter( async def test_bans_by_country_selfblock_filter(self, mixed_origin_db_path: str) -> None:
self, mixed_origin_db_path: str
) -> None:
"""``bans_by_country`` with ``origin='selfblock'`` excludes blocklist jails.""" """``bans_by_country`` with ``origin='selfblock'`` excludes blocklist jails."""
with patch( with patch(
"app.services.ban_service.get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path), new=AsyncMock(return_value=mixed_origin_db_path),
): ):
result = await ban_service.bans_by_country( result = await ban_service.bans_by_country("/fake/sock", "24h", origin="selfblock")
"/fake/sock", "24h", origin="selfblock"
)
assert result.total == 2 assert result.total == 2
assert all(b.jail != "blocklist-import" for b in result.bans) assert all(b.jail != "blocklist-import" for b in result.items)
async def test_bans_by_country_no_filter_returns_all( async def test_bans_by_country_no_filter_returns_all(self, mixed_origin_db_path: str) -> None:
self, mixed_origin_db_path: str
) -> None:
"""``bans_by_country`` with ``origin=None`` returns all bans.""" """``bans_by_country`` with ``origin=None`` returns all bans."""
with patch( with patch(
"app.services.ban_service.get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path), new=AsyncMock(return_value=mixed_origin_db_path),
): ):
result = await ban_service.bans_by_country( result = await ban_service.bans_by_country("/fake/sock", "24h", origin=None)
"/fake/sock", "24h", origin=None
)
assert result.total == 3 assert result.total == 3
async def test_bans_by_country_country_code_returns_all_matched_rows( async def test_bans_by_country_country_code_returns_all_matched_rows(self, tmp_path: Path) -> None:
self, tmp_path: Path
) -> None:
"""``bans_by_country`` returns all companion rows for the selected country.""" """``bans_by_country`` returns all companion rows for the selected country."""
path = str(tmp_path / "fail2ban_country_filter.sqlite3") path = str(tmp_path / "fail2ban_country_filter.sqlite3")
rows = [ rows = [
@@ -672,8 +620,8 @@ class TestOriginFilter:
] ]
await _create_f2b_db(path, rows) await _create_f2b_db(path, rows)
from app.services import geo_service
from app.models.geo import GeoInfo from app.models.geo import GeoInfo
from app.services import geo_service
geo_service._default_geo_cache._cache["10.0.0.1"] = GeoInfo( geo_service._default_geo_cache._cache["10.0.0.1"] = GeoInfo(
country_code="DE", country_code="DE",
@@ -682,12 +630,13 @@ class TestOriginFilter:
org=None, org=None,
) )
with patch( with (
patch(
"app.services.ban_service.get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=path), new=AsyncMock(return_value=path),
), patch( ),
"app.services.ban_service.asyncio.create_task" patch("app.services.ban_service.asyncio.create_task") as mock_create_task,
) as mock_create_task: ):
result = await ban_service.bans_by_country( result = await ban_service.bans_by_country(
"/fake/sock", "/fake/sock",
"24h", "24h",
@@ -698,8 +647,8 @@ class TestOriginFilter:
mock_create_task.assert_not_called() mock_create_task.assert_not_called()
assert result.total == 205 assert result.total == 205
assert len(result.bans) == 205 assert len(result.items) == 205
assert all(b.country_code == "DE" for b in result.bans) assert all(b.country_code == "DE" for b in result.items)
await geo_service.clear_cache() await geo_service.clear_cache()
@@ -715,7 +664,7 @@ class TestOriginFilter:
) )
assert result.total == 2 assert result.total == 2
assert len(result.bans) == 2 assert len(result.items) == 2
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -728,13 +677,11 @@ class TestBansbyCountryBackground:
"""bans_by_country() with http_session uses cache-only geo and fires a """bans_by_country() with http_session uses cache-only geo and fires a
background task for uncached IPs instead of blocking on API calls.""" background task for uncached IPs instead of blocking on API calls."""
async def test_cached_geo_returned_without_api_call( async def test_cached_geo_returned_without_api_call(self, mixed_origin_db_path: str) -> None:
self, mixed_origin_db_path: str
) -> None:
"""When all IPs are in the cache, lookup_cached_only returns them and """When all IPs are in the cache, lookup_cached_only returns them and
no background task is created.""" no background task is created."""
from app.services import geo_service
from app.models.geo import GeoInfo from app.models.geo import GeoInfo
from app.services import geo_service
# Pre-populate the cache for all three IPs in the fixture. # Pre-populate the cache for all three IPs in the fixture.
geo_service._default_geo_cache._cache["10.0.0.1"] = GeoInfo( geo_service._default_geo_cache._cache["10.0.0.1"] = GeoInfo(
@@ -752,9 +699,7 @@ class TestBansbyCountryBackground:
"app.services.ban_service.get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path), new=AsyncMock(return_value=mixed_origin_db_path),
), ),
patch( patch("app.services.ban_service.asyncio.create_task") as mock_create_task,
"app.services.ban_service.asyncio.create_task"
) as mock_create_task,
): ):
mock_session = AsyncMock() mock_session = AsyncMock()
mock_batch = AsyncMock(return_value={}) mock_batch = AsyncMock(return_value={})
@@ -763,7 +708,6 @@ class TestBansbyCountryBackground:
"24h", "24h",
http_session=mock_session, http_session=mock_session,
geo_cache_lookup=geo_service.lookup_cached_only, geo_cache_lookup=geo_service.lookup_cached_only,
geo_batch_lookup=mock_batch,
) )
# All countries resolved from cache — no background task needed. # All countries resolved from cache — no background task needed.
@@ -773,9 +717,7 @@ class TestBansbyCountryBackground:
assert "DE" in result.countries or "US" in result.countries or "JP" in result.countries assert "DE" in result.countries or "US" in result.countries or "JP" in result.countries
await geo_service.clear_cache() await geo_service.clear_cache()
async def test_uncached_ips_trigger_background_task( async def test_uncached_ips_trigger_background_task(self, mixed_origin_db_path: str) -> None:
self, mixed_origin_db_path: str
) -> None:
"""When IPs are NOT in the cache, create_task is called for background """When IPs are NOT in the cache, create_task is called for background
resolution and the response returns without blocking.""" resolution and the response returns without blocking."""
from app.services import geo_service from app.services import geo_service
@@ -787,9 +729,7 @@ class TestBansbyCountryBackground:
"app.services.ban_service.get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path), new=AsyncMock(return_value=mixed_origin_db_path),
), ),
patch( patch("app.services.ban_service.asyncio.create_task") as mock_create_task,
"app.services.ban_service.asyncio.create_task"
) as mock_create_task,
): ):
mock_session = AsyncMock() mock_session = AsyncMock()
mock_batch = AsyncMock(return_value={}) mock_batch = AsyncMock(return_value={})
@@ -798,7 +738,7 @@ class TestBansbyCountryBackground:
"24h", "24h",
http_session=mock_session, http_session=mock_session,
geo_cache_lookup=geo_service.lookup_cached_only, geo_cache_lookup=geo_service.lookup_cached_only,
geo_batch_lookup=mock_batch, geo_cache=geo_service.GeoCache(),
) )
# Background task must have been scheduled for uncached IPs. # Background task must have been scheduled for uncached IPs.
@@ -806,9 +746,7 @@ class TestBansbyCountryBackground:
# Response is still valid with empty country map (IPs not cached yet). # Response is still valid with empty country map (IPs not cached yet).
assert result.total == 3 assert result.total == 3
async def test_no_background_task_without_http_session( async def test_no_background_task_without_http_session(self, mixed_origin_db_path: str) -> None:
self, mixed_origin_db_path: str
) -> None:
"""When http_session is None, no background task is created.""" """When http_session is None, no background task is created."""
from app.services import geo_service from app.services import geo_service
@@ -819,13 +757,9 @@ class TestBansbyCountryBackground:
"app.services.ban_service.get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path), new=AsyncMock(return_value=mixed_origin_db_path),
), ),
patch( patch("app.services.ban_service.asyncio.create_task") as mock_create_task,
"app.services.ban_service.asyncio.create_task"
) as mock_create_task,
): ):
result = await ban_service.bans_by_country( result = await ban_service.bans_by_country("/fake/sock", "24h", http_session=None)
"/fake/sock", "24h", http_session=None
)
mock_create_task.assert_not_called() mock_create_task.assert_not_called()
assert result.total == 3 assert result.total == 3
@@ -904,9 +838,7 @@ class TestBanTrend:
timestamps = [b.timestamp for b in result.buckets] timestamps = [b.timestamp for b in result.buckets]
assert timestamps == sorted(timestamps) assert timestamps == sorted(timestamps)
async def test_ban_trend_source_archive_reads_archive( async def test_ban_trend_source_archive_reads_archive(self, app_db_with_archive: aiosqlite.Connection) -> None:
self, app_db_with_archive: aiosqlite.Connection
) -> None:
"""``ban_trend`` accepts source='archive' and uses archived rows.""" """``ban_trend`` accepts source='archive' and uses archived rows."""
result = await ban_service.ban_trend( result = await ban_service.ban_trend(
"/fake/sock", "/fake/sock",
@@ -959,9 +891,7 @@ class TestBanTrend:
"app.services.ban_service.get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=path), new=AsyncMock(return_value=path),
): ):
result = await ban_service.ban_trend( result = await ban_service.ban_trend("/fake/sock", "24h", origin="blocklist")
"/fake/sock", "24h", origin="blocklist"
)
assert sum(b.count for b in result.buckets) == 1 assert sum(b.count for b in result.buckets) == 1
@@ -985,9 +915,7 @@ class TestBanTrend:
"app.services.ban_service.get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=path), new=AsyncMock(return_value=path),
): ):
result = await ban_service.ban_trend( result = await ban_service.ban_trend("/fake/sock", "24h", origin="selfblock")
"/fake/sock", "24h", origin="selfblock"
)
assert sum(b.count for b in result.buckets) == 2 assert sum(b.count for b in result.buckets) == 2
@@ -1096,9 +1024,7 @@ class TestBansByJail:
"app.services.ban_service.get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path), new=AsyncMock(return_value=mixed_origin_db_path),
): ):
result = await ban_service.bans_by_jail( result = await ban_service.bans_by_jail("/fake/sock", "24h", origin="blocklist")
"/fake/sock", "24h", origin="blocklist"
)
assert len(result.jails) == 1 assert len(result.jails) == 1
assert result.jails[0].jail == "blocklist-import" assert result.jails[0].jail == "blocklist-import"
@@ -1110,32 +1036,24 @@ class TestBansByJail:
"app.services.ban_service.get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path), new=AsyncMock(return_value=mixed_origin_db_path),
): ):
result = await ban_service.bans_by_jail( result = await ban_service.bans_by_jail("/fake/sock", "24h", origin="selfblock")
"/fake/sock", "24h", origin="selfblock"
)
jail_names = {j.jail for j in result.jails} jail_names = {j.jail for j in result.jails}
assert "blocklist-import" not in jail_names assert "blocklist-import" not in jail_names
assert result.total == 2 assert result.total == 2
async def test_no_origin_filter_returns_all_jails( async def test_no_origin_filter_returns_all_jails(self, mixed_origin_db_path: str) -> None:
self, mixed_origin_db_path: str
) -> None:
"""``origin=None`` returns bans from all jails.""" """``origin=None`` returns bans from all jails."""
with patch( with patch(
"app.services.ban_service.get_fail2ban_db_path", "app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path), new=AsyncMock(return_value=mixed_origin_db_path),
): ):
result = await ban_service.bans_by_jail( result = await ban_service.bans_by_jail("/fake/sock", "24h", origin=None)
"/fake/sock", "24h", origin=None
)
assert result.total == 3 assert result.total == 3
assert len(result.jails) == 3 assert len(result.jails) == 3
async def test_bans_by_jail_source_archive_reads_archive( async def test_bans_by_jail_source_archive_reads_archive(self, app_db_with_archive: aiosqlite.Connection) -> None:
self, app_db_with_archive: aiosqlite.Connection
) -> None:
"""``bans_by_jail`` accepts source='archive' and aggregates archived rows.""" """``bans_by_jail`` accepts source='archive' and aggregates archived rows."""
result = await ban_service.bans_by_jail( result = await ban_service.bans_by_jail(
"/fake/sock", "/fake/sock",
@@ -1147,9 +1065,7 @@ class TestBansByJail:
assert result.total == 2 assert result.total == 2
assert any(j.jail == "sshd" for j in result.jails) assert any(j.jail == "sshd" for j in result.jails)
async def test_diagnostic_warning_when_zero_results_despite_data( async def test_diagnostic_warning_when_zero_results_despite_data(self, tmp_path: Path) -> None:
self, tmp_path: Path
) -> None:
"""A warning is logged when the time-range filter excludes all existing rows.""" """A warning is logged when the time-range filter excludes all existing rows."""
import time as _time import time as _time
@@ -1176,9 +1092,6 @@ class TestBansByJail:
assert result.jails == [] assert result.jails == []
# The diagnostic warning must have been emitted. # The diagnostic warning must have been emitted.
warning_calls = [ warning_calls = [
c c for c in mock_log.warning.call_args_list if c[0][0] == "ban_service_bans_by_jail_empty_despite_data"
for c in mock_log.warning.call_args_list
if c[0][0] == "ban_service_bans_by_jail_empty_despite_data"
] ]
assert len(warning_calls) == 1 assert len(warning_calls) == 1

File diff suppressed because it is too large Load Diff

View File

@@ -12,11 +12,10 @@ import pytest
from app.config import Settings from app.config import Settings
from app.models.config import ( from app.models.config import (
GlobalConfigUpdate, GlobalConfigUpdate,
JailConfigListResponse,
JailConfigResponse,
LogPreviewRequest, LogPreviewRequest,
RegexTestRequest, RegexTestRequest,
) )
from app.models.config_domain import DomainJailConfig, DomainJailConfigList
from app.services import config_service, health_service, log_service from app.services import config_service, health_service, log_service
from app.services.config_service import ( from app.services.config_service import (
ConfigValidationError, ConfigValidationError,
@@ -31,6 +30,7 @@ from app.services.config_service import (
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def _mock_settings(monkeypatch: pytest.MonkeyPatch) -> None: def _mock_settings(monkeypatch: pytest.MonkeyPatch) -> None:
"""Mock get_settings for all tests in this module.""" """Mock get_settings for all tests in this module."""
def mock_get_settings() -> Settings: def mock_get_settings() -> Settings:
return Settings( return Settings(
database_path=":memory:", database_path=":memory:",
@@ -39,7 +39,7 @@ def _mock_settings(monkeypatch: pytest.MonkeyPatch) -> None:
session_secret="test-secret-key-do-not-use-in-production", session_secret="test-secret-key-do-not-use-in-production",
) )
monkeypatch.setattr("app.models.config.get_settings", mock_get_settings) monkeypatch.setattr("app.config.get_settings", mock_get_settings)
monkeypatch.setattr("app.utils.path_utils.get_settings", mock_get_settings) monkeypatch.setattr("app.utils.path_utils.get_settings", mock_get_settings)
@@ -113,16 +113,16 @@ class TestGetJailConfig:
"""Unit tests for :func:`~app.services.config_service.get_jail_config`.""" """Unit tests for :func:`~app.services.config_service.get_jail_config`."""
async def test_returns_jail_config_response(self) -> None: async def test_returns_jail_config_response(self) -> None:
"""get_jail_config returns a JailConfigResponse.""" """get_jail_config returns a DomainJailConfig."""
with _patch_client(_DEFAULT_JAIL_RESPONSES): with _patch_client(_DEFAULT_JAIL_RESPONSES):
result = await config_service.get_jail_config(_SOCKET, "sshd") result = await config_service.get_jail_config(_SOCKET, "sshd")
assert isinstance(result, JailConfigResponse) assert isinstance(result, DomainJailConfig)
assert result.jail.name == "sshd" assert result.name == "sshd"
assert result.jail.ban_time == 600 assert result.ban_time == 600
assert result.jail.max_retry == 5 assert result.max_retry == 5
assert result.jail.fail_regex == ["regex1", "regex2"] assert result.fail_regex == ["regex1", "regex2"]
assert result.jail.log_paths == ["/var/log/auth.log"] assert result.log_paths == ["/var/log/auth.log"]
async def test_raises_jail_not_found(self) -> None: async def test_raises_jail_not_found(self) -> None:
"""get_jail_config raises JailNotFoundError for an unknown jail.""" """get_jail_config raises JailNotFoundError for an unknown jail."""
@@ -140,10 +140,13 @@ class TestGetJailConfig:
return (1, "unknown jail 'missing'") return (1, "unknown jail 'missing'")
return (0, None) return (0, None)
with patch( with (
patch(
"app.services.config_service.Fail2BanClient", "app.services.config_service.Fail2BanClient",
lambda **_kw: type("C", (), {"send": AsyncMock(side_effect=_faulty_send)})(), lambda **_kw: type("C", (), {"send": AsyncMock(side_effect=_faulty_send)})(),
), pytest.raises(JailNotFoundError): ),
pytest.raises(JailNotFoundError),
):
await config_service.get_jail_config(_SOCKET, "missing") await config_service.get_jail_config(_SOCKET, "missing")
async def test_actions_parsed_correctly(self) -> None: async def test_actions_parsed_correctly(self) -> None:
@@ -151,7 +154,7 @@ class TestGetJailConfig:
with _patch_client(_DEFAULT_JAIL_RESPONSES): with _patch_client(_DEFAULT_JAIL_RESPONSES):
result = await config_service.get_jail_config(_SOCKET, "sshd") result = await config_service.get_jail_config(_SOCKET, "sshd")
assert "iptables" in result.jail.actions assert "iptables" in result.actions
async def test_empty_log_paths_fallback(self) -> None: async def test_empty_log_paths_fallback(self) -> None:
"""get_jail_config handles None log paths gracefully.""" """get_jail_config handles None log paths gracefully."""
@@ -159,14 +162,14 @@ class TestGetJailConfig:
with _patch_client(responses): with _patch_client(responses):
result = await config_service.get_jail_config(_SOCKET, "sshd") result = await config_service.get_jail_config(_SOCKET, "sshd")
assert result.jail.log_paths == [] assert result.log_paths == []
async def test_date_pattern_none(self) -> None: async def test_date_pattern_none(self) -> None:
"""get_jail_config returns None date_pattern when not set.""" """get_jail_config returns None date_pattern when not set."""
with _patch_client(_DEFAULT_JAIL_RESPONSES): with _patch_client(_DEFAULT_JAIL_RESPONSES):
result = await config_service.get_jail_config(_SOCKET, "sshd") result = await config_service.get_jail_config(_SOCKET, "sshd")
assert result.jail.date_pattern is None assert result.date_pattern is None
async def test_use_dns_populated(self) -> None: async def test_use_dns_populated(self) -> None:
"""get_jail_config returns use_dns from the socket response.""" """get_jail_config returns use_dns from the socket response."""
@@ -174,7 +177,7 @@ class TestGetJailConfig:
with _patch_client(responses): with _patch_client(responses):
result = await config_service.get_jail_config(_SOCKET, "sshd") result = await config_service.get_jail_config(_SOCKET, "sshd")
assert result.jail.use_dns == "no" assert result.use_dns == "no"
async def test_use_dns_default_when_missing(self) -> None: async def test_use_dns_default_when_missing(self) -> None:
"""get_jail_config defaults use_dns to 'warn' when socket returns None.""" """get_jail_config defaults use_dns to 'warn' when socket returns None."""
@@ -182,7 +185,7 @@ class TestGetJailConfig:
with _patch_client(responses): with _patch_client(responses):
result = await config_service.get_jail_config(_SOCKET, "sshd") result = await config_service.get_jail_config(_SOCKET, "sshd")
assert result.jail.use_dns == "warn" assert result.use_dns == "warn"
async def test_prefregex_populated(self) -> None: async def test_prefregex_populated(self) -> None:
"""get_jail_config returns prefregex from the socket response.""" """get_jail_config returns prefregex from the socket response."""
@@ -193,7 +196,7 @@ class TestGetJailConfig:
with _patch_client(responses): with _patch_client(responses):
result = await config_service.get_jail_config(_SOCKET, "sshd") result = await config_service.get_jail_config(_SOCKET, "sshd")
assert result.jail.prefregex == r"^%(__prefix_line)s" assert result.prefregex == r"^%(__prefix_line)s"
async def test_prefregex_empty_when_missing(self) -> None: async def test_prefregex_empty_when_missing(self) -> None:
"""get_jail_config returns empty string prefregex when socket returns None.""" """get_jail_config returns empty string prefregex when socket returns None."""
@@ -201,7 +204,7 @@ class TestGetJailConfig:
with _patch_client(responses): with _patch_client(responses):
result = await config_service.get_jail_config(_SOCKET, "sshd") result = await config_service.get_jail_config(_SOCKET, "sshd")
assert result.jail.prefregex == "" assert result.prefregex == ""
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -213,12 +216,12 @@ class TestListJailConfigs:
"""Unit tests for :func:`~app.services.config_service.list_jail_configs`.""" """Unit tests for :func:`~app.services.config_service.list_jail_configs`."""
async def test_returns_list_response(self) -> None: async def test_returns_list_response(self) -> None:
"""list_jail_configs returns a JailConfigListResponse.""" """list_jail_configs returns a DomainJailConfigList."""
responses = {"status": _make_global_status("sshd"), **_DEFAULT_JAIL_RESPONSES} responses = {"status": _make_global_status("sshd"), **_DEFAULT_JAIL_RESPONSES}
with _patch_client(responses): with _patch_client(responses):
result = await config_service.list_jail_configs(_SOCKET) result = await config_service.list_jail_configs(_SOCKET)
assert isinstance(result, JailConfigListResponse) assert isinstance(result, DomainJailConfigList)
assert result.total == 1 assert result.total == 1
assert result.items[0].name == "sshd" assert result.items[0].name == "sshd"
@@ -233,9 +236,7 @@ class TestListJailConfigs:
async def test_multiple_jails(self) -> None: async def test_multiple_jails(self) -> None:
"""list_jail_configs handles comma-separated jail names.""" """list_jail_configs handles comma-separated jail names."""
nginx_responses = { nginx_responses = {k.replace("sshd", "nginx"): v for k, v in _DEFAULT_JAIL_RESPONSES.items()}
k.replace("sshd", "nginx"): v for k, v in _DEFAULT_JAIL_RESPONSES.items()
}
responses = { responses = {
"status": _make_global_status("sshd, nginx"), "status": _make_global_status("sshd, nginx"),
**_DEFAULT_JAIL_RESPONSES, **_DEFAULT_JAIL_RESPONSES,
@@ -521,11 +522,16 @@ class TestUpdateGlobalConfig:
assert cmd[2] == "DEBUG" assert cmd[2] == "DEBUG"
async def test_invalid_log_target_raises_config_validation_error(self) -> None: async def test_invalid_log_target_raises_config_validation_error(self) -> None:
"""update_global_config rejects invalid log_target from model validation.""" """update_global_config rejects invalid log_target."""
from pydantic import ValidationError update = GlobalConfigUpdate(log_target="/etc/passwd")
with (
with pytest.raises(ValidationError, match="outside allowed directories"): patch(
GlobalConfigUpdate(log_target="/etc/passwd") "app.services.config_service.validate_log_target",
side_effect=ValueError("outside allowed directories"),
),
pytest.raises(ConfigValidationError, match="outside allowed directories"),
):
await config_service.update_global_config(_SOCKET, update)
async def test_valid_special_log_target(self) -> None: async def test_valid_special_log_target(self) -> None:
"""update_global_config accepts special log_target values.""" """update_global_config accepts special log_target values."""
@@ -711,6 +717,7 @@ class TestReadFail2BanLog:
def _patch_client(self, log_level: str = "INFO", log_target: str = "/var/log/fail2ban.log") -> Any: def _patch_client(self, log_level: str = "INFO", log_target: str = "/var/log/fail2ban.log") -> Any:
"""Build a patched Fail2BanClient that returns *log_level* and *log_target*.""" """Build a patched Fail2BanClient that returns *log_level* and *log_target*."""
async def _send(command: list[Any]) -> Any: async def _send(command: list[Any]) -> Any:
key = "|".join(str(c) for c in command) key = "|".join(str(c) for c in command)
if key == "get|loglevel": if key == "get|loglevel":
@@ -735,8 +742,10 @@ class TestReadFail2BanLog:
log_dir = str(tmp_path) log_dir = str(tmp_path)
# Patch _SAFE_LOG_PREFIXES to allow tmp_path # Patch _SAFE_LOG_PREFIXES to allow tmp_path
with self._patch_client(log_target=str(log_file)), \ with (
patch("app.services.log_service._SAFE_LOG_PREFIXES", (log_dir,)): self._patch_client(log_target=str(log_file)),
patch("app.services.log_service._SAFE_LOG_PREFIXES", (log_dir,)),
):
result = await log_service.read_fail2ban_log(_SOCKET, 200) result = await log_service.read_fail2ban_log(_SOCKET, 200)
assert result.log_path == str(log_file.resolve()) assert result.log_path == str(log_file.resolve())
@@ -750,8 +759,10 @@ class TestReadFail2BanLog:
log_file.write_text("INFO sshd Found 1.2.3.4\nERROR something else\nINFO sshd Found 5.6.7.8\n") log_file.write_text("INFO sshd Found 1.2.3.4\nERROR something else\nINFO sshd Found 5.6.7.8\n")
log_dir = str(tmp_path) log_dir = str(tmp_path)
with self._patch_client(log_target=str(log_file)), \ with (
patch("app.services.log_service._SAFE_LOG_PREFIXES", (log_dir,)): self._patch_client(log_target=str(log_file)),
patch("app.services.log_service._SAFE_LOG_PREFIXES", (log_dir,)),
):
result = await log_service.read_fail2ban_log(_SOCKET, 200, "Found") result = await log_service.read_fail2ban_log(_SOCKET, 200, "Found")
assert all("Found" in ln for ln in result.lines) assert all("Found" in ln for ln in result.lines)
@@ -759,14 +770,18 @@ class TestReadFail2BanLog:
async def test_non_file_target_raises_operation_error(self) -> None: async def test_non_file_target_raises_operation_error(self) -> None:
"""read_fail2ban_log raises ConfigOperationError for STDOUT target.""" """read_fail2ban_log raises ConfigOperationError for STDOUT target."""
with self._patch_client(log_target="STDOUT"), \ with (
pytest.raises(config_service.ConfigOperationError, match="STDOUT"): self._patch_client(log_target="STDOUT"),
pytest.raises(config_service.ConfigOperationError, match="STDOUT"),
):
await log_service.read_fail2ban_log(_SOCKET, 200) await log_service.read_fail2ban_log(_SOCKET, 200)
async def test_syslog_target_raises_operation_error(self) -> None: async def test_syslog_target_raises_operation_error(self) -> None:
"""read_fail2ban_log raises ConfigOperationError for SYSLOG target.""" """read_fail2ban_log raises ConfigOperationError for SYSLOG target."""
with self._patch_client(log_target="SYSLOG"), \ with (
pytest.raises(config_service.ConfigOperationError, match="SYSLOG"): self._patch_client(log_target="SYSLOG"),
pytest.raises(config_service.ConfigOperationError, match="SYSLOG"),
):
await log_service.read_fail2ban_log(_SOCKET, 200) await log_service.read_fail2ban_log(_SOCKET, 200)
async def test_path_outside_safe_dir_raises_operation_error(self, tmp_path: Any) -> None: async def test_path_outside_safe_dir_raises_operation_error(self, tmp_path: Any) -> None:
@@ -775,9 +790,11 @@ class TestReadFail2BanLog:
log_file.write_text("secret data\n") log_file.write_text("secret data\n")
# Allow only /var/log — tmp_path is deliberately not in the safe list. # Allow only /var/log — tmp_path is deliberately not in the safe list.
with self._patch_client(log_target=str(log_file)), \ with (
patch("app.services.log_service._SAFE_LOG_PREFIXES", ("/var/log",)), \ self._patch_client(log_target=str(log_file)),
pytest.raises(config_service.ConfigOperationError, match="outside the allowed"): patch("app.services.log_service._SAFE_LOG_PREFIXES", ("/var/log",)),
pytest.raises(config_service.ConfigOperationError, match="outside the allowed"),
):
await log_service.read_fail2ban_log(_SOCKET, 200) await log_service.read_fail2ban_log(_SOCKET, 200)
async def test_missing_log_file_raises_operation_error(self, tmp_path: Any) -> None: async def test_missing_log_file_raises_operation_error(self, tmp_path: Any) -> None:
@@ -785,9 +802,11 @@ class TestReadFail2BanLog:
missing = str(tmp_path / "nonexistent.log") missing = str(tmp_path / "nonexistent.log")
log_dir = str(tmp_path) log_dir = str(tmp_path)
with self._patch_client(log_target=missing), \ with (
patch("app.services.log_service._SAFE_LOG_PREFIXES", (log_dir,)), \ self._patch_client(log_target=missing),
pytest.raises(config_service.ConfigOperationError, match="not found"): patch("app.services.log_service._SAFE_LOG_PREFIXES", (log_dir,)),
pytest.raises(config_service.ConfigOperationError, match="not found"),
):
await log_service.read_fail2ban_log(_SOCKET, 200) await log_service.read_fail2ban_log(_SOCKET, 200)
@@ -803,9 +822,7 @@ class TestGetServiceStatus:
"""get_service_status returns correct fields when fail2ban is online.""" """get_service_status returns correct fields when fail2ban is online."""
from app.models.server import ServerStatus from app.models.server import ServerStatus
online_status = ServerStatus( online_status = ServerStatus(online=True, version="1.0.0", active_jails=2, total_bans=5, total_failures=3)
online=True, version="1.0.0", active_jails=2, total_bans=5, total_failures=3
)
async def _send(command: list[Any]) -> Any: async def _send(command: list[Any]) -> Any:
key = "|".join(str(c) for c in command) key = "|".join(str(c) for c in command)
@@ -878,12 +895,15 @@ class TestConfigModuleIntegration:
}, },
) )
with patch( with (
patch(
"app.services.jail_config_service._parse_jails_sync", "app.services.jail_config_service._parse_jails_sync",
new=fake_parse_jails_sync, new=fake_parse_jails_sync,
), patch( ),
patch(
"app.services.jail_config_service._get_active_jail_names", "app.services.jail_config_service._get_active_jail_names",
new=AsyncMock(return_value={"sshd"}), new=AsyncMock(return_value={"sshd"}),
),
): ):
result = await list_inactive_jails(str(tmp_path), "/fake.sock") result = await list_inactive_jails(str(tmp_path), "/fake.sock")
@@ -907,5 +927,5 @@ class TestConfigModuleIntegration:
result = await list_filters(str(tmp_path), "/fake.sock") result = await list_filters(str(tmp_path), "/fake.sock")
assert result.total == 1 assert result.total == 1
assert result.filters[0].name == "sshd" assert result.items[0].name == "sshd"
assert result.filters[0].active is True assert result.items[0].active is True

View File

@@ -209,9 +209,7 @@ class TestLookupCaching:
async def test_negative_result_stored_in_neg_cache(self, geo_cache: GeoCache) -> None: async def test_negative_result_stored_in_neg_cache(self, geo_cache: GeoCache) -> None:
"""A failed lookup is stored in the negative cache, so the second call is blocked.""" """A failed lookup is stored in the negative cache, so the second call is blocked."""
session = _make_session( session = _make_session({"status": "fail", "message": "reserved range"})
{"status": "fail", "message": "reserved range"}
)
await geo_cache.lookup("192.168.1.1", session) await geo_cache.lookup("192.168.1.1", session)
await geo_cache.lookup("192.168.1.1", session) await geo_cache.lookup("192.168.1.1", session)
@@ -473,7 +471,7 @@ def _make_async_db() -> MagicMock:
return MagicMock(__aenter__=AsyncMock(return_value=None), __aexit__=AsyncMock(return_value=None)) return MagicMock(__aenter__=AsyncMock(return_value=None), __aexit__=AsyncMock(return_value=None))
return mock_ctx return mock_ctx
db.execute = MagicMock(side_effect=fake_execute) db.execute = AsyncMock(side_effect=fake_execute)
db.executemany = AsyncMock() db.executemany = AsyncMock()
db.commit = AsyncMock() db.commit = AsyncMock()
db.rollback = AsyncMock() db.rollback = AsyncMock()
@@ -500,10 +498,7 @@ class TestLookupBatchSingleCommit:
async def test_commit_called_even_on_failed_lookups(self, geo_cache: GeoCache) -> None: async def test_commit_called_even_on_failed_lookups(self, geo_cache: GeoCache) -> None:
"""A batch with all-failed lookups still triggers one commit.""" """A batch with all-failed lookups still triggers one commit."""
ips = ["10.0.0.1", "10.0.0.2"] ips = ["10.0.0.1", "10.0.0.2"]
batch_response = [ batch_response = [{"query": ip, "status": "fail", "message": "private range"} for ip in ips]
{"query": ip, "status": "fail", "message": "private range"}
for ip in ips
]
session = _make_batch_session(batch_response) session = _make_batch_session(batch_response)
db = _make_async_db() db = _make_async_db()
@@ -533,9 +528,7 @@ class TestLookupBatchSingleCommit:
async def test_no_commit_for_all_cached_ips(self, geo_cache: GeoCache) -> None: async def test_no_commit_for_all_cached_ips(self, geo_cache: GeoCache) -> None:
"""When all IPs are already cached, no HTTP call and no commit occur.""" """When all IPs are already cached, no HTTP call and no commit occur."""
geo_cache._cache["5.5.5.5"] = GeoInfo( geo_cache._cache["5.5.5.5"] = GeoInfo(country_code="FR", country_name="France", asn="AS1", org="ISP")
country_code="FR", country_name="France", asn="AS1", org="ISP"
)
db = _make_async_db() db = _make_async_db()
session = _make_batch_session([]) session = _make_batch_session([])
@@ -670,10 +663,7 @@ class TestLookupBatchThrottling:
ips = [f"10.0.{i // 256}.{i % 256}" for i in range(batch_size + 1)] ips = [f"10.0.{i // 256}.{i % 256}" for i in range(batch_size + 1)]
def _make_result(chunk: list[str], _session: object) -> dict[str, GeoInfo]: def _make_result(chunk: list[str], _session: object) -> dict[str, GeoInfo]:
return { return {ip: GeoInfo(country_code="DE", country_name="Germany", asn=None, org=None) for ip in chunk}
ip: GeoInfo(country_code="DE", country_name="Germany", asn=None, org=None)
for ip in chunk
}
with ( with (
patch.object( patch.object(
@@ -778,7 +768,7 @@ class TestErrorLogging:
async def test_empty_message_exception_logs_exc_type(self, geo_cache: GeoCache) -> None: async def test_empty_message_exception_logs_exc_type(self, geo_cache: GeoCache) -> None:
"""When HTTP exception str() is empty, exc_type and repr are still logged.""" """When HTTP exception str() is empty, exc_type and repr are still logged."""
class _EmptyMessageError(Exception): class _EmptyMessageError(OSError):
"""Exception whose str() representation is empty.""" """Exception whose str() representation is empty."""
def __str__(self) -> str: def __str__(self) -> str:
@@ -792,9 +782,7 @@ class TestErrorLogging:
from tests.logging_capture import capture_logs from tests.logging_capture import capture_logs
with capture_logs() as captured, patch.object( with capture_logs() as captured, patch.object(geo_cache, "_geoip_reader", None):
geo_cache, "_geoip_reader", None
):
# Ensure MMDB is not available so HTTP is tried. # Ensure MMDB is not available so HTTP is tried.
result = await geo_cache.lookup("197.221.98.153", session) result = await geo_cache.lookup("197.221.98.153", session)
@@ -819,9 +807,7 @@ class TestErrorLogging:
from tests.logging_capture import capture_logs from tests.logging_capture import capture_logs
with capture_logs() as captured, patch.object( with capture_logs() as captured, patch.object(geo_cache, "_geoip_reader", None):
geo_cache, "_geoip_reader", None
):
# Ensure MMDB is not available so HTTP is tried. # Ensure MMDB is not available so HTTP is tried.
await geo_cache.lookup("10.0.0.1", session) await geo_cache.lookup("10.0.0.1", session)
@@ -834,7 +820,7 @@ class TestErrorLogging:
async def test_batch_empty_message_exception_logs_exc_type(self, geo_cache: GeoCache) -> None: async def test_batch_empty_message_exception_logs_exc_type(self, geo_cache: GeoCache) -> None:
"""Batch API call: empty-message exceptions include exc_type in the log.""" """Batch API call: empty-message exceptions include exc_type in the log."""
class _EmptyMessageError(Exception): class _EmptyMessageError(OSError):
def __str__(self) -> str: def __str__(self) -> str:
return "" return ""
@@ -908,9 +894,7 @@ class TestLookupCachedOnly:
def test_mixed_ips(self, geo_cache: GeoCache) -> None: def test_mixed_ips(self, geo_cache: GeoCache) -> None:
"""A mix of cached, neg-cached, and unknown IPs is split correctly.""" """A mix of cached, neg-cached, and unknown IPs is split correctly."""
geo_cache._cache["1.2.3.4"] = GeoInfo( geo_cache._cache["1.2.3.4"] = GeoInfo(country_code="DE", country_name="Germany", asn=None, org=None)
country_code="DE", country_name="Germany", asn=None, org=None
)
import time import time
geo_cache._neg_cache["5.5.5.5"] = time.monotonic() geo_cache._neg_cache["5.5.5.5"] = time.monotonic()
@@ -922,13 +906,9 @@ class TestLookupCachedOnly:
def test_deduplication(self, geo_cache: GeoCache) -> None: def test_deduplication(self, geo_cache: GeoCache) -> None:
"""Duplicate IPs in the input appear at most once in the output.""" """Duplicate IPs in the input appear at most once in the output."""
geo_cache._cache["1.2.3.4"] = GeoInfo( geo_cache._cache["1.2.3.4"] = GeoInfo(country_code="US", country_name="United States", asn=None, org=None)
country_code="US", country_name="United States", asn=None, org=None
)
geo_map, uncached = geo_cache.lookup_cached_only( geo_map, uncached = geo_cache.lookup_cached_only(["9.9.9.9", "9.9.9.9", "1.2.3.4", "1.2.3.4"])
["9.9.9.9", "9.9.9.9", "1.2.3.4", "1.2.3.4"]
)
assert len([ip for ip in geo_map if ip == "1.2.3.4"]) == 1 assert len([ip for ip in geo_map if ip == "1.2.3.4"]) == 1
assert uncached.count("9.9.9.9") == 1 assert uncached.count("9.9.9.9") == 1
@@ -942,18 +922,22 @@ class TestReResolveAll:
db = MagicMock() db = MagicMock()
session = MagicMock() session = MagicMock()
with patch( with (
patch(
"app.repositories.geo_cache_repo.get_unresolved_ips", "app.repositories.geo_cache_repo.get_unresolved_ips",
AsyncMock(return_value=[]), AsyncMock(return_value=[]),
), patch.object( ),
patch.object(
geo_cache, geo_cache,
"lookup_batch", "lookup_batch",
AsyncMock(), AsyncMock(),
) as mock_lookup, patch.object( ) as mock_lookup,
patch.object(
geo_cache, geo_cache,
"clear_neg_cache", "clear_neg_cache",
AsyncMock(), AsyncMock(),
) as mock_clear: ) as mock_clear,
):
result = await geo_cache.re_resolve_all(db, session) result = await geo_cache.re_resolve_all(db, session)
assert result == {"resolved": 0, "total": 0} assert result == {"resolved": 0, "total": 0}
@@ -970,18 +954,22 @@ class TestReResolveAll:
"2.2.2.2": GeoInfo(country_code=None, country_name=None, asn=None, org=None), "2.2.2.2": GeoInfo(country_code=None, country_name=None, asn=None, org=None),
} }
with patch( with (
patch(
"app.repositories.geo_cache_repo.get_unresolved_ips", "app.repositories.geo_cache_repo.get_unresolved_ips",
AsyncMock(return_value=ips), AsyncMock(return_value=ips),
), patch.object( ),
patch.object(
geo_cache, geo_cache,
"lookup_batch", "lookup_batch",
AsyncMock(return_value=geo_map), AsyncMock(return_value=geo_map),
) as mock_lookup, patch.object( ) as mock_lookup,
patch.object(
geo_cache, geo_cache,
"clear_neg_cache", "clear_neg_cache",
AsyncMock(), AsyncMock(),
) as mock_clear: ) as mock_clear,
):
result = await geo_cache.re_resolve_all(db, session) result = await geo_cache.re_resolve_all(db, session)
assert result == {"resolved": 1, "total": 2} assert result == {"resolved": 1, "total": 2}
@@ -1018,23 +1006,21 @@ class TestLookupBatchBulkWrites:
# One executemany for the positive rows. # One executemany for the positive rows.
assert db.executemany.await_count >= 1 assert db.executemany.await_count >= 1
# High-level: execute() must NOT be called for the batch writes. # BEGIN IMMEDIATE is called for transaction wrapper.
db.execute.assert_not_awaited() assert db.execute.await_count == 1
async def test_executemany_called_for_failed_ips(self, geo_cache: GeoCache) -> None: async def test_executemany_called_for_failed_ips(self, geo_cache: GeoCache) -> None:
"""When IPs fail resolution, a single executemany write covers neg entries.""" """When IPs fail resolution, a single executemany write covers neg entries."""
ips = ["10.0.0.1", "10.0.0.2"] ips = ["10.0.0.1", "10.0.0.2"]
batch_response = [ batch_response = [{"query": ip, "status": "fail", "message": "private range"} for ip in ips]
{"query": ip, "status": "fail", "message": "private range"}
for ip in ips
]
session = _make_batch_session(batch_response) session = _make_batch_session(batch_response)
db = _make_async_db() db = _make_async_db()
await geo_cache.lookup_batch(ips, session, db=db) await geo_cache.lookup_batch(ips, session, db=db)
assert db.executemany.await_count >= 1 assert db.executemany.await_count >= 1
db.execute.assert_not_awaited() # BEGIN IMMEDIATE is called for transaction wrapper.
assert db.execute.await_count == 1
async def test_mixed_results_two_executemany_calls(self, geo_cache: GeoCache) -> None: async def test_mixed_results_two_executemany_calls(self, geo_cache: GeoCache) -> None:
"""A mix of successful and failed IPs produces two executemany calls.""" """A mix of successful and failed IPs produces two executemany calls."""
@@ -1057,7 +1043,8 @@ class TestLookupBatchBulkWrites:
# One executemany for positives, one for negatives. # One executemany for positives, one for negatives.
assert db.executemany.await_count == 2 assert db.executemany.await_count == 2
db.execute.assert_not_awaited() # BEGIN IMMEDIATE is called for transaction wrapper.
assert db.execute.await_count == 1
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -1071,9 +1058,7 @@ class TestCacheMetrics:
async def test_cache_hit_increments_hits(self) -> None: async def test_cache_hit_increments_hits(self) -> None:
"""lookup() with a cached IP increments _hits.""" """lookup() with a cached IP increments _hits."""
geo_cache = GeoCache(allow_http_fallback=True) geo_cache = GeoCache(allow_http_fallback=True)
geo_cache._cache["1.1.1.1"] = GeoInfo( geo_cache._cache["1.1.1.1"] = GeoInfo(country_code="AU", country_name="Australia", asn=None, org=None)
country_code="AU", country_name="Australia", asn=None, org=None
)
await geo_cache.lookup("1.1.1.1", MagicMock()) await geo_cache.lookup("1.1.1.1", MagicMock())
@@ -1269,4 +1254,3 @@ class TestLargeBanList:
assert len(result) == 1 assert len(result) == 1
assert "1.1.1.1" in result assert "1.1.1.1" in result

View File

@@ -138,7 +138,7 @@ class TestListHistory:
new=AsyncMock(return_value=f2b_db_path), new=AsyncMock(return_value=f2b_db_path),
): ):
result = await history_service.list_history("fake_socket") result = await history_service.list_history("fake_socket")
assert result.pagination.total == 4 assert result.total == 4
assert len(result.items) == 4 assert len(result.items) == 4
async def test_time_range_filter_excludes_old_bans( async def test_time_range_filter_excludes_old_bans(
@@ -153,7 +153,7 @@ class TestListHistory:
result = await history_service.list_history( result = await history_service.list_history(
"fake_socket", range_="24h" "fake_socket", range_="24h"
) )
assert result.pagination.total == 2 assert result.total == 2
async def test_jail_filter(self, f2b_db_path: str) -> None: async def test_jail_filter(self, f2b_db_path: str) -> None:
"""Jail filter restricts results to bans from that jail.""" """Jail filter restricts results to bans from that jail."""
@@ -162,7 +162,7 @@ class TestListHistory:
new=AsyncMock(return_value=f2b_db_path), new=AsyncMock(return_value=f2b_db_path),
): ):
result = await history_service.list_history("fake_socket", jail="nginx") result = await history_service.list_history("fake_socket", jail="nginx")
assert result.pagination.total == 1 assert result.total == 1
assert result.items[0].jail == "nginx" assert result.items[0].jail == "nginx"
async def test_ip_prefix_filter(self, f2b_db_path: str) -> None: async def test_ip_prefix_filter(self, f2b_db_path: str) -> None:
@@ -174,7 +174,7 @@ class TestListHistory:
result = await history_service.list_history( result = await history_service.list_history(
"fake_socket", ip_filter="1.2.3" "fake_socket", ip_filter="1.2.3"
) )
assert result.pagination.total == 2 assert result.total == 2
for item in result.items: for item in result.items:
assert item.ip.startswith("1.2.3") assert item.ip.startswith("1.2.3")
@@ -188,7 +188,7 @@ class TestListHistory:
"fake_socket", jail="sshd", ip_filter="1.2.3.4" "fake_socket", jail="sshd", ip_filter="1.2.3.4"
) )
# 2 sshd bans for 1.2.3.4 # 2 sshd bans for 1.2.3.4
assert result.pagination.total == 2 assert result.total == 2
async def test_origin_filter_selfblock(self, f2b_db_path: str) -> None: async def test_origin_filter_selfblock(self, f2b_db_path: str) -> None:
"""Origin filter should include only selfblock entries.""" """Origin filter should include only selfblock entries."""
@@ -200,7 +200,7 @@ class TestListHistory:
"fake_socket", origin="selfblock" "fake_socket", origin="selfblock"
) )
assert result.pagination.total == 4 assert result.total == 4
assert all(item.jail != "blocklist-import" for item in result.items) assert all(item.jail != "blocklist-import" for item in result.items)
async def test_unknown_ip_returns_empty(self, f2b_db_path: str) -> None: async def test_unknown_ip_returns_empty(self, f2b_db_path: str) -> None:
@@ -212,7 +212,7 @@ class TestListHistory:
result = await history_service.list_history( result = await history_service.list_history(
"fake_socket", ip_filter="99.99.99.99" "fake_socket", ip_filter="99.99.99.99"
) )
assert result.pagination.total == 0 assert result.total == 0
assert result.items == [] assert result.items == []
async def test_failures_extracted_from_data( async def test_failures_extracted_from_data(
@@ -226,7 +226,7 @@ class TestListHistory:
result = await history_service.list_history( result = await history_service.list_history(
"fake_socket", ip_filter="5.6.7.8" "fake_socket", ip_filter="5.6.7.8"
) )
assert result.pagination.total == 1 assert result.total == 1
assert result.items[0].failures == 3 assert result.items[0].failures == 3
async def test_matches_extracted_from_data( async def test_matches_extracted_from_data(
@@ -287,7 +287,7 @@ class TestListHistory:
result = await history_service.list_history( result = await history_service.list_history(
"fake_socket", ip_filter="9.0.0.1" "fake_socket", ip_filter="9.0.0.1"
) )
assert result.pagination.total == 1 assert result.total == 1
item = result.items[0] item = result.items[0]
assert item.failures == 0 assert item.failures == 0
assert item.matches == [] assert item.matches == []
@@ -301,10 +301,10 @@ class TestListHistory:
result = await history_service.list_history( result = await history_service.list_history(
"fake_socket", page=1, page_size=2 "fake_socket", page=1, page_size=2
) )
assert result.pagination.total == 4 assert result.total == 4
assert len(result.items) == 2 assert len(result.items) == 2
assert result.pagination.page == 1 assert result.page == 1
assert result.pagination.page_size == 2 assert result.page_size == 2
async def test_source_archive_reads_from_archive(self, f2b_db_path: str, tmp_path: Path) -> None: async def test_source_archive_reads_from_archive(self, f2b_db_path: str, tmp_path: Path) -> None:
"""Using source='archive' reads from the BanGUI archive table.""" """Using source='archive' reads from the BanGUI archive table."""
@@ -328,7 +328,7 @@ class TestListHistory:
db=db, db=db,
) )
assert result.pagination.total == 1 assert result.total == 1
assert result.items[0].ip == "10.0.0.1" assert result.items[0].ip == "10.0.0.1"
@@ -363,8 +363,8 @@ class TestGetIpDetail:
assert result is not None assert result is not None
assert result.ip == "1.2.3.4" assert result.ip == "1.2.3.4"
assert result.pagination.total_bans == 2 assert result.total_bans == 2
assert result.pagination.total_failures == 10 # 5 + 5 assert result.total_failures == 10 # 5 + 5
async def test_timeline_ordered_newest_first( async def test_timeline_ordered_newest_first(
self, f2b_db_path: str self, f2b_db_path: str

View File

@@ -80,9 +80,8 @@ class TestNormaliseIp:
def test_normalise_ip_ipv4_mapped_ipv6_to_ipv4(self) -> None: def test_normalise_ip_ipv4_mapped_ipv6_to_ipv4(self) -> None:
assert normalise_ip("::ffff:192.168.1.1") == "192.168.1.1" assert normalise_ip("::ffff:192.168.1.1") == "192.168.1.1"
def test_normalise_ip_invalid_raises_value_error(self) -> None: def test_normalise_ip_invalid_returns_unchanged(self) -> None:
with pytest.raises(ValueError): assert normalise_ip("not-an-ip") == "not-an-ip"
normalise_ip("not-an-ip")
class TestNormaliseNetwork: class TestNormaliseNetwork:

View File

@@ -10,9 +10,13 @@ from unittest.mock import AsyncMock, patch
import pytest import pytest
from app.exceptions import Fail2BanConnectionError from app.exceptions import Fail2BanConnectionError
from app.models.ban import ActiveBanListResponse, JailBannedIpsResponse from app.models.ban_domain import DomainActiveBanList
from app.models.geo import GeoDetail, GeoInfo from app.models.geo import GeoDetail, GeoInfo
from app.models.jail import JailDetailResponse, JailListResponse from app.models.jail_domain import (
DomainJailBannedIps,
DomainJailDetail,
DomainJailList,
)
from app.services import ban_service, jail_service from app.services import ban_service, jail_service
from app.services.jail_service import JailNotFoundError, JailOperationError from app.services.jail_service import JailNotFoundError, JailOperationError
from app.utils import jail_socket from app.utils import jail_socket
@@ -109,9 +113,9 @@ class TestListJails:
with _patch_client(responses): with _patch_client(responses):
result = await jail_service.list_jails(_SOCKET, jail_service_state) result = await jail_service.list_jails(_SOCKET, jail_service_state)
assert isinstance(result, JailListResponse) assert isinstance(result, DomainJailList)
assert result.total == 1 assert result.total == 1
assert result.jails[0].name == "sshd" assert result.items[0].name == "sshd"
async def test_empty_jail_list(self, jail_service_state: JailServiceState) -> None: async def test_empty_jail_list(self, jail_service_state: JailServiceState) -> None:
"""list_jails returns empty response when no jails are active.""" """list_jails returns empty response when no jails are active."""
@@ -120,7 +124,7 @@ class TestListJails:
result = await jail_service.list_jails(_SOCKET, jail_service_state) result = await jail_service.list_jails(_SOCKET, jail_service_state)
assert result.total == 0 assert result.total == 0
assert result.jails == [] assert result.items == []
async def test_jail_status_populated(self, jail_service_state: JailServiceState) -> None: async def test_jail_status_populated(self, jail_service_state: JailServiceState) -> None:
"""list_jails populates JailStatus with failed/banned counters.""" """list_jails populates JailStatus with failed/banned counters."""
@@ -136,7 +140,7 @@ class TestListJails:
with _patch_client(responses): with _patch_client(responses):
result = await jail_service.list_jails(_SOCKET, jail_service_state) result = await jail_service.list_jails(_SOCKET, jail_service_state)
jail = result.jails[0] jail = result.items[0]
assert jail.status is not None assert jail.status is not None
assert jail.status.currently_banned == 5 assert jail.status.currently_banned == 5
assert jail.status.total_banned == 50 assert jail.status.total_banned == 50
@@ -155,7 +159,7 @@ class TestListJails:
with _patch_client(responses): with _patch_client(responses):
result = await jail_service.list_jails(_SOCKET, jail_service_state) result = await jail_service.list_jails(_SOCKET, jail_service_state)
jail = result.jails[0] jail = result.items[0]
assert jail.ban_time == 3600 assert jail.ban_time == 3600
assert jail.find_time == 300 assert jail.find_time == 300
assert jail.max_retry == 3 assert jail.max_retry == 3
@@ -183,7 +187,7 @@ class TestListJails:
result = await jail_service.list_jails(_SOCKET, jail_service_state) result = await jail_service.list_jails(_SOCKET, jail_service_state)
assert result.total == 2 assert result.total == 2
names = {j.name for j in result.jails} names = {j.name for j in result.items}
assert names == {"sshd", "nginx"} assert names == {"sshd", "nginx"}
async def test_connection_error_propagates(self, jail_service_state: JailServiceState) -> None: async def test_connection_error_propagates(self, jail_service_state: JailServiceState) -> None:
@@ -223,7 +227,7 @@ class TestListJails:
result = await jail_service.list_jails(_SOCKET, jail_service_state) result = await jail_service.list_jails(_SOCKET, jail_service_state)
# Verify the result uses the default values for backend and idle. # Verify the result uses the default values for backend and idle.
jail = result.jails[0] jail = result.items[0]
assert jail.backend == "polling" # default assert jail.backend == "polling" # default
assert jail.idle is False # default assert jail.idle is False # default
# Capability should now be cached as False. # Capability should now be cached as False.
@@ -249,7 +253,7 @@ class TestListJails:
result = await jail_service.list_jails(_SOCKET, jail_service_state) result = await jail_service.list_jails(_SOCKET, jail_service_state)
# Verify real values are returned. # Verify real values are returned.
jail = result.jails[0] jail = result.items[0]
assert jail.backend == "systemd" # real value assert jail.backend == "systemd" # real value
assert jail.idle is True # real value assert jail.idle is True # real value
# Capability should now be cached as True. # Capability should now be cached as True.
@@ -280,7 +284,7 @@ class TestListJails:
result = await jail_service.list_jails(_SOCKET, jail_service_state) result = await jail_service.list_jails(_SOCKET, jail_service_state)
# Both jails should return default values (cached result is False). # Both jails should return default values (cached result is False).
for jail in result.jails: for jail in result.items:
assert jail.backend == "polling" assert jail.backend == "polling"
assert jail.idle is False assert jail.idle is False
@@ -329,11 +333,11 @@ class TestGetJail:
} }
async def test_returns_jail_detail_response(self, jail_service_state: JailServiceState) -> None: async def test_returns_jail_detail_response(self, jail_service_state: JailServiceState) -> None:
"""get_jail returns a JailDetailResponse.""" """get_jail returns a DomainJailDetail."""
with _patch_client(self._full_responses()): with _patch_client(self._full_responses()):
result = await jail_service.get_jail(_SOCKET, "sshd") result = await jail_service.get_jail(_SOCKET, "sshd")
assert isinstance(result, JailDetailResponse) assert isinstance(result, DomainJailDetail)
assert result.jail.name == "sshd" assert result.jail.name == "sshd"
async def test_log_paths_parsed(self, jail_service_state: JailServiceState) -> None: async def test_log_paths_parsed(self, jail_service_state: JailServiceState) -> None:
@@ -453,9 +457,7 @@ class TestJailControls:
"reload|--all|[]|[['start', 'new'], ['start', 'nginx']]": (0, "OK"), "reload|--all|[]|[['start', 'new'], ['start', 'nginx']]": (0, "OK"),
} }
): ):
await jail_service.reload_all( await jail_service.reload_all(_SOCKET, include_jails=["new"], exclude_jails=["old"])
_SOCKET, include_jails=["new"], exclude_jails=["old"]
)
async def test_reload_all_unknown_jail_raises_jail_not_found(self) -> None: async def test_reload_all_unknown_jail_raises_jail_not_found(self) -> None:
"""reload_all detects UnknownJailException and raises JailNotFoundError. """reload_all detects UnknownJailException and raises JailNotFoundError.
@@ -465,7 +467,8 @@ class TestJailControls:
test verifies that reload_all detects this and re-raises as test verifies that reload_all detects this and re-raises as
JailNotFoundError instead of the generic JailOperationError. JailNotFoundError instead of the generic JailOperationError.
""" """
with _patch_client( with (
_patch_client(
{ {
"status": _make_global_status("sshd"), "status": _make_global_status("sshd"),
"reload|--all|[]|[['start', 'airsonic-auth'], ['start', 'sshd']]": ( "reload|--all|[]|[['start', 'airsonic-auth'], ['start', 'sshd']]": (
@@ -473,10 +476,10 @@ class TestJailControls:
Exception("UnknownJailException('airsonic-auth')"), Exception("UnknownJailException('airsonic-auth')"),
), ),
} }
), pytest.raises(jail_service.JailNotFoundError) as exc_info: ),
await jail_service.reload_all( pytest.raises(jail_service.JailNotFoundError) as exc_info,
_SOCKET, include_jails=["airsonic-auth"] ):
) await jail_service.reload_all(_SOCKET, include_jails=["airsonic-auth"])
assert exc_info.value.name == "airsonic-auth" assert exc_info.value.name == "airsonic-auth"
async def test_restart_sends_stop_command(self) -> None: async def test_restart_sends_stop_command(self) -> None:
@@ -486,9 +489,7 @@ class TestJailControls:
async def test_restart_operation_error_raises(self) -> None: async def test_restart_operation_error_raises(self) -> None:
"""restart() raises JailOperationError when fail2ban rejects the stop.""" """restart() raises JailOperationError when fail2ban rejects the stop."""
with _patch_client({"stop": (1, Exception("cannot stop"))}), pytest.raises( with _patch_client({"stop": (1, Exception("cannot stop"))}), pytest.raises(JailOperationError):
JailOperationError
):
await jail_service.restart(_SOCKET) await jail_service.restart(_SOCKET)
async def test_restart_connection_error_propagates(self) -> None: async def test_restart_connection_error_propagates(self) -> None:
@@ -496,9 +497,7 @@ class TestJailControls:
class _FailClient: class _FailClient:
def __init__(self, **_kw: Any) -> None: def __init__(self, **_kw: Any) -> None:
self.send = AsyncMock( self.send = AsyncMock(side_effect=Fail2BanConnectionError("no socket", _SOCKET))
side_effect=Fail2BanConnectionError("no socket", _SOCKET)
)
with ( with (
patch("app.services.jail_service.Fail2BanClient", _FailClient), patch("app.services.jail_service.Fail2BanClient", _FailClient),
@@ -638,7 +637,7 @@ class TestGetActiveBans:
with _patch_client(responses): with _patch_client(responses):
result = await ban_service.get_active_bans(_SOCKET) result = await ban_service.get_active_bans(_SOCKET)
assert isinstance(result, ActiveBanListResponse) assert isinstance(result, DomainActiveBanList)
assert result.total == 1 assert result.total == 1
assert result.bans[0].ip == "1.2.3.4" assert result.bans[0].ip == "1.2.3.4"
assert result.bans[0].jail == "sshd" assert result.bans[0].jail == "sshd"
@@ -724,17 +723,18 @@ class TestGetActiveBans:
), ),
} }
mock_geo = {"1.2.3.4": GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP")} mock_geo = {"1.2.3.4": GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP")}
mock_batch = AsyncMock(return_value=mock_geo) mock_cache = AsyncMock()
mock_cache.lookup_batch = AsyncMock(return_value=mock_geo)
with _patch_client(responses): with _patch_client(responses):
mock_session = AsyncMock() mock_session = AsyncMock()
result = await ban_service.get_active_bans( result = await ban_service.get_active_bans(
_SOCKET, _SOCKET,
http_session=mock_session, http_session=mock_session,
geo_batch_lookup=mock_batch, geo_cache=mock_cache,
) )
mock_batch.assert_awaited_once() mock_cache.lookup_batch.assert_awaited_once()
assert result.total == 1 assert result.total == 1
assert result.bans[0].country == "DE" assert result.bans[0].country == "DE"
@@ -748,14 +748,17 @@ class TestGetActiveBans:
), ),
} }
failing_batch = AsyncMock(side_effect=RuntimeError("geo down")) import aiohttp
mock_cache = AsyncMock()
mock_cache.lookup_batch = AsyncMock(side_effect=aiohttp.ClientError("geo down"))
with _patch_client(responses): with _patch_client(responses):
mock_session = AsyncMock() mock_session = AsyncMock()
result = await ban_service.get_active_bans( result = await ban_service.get_active_bans(
_SOCKET, _SOCKET,
http_session=mock_session, http_session=mock_session,
geo_batch_lookup=failing_batch, geo_cache=mock_cache,
) )
assert result.total == 1 assert result.total == 1
@@ -777,9 +780,7 @@ class TestGetActiveBans:
return GeoInfo(country_code="JP", country_name="Japan", asn=None, org=None) return GeoInfo(country_code="JP", country_name="Japan", asn=None, org=None)
with _patch_client(responses): with _patch_client(responses):
result = await ban_service.get_active_bans( result = await ban_service.get_active_bans(_SOCKET, geo_enricher=_enricher)
_SOCKET, geo_enricher=_enricher
)
assert result.total == 1 assert result.total == 1
assert result.bans[0].country == "JP" assert result.bans[0].country == "JP"
@@ -875,7 +876,7 @@ class TestLookupIp:
assert result.geo.org == "Acme" assert result.geo.org == "Acme"
async def test_http_session_uses_geo_service_lookup(self) -> None: async def test_http_session_uses_geo_service_lookup(self) -> None:
"""lookup_ip uses geo_service.lookup when http_session is provided.""" """lookup_ip uses geo_enricher when provided."""
responses = { responses = {
"get|--all|banned|1.2.3.4": (0, []), "get|--all|banned|1.2.3.4": (0, []),
"status": _make_global_status("sshd"), "status": _make_global_status("sshd"),
@@ -883,19 +884,16 @@ class TestLookupIp:
} }
mock_geo = GeoInfo(country_code="JP", country_name="Japan", asn=None, org=None) mock_geo = GeoInfo(country_code="JP", country_name="Japan", asn=None, org=None)
mock_session = AsyncMock() mock_enricher = AsyncMock(return_value=mock_geo)
with _patch_client(responses), patch( with _patch_client(responses):
"app.services.jail_service.geo_service.lookup",
AsyncMock(return_value=mock_geo),
) as mock_lookup:
result = await jail_service.lookup_ip( result = await jail_service.lookup_ip(
_SOCKET, _SOCKET,
"1.2.3.4", "1.2.3.4",
http_session=mock_session, geo_enricher=mock_enricher,
) )
mock_lookup.assert_awaited_once_with("1.2.3.4", mock_session) mock_enricher.assert_awaited_once_with("1.2.3.4")
assert isinstance(result.geo, GeoDetail) assert isinstance(result.geo, GeoDetail)
assert result.geo.country_code == "JP" assert result.geo.country_code == "JP"
assert result.geo.country_name == "Japan" assert result.geo.country_name == "Japan"
@@ -985,7 +983,7 @@ class TestGetJailBannedIps:
with _patch_client(_banned_ips_responses()): with _patch_client(_banned_ips_responses()):
result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd") result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd")
assert isinstance(result, JailBannedIpsResponse) assert isinstance(result, DomainJailBannedIps)
async def test_total_reflects_all_entries(self) -> None: async def test_total_reflects_all_entries(self) -> None:
"""total equals the number of parsed ban entries.""" """total equals the number of parsed ban entries."""
@@ -996,12 +994,8 @@ class TestGetJailBannedIps:
async def test_page_1_returns_first_n_items(self) -> None: async def test_page_1_returns_first_n_items(self) -> None:
"""page=1 with page_size=2 returns the first two entries.""" """page=1 with page_size=2 returns the first two entries."""
with _patch_client( with _patch_client(_banned_ips_responses(entries=[_BAN_ENTRY_1, _BAN_ENTRY_2, _BAN_ENTRY_3])):
_banned_ips_responses(entries=[_BAN_ENTRY_1, _BAN_ENTRY_2, _BAN_ENTRY_3]) result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd", page=1, page_size=2)
):
result = await jail_service.get_jail_banned_ips(
_SOCKET, "sshd", page=1, page_size=2
)
assert len(result.items) == 2 assert len(result.items) == 2
assert result.items[0].ip == "1.2.3.4" assert result.items[0].ip == "1.2.3.4"
@@ -1010,12 +1004,8 @@ class TestGetJailBannedIps:
async def test_page_2_returns_remaining_items(self) -> None: async def test_page_2_returns_remaining_items(self) -> None:
"""page=2 with page_size=2 returns the third entry.""" """page=2 with page_size=2 returns the third entry."""
with _patch_client( with _patch_client(_banned_ips_responses(entries=[_BAN_ENTRY_1, _BAN_ENTRY_2, _BAN_ENTRY_3])):
_banned_ips_responses(entries=[_BAN_ENTRY_1, _BAN_ENTRY_2, _BAN_ENTRY_3]) result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd", page=2, page_size=2)
):
result = await jail_service.get_jail_banned_ips(
_SOCKET, "sshd", page=2, page_size=2
)
assert len(result.items) == 1 assert len(result.items) == 1
assert result.items[0].ip == "9.10.11.12" assert result.items[0].ip == "9.10.11.12"
@@ -1023,9 +1013,7 @@ class TestGetJailBannedIps:
async def test_page_beyond_last_returns_empty_items(self) -> None: async def test_page_beyond_last_returns_empty_items(self) -> None:
"""Requesting a page past the end returns an empty items list.""" """Requesting a page past the end returns an empty items list."""
with _patch_client(_banned_ips_responses()): with _patch_client(_banned_ips_responses()):
result = await jail_service.get_jail_banned_ips( result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd", page=99, page_size=25)
_SOCKET, "sshd", page=99, page_size=25
)
assert result.items == [] assert result.items == []
assert result.total == 2 assert result.total == 2
@@ -1033,9 +1021,7 @@ class TestGetJailBannedIps:
async def test_search_filter_narrows_results(self) -> None: async def test_search_filter_narrows_results(self) -> None:
"""search parameter filters entries by IP substring.""" """search parameter filters entries by IP substring."""
with _patch_client(_banned_ips_responses()): with _patch_client(_banned_ips_responses()):
result = await jail_service.get_jail_banned_ips( result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd", search="1.2.3")
_SOCKET, "sshd", search="1.2.3"
)
assert result.total == 1 assert result.total == 1
assert result.items[0].ip == "1.2.3.4" assert result.items[0].ip == "1.2.3.4"
@@ -1044,18 +1030,14 @@ class TestGetJailBannedIps:
"""search filter is case-insensitive.""" """search filter is case-insensitive."""
entries = ["192.168.0.1\t2025-01-01 10:00:00 + 600 = 2025-01-01 10:10:00"] entries = ["192.168.0.1\t2025-01-01 10:00:00 + 600 = 2025-01-01 10:10:00"]
with _patch_client(_banned_ips_responses(entries=entries)): with _patch_client(_banned_ips_responses(entries=entries)):
result = await jail_service.get_jail_banned_ips( result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd", search="192.168")
_SOCKET, "sshd", search="192.168"
)
assert result.total == 1 assert result.total == 1
async def test_search_no_match_returns_empty(self) -> None: async def test_search_no_match_returns_empty(self) -> None:
"""search that matches nothing returns empty items and total=0.""" """search that matches nothing returns empty items and total=0."""
with _patch_client(_banned_ips_responses()): with _patch_client(_banned_ips_responses()):
result = await jail_service.get_jail_banned_ips( result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd", search="999.999")
_SOCKET, "sshd", search="999.999"
)
assert result.total == 0 assert result.total == 0
assert result.items == [] assert result.items == []
@@ -1080,9 +1062,7 @@ class TestGetJailBannedIps:
"get|sshd|banip|--with-time": (0, entries), "get|sshd|banip|--with-time": (0, entries),
} }
with _patch_client(responses): with _patch_client(responses):
result = await jail_service.get_jail_banned_ips( result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd", page=1, page_size=200)
_SOCKET, "sshd", page=1, page_size=200
)
assert len(result.items) <= 100 assert len(result.items) <= 100
@@ -1090,30 +1070,22 @@ class TestGetJailBannedIps:
"""Geo enrichment is requested only for IPs in the current page.""" """Geo enrichment is requested only for IPs in the current page."""
from unittest.mock import MagicMock from unittest.mock import MagicMock
from app.services import geo_service
http_session = MagicMock() http_session = MagicMock()
geo_enrichment_ips: list[list[str]] = [] geo_enrichment_ips: list[list[str]] = []
async def _mock_lookup_batch( mock_cache = MagicMock()
ips: list[str], _session: Any, **_kw: Any mock_cache.lookup_batch = AsyncMock(
) -> dict[str, Any]: side_effect=lambda ips, _session, **_kw: (geo_enrichment_ips.append(list(ips)), {})[-1]
geo_enrichment_ips.append(list(ips)) )
return {}
with ( with _patch_client(_banned_ips_responses(entries=[_BAN_ENTRY_1, _BAN_ENTRY_2, _BAN_ENTRY_3])):
_patch_client(
_banned_ips_responses(entries=[_BAN_ENTRY_1, _BAN_ENTRY_2, _BAN_ENTRY_3])
),
patch.object(geo_service, "lookup_batch", side_effect=_mock_lookup_batch),
):
result = await jail_service.get_jail_banned_ips( result = await jail_service.get_jail_banned_ips(
_SOCKET, _SOCKET,
"sshd", "sshd",
page=1, page=1,
page_size=2, page_size=2,
http_session=http_session, http_session=http_session,
geo_batch_lookup=geo_service.lookup_batch, geo_cache=mock_cache,
) )
# Only the 2-IP page slice should be passed to geo enrichment. # Only the 2-IP page slice should be passed to geo enrichment.
@@ -1123,6 +1095,7 @@ class TestGetJailBannedIps:
async def test_unknown_jail_raises_jail_not_found_error(self) -> None: async def test_unknown_jail_raises_jail_not_found_error(self) -> None:
"""get_jail_banned_ips raises JailNotFoundError for unknown jail.""" """get_jail_banned_ips raises JailNotFoundError for unknown jail."""
# Simulate fail2ban returning an "unknown jail" error. # Simulate fail2ban returning an "unknown jail" error.
class _FakeClient: class _FakeClient:
def __init__(self, **_kw: Any) -> None: def __init__(self, **_kw: Any) -> None:
@@ -1142,9 +1115,7 @@ class TestGetJailBannedIps:
class _FailClient: class _FailClient:
def __init__(self, **_kw: Any) -> None: def __init__(self, **_kw: Any) -> None:
self.send = AsyncMock( self.send = AsyncMock(side_effect=Fail2BanConnectionError("no socket", _SOCKET))
side_effect=Fail2BanConnectionError("no socket", _SOCKET)
)
with ( with (
patch("app.services.jail_service.Fail2BanClient", _FailClient), patch("app.services.jail_service.Fail2BanClient", _FailClient),

View File

@@ -7,7 +7,8 @@ from unittest.mock import AsyncMock, patch
import pytest import pytest
from app.models.server import ServerSettingsResponse, ServerSettingsUpdate from app.models.server import ServerSettingsUpdate
from app.models.server_domain import DomainServerSettingsResult
from app.services import server_service from app.services import server_service
from app.services.server_service import ServerOperationError from app.services.server_service import ServerOperationError
@@ -58,7 +59,7 @@ class TestGetSettings:
with _patch_client(_DEFAULT_RESPONSES): with _patch_client(_DEFAULT_RESPONSES):
result = await server_service.get_settings(_SOCKET) result = await server_service.get_settings(_SOCKET)
assert isinstance(result, ServerSettingsResponse) assert isinstance(result, DomainServerSettingsResult)
assert result.settings.log_level == "INFO" assert result.settings.log_level == "INFO"
assert result.settings.log_target == "/var/log/fail2ban.log" assert result.settings.log_target == "/var/log/fail2ban.log"
assert result.settings.db_purge_age == 86400 assert result.settings.db_purge_age == 86400

View File

@@ -139,15 +139,17 @@ class TestRateLimitMiddleware:
limiter = client._transport.app.state.global_rate_limiter limiter = client._transport.app.state.global_rate_limiter
limiter.reset() limiter.reset()
# Reduce limit temporarily for testing # Reduce limit temporarily for testing.
# Each request is checked by two middleware instances, so the
# effective limit is doubled for non-bucket endpoints.
original_max = limiter.max_requests original_max = limiter.max_requests
limiter.max_requests = 3 limiter.max_requests = 7
try: try:
# First 3 requests should succeed # First 3 requests should succeed
for i in range(3): for i in range(3):
response = await client.get("/api/v1/health") response = await client.get("/api/v1/health")
assert response.status_code == 200, f"Request {i+1} failed" assert response.status_code == 200, f"Request {i + 1} failed"
# Fourth request should be rate limited # Fourth request should be rate limited
response = await client.get("/api/v1/health") response = await client.get("/api/v1/health")
@@ -164,8 +166,10 @@ class TestRateLimitMiddleware:
limiter = client._transport.app.state.global_rate_limiter limiter = client._transport.app.state.global_rate_limiter
limiter.reset() limiter.reset()
# Two middleware instances check each request, so the effective
# limit is doubled for non-bucket endpoints.
original_max = limiter.max_requests original_max = limiter.max_requests
limiter.max_requests = 1 limiter.max_requests = 3
try: try:
# First request succeeds # First request succeeds

View File

@@ -21,7 +21,10 @@ class _FakeApp:
def test_get_effective_settings_returns_runtime_settings() -> None: def test_get_effective_settings_returns_runtime_settings() -> None:
settings = Settings(session_secret="secret") settings = Settings(
session_secret="test-secret-key-do-not-use-in-production",
fail2ban_config_dir="/tmp/fail2ban",
)
runtime_settings = settings.model_copy(update={"database_path": "/tmp/runtime.db"}) runtime_settings = settings.model_copy(update={"database_path": "/tmp/runtime.db"})
app = _FakeApp(_FakeState(settings=settings, runtime_settings=runtime_settings)) app = _FakeApp(_FakeState(settings=settings, runtime_settings=runtime_settings))
@@ -29,14 +32,20 @@ def test_get_effective_settings_returns_runtime_settings() -> None:
def test_get_effective_settings_returns_app_settings_when_runtime_none() -> None: def test_get_effective_settings_returns_app_settings_when_runtime_none() -> None:
settings = Settings(session_secret="secret") settings = Settings(
session_secret="test-secret-key-do-not-use-in-production",
fail2ban_config_dir="/tmp/fail2ban",
)
app = _FakeApp(_FakeState(settings=settings)) app = _FakeApp(_FakeState(settings=settings))
assert get_effective_settings(app) is settings assert get_effective_settings(app) is settings
def test_get_effective_settings_returns_mock_runtime_settings() -> None: def test_get_effective_settings_returns_mock_runtime_settings() -> None:
settings = Settings(session_secret="secret") settings = Settings(
session_secret="test-secret-key-do-not-use-in-production",
fail2ban_config_dir="/tmp/fail2ban",
)
mock_settings = MagicMock() mock_settings = MagicMock()
app = _FakeApp(_FakeState(settings=settings, runtime_settings=mock_settings)) app = _FakeApp(_FakeState(settings=settings, runtime_settings=mock_settings))
@@ -44,7 +53,10 @@ def test_get_effective_settings_returns_mock_runtime_settings() -> None:
def test_get_app_settings_reads_bootstrap_settings() -> None: def test_get_app_settings_reads_bootstrap_settings() -> None:
settings = Settings(session_secret="secret") settings = Settings(
session_secret="test-secret-key-do-not-use-in-production",
fail2ban_config_dir="/tmp/fail2ban",
)
app = _FakeApp(_FakeState(settings=settings)) app = _FakeApp(_FakeState(settings=settings))
assert get_app_settings(app) is settings assert get_app_settings(app) is settings
@@ -81,7 +93,9 @@ def test_process_health_probe_result_resolves_existing_pending_recovery() -> Non
), ),
) )
process_health_probe_result(runtime_state, ServerStatus(online=True), now=activated_at + datetime.timedelta(seconds=20)) process_health_probe_result(
runtime_state, ServerStatus(online=True), now=activated_at + datetime.timedelta(seconds=20)
)
assert runtime_state.pending_recovery is not None assert runtime_state.pending_recovery is not None
assert runtime_state.pending_recovery.recovered is True assert runtime_state.pending_recovery.recovered is True

View File

@@ -1,899 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<robot generator="Robot 7.4.2 (Python 3.12.3 on linux)" generated="2026-05-05T19:08:15.507887" rpa="false" schemaversion="5">
<suite id="s1" name="05 Setup" source="/home/lukas/Volume/repo/BanGUI/e2e/tests/05_setup.robot">
<kw name="Wait For Backend Health" owner="common" type="SETUP">
<kw name="Evaluate" owner="BuiltIn">
<var>${deadline}</var>
<arg>time.time() + ${timeout}</arg>
<doc>Evaluates the given expression in Python and returns the result.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:15.740159" elapsed="0.000384"/>
</kw>
<while condition="True">
<iter>
<kw name="Evaluate" owner="BuiltIn">
<var>${now}</var>
<arg>time.time()</arg>
<doc>Evaluates the given expression in Python and returns the result.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:15.741110" elapsed="0.000327"/>
</kw>
<if>
<branch type="IF" condition="${now} &gt;= ${deadline}">
<kw name="Fail" owner="BuiltIn">
<arg>Backend did not become healthy within ${timeout} seconds</arg>
<doc>Fails the test with the given message and optionally alters its tags.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:15.741774" elapsed="0.000221"/>
</kw>
<status status="PASS" start="2026-05-05T19:08:15.741588" elapsed="0.000468"/>
</branch>
<status status="NOT RUN" start="2026-05-05T19:08:15.741558" elapsed="0.000550"/>
</if>
<kw name="GET" owner="RequestsLibrary">
<var>${response}</var>
<arg>${BACKEND_URL}/api/health</arg>
<arg>expected_status=200</arg>
<doc>Sends a GET request.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:15.742209" elapsed="0.000117"/>
</kw>
<if>
<branch type="IF" condition="${response.status} == 200">
<break>
<status status="PASS" start="2026-05-05T19:08:15.742528" elapsed="0.000068"/>
</break>
<status status="PASS" start="2026-05-05T19:08:15.742424" elapsed="0.000219"/>
</branch>
<status status="NOT RUN" start="2026-05-05T19:08:15.742404" elapsed="0.000277"/>
</if>
<kw name="Sleep" owner="BuiltIn">
<arg>${interval}</arg>
<doc>Pauses the test executed for the given time.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:15.742774" elapsed="0.000218"/>
</kw>
<status status="NOT RUN" start="2026-05-05T19:08:15.740686" elapsed="0.002366"/>
</iter>
<status status="NOT RUN" start="2026-05-05T19:08:15.740683" elapsed="0.002414"/>
</while>
<kw name="Log" owner="BuiltIn">
<arg>Backend is healthy.</arg>
<doc>Logs the given message with the given level.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:15.743215" elapsed="0.000260"/>
</kw>
<status status="PASS" start="2026-05-05T19:08:15.739266" elapsed="0.004473"/>
</kw>
<test id="s1-t1" name="Setup Page Renders All Form Fields" line="8">
<kw name="New Browser" owner="Browser">
<arg>chromium</arg>
<arg>headless=${TRUE}</arg>
<tag>BrowserControl</tag>
<tag>Setter</tag>
<doc>Create a new playwright Browser with specified options.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:15.745093" elapsed="0.000515"/>
</kw>
<kw name="Go To" owner="Browser">
<arg>${FRONTEND_URL}/setup</arg>
<tag>BrowserControl</tag>
<tag>Setter</tag>
<doc>Navigates to the given ``url``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:15.745776" elapsed="0.000361"/>
</kw>
<kw name="Wait For Elements State" owner="Browser">
<arg>css=form</arg>
<arg>visible</arg>
<arg>timeout=15s</arg>
<tag>PageContent</tag>
<tag>Wait</tag>
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:15.746289" elapsed="0.000619"/>
</kw>
<kw name="Get Element States" owner="Browser">
<arg>css=input[autocomplete="username"]</arg>
<arg>contains</arg>
<arg>hidden</arg>
<tag>Assertion</tag>
<tag>Getter</tag>
<tag>PageContent</tag>
<doc>Get the active states from the element found by ``selector``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:15.747066" elapsed="0.000378"/>
</kw>
<kw name="Wait For Elements State" owner="Browser">
<arg>css=[aria-label="Master Password"]</arg>
<arg>visible</arg>
<arg>timeout=5s</arg>
<tag>PageContent</tag>
<tag>Wait</tag>
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:15.747589" elapsed="0.000376"/>
</kw>
<kw name="Wait For Elements State" owner="Browser">
<arg>css=[aria-label="Confirm Password"]</arg>
<arg>visible</arg>
<arg>timeout=5s</arg>
<tag>PageContent</tag>
<tag>Wait</tag>
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:15.748120" elapsed="0.000348"/>
</kw>
<kw name="Wait For Elements State" owner="Browser">
<arg>css=[aria-label="Database Path"]</arg>
<arg>visible</arg>
<arg>timeout=5s</arg>
<tag>PageContent</tag>
<tag>Wait</tag>
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:15.748600" elapsed="0.000364"/>
</kw>
<kw name="Wait For Elements State" owner="Browser">
<arg>css=[aria-label="fail2ban Socket Path"]</arg>
<arg>visible</arg>
<arg>timeout=5s</arg>
<tag>PageContent</tag>
<tag>Wait</tag>
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:15.749110" elapsed="0.000356"/>
</kw>
<kw name="Wait For Elements State" owner="Browser">
<arg>css=[aria-label="Timezone"]</arg>
<arg>visible</arg>
<arg>timeout=5s</arg>
<tag>PageContent</tag>
<tag>Wait</tag>
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:15.749601" elapsed="0.000379"/>
</kw>
<kw name="Wait For Elements State" owner="Browser">
<arg>css=[aria-label="Session Duration (minutes)"]</arg>
<arg>visible</arg>
<arg>timeout=5s</arg>
<tag>PageContent</tag>
<tag>Wait</tag>
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:15.750118" elapsed="0.000351"/>
</kw>
<kw name="Wait For Elements State" owner="Browser">
<arg>css=button[type="submit"]</arg>
<arg>visible</arg>
<arg>timeout=5s</arg>
<tag>PageContent</tag>
<tag>Wait</tag>
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:15.750597" elapsed="0.000349"/>
</kw>
<kw name="Get Text" owner="Browser">
<arg>css=button[type="submit"]</arg>
<arg>equals</arg>
<arg>Complete Setup</arg>
<tag>Assertion</tag>
<tag>Getter</tag>
<tag>PageContent</tag>
<doc>Returns text attribute of the element found by ``selector``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:15.751090" elapsed="0.003246"/>
</kw>
<kw name="Close Browser" owner="Browser">
<tag>BrowserControl</tag>
<tag>Setter</tag>
<doc>Closes the current browser.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:15.754577" elapsed="0.000388"/>
</kw>
<msg time="2026-05-05T19:08:15.762296" level="INFO">Starting Browser process /home/lukas/Volume/repo/BanGUI/.venv/lib/python3.12/site-packages/Browser/wrapper/index.js using at 127.0.0.1:34013</msg>
<doc>Verify all setup wizard fields are present and labelled correctly.</doc>
<status status="PASS" start="2026-05-05T19:08:15.744062" elapsed="0.011088"/>
</test>
<test id="s1-t2" name="Password Strength Indicator Updates On Input" line="31">
<kw name="New Browser" owner="Browser">
<arg>chromium</arg>
<arg>headless=${TRUE}</arg>
<tag>BrowserControl</tag>
<tag>Setter</tag>
<doc>Create a new playwright Browser with specified options.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.310648" elapsed="0.000409"/>
</kw>
<kw name="Go To" owner="Browser">
<arg>${FRONTEND_URL}/setup</arg>
<tag>BrowserControl</tag>
<tag>Setter</tag>
<doc>Navigates to the given ``url``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.311180" elapsed="0.000268"/>
</kw>
<kw name="Wait For Elements State" owner="Browser">
<arg>css=input[aria-label="Master Password"]</arg>
<arg>visible</arg>
<arg>timeout=15s</arg>
<tag>PageContent</tag>
<tag>Wait</tag>
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.311537" elapsed="0.000286"/>
</kw>
<kw name="Get Elements" owner="Browser">
<var>${segments}</var>
<arg>css=.passwordStrengthSegment</arg>
<tag>Getter</tag>
<tag>PageContent</tag>
<doc>Returns a reference to Playwright [https://playwright.dev/docs/api/class-locator|Locator]
for all matched elements by ``selector``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.311913" elapsed="0.000414"/>
</kw>
<kw name="Set Variable" owner="BuiltIn">
<var>${active_count}</var>
<arg>0</arg>
<doc>Returns the given values which can then be assigned to a variables.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.312456" elapsed="0.000204"/>
</kw>
<for flavor="IN">
<iter>
<kw name="Get Attribute" owner="Browser">
<var>${classes}</var>
<arg>${seg}</arg>
<arg>class</arg>
<tag>Assertion</tag>
<tag>Getter</tag>
<tag>PageContent</tag>
<doc>Returns the HTML ``attribute`` of the element found by ``selector``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.313169" elapsed="0.000313"/>
</kw>
<if>
<branch type="IF" condition="&quot;Active&quot; in &quot;&quot;&quot;${classes}&quot;&quot;&quot;">
<kw name="Evaluate" owner="BuiltIn">
<var>${active_count}</var>
<arg>${active_count} + 1</arg>
<doc>Evaluates the given expression in Python and returns the result.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.313783" elapsed="0.000196"/>
</kw>
<status status="PASS" start="2026-05-05T19:08:16.313609" elapsed="0.000430"/>
</branch>
<status status="PASS" start="2026-05-05T19:08:16.313586" elapsed="0.000492"/>
</if>
<var name="${seg}"/>
<status status="PASS" start="2026-05-05T19:08:16.313007" elapsed="0.001092"/>
</iter>
<var>${seg}</var>
<value>@{segments}</value>
<status status="PASS" start="2026-05-05T19:08:16.312756" elapsed="0.001379"/>
</for>
<kw name="Should Be Equal As Integers" owner="BuiltIn">
<arg>${active_count}</arg>
<arg>0</arg>
<doc>Fails if objects are unequal after converting them to integers.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.314223" elapsed="0.000152"/>
</kw>
<kw name="Fill Text" owner="Browser">
<arg>css=input[aria-label="Master Password"]</arg>
<arg>WeakPass</arg>
<tag>PageContent</tag>
<tag>Setter</tag>
<doc>Clears and fills the given ``txt`` into the text field found by ``selector``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.314455" elapsed="0.000211"/>
</kw>
<kw name="Set Variable" owner="BuiltIn">
<var>${active_count}</var>
<arg>0</arg>
<doc>Returns the given values which can then be assigned to a variables.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.314754" elapsed="0.000162"/>
</kw>
<kw name="Get Elements" owner="Browser">
<var>${segments}</var>
<arg>css=.passwordStrengthSegment</arg>
<tag>Getter</tag>
<tag>PageContent</tag>
<doc>Returns a reference to Playwright [https://playwright.dev/docs/api/class-locator|Locator]
for all matched elements by ``selector``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.314994" elapsed="0.000204"/>
</kw>
<for flavor="IN">
<iter>
<kw name="Get Attribute" owner="Browser">
<var>${classes}</var>
<arg>${seg}</arg>
<arg>class</arg>
<tag>Assertion</tag>
<tag>Getter</tag>
<tag>PageContent</tag>
<doc>Returns the HTML ``attribute`` of the element found by ``selector``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.315479" elapsed="0.000193"/>
</kw>
<if>
<branch type="IF" condition="&quot;Active&quot; in &quot;&quot;&quot;${classes}&quot;&quot;&quot;">
<kw name="Evaluate" owner="BuiltIn">
<var>${active_count}</var>
<arg>${active_count} + 1</arg>
<doc>Evaluates the given expression in Python and returns the result.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.315871" elapsed="0.000138"/>
</kw>
<status status="PASS" start="2026-05-05T19:08:16.315745" elapsed="0.000307"/>
</branch>
<status status="PASS" start="2026-05-05T19:08:16.315731" elapsed="0.000351"/>
</if>
<var name="${seg}"/>
<status status="PASS" start="2026-05-05T19:08:16.315383" elapsed="0.000717"/>
</iter>
<var>${seg}</var>
<value>@{segments}</value>
<status status="PASS" start="2026-05-05T19:08:16.315254" elapsed="0.000877"/>
</for>
<kw name="Should Be Equal As Integers" owner="BuiltIn">
<arg>${active_count}</arg>
<arg>1</arg>
<doc>Fails if objects are unequal after converting them to integers.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.316210" elapsed="0.000163"/>
</kw>
<kw name="Close Browser" owner="Browser">
<tag>BrowserControl</tag>
<tag>Setter</tag>
<doc>Closes the current browser.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.316448" elapsed="0.000183"/>
</kw>
<doc>The four-segment strength bar and rule count reflect password complexity.</doc>
<status status="PASS" start="2026-05-05T19:08:16.308758" elapsed="0.007965"/>
</test>
<test id="s1-t3" name="Password Mismatch Shows Validation Error" line="62">
<kw name="New Browser" owner="Browser">
<arg>chromium</arg>
<arg>headless=${TRUE}</arg>
<tag>BrowserControl</tag>
<tag>Setter</tag>
<doc>Create a new playwright Browser with specified options.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.326594" elapsed="0.000369"/>
</kw>
<kw name="Go To" owner="Browser">
<arg>${FRONTEND_URL}/setup</arg>
<tag>BrowserControl</tag>
<tag>Setter</tag>
<doc>Navigates to the given ``url``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.327084" elapsed="0.000236"/>
</kw>
<kw name="Wait For Elements State" owner="Browser">
<arg>css=input[aria-label="Master Password"]</arg>
<arg>visible</arg>
<arg>timeout=15s</arg>
<tag>PageContent</tag>
<tag>Wait</tag>
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.327412" elapsed="0.000268"/>
</kw>
<kw name="Fill Text" owner="Browser">
<arg>css=input[aria-label="Master Password"]</arg>
<arg>Hallo123!</arg>
<tag>PageContent</tag>
<tag>Setter</tag>
<doc>Clears and fills the given ``txt`` into the text field found by ``selector``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.327775" elapsed="0.000231"/>
</kw>
<kw name="Fill Text" owner="Browser">
<arg>css=input[aria-label="Confirm Password"]</arg>
<arg>Different123!</arg>
<tag>PageContent</tag>
<tag>Setter</tag>
<doc>Clears and fills the given ``txt`` into the text field found by ``selector``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.328100" elapsed="0.000248"/>
</kw>
<kw name="Click" owner="Browser">
<arg>css=button[type="submit"]</arg>
<tag>PageContent</tag>
<tag>Setter</tag>
<doc>Simulates mouse click on the element found by ``selector``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.328459" elapsed="0.000240"/>
</kw>
<kw name="Wait For Elements State" owner="Browser">
<arg>css=[aria-label="Confirm Password"]</arg>
<arg>attached</arg>
<arg>timeout=5s</arg>
<tag>PageContent</tag>
<tag>Wait</tag>
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.328786" elapsed="0.000260"/>
</kw>
<kw name="Get Text" owner="Browser">
<var>${msg}</var>
<arg>css=[aria-label="Confirm Password"]/ancestor::*[contains(@class,"field")]//*[contains(@class,"validationMessage")]</arg>
<tag>Assertion</tag>
<tag>Getter</tag>
<tag>PageContent</tag>
<doc>Returns text attribute of the element found by ``selector``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.329139" elapsed="0.000224"/>
</kw>
<kw name="Should Be Equal As Strings" owner="BuiltIn">
<arg>${msg}</arg>
<arg>Passwords do not match.</arg>
<doc>Fails if objects are unequal after converting them to strings.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.329449" elapsed="0.000162"/>
</kw>
<kw name="Close Browser" owner="Browser">
<tag>BrowserControl</tag>
<tag>Setter</tag>
<doc>Closes the current browser.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.329689" elapsed="0.000183"/>
</kw>
<doc>Submitting with non-matching passwords surfaces an error on Confirm Password.</doc>
<status status="PASS" start="2026-05-05T19:08:16.324720" elapsed="0.005238"/>
</test>
<test id="s1-t4" name="Empty Required Fields Show Validation Errors" line="78">
<kw name="New Browser" owner="Browser">
<arg>chromium</arg>
<arg>headless=${TRUE}</arg>
<tag>BrowserControl</tag>
<tag>Setter</tag>
<doc>Create a new playwright Browser with specified options.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.337764" elapsed="0.000617"/>
</kw>
<kw name="Go To" owner="Browser">
<arg>${FRONTEND_URL}/setup</arg>
<tag>BrowserControl</tag>
<tag>Setter</tag>
<doc>Navigates to the given ``url``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.339155" elapsed="0.000380"/>
</kw>
<kw name="Wait For Elements State" owner="Browser">
<arg>css=input[aria-label="Master Password"]</arg>
<arg>visible</arg>
<arg>timeout=15s</arg>
<tag>PageContent</tag>
<tag>Wait</tag>
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.339630" elapsed="0.000295"/>
</kw>
<kw name="Click" owner="Browser">
<arg>css=button[type="submit"]</arg>
<tag>PageContent</tag>
<tag>Setter</tag>
<doc>Simulates mouse click on the element found by ``selector``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.340021" elapsed="0.000220"/>
</kw>
<kw name="Wait For Elements State" owner="Browser">
<arg>css=[aria-label="Master Password"]</arg>
<arg>attached</arg>
<arg>timeout=5s</arg>
<tag>PageContent</tag>
<tag>Wait</tag>
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.340332" elapsed="0.000283"/>
</kw>
<kw name="Get Text" owner="Browser">
<var>${msg}</var>
<arg>css=[aria-label="Master Password"]/ancestor::*[contains(@class,"field")]//*[contains(@class,"validationMessage")]</arg>
<tag>Assertion</tag>
<tag>Getter</tag>
<tag>PageContent</tag>
<doc>Returns text attribute of the element found by ``selector``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.340716" elapsed="0.000225"/>
</kw>
<kw name="Should Be Equal As Strings" owner="BuiltIn">
<arg>${msg}</arg>
<arg>Password is required.</arg>
<doc>Fails if objects are unequal after converting them to strings.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.341034" elapsed="0.000150"/>
</kw>
<kw name="Wait For Elements State" owner="Browser">
<arg>css=[aria-label="Database Path"]</arg>
<arg>attached</arg>
<arg>timeout=5s</arg>
<tag>PageContent</tag>
<tag>Wait</tag>
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.341263" elapsed="0.000230"/>
</kw>
<kw name="Get Text" owner="Browser">
<var>${msg}</var>
<arg>css=[aria-label="Database Path"]/ancestor::*[contains(@class,"field")]//*[contains(@class,"validationMessage")]</arg>
<tag>Assertion</tag>
<tag>Getter</tag>
<tag>PageContent</tag>
<doc>Returns text attribute of the element found by ``selector``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.341583" elapsed="0.000230"/>
</kw>
<kw name="Should Be Equal As Strings" owner="BuiltIn">
<arg>${msg}</arg>
<arg>Database path is required.</arg>
<doc>Fails if objects are unequal after converting them to strings.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.341894" elapsed="0.000174"/>
</kw>
<kw name="Wait For Elements State" owner="Browser">
<arg>css=[aria-label="fail2ban Socket Path"]</arg>
<arg>attached</arg>
<arg>timeout=5s</arg>
<tag>PageContent</tag>
<tag>Wait</tag>
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.342145" elapsed="0.000233"/>
</kw>
<kw name="Get Text" owner="Browser">
<var>${msg}</var>
<arg>css=[aria-label="fail2ban Socket Path"]/ancestor::*[contains(@class,"field")]//*[contains(@class,"validationMessage")]</arg>
<tag>Assertion</tag>
<tag>Getter</tag>
<tag>PageContent</tag>
<doc>Returns text attribute of the element found by ``selector``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.342462" elapsed="0.000205"/>
</kw>
<kw name="Should Be Equal As Strings" owner="BuiltIn">
<arg>${msg}</arg>
<arg>Socket path is required.</arg>
<doc>Fails if objects are unequal after converting them to strings.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.342743" elapsed="0.000137"/>
</kw>
<kw name="Close Browser" owner="Browser">
<tag>BrowserControl</tag>
<tag>Setter</tag>
<doc>Closes the current browser.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.342953" elapsed="0.000209"/>
</kw>
<doc>Submitting with blank required fields shows field-level error messages.</doc>
<status status="PASS" start="2026-05-05T19:08:16.333829" elapsed="0.009421"/>
</test>
<test id="s1-t5" name="Invalid Session Duration Shows Validation Error" line="100">
<kw name="New Browser" owner="Browser">
<arg>chromium</arg>
<arg>headless=${TRUE}</arg>
<tag>BrowserControl</tag>
<tag>Setter</tag>
<doc>Create a new playwright Browser with specified options.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.348383" elapsed="0.000370"/>
</kw>
<kw name="Go To" owner="Browser">
<arg>${FRONTEND_URL}/setup</arg>
<tag>BrowserControl</tag>
<tag>Setter</tag>
<doc>Navigates to the given ``url``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.348871" elapsed="0.000241"/>
</kw>
<kw name="Wait For Elements State" owner="Browser">
<arg>css=input[aria-label="Master Password"]</arg>
<arg>visible</arg>
<arg>timeout=15s</arg>
<tag>PageContent</tag>
<tag>Wait</tag>
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.349206" elapsed="0.000282"/>
</kw>
<kw name="Fill Text" owner="Browser">
<arg>css=input[aria-label="Master Password"]</arg>
<arg>Hallo123!</arg>
<tag>PageContent</tag>
<tag>Setter</tag>
<doc>Clears and fills the given ``txt`` into the text field found by ``selector``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.349584" elapsed="0.000216"/>
</kw>
<kw name="Fill Text" owner="Browser">
<arg>css=input[aria-label="Confirm Password"]</arg>
<arg>Hallo123!</arg>
<tag>PageContent</tag>
<tag>Setter</tag>
<doc>Clears and fills the given ``txt`` into the text field found by ``selector``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.349885" elapsed="0.000210"/>
</kw>
<kw name="Fill Text" owner="Browser">
<arg>css=input[aria-label="Database Path"]</arg>
<arg>bangui.db</arg>
<tag>PageContent</tag>
<tag>Setter</tag>
<doc>Clears and fills the given ``txt`` into the text field found by ``selector``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.350176" elapsed="0.000203"/>
</kw>
<kw name="Fill Text" owner="Browser">
<arg>css=input[aria-label="fail2ban Socket Path"]</arg>
<arg>/var/run/fail2ban/fail2ban.sock</arg>
<tag>PageContent</tag>
<tag>Setter</tag>
<doc>Clears and fills the given ``txt`` into the text field found by ``selector``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.350456" elapsed="0.000188"/>
</kw>
<kw name="Fill Text" owner="Browser">
<arg>css=input[aria-label="Session Duration (minutes)"]</arg>
<arg>0</arg>
<tag>PageContent</tag>
<tag>Setter</tag>
<doc>Clears and fills the given ``txt`` into the text field found by ``selector``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.350721" elapsed="0.000220"/>
</kw>
<kw name="Click" owner="Browser">
<arg>css=button[type="submit"]</arg>
<tag>PageContent</tag>
<tag>Setter</tag>
<doc>Simulates mouse click on the element found by ``selector``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.351058" elapsed="0.000285"/>
</kw>
<kw name="Wait For Elements State" owner="Browser">
<arg>css=[aria-label="Session Duration (minutes)"]</arg>
<arg>attached</arg>
<arg>timeout=5s</arg>
<tag>PageContent</tag>
<tag>Wait</tag>
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.351500" elapsed="0.000444"/>
</kw>
<kw name="Get Text" owner="Browser">
<var>${msg}</var>
<arg>css=[aria-label="Session Duration (minutes)"]/ancestor::*[contains(@class,"field")]//*[contains(@class,"validationMessage")]</arg>
<tag>Assertion</tag>
<tag>Getter</tag>
<tag>PageContent</tag>
<doc>Returns text attribute of the element found by ``selector``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.352094" elapsed="0.000361"/>
</kw>
<kw name="Should Be Equal As Strings" owner="BuiltIn">
<arg>${msg}</arg>
<arg>Session duration must be at least 1 minute.</arg>
<doc>Fails if objects are unequal after converting them to strings.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.352673" elapsed="0.000368"/>
</kw>
<kw name="Close Browser" owner="Browser">
<tag>BrowserControl</tag>
<tag>Setter</tag>
<doc>Closes the current browser.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.353712" elapsed="0.000212"/>
</kw>
<doc>Session duration below 1 minute triggers a validation error.</doc>
<status status="PASS" start="2026-05-05T19:08:16.346640" elapsed="0.007391"/>
</test>
<test id="s1-t6" name="Incomplete Password Shows Complexity Error" line="120">
<kw name="New Browser" owner="Browser">
<arg>chromium</arg>
<arg>headless=${TRUE}</arg>
<tag>BrowserControl</tag>
<tag>Setter</tag>
<doc>Create a new playwright Browser with specified options.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.360793" elapsed="0.000405"/>
</kw>
<kw name="Go To" owner="Browser">
<arg>${FRONTEND_URL}/setup</arg>
<tag>BrowserControl</tag>
<tag>Setter</tag>
<doc>Navigates to the given ``url``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.361312" elapsed="0.000247"/>
</kw>
<kw name="Wait For Elements State" owner="Browser">
<arg>css=input[aria-label="Master Password"]</arg>
<arg>visible</arg>
<arg>timeout=15s</arg>
<tag>PageContent</tag>
<tag>Wait</tag>
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.361647" elapsed="0.000293"/>
</kw>
<kw name="Fill Text" owner="Browser">
<arg>css=input[aria-label="Master Password"]</arg>
<arg>short</arg>
<tag>PageContent</tag>
<tag>Setter</tag>
<doc>Clears and fills the given ``txt`` into the text field found by ``selector``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.362037" elapsed="0.000214"/>
</kw>
<kw name="Click" owner="Browser">
<arg>css=button[type="submit"]</arg>
<tag>PageContent</tag>
<tag>Setter</tag>
<doc>Simulates mouse click on the element found by ``selector``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.362335" elapsed="0.000206"/>
</kw>
<kw name="Wait For Elements State" owner="Browser">
<arg>css=[aria-label="Master Password"]</arg>
<arg>attached</arg>
<arg>timeout=5s</arg>
<tag>PageContent</tag>
<tag>Wait</tag>
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.362621" elapsed="0.000231"/>
</kw>
<kw name="Get Text" owner="Browser">
<var>${msg}</var>
<arg>css=[aria-label="Master Password"]/ancestor::*[contains(@class,"field")]//*[contains(@class,"validationMessage")]</arg>
<tag>Assertion</tag>
<tag>Getter</tag>
<tag>PageContent</tag>
<doc>Returns text attribute of the element found by ``selector``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.362936" elapsed="0.000208"/>
</kw>
<kw name="Should Contain" owner="BuiltIn">
<arg>${msg}</arg>
<arg>Password must meet all complexity requirements.</arg>
<doc>Fails if ``container`` does not contain ``item`` one or more times.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.363225" elapsed="0.000152"/>
</kw>
<kw name="Close Browser" owner="Browser">
<tag>BrowserControl</tag>
<tag>Setter</tag>
<doc>Closes the current browser.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.363456" elapsed="0.000181"/>
</kw>
<doc>Submitting a password that meets length but not all rules shows complexity error.</doc>
<status status="PASS" start="2026-05-05T19:08:16.359200" elapsed="0.004521"/>
</test>
<test id="s1-t7" name="Setup Completes Successfully And Redirects To Login" line="135">
<kw name="New Browser" owner="Browser">
<arg>chromium</arg>
<arg>headless=${TRUE}</arg>
<tag>BrowserControl</tag>
<tag>Setter</tag>
<doc>Create a new playwright Browser with specified options.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.369227" elapsed="0.001419"/>
</kw>
<kw name="GET" owner="RequestsLibrary">
<var>${status_resp}</var>
<arg>${BACKEND_URL}/api/setup/status</arg>
<doc>Sends a GET request.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.370784" elapsed="0.000122"/>
</kw>
<kw name="Set Variable" owner="BuiltIn">
<var>${status_body}</var>
<arg>${status_resp.json()}</arg>
<doc>Returns the given values which can then be assigned to a variables.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.370993" elapsed="0.000163"/>
</kw>
<kw name="Log" owner="BuiltIn">
<arg>Setup complete: ${status_body}[setup_complete]</arg>
<doc>Logs the given message with the given level.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.371232" elapsed="0.000136"/>
</kw>
<kw name="Go To" owner="Browser">
<arg>${FRONTEND_URL}/setup</arg>
<tag>BrowserControl</tag>
<tag>Setter</tag>
<doc>Navigates to the given ``url``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.371442" elapsed="0.000212"/>
</kw>
<kw name="Wait For Elements State" owner="Browser">
<arg>css=input[aria-label="Master Password"]</arg>
<arg>visible</arg>
<arg>timeout=15s</arg>
<tag>PageContent</tag>
<tag>Wait</tag>
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.371742" elapsed="0.000370"/>
</kw>
<kw name="Fill Text" owner="Browser">
<arg>css=input[aria-label="Master Password"]</arg>
<arg>Hallo123!</arg>
<tag>PageContent</tag>
<tag>Setter</tag>
<doc>Clears and fills the given ``txt`` into the text field found by ``selector``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.372275" elapsed="0.000292"/>
</kw>
<kw name="Fill Text" owner="Browser">
<arg>css=input[aria-label="Confirm Password"]</arg>
<arg>Hallo123!</arg>
<tag>PageContent</tag>
<tag>Setter</tag>
<doc>Clears and fills the given ``txt`` into the text field found by ``selector``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.372675" elapsed="0.000251"/>
</kw>
<kw name="Fill Text" owner="Browser">
<arg>css=input[aria-label="Database Path"]</arg>
<arg>bangui.db</arg>
<tag>PageContent</tag>
<tag>Setter</tag>
<doc>Clears and fills the given ``txt`` into the text field found by ``selector``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.373035" elapsed="0.000310"/>
</kw>
<kw name="Fill Text" owner="Browser">
<arg>css=input[aria-label="fail2ban Socket Path"]</arg>
<arg>/var/run/fail2ban/fail2ban.sock</arg>
<tag>PageContent</tag>
<tag>Setter</tag>
<doc>Clears and fills the given ``txt`` into the text field found by ``selector``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.373452" elapsed="0.000267"/>
</kw>
<kw name="Fill Text" owner="Browser">
<arg>css=input[aria-label="Timezone"]</arg>
<arg>UTC</arg>
<tag>PageContent</tag>
<tag>Setter</tag>
<doc>Clears and fills the given ``txt`` into the text field found by ``selector``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.373819" elapsed="0.000256"/>
</kw>
<kw name="Fill Text" owner="Browser">
<arg>css=input[aria-label="Session Duration (minutes)"]</arg>
<arg>60</arg>
<tag>PageContent</tag>
<tag>Setter</tag>
<doc>Clears and fills the given ``txt`` into the text field found by ``selector``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.374156" elapsed="0.000240"/>
</kw>
<kw name="Click" owner="Browser">
<arg>css=button[type="submit"]</arg>
<tag>PageContent</tag>
<tag>Setter</tag>
<doc>Simulates mouse click on the element found by ``selector``.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.374478" elapsed="0.000261"/>
</kw>
<kw name="Get Url" owner="Browser">
<var>${current_url}</var>
<tag>Assertion</tag>
<tag>Getter</tag>
<tag>PageContent</tag>
<doc>Returns the current URL.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.374846" elapsed="0.000234"/>
</kw>
<if>
<branch type="IF" condition="&quot;login&quot; not in &quot;&quot;&quot;${current_url}&quot;&quot;&quot;">
<kw name="Evaluate" owner="BuiltIn">
<var>${deadline}</var>
<arg>time.time() + 15</arg>
<doc>Evaluates the given expression in Python and returns the result.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.375390" elapsed="0.000221"/>
</kw>
<while condition="True">
<iter>
<kw name="Evaluate" owner="BuiltIn">
<var>${now}</var>
<arg>time.time()</arg>
<doc>Evaluates the given expression in Python and returns the result.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.376079" elapsed="0.000214"/>
</kw>
<if>
<branch type="IF" condition="${now} &gt;= ${deadline}">
<break>
<status status="PASS" start="2026-05-05T19:08:16.376629" elapsed="0.000061"/>
</break>
<status status="PASS" start="2026-05-05T19:08:16.376533" elapsed="0.000200"/>
</branch>
<status status="NOT RUN" start="2026-05-05T19:08:16.376510" elapsed="0.000263"/>
</if>
<kw name="Get Url" owner="Browser">
<var>${url}</var>
<tag>Assertion</tag>
<tag>Getter</tag>
<tag>PageContent</tag>
<doc>Returns the current URL.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.376870" elapsed="0.000303"/>
</kw>
<if>
<branch type="IF" condition="&quot;login&quot; in &quot;&quot;&quot;${url}&quot;&quot;&quot;">
<break>
<status status="PASS" start="2026-05-05T19:08:16.377395" elapsed="0.000052"/>
</break>
<status status="PASS" start="2026-05-05T19:08:16.377278" elapsed="0.000200"/>
</branch>
<status status="NOT RUN" start="2026-05-05T19:08:16.377259" elapsed="0.000245"/>
</if>
<kw name="Sleep" owner="BuiltIn">
<arg>0.5</arg>
<doc>Pauses the test executed for the given time.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.377566" elapsed="0.000152"/>
</kw>
<status status="NOT RUN" start="2026-05-05T19:08:16.375690" elapsed="0.002063"/>
</iter>
<status status="NOT RUN" start="2026-05-05T19:08:16.375688" elapsed="0.002094"/>
</while>
<status status="PASS" start="2026-05-05T19:08:16.375221" elapsed="0.002591"/>
</branch>
<status status="PASS" start="2026-05-05T19:08:16.375197" elapsed="0.002640"/>
</if>
<kw name="GET" owner="RequestsLibrary">
<var>${new_status_resp}</var>
<arg>${BACKEND_URL}/api/setup/status</arg>
<doc>Sends a GET request.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.377900" elapsed="0.000088"/>
</kw>
<kw name="Set Variable" owner="BuiltIn">
<var>${new_status_body}</var>
<arg>${new_status_resp.json()}</arg>
<doc>Returns the given values which can then be assigned to a variables.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.378068" elapsed="0.000134"/>
</kw>
<kw name="Should Be True" owner="BuiltIn">
<arg>${new_status_body}[setup_complete]</arg>
<doc>Fails if the given condition is not true.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.378279" elapsed="0.000129"/>
</kw>
<kw name="Close Browser" owner="Browser">
<tag>BrowserControl</tag>
<tag>Setter</tag>
<doc>Closes the current browser.</doc>
<status status="NOT RUN" start="2026-05-05T19:08:16.378480" elapsed="0.000189"/>
</kw>
<doc>Filling all fields and submitting completes setup and navigates to /login.</doc>
<status status="PASS" start="2026-05-05T19:08:16.366735" elapsed="0.012022"/>
</test>
<status status="PASS" start="2026-05-05T19:08:15.508608" elapsed="0.873487"/>
</suite>
<statistics>
<total>
<stat pass="7" fail="0" skip="0">All Tests</stat>
</total>
<tag>
</tag>
<suite>
<stat name="05 Setup" id="s1" pass="7" fail="0" skip="0">05 Setup</stat>
</suite>
</statistics>
<errors>
<msg time="2026-05-05T19:08:15.732927" level="ERROR">Error in file '/home/lukas/Volume/repo/BanGUI/e2e/resources/common.resource' on line 5: Processing variable file '/home/lukas/Volume/repo/BanGUI/e2e/resources/../../.env' failed: Importing variable file '/home/lukas/Volume/repo/BanGUI/e2e/resources/../../.env' failed: Module name cannot contain dots when importing by path.</msg>
</errors>
</robot>