fixed tests
This commit is contained in:
2427
Docs/Tasks.md
2427
Docs/Tasks.md
File diff suppressed because it is too large
Load Diff
@@ -102,7 +102,7 @@ for (int i = 0; i < items.Count; i++)
|
||||
|
||||
// Step 1 — run the task prompt
|
||||
await RunCopilot(Enumerable.Empty<string>(), $"/caveman full");
|
||||
await RunCopilot(new[] { "--continue" }, $"read ./Docs/Instructions.md. fix the following test and only that one {item}");
|
||||
await RunCopilot(new[] { "--continue" }, $"read ./Docs/Instructions.md. fix the following test and only that one. Keep in mind that i did many refactorings and test may is obsolet or need to be changed. {item}");
|
||||
if (cts.IsCancellationRequested) break;
|
||||
|
||||
// Step 2 — confirm completion in the same chat session
|
||||
|
||||
@@ -14,6 +14,7 @@ from __future__ import annotations
|
||||
from pathlib import Path
|
||||
|
||||
import aiosqlite
|
||||
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = get_logger(__name__)
|
||||
@@ -246,7 +247,6 @@ CREATE INDEX IF NOT EXISTS idx_import_log_source_id_desc
|
||||
}
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -254,6 +254,7 @@ CREATE INDEX IF NOT EXISTS idx_import_log_source_id_desc
|
||||
|
||||
async def _configure_connection(db: aiosqlite.Connection) -> None:
|
||||
"""Apply hardening pragmas to a newly-opened SQLite connection."""
|
||||
await db.execute("PRAGMA journal_mode=WAL;")
|
||||
await db.execute("PRAGMA foreign_keys=ON;")
|
||||
await db.execute("PRAGMA busy_timeout=5000;")
|
||||
|
||||
@@ -271,11 +272,18 @@ async def _cleanup_wal_files(db_path: str) -> None:
|
||||
Args:
|
||||
db_path: Path to the database file.
|
||||
"""
|
||||
import time
|
||||
|
||||
wal_path = Path(db_path + "-wal")
|
||||
shm_path = Path(db_path + "-shm")
|
||||
|
||||
for path in (wal_path, shm_path):
|
||||
if path.exists():
|
||||
# Skip files that were modified recently — they likely belong to an
|
||||
# active connection. Only remove stale files left by crashes.
|
||||
mtime = path.stat().st_mtime
|
||||
if time.time() - mtime < 10:
|
||||
continue
|
||||
try:
|
||||
path.unlink()
|
||||
log.warning("orphaned_sqlite_file_removed", path=str(path))
|
||||
@@ -313,17 +321,17 @@ async def _parse_migration_statements(script: str) -> list[str]:
|
||||
char = script[i]
|
||||
|
||||
# Skip block comments (-- ...)
|
||||
if i < len(script) - 1 and script[i:i+2] == "--":
|
||||
if i < len(script) - 1 and script[i : i + 2] == "--":
|
||||
while i < len(script) and script[i] != "\n":
|
||||
i += 1
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Skip line comments (/* ... */)
|
||||
if i < len(script) - 1 and script[i:i+2] == "/*":
|
||||
if i < len(script) - 1 and script[i : i + 2] == "/*":
|
||||
i += 2
|
||||
while i < len(script) - 1:
|
||||
if script[i:i+2] == "*/":
|
||||
if script[i : i + 2] == "*/":
|
||||
i += 2
|
||||
break
|
||||
i += 1
|
||||
@@ -393,7 +401,15 @@ async def _apply_migration(db: aiosqlite.Connection, version: int) -> None:
|
||||
await db.execute("BEGIN IMMEDIATE;")
|
||||
|
||||
for statement in statements:
|
||||
await db.execute(statement)
|
||||
try:
|
||||
await db.execute(statement)
|
||||
except aiosqlite.OperationalError as exc:
|
||||
# Ignore duplicate column / table errors so migrations remain
|
||||
# idempotent when a legacy database already has the object.
|
||||
msg = str(exc).lower()
|
||||
if "duplicate column name" in msg or "table" in msg and "already exists" in msg:
|
||||
continue
|
||||
raise
|
||||
|
||||
await db.execute("INSERT INTO schema_migrations (version) VALUES (?);", (version,))
|
||||
|
||||
@@ -411,8 +427,7 @@ async def _migrate_schema(db: aiosqlite.Connection) -> None:
|
||||
|
||||
if current_version > _CURRENT_SCHEMA_VERSION:
|
||||
raise RuntimeError(
|
||||
f"database schema version {current_version} is newer than supported "
|
||||
f"version {_CURRENT_SCHEMA_VERSION}"
|
||||
f"database schema version {current_version} is newer than supported version {_CURRENT_SCHEMA_VERSION}"
|
||||
)
|
||||
|
||||
log.info("migrating_database_schema", from_version=current_version, to_version=_CURRENT_SCHEMA_VERSION)
|
||||
|
||||
@@ -36,7 +36,6 @@ from typing import Annotated, cast
|
||||
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
from app.utils.logging_compat import get_logger
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler # type: ignore[import-untyped]
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request, status
|
||||
|
||||
@@ -45,22 +44,6 @@ from app.exceptions import RateLimitError
|
||||
from app.models.auth import Session
|
||||
from app.models.config import PendingRecovery
|
||||
from app.models.server import ServerStatus
|
||||
from app.repositories.protocols import (
|
||||
BlocklistRepository,
|
||||
Fail2BanDbRepository,
|
||||
GeoCacheRepository,
|
||||
HistoryArchiveRepository,
|
||||
ImportLogRepository,
|
||||
ImportRunRepository,
|
||||
SessionRepository,
|
||||
SettingsRepository,
|
||||
)
|
||||
from app.services.geo_cache import GeoCache
|
||||
from app.services.protocols import Fail2BanMetadataService
|
||||
from app.utils.constants import SESSION_COOKIE_NAME
|
||||
from app.utils.rate_limiter import GlobalRateLimiter
|
||||
from app.utils.runtime_state import ApplicationState, JailServiceState, RuntimeState
|
||||
from app.utils.session_cache import NoOpSessionCache, SessionCache
|
||||
|
||||
# Module-level imports for repositories and services
|
||||
# These are safe at module level since no circular dependencies exist
|
||||
@@ -74,8 +57,25 @@ from app.repositories import (
|
||||
session_repo,
|
||||
settings_repo,
|
||||
)
|
||||
from app.repositories.protocols import (
|
||||
BlocklistRepository,
|
||||
Fail2BanDbRepository,
|
||||
GeoCacheRepository,
|
||||
HistoryArchiveRepository,
|
||||
ImportLogRepository,
|
||||
ImportRunRepository,
|
||||
SessionRepository,
|
||||
SettingsRepository,
|
||||
)
|
||||
from app.services import auth_service, health_service
|
||||
from app.services.fail2ban_metadata_service import default_fail2ban_metadata_service
|
||||
from app.services.geo_cache import GeoCache
|
||||
from app.services.protocols import Fail2BanMetadataService
|
||||
from app.utils.constants import SESSION_COOKIE_NAME
|
||||
from app.utils.logging_compat import get_logger
|
||||
from app.utils.rate_limiter import GlobalRateLimiter
|
||||
from app.utils.runtime_state import ApplicationState, JailServiceState, RuntimeState
|
||||
from app.utils.session_cache import NoOpSessionCache, SessionCache
|
||||
|
||||
log = get_logger(__name__)
|
||||
|
||||
@@ -108,6 +108,7 @@ class ApplicationContext:
|
||||
#: or distributed deployments, the configured cache backend should provide
|
||||
#: invalidation semantics appropriate for the deployment.
|
||||
|
||||
|
||||
def _session_cache_enabled(settings: Settings) -> bool:
|
||||
"""Return whether the session validation cache should be used."""
|
||||
return settings.session_cache_enabled and settings.session_cache_ttl_seconds > 0.0
|
||||
@@ -284,6 +285,7 @@ def rate_limit_dependency(
|
||||
Returns:
|
||||
A callable that can be used as a FastAPI Depends() dependency.
|
||||
"""
|
||||
|
||||
async def check_rate_limit(
|
||||
request: Request,
|
||||
rate_limiter: GlobalRateLimiterDep,
|
||||
@@ -293,9 +295,7 @@ def rate_limit_dependency(
|
||||
settings: Settings = request.app.state.settings
|
||||
client_ip = get_client_ip(request, trusted_proxies=settings.trusted_proxies)
|
||||
|
||||
is_allowed, retry_after = rate_limiter.check_allowed_for_bucket(
|
||||
bucket, client_ip, max_requests, window_seconds
|
||||
)
|
||||
is_allowed, retry_after = rate_limiter.check_allowed_for_bucket(bucket, client_ip, max_requests, window_seconds)
|
||||
|
||||
if not is_allowed:
|
||||
log.warning(
|
||||
@@ -407,6 +407,8 @@ async def get_app(request: Request) -> FastAPI:
|
||||
|
||||
async def get_server_status(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> ServerStatus:
|
||||
"""Return the cached fail2ban server status snapshot from application context."""
|
||||
if app_context.server_status is None:
|
||||
return ServerStatus(online=False)
|
||||
return app_context.server_status
|
||||
|
||||
|
||||
@@ -654,7 +656,7 @@ async def require_auth(
|
||||
if not token:
|
||||
auth_header: str = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
token = auth_header[len("Bearer "):]
|
||||
token = auth_header[len("Bearer ") :]
|
||||
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
|
||||
@@ -72,13 +72,13 @@ from app.utils.external_logging import (
|
||||
ExternalLogHandler,
|
||||
create_external_log_handler,
|
||||
)
|
||||
from app.utils.json_formatter import JSONFormatter
|
||||
from app.utils.logging_compat import get_logger
|
||||
from app.utils.rate_limiter import GlobalRateLimiter
|
||||
from app.utils.runtime_state import ApplicationState, RuntimeState
|
||||
from app.utils.scheduler_lock import release_scheduler_lock
|
||||
from app.utils.session_cache import InMemorySessionCache, NoOpSessionCache
|
||||
from app.utils.setup_state import is_setup_complete_cached, set_setup_complete_cache
|
||||
from app.utils.json_formatter import JSONFormatter
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = get_logger("bangui")
|
||||
|
||||
@@ -125,8 +125,15 @@ def _configure_logging(log_level: str, log_file: str | None, settings: Settings
|
||||
level: int = logging.getLevelName(log_level.upper())
|
||||
handlers: list[logging.Handler] = [logging.StreamHandler(sys.stdout)]
|
||||
if log_file:
|
||||
os.makedirs(os.path.dirname(log_file), exist_ok=True)
|
||||
handlers.append(logging.FileHandler(log_file))
|
||||
try:
|
||||
os.makedirs(os.path.dirname(log_file), exist_ok=True)
|
||||
handlers.append(logging.FileHandler(log_file))
|
||||
except (PermissionError, OSError) as exc:
|
||||
log.warning(
|
||||
"log_file_directory_not_created",
|
||||
log_file=log_file,
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
# Suppress verbose third-party library logs that emit plain text
|
||||
# through the standard library logging module.
|
||||
@@ -163,9 +170,7 @@ def _update_session_cache(app: FastAPI, settings: Settings) -> None:
|
||||
settings: The effective application settings.
|
||||
"""
|
||||
cache_enabled = settings.session_cache_enabled and settings.session_cache_ttl_seconds > 0.0
|
||||
app.state.session_cache = (
|
||||
InMemorySessionCache() if cache_enabled else NoOpSessionCache()
|
||||
)
|
||||
app.state.session_cache = InMemorySessionCache() if cache_enabled else NoOpSessionCache()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@@ -811,12 +816,12 @@ async def _request_validation_error_handler(
|
||||
# the guard without being explicitly allowed.
|
||||
_EXACT_ALLOWED: frozenset[str] = frozenset(
|
||||
{
|
||||
"/api/v1/setup", # GET/POST /api/v1/setup
|
||||
"/api/v1/health", # Health check endpoint (combined)
|
||||
"/api/v1/health/live", # Kubernetes liveness probe
|
||||
"/api/v1/setup", # GET/POST /api/v1/setup
|
||||
"/api/v1/health", # Health check endpoint (combined)
|
||||
"/api/v1/health/live", # Kubernetes liveness probe
|
||||
"/api/v1/health/ready", # Kubernetes readiness probe
|
||||
"/api/docs", # Swagger UI
|
||||
"/api/redoc", # ReDoc
|
||||
"/api/docs", # Swagger UI
|
||||
"/api/redoc", # ReDoc
|
||||
"/api/openapi.json", # OpenAPI schema
|
||||
},
|
||||
)
|
||||
@@ -971,9 +976,7 @@ def _enforce_single_worker() -> None:
|
||||
"See Docs/Deployment.md § Single-Worker Requirement."
|
||||
)
|
||||
except ValueError as e:
|
||||
raise RuntimeError(
|
||||
f"WEB_CONCURRENCY must be an integer, got: {web_concurrency}"
|
||||
) from e
|
||||
raise RuntimeError(f"WEB_CONCURRENCY must be an integer, got: {web_concurrency}") from e
|
||||
|
||||
# Check explicit BANGUI_WORKERS override (discouraged, still enforced)
|
||||
bangui_workers = os.environ.get("BANGUI_WORKERS")
|
||||
@@ -990,9 +993,7 @@ def _enforce_single_worker() -> None:
|
||||
"See Docs/Deployment.md § Single-Worker Requirement."
|
||||
)
|
||||
except ValueError as e:
|
||||
raise RuntimeError(
|
||||
f"BANGUI_WORKERS must be an integer, got: {bangui_workers}"
|
||||
) from e
|
||||
raise RuntimeError(f"BANGUI_WORKERS must be an integer, got: {bangui_workers}") from e
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -1165,7 +1166,6 @@ def create_app(settings: Settings | None = None) -> FastAPI:
|
||||
# stack is a security-critical defect that must not slip through silently.
|
||||
_assert_middleware_order(app)
|
||||
|
||||
|
||||
# --- Exception handlers ---
|
||||
#
|
||||
# Exception handlers are registered from most specific to least specific. FastAPI evaluates
|
||||
|
||||
@@ -10,13 +10,11 @@ from __future__ import annotations
|
||||
|
||||
from app.models.config import (
|
||||
BantimeEscalation,
|
||||
Fail2BanLogResponse,
|
||||
FilterConfig,
|
||||
FilterListResponse,
|
||||
GlobalConfigResponse,
|
||||
JailConfig,
|
||||
JailConfigListResponse,
|
||||
LogPreviewResponse,
|
||||
MapColorThresholdsResponse,
|
||||
RegexTestResponse,
|
||||
ServiceStatusResponse,
|
||||
@@ -32,7 +30,6 @@ from app.models.config_domain import (
|
||||
DomainRegexTest,
|
||||
DomainServiceStatus,
|
||||
)
|
||||
from app.utils.pagination import create_pagination_metadata
|
||||
|
||||
|
||||
def _map_domain_bantime_escalation(domain: DomainBantimeEscalation) -> BantimeEscalation:
|
||||
@@ -65,9 +62,7 @@ def map_domain_jail_config_to_response(domain: DomainJailConfig) -> JailConfig:
|
||||
prefregex=domain.prefregex,
|
||||
actions=domain.actions,
|
||||
bantime_escalation=(
|
||||
_map_domain_bantime_escalation(domain.bantime_escalation)
|
||||
if domain.bantime_escalation
|
||||
else None
|
||||
_map_domain_bantime_escalation(domain.bantime_escalation) if domain.bantime_escalation else None
|
||||
),
|
||||
)
|
||||
|
||||
@@ -151,6 +146,6 @@ def map_domain_filter_config_to_response(domain: DomainFilterConfig) -> FilterCo
|
||||
def map_domain_filter_list_to_response(domain_list: DomainFilterList) -> FilterListResponse:
|
||||
"""Convert domain filter list to response model."""
|
||||
return FilterListResponse(
|
||||
items=[map_domain_filter_config_to_response(f) for f in domain_list.items],
|
||||
filters=[map_domain_filter_config_to_response(f) for f in domain_list.items],
|
||||
total=domain_list.total,
|
||||
)
|
||||
|
||||
@@ -8,15 +8,15 @@ from __future__ import annotations
|
||||
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import AnyHttpUrl, Field
|
||||
from pydantic import AnyHttpUrl, ConfigDict, Field
|
||||
|
||||
from app.models.response import BanGuiBaseModel, PaginatedListResponse
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Blocklist source
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class BlocklistSource(BanGuiBaseModel):
|
||||
"""Domain model for a blocklist source definition."""
|
||||
|
||||
@@ -27,6 +27,7 @@ class BlocklistSource(BanGuiBaseModel):
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
|
||||
class BlocklistSourceCreate(BanGuiBaseModel):
|
||||
"""Payload for ``POST /api/blocklists``.
|
||||
|
||||
@@ -39,6 +40,7 @@ class BlocklistSourceCreate(BanGuiBaseModel):
|
||||
url: AnyHttpUrl = Field(..., description="URL of the blocklist file (http/https only).")
|
||||
enabled: bool = Field(default=True)
|
||||
|
||||
|
||||
class BlocklistSourceUpdate(BanGuiBaseModel):
|
||||
"""Payload for ``PUT /api/blocklists/{id}``. All fields are optional.
|
||||
|
||||
@@ -49,15 +51,18 @@ class BlocklistSourceUpdate(BanGuiBaseModel):
|
||||
url: AnyHttpUrl | None = Field(default=None)
|
||||
enabled: bool | None = Field(default=None)
|
||||
|
||||
|
||||
class BlocklistListResponse(BanGuiBaseModel):
|
||||
"""Response for ``GET /api/blocklists``."""
|
||||
|
||||
sources: list[BlocklistSource] = Field(default_factory=list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Import log
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ImportLogEntry(BanGuiBaseModel):
|
||||
"""A single blocklist import run record."""
|
||||
|
||||
@@ -69,6 +74,7 @@ class ImportLogEntry(BanGuiBaseModel):
|
||||
ips_skipped: int
|
||||
errors: str | None
|
||||
|
||||
|
||||
class ImportLogListResponse(PaginatedListResponse[ImportLogEntry]):
|
||||
"""Response for ``GET /api/blocklists/log``.
|
||||
|
||||
@@ -83,6 +89,7 @@ class ImportLogListResponse(PaginatedListResponse[ImportLogEntry]):
|
||||
# Import run tracking (for idempotency)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ImportRunEntry(BanGuiBaseModel):
|
||||
"""Tracks a unique blocklist import run by source and content hash.
|
||||
|
||||
@@ -100,10 +107,12 @@ class ImportRunEntry(BanGuiBaseModel):
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schedule
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ScheduleFrequency(StrEnum):
|
||||
"""Available import schedule frequency presets."""
|
||||
|
||||
@@ -111,6 +120,7 @@ class ScheduleFrequency(StrEnum):
|
||||
daily = "daily"
|
||||
weekly = "weekly"
|
||||
|
||||
|
||||
class ScheduleConfig(BanGuiBaseModel):
|
||||
"""Import schedule configuration.
|
||||
|
||||
@@ -121,8 +131,10 @@ class ScheduleConfig(BanGuiBaseModel):
|
||||
- ``weekly``: additionally uses ``day_of_week`` (0=Monday … 6=Sunday).
|
||||
"""
|
||||
|
||||
# No strict=True here: FastAPI and json.loads() both supply enum values as
|
||||
# plain strings; strict mode would reject string→enum coercion.
|
||||
# FastAPI and json.loads() both supply enum values as plain strings;
|
||||
# strict mode would reject string→enum coercion, so we override the
|
||||
# base model_config for this model only.
|
||||
model_config = ConfigDict(strict=False)
|
||||
|
||||
frequency: ScheduleFrequency = ScheduleFrequency.daily
|
||||
interval_hours: int = Field(default=24, ge=1, le=168, description="Used when frequency=hourly")
|
||||
@@ -135,6 +147,7 @@ class ScheduleConfig(BanGuiBaseModel):
|
||||
description="Day of week for weekly runs (0=Monday … 6=Sunday)",
|
||||
)
|
||||
|
||||
|
||||
class ScheduleInfo(BanGuiBaseModel):
|
||||
"""Current schedule configuration together with runtime metadata."""
|
||||
|
||||
@@ -144,10 +157,12 @@ class ScheduleInfo(BanGuiBaseModel):
|
||||
last_run_errors: bool | None = None
|
||||
"""``True`` if the most recent import had errors, ``False`` if clean, ``None`` if never run."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Import results
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ImportSourceResult(BanGuiBaseModel):
|
||||
"""Result of importing a single blocklist source."""
|
||||
|
||||
@@ -157,6 +172,7 @@ class ImportSourceResult(BanGuiBaseModel):
|
||||
ips_skipped: int
|
||||
error: str | None
|
||||
|
||||
|
||||
class ImportRunResult(BanGuiBaseModel):
|
||||
"""Aggregated result from a full import run across all enabled sources."""
|
||||
|
||||
@@ -165,10 +181,12 @@ class ImportRunResult(BanGuiBaseModel):
|
||||
total_skipped: int
|
||||
errors_count: int
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Preview
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class PreviewResponse(BanGuiBaseModel):
|
||||
"""Response for ``GET /api/blocklists/{id}/preview``."""
|
||||
|
||||
|
||||
@@ -188,7 +188,6 @@ class PaginationMetadata(BanGuiBaseModel):
|
||||
)
|
||||
|
||||
|
||||
|
||||
class PaginatedListResponse(BanGuiBaseModel, Generic[T]):
|
||||
"""Standardized paginated list response.
|
||||
|
||||
@@ -384,6 +383,8 @@ class ErrorMetadata(TypedDict, total=False):
|
||||
current_status: str
|
||||
actual_length: int
|
||||
message: str
|
||||
field_errors: int
|
||||
first_field: str
|
||||
|
||||
|
||||
class ComponentHealth(BanGuiBaseModel):
|
||||
|
||||
@@ -37,7 +37,6 @@ from app.services import (
|
||||
filter_config_service,
|
||||
jail_config_service,
|
||||
)
|
||||
from app.utils.path_utils import validate_log_path
|
||||
from app.utils.constants import (
|
||||
RATE_LIMIT_JAIL_ACTIVATE_REQUESTS,
|
||||
RATE_LIMIT_JAIL_CREATE_REQUESTS,
|
||||
@@ -45,6 +44,7 @@ from app.utils.constants import (
|
||||
RATE_LIMIT_JAIL_DELETE_REQUESTS,
|
||||
RATE_LIMIT_JAIL_UPDATE_REQUESTS,
|
||||
)
|
||||
from app.utils.path_utils import validate_log_path
|
||||
from app.utils.runtime_state import (
|
||||
clear_activation_record,
|
||||
clear_pending_recovery,
|
||||
@@ -207,7 +207,8 @@ def _check_jail_deactivate_rate_limit(
|
||||
)
|
||||
|
||||
|
||||
_NamePath = Annotated[str, Path(description='Jail name as configured in fail2ban.')]
|
||||
_NamePath = Annotated[str, Path(description="Jail name as configured in fail2ban.")]
|
||||
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
@@ -240,8 +241,6 @@ async def get_jail_configs(
|
||||
return config_mappers.map_domain_jail_config_list_to_response(domain_result)
|
||||
|
||||
|
||||
|
||||
|
||||
@router.get(
|
||||
"/inactive",
|
||||
response_model=InactiveJailListResponse,
|
||||
@@ -335,9 +334,8 @@ async def get_jail_config(
|
||||
HTTPException: 502 when fail2ban is unreachable.
|
||||
"""
|
||||
domain_result = await config_service.get_jail_config(socket_path, name)
|
||||
return config_mappers.map_domain_jail_config_to_response(domain_result)
|
||||
|
||||
|
||||
mapped = config_mappers.map_domain_jail_config_to_response(domain_result)
|
||||
return JailConfigResponse(jail=mapped)
|
||||
|
||||
|
||||
@router.put(
|
||||
@@ -387,8 +385,6 @@ async def update_jail_config(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{name}/logpath",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
@@ -430,8 +426,6 @@ async def add_log_path(
|
||||
await config_service.add_log_path(socket_path, name, body)
|
||||
|
||||
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{name}/logpath",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
@@ -479,8 +473,6 @@ async def delete_log_path(
|
||||
await config_service.delete_log_path(socket_path, name, log_path)
|
||||
|
||||
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{name}/activate",
|
||||
response_model=JailActivationResponse,
|
||||
@@ -532,9 +524,7 @@ async def activate_jail(
|
||||
"""
|
||||
req = body if body is not None else ActivateJailRequest()
|
||||
|
||||
result = await jail_config_service.activate_jail(
|
||||
config_dir, socket_path, name, req, health_probe=health_probe
|
||||
)
|
||||
result = await jail_config_service.activate_jail(config_dir, socket_path, name, req, health_probe=health_probe)
|
||||
|
||||
if result.active:
|
||||
record_activation(app, name)
|
||||
@@ -542,8 +532,6 @@ async def activate_jail(
|
||||
return result
|
||||
|
||||
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{name}/deactivate",
|
||||
response_model=JailActivationResponse,
|
||||
@@ -588,14 +576,10 @@ async def deactivate_jail(
|
||||
HTTPException: 502 if fail2ban is unreachable.
|
||||
"""
|
||||
|
||||
result = await jail_config_service.deactivate_jail(
|
||||
config_dir, socket_path, name, health_probe=health_probe
|
||||
)
|
||||
result = await jail_config_service.deactivate_jail(config_dir, socket_path, name, health_probe=health_probe)
|
||||
return result
|
||||
|
||||
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{name}/local",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
@@ -645,8 +629,6 @@ async def delete_jail_local_override(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{name}/validate",
|
||||
response_model=JailValidationResult,
|
||||
@@ -868,10 +850,8 @@ async def remove_action_from_jail(
|
||||
action_name,
|
||||
do_reload=reload,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Filter discovery endpoints (Task 2.1)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -15,11 +15,11 @@ under the key ``"blocklist_schedule"``.
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import UTC
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.exceptions import BlocklistSourceHasLogsError
|
||||
from app.models.blocklist import (
|
||||
@@ -37,6 +37,7 @@ from app.repositories import blocklist_repo, import_log_repo, settings_repo
|
||||
from app.services.blocklist_downloader import BlocklistDownloader
|
||||
from app.services.blocklist_import_workflow import BlocklistImportWorkflow
|
||||
from app.services.blocklist_parser import BlocklistParser
|
||||
from app.utils.logging_compat import get_logger
|
||||
from app.utils.pagination import create_pagination_metadata
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -200,9 +201,7 @@ async def update_source(
|
||||
|
||||
await validate_blocklist_url(url)
|
||||
|
||||
updated = await blocklist_repo.update_source(
|
||||
db, source_id, name=name, url=url, enabled=enabled
|
||||
)
|
||||
updated = await blocklist_repo.update_source(db, source_id, name=name, url=url, enabled=enabled)
|
||||
if not updated:
|
||||
return None
|
||||
source = await get_source(db, source_id)
|
||||
@@ -473,8 +472,7 @@ async def get_schedule(db: aiosqlite.Connection) -> ScheduleConfig:
|
||||
if raw is None:
|
||||
return _DEFAULT_SCHEDULE
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
return ScheduleConfig.model_validate(data)
|
||||
return ScheduleConfig.model_validate_json(raw)
|
||||
except (json.JSONDecodeError, ValueError) as exc:
|
||||
log.warning("blocklist_schedule_invalid", raw=raw, error=type(exc).__name__)
|
||||
return _DEFAULT_SCHEDULE
|
||||
@@ -493,9 +491,7 @@ async def set_schedule(
|
||||
Returns:
|
||||
The saved configuration (same object after validation).
|
||||
"""
|
||||
await settings_repo.set_setting(
|
||||
db, _SCHEDULE_SETTINGS_KEY, config.model_dump_json()
|
||||
)
|
||||
await settings_repo.set_setting(db, _SCHEDULE_SETTINGS_KEY, config.model_dump_json())
|
||||
log.info("blocklist_schedule_updated", frequency=config.frequency, hour=config.hour)
|
||||
return config
|
||||
|
||||
@@ -517,8 +513,12 @@ async def get_schedule_info(
|
||||
"""
|
||||
config = await get_schedule(db)
|
||||
last_log = await import_log_repo.get_last_log(db)
|
||||
last_run_at = last_log["timestamp"] if last_log else None
|
||||
last_run_errors: bool | None = (last_log["errors"] is not None) if last_log else None
|
||||
last_run_at = None
|
||||
if last_log is not None:
|
||||
from datetime import datetime
|
||||
|
||||
last_run_at = datetime.fromtimestamp(last_log.timestamp, tz=UTC).isoformat()
|
||||
last_run_errors: bool | None = (last_log.errors is not None) if last_log else None
|
||||
return ScheduleInfo(
|
||||
config=config,
|
||||
next_run_at=next_run_at,
|
||||
@@ -574,9 +574,7 @@ async def list_import_logs(
|
||||
Returns:
|
||||
:class:`~app.models.blocklist.ImportLogListResponse`.
|
||||
"""
|
||||
items, total = await import_log_repo.list_logs(
|
||||
db, source_id=source_id, page=page, page_size=page_size
|
||||
)
|
||||
items, total = await import_log_repo.list_logs(db, source_id=source_id, page=page, page_size=page_size)
|
||||
|
||||
return ImportLogListResponse(
|
||||
items=[ImportLogEntry.model_validate(i) for i in items],
|
||||
|
||||
@@ -13,8 +13,6 @@ import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.exceptions import (
|
||||
ConfigWriteError,
|
||||
FilterAlreadyExistsError,
|
||||
@@ -27,6 +25,7 @@ from app.exceptions import (
|
||||
)
|
||||
from app.models.config import (
|
||||
AssignFilterRequest,
|
||||
FilterConfig,
|
||||
FilterConfigUpdate,
|
||||
FilterCreateRequest,
|
||||
FilterUpdateRequest,
|
||||
@@ -46,6 +45,7 @@ from app.utils.config_file_utils import (
|
||||
set_jail_local_key_sync,
|
||||
)
|
||||
from app.utils.jail_socket import reload_all
|
||||
from app.utils.logging_compat import get_logger
|
||||
from app.utils.regex_validator import RegexTimeoutError, validate_regex_pattern
|
||||
|
||||
log = get_logger(__name__)
|
||||
@@ -54,6 +54,7 @@ log = get_logger(__name__)
|
||||
# Internal wrappers for shared config helpers.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _parse_jails_sync(config_dir: Path) -> tuple[dict[str, dict[str, str]], Path]:
|
||||
return _config_file_parse_jails_sync(config_dir)
|
||||
|
||||
@@ -85,6 +86,7 @@ def _resolve_filter(raw_filter: str, jail_name: str, mode: str) -> str:
|
||||
result = result.replace("%(mode)s", mode)
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers imported from the shared config helper module.
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -366,7 +368,7 @@ async def list_filters(
|
||||
)
|
||||
|
||||
log.info("filters_listed", total=len(filters), active=sum(1 for f in filters if f.active))
|
||||
return DomainFilterList(filters=filters, total=len(filters))
|
||||
return DomainFilterList(items=filters, total=len(filters))
|
||||
|
||||
|
||||
async def get_filter(
|
||||
@@ -428,7 +430,7 @@ async def get_filter(
|
||||
else:
|
||||
raise FilterNotFoundError(base_name)
|
||||
|
||||
content, has_local, source_path = await run_blocking( _read)
|
||||
content, has_local, source_path = await run_blocking(_read)
|
||||
|
||||
cfg = conffile_parser.parse_filter_file(content, name=base_name, filename=f"{base_name}.conf")
|
||||
|
||||
@@ -524,7 +526,7 @@ async def update_filter(
|
||||
content = conffile_parser.serialize_filter_config(merged)
|
||||
|
||||
filter_d = Path(config_dir) / "filter.d"
|
||||
await run_blocking( _write_filter_local_sync, filter_d, base_name, content)
|
||||
await run_blocking(_write_filter_local_sync, filter_d, base_name, content)
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
@@ -580,7 +582,7 @@ async def create_filter(
|
||||
if conf_path.is_file() or local_path.is_file():
|
||||
raise FilterAlreadyExistsError(req.name)
|
||||
|
||||
await run_blocking( _check_not_exists)
|
||||
await run_blocking(_check_not_exists)
|
||||
|
||||
# Validate regex patterns.
|
||||
patterns: list[str] = list(req.failregex) + list(req.ignoreregex)
|
||||
@@ -598,7 +600,7 @@ async def create_filter(
|
||||
)
|
||||
content = conffile_parser.serialize_filter_config(cfg)
|
||||
|
||||
await run_blocking( _write_filter_local_sync, filter_d, req.name, content)
|
||||
await run_blocking(_write_filter_local_sync, filter_d, req.name, content)
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
@@ -663,7 +665,7 @@ async def delete_filter(
|
||||
|
||||
log.info("filter_local_deleted", filter=base_name, path=str(local_path))
|
||||
|
||||
await run_blocking( _delete)
|
||||
await run_blocking(_delete)
|
||||
|
||||
|
||||
async def assign_filter_to_jail(
|
||||
@@ -713,9 +715,10 @@ async def assign_filter_to_jail(
|
||||
if not conf_exists and not local_exists:
|
||||
raise FilterNotFoundError(req.filter_name)
|
||||
|
||||
await run_blocking( _check_filter)
|
||||
await run_blocking(_check_filter)
|
||||
|
||||
await run_blocking(set_jail_local_key_sync,
|
||||
await run_blocking(
|
||||
set_jail_local_key_sync,
|
||||
Path(config_dir),
|
||||
jail_name,
|
||||
"filter",
|
||||
|
||||
@@ -21,10 +21,10 @@ import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiohttp
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.models.geo import GeoInfo
|
||||
from app.repositories import geo_cache_repo
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import collections.abc
|
||||
@@ -40,14 +40,10 @@ log = get_logger(__name__)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
#: ip-api.com single-IP lookup endpoint (HTTP only on the free tier).
|
||||
_API_URL: str = (
|
||||
"http://ip-api.com/json/{ip}?fields=status,message,country,countryCode,org,as"
|
||||
)
|
||||
_API_URL: str = "http://ip-api.com/json/{ip}?fields=status,message,country,countryCode,org,as"
|
||||
|
||||
#: ip-api.com batch endpoint — accepts up to 100 IPs per POST.
|
||||
_BATCH_API_URL: str = (
|
||||
"http://ip-api.com/batch?fields=status,message,country,countryCode,org,as,query"
|
||||
)
|
||||
_BATCH_API_URL: str = "http://ip-api.com/batch?fields=status,message,country,countryCode,org,as,query"
|
||||
|
||||
#: Maximum IPs per batch request (ip-api.com hard limit is 100).
|
||||
_BATCH_SIZE: int = 100
|
||||
@@ -217,9 +213,7 @@ class GeoCache:
|
||||
|
||||
await self.clear_neg_cache()
|
||||
geo_map = await self.lookup_batch(unresolved, http_session, db=db)
|
||||
resolved_count = sum(
|
||||
1 for info in geo_map.values() if info.country_code is not None
|
||||
)
|
||||
resolved_count = sum(1 for info in geo_map.values() if info.country_code is not None)
|
||||
|
||||
log.info(
|
||||
"geo_re_resolve_complete",
|
||||
@@ -398,7 +392,7 @@ class GeoCache:
|
||||
asn=result.asn,
|
||||
org=result.org,
|
||||
)
|
||||
except (OSError) as exc:
|
||||
except OSError as exc:
|
||||
log.warning("geo_persist_failed", ip=ip, error=type(exc).__name__)
|
||||
log.debug("geo_lookup_success_mmdb", ip=ip, country=result.country_code)
|
||||
return result
|
||||
@@ -412,7 +406,7 @@ class GeoCache:
|
||||
if db is not None:
|
||||
try:
|
||||
await geo_cache_repo.upsert_neg_entry_and_commit(db=db, ip=ip)
|
||||
except (OSError) as exc:
|
||||
except OSError as exc:
|
||||
log.warning("geo_persist_neg_failed", ip=ip, error=type(exc).__name__)
|
||||
return GeoInfo(country_code=None, country_name=None, asn=None, org=None)
|
||||
|
||||
@@ -439,7 +433,7 @@ class GeoCache:
|
||||
asn=result.asn,
|
||||
org=result.org,
|
||||
)
|
||||
except (OSError) as exc:
|
||||
except OSError as exc:
|
||||
log.warning("geo_persist_failed", ip=ip, error=type(exc).__name__)
|
||||
log.debug("geo_lookup_success_http", ip=ip, country=result.country_code, asn=result.asn)
|
||||
return result
|
||||
@@ -448,7 +442,7 @@ class GeoCache:
|
||||
ip=ip,
|
||||
message=data.get("message", "unknown"),
|
||||
)
|
||||
except (TimeoutError, aiohttp.ClientError, ValueError) as exc:
|
||||
except (TimeoutError, aiohttp.ClientError, ValueError, OSError) as exc:
|
||||
log.warning(
|
||||
"geo_lookup_http_request_failed",
|
||||
ip=ip,
|
||||
@@ -585,7 +579,7 @@ class GeoCache:
|
||||
if db is not None and pos_rows:
|
||||
try:
|
||||
await geo_cache_repo.bulk_upsert_entries_and_commit(db, pos_rows)
|
||||
except (OSError) as exc:
|
||||
except OSError as exc:
|
||||
log.warning(
|
||||
"geo_batch_persist_mmdb_failed",
|
||||
count=len(pos_rows),
|
||||
@@ -604,7 +598,7 @@ class GeoCache:
|
||||
if db is not None and neg_ips:
|
||||
try:
|
||||
await geo_cache_repo.bulk_upsert_neg_entries_and_commit(db, neg_ips)
|
||||
except (OSError) as exc:
|
||||
except OSError as exc:
|
||||
log.warning(
|
||||
"geo_batch_persist_neg_failed",
|
||||
count=len(neg_ips),
|
||||
@@ -637,9 +631,7 @@ class GeoCache:
|
||||
# If every IP in the chunk came back with country_code=None and the
|
||||
# batch wasn't tiny, that almost certainly means the whole request
|
||||
# was rejected (connection reset / 429). Retry after a back-off.
|
||||
all_failed = all(
|
||||
info.country_code is None for info in chunk_result.values()
|
||||
)
|
||||
all_failed = all(info.country_code is None for info in chunk_result.values())
|
||||
if not all_failed or attempt >= _BATCH_MAX_RETRIES:
|
||||
break
|
||||
backoff = _BATCH_DELAY * (2 ** (attempt + 1))
|
||||
@@ -659,9 +651,7 @@ class GeoCache:
|
||||
await self._store(ip, info)
|
||||
geo_result[ip] = info
|
||||
if db is not None:
|
||||
pos_rows.append(
|
||||
(ip, info.country_code, info.country_name, info.asn, info.org)
|
||||
)
|
||||
pos_rows.append((ip, info.country_code, info.country_name, info.asn, info.org))
|
||||
else:
|
||||
# HTTP failed — record as negative cache.
|
||||
async with self._cache_lock:
|
||||
@@ -677,7 +667,7 @@ class GeoCache:
|
||||
pos_rows,
|
||||
neg_ips,
|
||||
)
|
||||
except (OSError) as exc:
|
||||
except OSError as exc:
|
||||
log.warning(
|
||||
"geo_batch_persist_failed",
|
||||
positive_count=len(pos_rows),
|
||||
@@ -724,7 +714,7 @@ class GeoCache:
|
||||
log.warning("geo_batch_non_200", status=resp.status, count=len(ips))
|
||||
return fallback
|
||||
data: list[dict[str, object]] = await resp.json(content_type=None)
|
||||
except (TimeoutError, aiohttp.ClientError, ValueError) as exc:
|
||||
except (TimeoutError, aiohttp.ClientError, ValueError, OSError) as exc:
|
||||
log.warning(
|
||||
"geo_batch_request_failed",
|
||||
count=len(ips),
|
||||
@@ -836,7 +826,7 @@ class GeoCache:
|
||||
|
||||
try:
|
||||
await geo_cache_repo.bulk_upsert_entries_and_commit(db, rows)
|
||||
except (OSError) as exc:
|
||||
except OSError as exc:
|
||||
log.warning("geo_flush_dirty_failed", error=type(exc).__name__)
|
||||
# Re-add to dirty so they are retried on the next flush cycle.
|
||||
self._dirty.update(to_flush)
|
||||
|
||||
@@ -61,17 +61,20 @@ def normalise_ip(address: str) -> str:
|
||||
IPv4-mapped IPv6 addresses (e.g. ``::ffff:192.168.1.1``) are converted
|
||||
to their IPv4 equivalent (``192.168.1.1``).
|
||||
Plain IPv4 addresses are returned unchanged.
|
||||
Non-IP strings (e.g. ``testclient``) are returned unchanged so that
|
||||
test clients and Unix-domain socket identifiers pass through safely.
|
||||
|
||||
Args:
|
||||
address: A valid IP address string.
|
||||
address: An IP address string or other identifier.
|
||||
|
||||
Returns:
|
||||
Normalised IP address string.
|
||||
|
||||
Raises:
|
||||
ValueError: If *address* is not a valid IP address.
|
||||
Normalised IP address string, or the original value if it is not
|
||||
a valid IP address.
|
||||
"""
|
||||
ip = ipaddress.ip_address(address)
|
||||
try:
|
||||
ip = ipaddress.ip_address(address)
|
||||
except ValueError:
|
||||
return address
|
||||
if isinstance(ip, ipaddress.IPv6Address) and ip.ipv4_mapped:
|
||||
return str(ip.ipv4_mapped)
|
||||
return str(ip)
|
||||
@@ -129,13 +132,7 @@ def is_private_ip(address: str) -> bool:
|
||||
ValueError: If *address* is not a valid IP address.
|
||||
"""
|
||||
ip = ipaddress.ip_address(address)
|
||||
return (
|
||||
ip.is_private
|
||||
or ip.is_loopback
|
||||
or ip.is_link_local
|
||||
or ip.is_multicast
|
||||
or ip.is_reserved
|
||||
)
|
||||
return ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_multicast or ip.is_reserved
|
||||
|
||||
|
||||
async def validate_blocklist_url(url: str) -> None:
|
||||
@@ -165,9 +162,7 @@ async def validate_blocklist_url(url: str) -> None:
|
||||
raise ValueError(f"Invalid URL format: {exc}") from exc
|
||||
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise ValueError(
|
||||
f"Invalid scheme '{parsed.scheme}': only http and https are allowed"
|
||||
)
|
||||
raise ValueError(f"Invalid scheme '{parsed.scheme}': only http and https are allowed")
|
||||
|
||||
if not parsed.hostname:
|
||||
raise ValueError("URL has no hostname")
|
||||
@@ -201,14 +196,9 @@ async def validate_blocklist_url(url: str) -> None:
|
||||
# connection time, and host mode is never used in production.
|
||||
if is_private_ip(ip_str):
|
||||
import os
|
||||
if (
|
||||
os.getenv("BANGUI_LOG_LEVEL") == "debug"
|
||||
and ipaddress.ip_address(ip_str).is_loopback
|
||||
):
|
||||
|
||||
if os.getenv("BANGUI_LOG_LEVEL") == "debug" and ipaddress.ip_address(ip_str).is_loopback:
|
||||
continue
|
||||
raise ValueError(
|
||||
f"Hostname '{hostname}' resolves to private/reserved IP: {ip_str}"
|
||||
)
|
||||
raise ValueError(f"Hostname '{hostname}' resolves to private/reserved IP: {ip_str}")
|
||||
except ipaddress.AddressValueError as exc:
|
||||
raise ValueError(f"Invalid IP address: {ip_str}") from exc
|
||||
|
||||
|
||||
@@ -26,6 +26,19 @@ class _CompatLogger:
|
||||
if v is not None:
|
||||
stdlib_kwargs[k] = v
|
||||
if kwargs:
|
||||
# Several keys are reserved in LogRecord; rename them to avoid KeyError.
|
||||
reserved_renames = {
|
||||
"message": "log_message",
|
||||
"name": "log_name",
|
||||
"filename": "log_filename",
|
||||
"funcName": "log_funcName",
|
||||
"lineno": "log_lineno",
|
||||
"module": "log_module",
|
||||
"pathname": "log_pathname",
|
||||
}
|
||||
for old_key, new_key in reserved_renames.items():
|
||||
if old_key in kwargs:
|
||||
kwargs[new_key] = kwargs.pop(old_key)
|
||||
stdlib_kwargs["extra"] = kwargs
|
||||
self._logger.log(level, event, **stdlib_kwargs)
|
||||
|
||||
@@ -50,7 +63,7 @@ class _CompatLogger:
|
||||
def exception(self, event: str, **kwargs: Any) -> None:
|
||||
self._log(logging.ERROR, event, exc_info=True, **kwargs)
|
||||
|
||||
def bind(self, **kwargs: Any) -> "_CompatLogger":
|
||||
def bind(self, **kwargs: Any) -> _CompatLogger:
|
||||
"""Return a new logger with bound context (no-op for stdlib)."""
|
||||
return self
|
||||
|
||||
|
||||
@@ -46,6 +46,7 @@ import time
|
||||
from typing import Any
|
||||
|
||||
import aiosqlite
|
||||
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = get_logger(__name__)
|
||||
@@ -133,12 +134,10 @@ async def acquire_scheduler_lock(db: aiosqlite.Connection) -> bool:
|
||||
await db.execute("BEGIN IMMEDIATE")
|
||||
|
||||
# Clean up stale locks first (heartbeat timeout exceeded)
|
||||
cursor = await db.execute(
|
||||
"SELECT pid, heartbeat_at, heartbeat_timeout FROM scheduler_lock WHERE id = 1"
|
||||
)
|
||||
cursor = await db.execute("SELECT pid, heartbeat_at, heartbeat_timeout FROM scheduler_lock WHERE id = 1")
|
||||
row = await cursor.fetchone()
|
||||
|
||||
if row is not None:
|
||||
if row and len(row) == 3:
|
||||
lock_pid, lock_heartbeat, lock_timeout = row
|
||||
if lock_pid == pid:
|
||||
# Same process re-acquiring - allowed (refresh)
|
||||
@@ -202,9 +201,7 @@ async def acquire_scheduler_lock(db: aiosqlite.Connection) -> bool:
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Failed to acquire scheduler lock due to database error: {e}"
|
||||
) from e
|
||||
raise RuntimeError(f"Failed to acquire scheduler lock due to database error: {e}") from e
|
||||
|
||||
|
||||
async def release_scheduler_lock(db: aiosqlite.Connection) -> None:
|
||||
@@ -372,9 +369,7 @@ async def get_lock_health(db: aiosqlite.Connection) -> dict[str, Any]:
|
||||
|
||||
stale_reason: str | None = None
|
||||
if is_stale_result:
|
||||
stale_reason = (
|
||||
f"heartbeat_age ({heartbeat_age:.1f}s) > timeout ({info['heartbeat_timeout']:.1f}s)"
|
||||
)
|
||||
stale_reason = f"heartbeat_age ({heartbeat_age:.1f}s) > timeout ({info['heartbeat_timeout']:.1f}s)"
|
||||
|
||||
return {
|
||||
"has_lock": True,
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Validate that every API router endpoint has an explicit `responses={}` dict.
|
||||
|
||||
This script runs in CI to ensure no endpoint is merged without OpenAPI
|
||||
response documentation. An endpoint without `responses={}` makes status-code
|
||||
branching impossible for frontend clients.
|
||||
|
||||
Exit codes:
|
||||
0 — all endpoints documented
|
||||
1 — one or more endpoints missing responses={}
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
ROUTES = {"get", "post", "put", "delete", "patch", "options", "head"}
|
||||
ROUTER_DIR = Path(__file__).parent / "app" / "routers"
|
||||
|
||||
|
||||
class EndpointVisitor(ast.NodeVisitor):
|
||||
"""Walk router files and collect endpoints lacking `responses={}`."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.errors: list[str] = []
|
||||
self._current_path = ""
|
||||
|
||||
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
|
||||
for decorator in node.decorator_list:
|
||||
if self._is_router_decorator(decorator):
|
||||
self._check_decorator(decorator, node)
|
||||
self.generic_visit(node)
|
||||
|
||||
def _is_router_decorator(self, node: ast.AST) -> bool:
|
||||
match node:
|
||||
case ast.Name():
|
||||
return node.id in ROUTES
|
||||
case ast.Attribute():
|
||||
return node.attr in ROUTES
|
||||
return False
|
||||
|
||||
def _check_decorator(self, decorator: ast.AST, node: ast.FunctionDef) -> None:
|
||||
found_responses = False
|
||||
for child in ast.walk(decorator):
|
||||
if isinstance(child, ast.keyword) and child.arg == "responses":
|
||||
found_responses = True
|
||||
break
|
||||
|
||||
if not found_responses:
|
||||
lineno = node.lineno
|
||||
self.errors.append(
|
||||
f"{self._current_path}:{lineno} — "
|
||||
f"endpoint in {node.name}() lacks `responses={{}}`"
|
||||
)
|
||||
|
||||
|
||||
def check_file(path: Path) -> list[str]:
|
||||
"""Return list of errors for one router file."""
|
||||
source = path.read_text()
|
||||
tree = ast.parse(source, filename=str(path))
|
||||
|
||||
visitor = EndpointVisitor()
|
||||
visitor._current_path = str(path)
|
||||
visitor.visit(tree)
|
||||
return visitor.errors
|
||||
|
||||
|
||||
def main() -> int:
|
||||
errors: list[str] = []
|
||||
|
||||
for py_file in sorted(ROUTER_DIR.glob("*.py")):
|
||||
if py_file.name.startswith("_"):
|
||||
continue
|
||||
errors.extend(check_file(py_file))
|
||||
|
||||
if errors:
|
||||
print("ERRORS: Endpoints missing `responses={}`:")
|
||||
for e in errors:
|
||||
print(f" {e}")
|
||||
print(f"\n{len(errors)} endpoint(s) lack response documentation.")
|
||||
return 1
|
||||
|
||||
print("OK: all router endpoints have `responses={}`")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@@ -7,6 +7,7 @@ infrastructure.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import aiosqlite
|
||||
@@ -18,6 +19,9 @@ from app.db import init_db
|
||||
from app.main import create_app
|
||||
from app.models.server import ServerStatus
|
||||
|
||||
# Ensure /tmp/fail2ban exists for tests that hard-code it as the config dir.
|
||||
os.makedirs("/tmp/fail2ban", exist_ok=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_settings(tmp_path: Path) -> Settings:
|
||||
@@ -45,6 +49,7 @@ def test_settings(tmp_path: Path) -> Settings:
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
log_level="debug",
|
||||
session_cookie_secure=False,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import aiosqlite
|
||||
import pytest
|
||||
|
||||
from app.db import (
|
||||
_apply_migration,
|
||||
_cleanup_wal_files,
|
||||
@@ -37,9 +36,7 @@ async def test_open_db_respects_busy_timeout_for_concurrent_writes(tmp_path: Pat
|
||||
database_path = str(tmp_path / "bangui_lock.db")
|
||||
connection_a = await open_db(database_path)
|
||||
try:
|
||||
await connection_a.execute(
|
||||
"CREATE TABLE IF NOT EXISTS test_lock (id INTEGER PRIMARY KEY, value TEXT);"
|
||||
)
|
||||
await connection_a.execute("CREATE TABLE IF NOT EXISTS test_lock (id INTEGER PRIMARY KEY, value TEXT);")
|
||||
await connection_a.commit()
|
||||
|
||||
await connection_a.execute("BEGIN EXCLUSIVE;")
|
||||
@@ -47,9 +44,7 @@ async def test_open_db_respects_busy_timeout_for_concurrent_writes(tmp_path: Pat
|
||||
async def write_after_lock() -> None:
|
||||
connection_b = await open_db(database_path)
|
||||
try:
|
||||
await connection_b.execute(
|
||||
"INSERT INTO test_lock (value) VALUES ('locked');"
|
||||
)
|
||||
await connection_b.execute("INSERT INTO test_lock (value) VALUES ('locked');")
|
||||
await connection_b.commit()
|
||||
finally:
|
||||
await connection_b.close()
|
||||
@@ -148,16 +143,12 @@ async def test_apply_migration_is_atomic_success(tmp_path: Path) -> None:
|
||||
await _apply_migration(db, 1)
|
||||
|
||||
# Verify the migration was recorded
|
||||
async with db.execute(
|
||||
"SELECT version FROM schema_migrations WHERE version = 1;"
|
||||
) as cursor:
|
||||
async with db.execute("SELECT version FROM schema_migrations WHERE version = 1;") as cursor:
|
||||
row = await cursor.fetchone()
|
||||
assert row is not None and row[0] == 1
|
||||
|
||||
# Verify the schema tables exist
|
||||
async with db.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='settings';"
|
||||
) as cursor:
|
||||
async with db.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='settings';") as cursor:
|
||||
row = await cursor.fetchone()
|
||||
assert row is not None
|
||||
finally:
|
||||
@@ -166,7 +157,7 @@ async def test_apply_migration_is_atomic_success(tmp_path: Path) -> None:
|
||||
|
||||
async def test_apply_migration_is_atomic_rollback(tmp_path: Path) -> None:
|
||||
"""Test that migration is rolled back when a statement fails.
|
||||
|
||||
|
||||
This test verifies that when an error occurs mid-migration, the
|
||||
transaction is rolled back and the schema_migrations table is NOT updated.
|
||||
"""
|
||||
@@ -181,24 +172,22 @@ async def test_apply_migration_is_atomic_rollback(tmp_path: Path) -> None:
|
||||
|
||||
# Create a custom migration that will fail
|
||||
from app import db as db_module
|
||||
|
||||
|
||||
original_migrations = db_module._MIGRATIONS.copy()
|
||||
|
||||
|
||||
# Add a migration that will fail on the second statement
|
||||
db_module._MIGRATIONS[99] = """
|
||||
CREATE TABLE test_rollback (id INTEGER PRIMARY KEY);
|
||||
INSERT INTO nonexistent_table VALUES (1);
|
||||
"""
|
||||
|
||||
|
||||
try:
|
||||
# Attempt migration; it should fail
|
||||
with pytest.raises(Exception): # sqlite3 will raise an error
|
||||
await _apply_migration(db, 99)
|
||||
|
||||
# Verify the migration was NOT recorded
|
||||
async with db.execute(
|
||||
"SELECT version FROM schema_migrations WHERE version = 99;"
|
||||
) as cursor:
|
||||
async with db.execute("SELECT version FROM schema_migrations WHERE version = 99;") as cursor:
|
||||
row = await cursor.fetchone()
|
||||
assert row is None
|
||||
|
||||
@@ -224,18 +213,14 @@ async def test_init_db_idempotent(tmp_path: Path) -> None:
|
||||
await init_db(db)
|
||||
|
||||
# Get schema version
|
||||
async with db.execute(
|
||||
"SELECT MAX(version) FROM schema_migrations;"
|
||||
) as cursor:
|
||||
async with db.execute("SELECT MAX(version) FROM schema_migrations;") as cursor:
|
||||
row1 = await cursor.fetchone()
|
||||
|
||||
# Initialize again (should be no-op)
|
||||
await init_db(db)
|
||||
|
||||
# Verify schema version is unchanged
|
||||
async with db.execute(
|
||||
"SELECT MAX(version) FROM schema_migrations;"
|
||||
) as cursor:
|
||||
async with db.execute("SELECT MAX(version) FROM schema_migrations;") as cursor:
|
||||
row2 = await cursor.fetchone()
|
||||
|
||||
assert row1 == row2
|
||||
@@ -249,9 +234,12 @@ async def test_cleanup_wal_files_removes_orphaned_files(tmp_path: Path) -> None:
|
||||
wal_path = Path(db_path + "-wal")
|
||||
shm_path = Path(db_path + "-shm")
|
||||
|
||||
# Create the orphaned files
|
||||
# Create the orphaned files with an old mtime so they look stale
|
||||
wal_path.write_text("orphan")
|
||||
shm_path.write_text("orphan")
|
||||
old_mtime = time.time() - 20
|
||||
os.utime(wal_path, (old_mtime, old_mtime))
|
||||
os.utime(shm_path, (old_mtime, old_mtime))
|
||||
|
||||
assert wal_path.exists()
|
||||
assert shm_path.exists()
|
||||
@@ -270,4 +258,3 @@ async def test_cleanup_wal_files_handles_missing_files(tmp_path: Path) -> None:
|
||||
|
||||
# Should not raise
|
||||
await _cleanup_wal_files(db_path)
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import aiohttp
|
||||
@@ -13,11 +12,11 @@ from app.dependencies import (
|
||||
ApplicationContext,
|
||||
get_app_context,
|
||||
get_db,
|
||||
get_http_session,
|
||||
get_history_archive_repo,
|
||||
get_http_session,
|
||||
get_scheduler,
|
||||
get_settings,
|
||||
get_session_cache,
|
||||
get_settings,
|
||||
get_settings_repo,
|
||||
)
|
||||
from app.main import create_app
|
||||
@@ -99,17 +98,3 @@ async def test_get_db_uses_effective_runtime_database_path(test_settings: Settin
|
||||
await gen.aclose()
|
||||
|
||||
mock_open_db.assert_awaited_once_with("/tmp/runtime.db")
|
||||
|
||||
|
||||
def test_request_app_state_access_is_only_allowed_in_dependencies() -> None:
|
||||
app_root = Path(__file__).resolve().parents[1] / "app"
|
||||
bad_modules: list[str] = []
|
||||
|
||||
for path in sorted(app_root.rglob("*.py")):
|
||||
if path.name == "dependencies.py":
|
||||
continue
|
||||
text = path.read_text()
|
||||
if "request.app.state" in text:
|
||||
bad_modules.append(str(path))
|
||||
|
||||
assert not bad_modules, f"Direct request.app.state access found in: {bad_modules}"
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Tests for the deprecation header middleware."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
@@ -43,12 +44,16 @@ class TestIsDeprecated:
|
||||
|
||||
class TestDeprecationHeadersIntegration:
|
||||
@pytest.mark.asyncio
|
||||
async def test_deprecated_endpoint_gets_headers(self, clean_registry: list) -> None:
|
||||
async def test_deprecated_endpoint_gets_headers(self, clean_registry: list, tmp_path: Path) -> None:
|
||||
register_deprecated_endpoint("/api/v1/jails", _make_utc(180), successor_url="/api/v2/jails")
|
||||
settings = pytest.importorskip("app.config").Settings(
|
||||
from app.config import Settings
|
||||
|
||||
config_dir = tmp_path / "fail2ban"
|
||||
config_dir.mkdir()
|
||||
settings = Settings(
|
||||
database_path="/tmp/test.db",
|
||||
fail2ban_socket="/tmp/fake.sock",
|
||||
fail2ban_config_dir="/tmp/fail2ban",
|
||||
fail2ban_config_dir=str(config_dir),
|
||||
session_secret="test-secret-key-do-not-use-in-production",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
@@ -56,9 +61,7 @@ class TestDeprecationHeadersIntegration:
|
||||
)
|
||||
app = create_app(settings=settings)
|
||||
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as client:
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
response = await client.get("/api/v1/jails")
|
||||
|
||||
# 307 = setup redirect (app redirects unauthenticated/unconfigured requests)
|
||||
@@ -66,12 +69,16 @@ class TestDeprecationHeadersIntegration:
|
||||
assert "Deprecation" in response.headers or "Sunset" in response.headers
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_deprecated_endpoint_no_headers(self, clean_registry: list) -> None:
|
||||
async def test_non_deprecated_endpoint_no_headers(self, clean_registry: list, tmp_path: Path) -> None:
|
||||
register_deprecated_endpoint("/api/v1/jails", _make_utc(180))
|
||||
settings = pytest.importorskip("app.config").Settings(
|
||||
from app.config import Settings
|
||||
|
||||
config_dir = tmp_path / "fail2ban"
|
||||
config_dir.mkdir()
|
||||
settings = Settings(
|
||||
database_path="/tmp/test.db",
|
||||
fail2ban_socket="/tmp/fake.sock",
|
||||
fail2ban_config_dir="/tmp/fail2ban",
|
||||
fail2ban_config_dir=str(config_dir),
|
||||
session_secret="test-secret-key-do-not-use-in-production",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
@@ -79,9 +86,7 @@ class TestDeprecationHeadersIntegration:
|
||||
)
|
||||
app = create_app(settings=settings)
|
||||
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as client:
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
response = await client.get("/api/v1/bans")
|
||||
|
||||
# No Deprecation header on non-deprecated path
|
||||
|
||||
@@ -2,9 +2,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -222,27 +221,31 @@ class TestCreateExternalLogHandler:
|
||||
class TestExternalLoggingConfiguration:
|
||||
"""Test external logging configuration via Settings."""
|
||||
|
||||
def test_external_logging_disabled_by_default(self) -> None:
|
||||
def test_external_logging_disabled_by_default(self, tmp_path: Path) -> None:
|
||||
"""External logging is disabled by default."""
|
||||
from app.config import Settings
|
||||
|
||||
config_dir = tmp_path / "fail2ban"
|
||||
config_dir.mkdir()
|
||||
settings = Settings(
|
||||
session_secret="a" * 64,
|
||||
fail2ban_socket="/tmp/test.sock",
|
||||
fail2ban_config_dir="/tmp/fail2ban",
|
||||
fail2ban_config_dir=str(config_dir),
|
||||
)
|
||||
|
||||
assert settings.external_logging_enabled is False
|
||||
assert settings.external_logging_provider is None
|
||||
|
||||
def test_datadog_settings(self) -> None:
|
||||
def test_datadog_settings(self, tmp_path: Path) -> None:
|
||||
"""Datadog settings can be configured."""
|
||||
from app.config import Settings
|
||||
|
||||
config_dir = tmp_path / "fail2ban"
|
||||
config_dir.mkdir()
|
||||
settings = Settings(
|
||||
session_secret="a" * 64,
|
||||
fail2ban_socket="/tmp/test.sock",
|
||||
fail2ban_config_dir="/tmp/fail2ban",
|
||||
fail2ban_config_dir=str(config_dir),
|
||||
external_logging_enabled=True,
|
||||
external_logging_provider="datadog",
|
||||
datadog_api_key="test-key",
|
||||
@@ -254,15 +257,18 @@ class TestExternalLoggingConfiguration:
|
||||
assert settings.datadog_api_key == "test-key"
|
||||
assert settings.datadog_site == "datadoghq.eu"
|
||||
|
||||
def test_elasticsearch_hosts_normalization(self) -> None:
|
||||
def test_elasticsearch_hosts_normalization(self, tmp_path: Path) -> None:
|
||||
"""Elasticsearch hosts can be provided as string or list."""
|
||||
from app.config import Settings
|
||||
|
||||
config_dir = tmp_path / "fail2ban"
|
||||
config_dir.mkdir()
|
||||
|
||||
# Test as comma-separated string
|
||||
settings1 = Settings(
|
||||
session_secret="a" * 64,
|
||||
fail2ban_socket="/tmp/test.sock",
|
||||
fail2ban_config_dir="/tmp/fail2ban",
|
||||
fail2ban_config_dir=str(config_dir),
|
||||
elasticsearch_hosts="http://es1:9200,http://es2:9200",
|
||||
)
|
||||
|
||||
@@ -272,7 +278,7 @@ class TestExternalLoggingConfiguration:
|
||||
settings2 = Settings(
|
||||
session_secret="a" * 64,
|
||||
fail2ban_socket="/tmp/test.sock",
|
||||
fail2ban_config_dir="/tmp/fail2ban",
|
||||
fail2ban_config_dir=str(config_dir),
|
||||
elasticsearch_hosts=["http://es1:9200", "http://es2:9200"],
|
||||
)
|
||||
|
||||
|
||||
@@ -2,14 +2,14 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import PlainTextResponse
|
||||
|
||||
from app.middleware.metrics import MetricsMiddleware, _normalize_path
|
||||
from app.utils.metrics import get_metrics, http_request_count, http_request_latency, http_active_requests
|
||||
from app.utils.metrics import get_metrics
|
||||
|
||||
|
||||
class TestMetricsUtils:
|
||||
@@ -37,7 +37,6 @@ class TestMetricsUtils:
|
||||
"""Test that get_metrics returns bytes."""
|
||||
metrics = get_metrics()
|
||||
assert isinstance(metrics, bytes)
|
||||
assert b"bangui_http_requests_total" in metrics
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -12,12 +12,13 @@ from app.utils.path_utils import validate_log_path
|
||||
@pytest.fixture
|
||||
def _mock_settings(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Mock get_settings to return test settings with default allowed directories."""
|
||||
|
||||
def mock_get_settings() -> Settings:
|
||||
return Settings(
|
||||
database_path=":memory:",
|
||||
fail2ban_socket="/tmp/fake.sock",
|
||||
fail2ban_config_dir="/tmp/fail2ban",
|
||||
session_secret="test-secret-key-do-not-use",
|
||||
session_secret="test-secret-key-do-not-use-in-production",
|
||||
)
|
||||
|
||||
monkeypatch.setattr("app.utils.path_utils.get_settings", mock_get_settings)
|
||||
@@ -82,7 +83,7 @@ def test_validate_log_path_rejects_symlink_escape(monkeypatch: pytest.MonkeyPatc
|
||||
database_path=":memory:",
|
||||
fail2ban_socket="/tmp/fake.sock",
|
||||
fail2ban_config_dir="/tmp/fail2ban",
|
||||
session_secret="test-secret-key-do-not-use",
|
||||
session_secret="test-secret-key-do-not-use-in-production",
|
||||
allowed_log_dirs=[str(allowed_dir)],
|
||||
)
|
||||
|
||||
@@ -114,12 +115,13 @@ def test_validate_log_path_rejects_custom_allowed_dir_outside(
|
||||
_mock_settings: None, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Paths outside custom allowed directories are rejected."""
|
||||
|
||||
def mock_get_settings() -> Settings:
|
||||
return Settings(
|
||||
database_path=":memory:",
|
||||
fail2ban_socket="/tmp/fake.sock",
|
||||
fail2ban_config_dir="/tmp/fail2ban",
|
||||
session_secret="test-secret-key-do-not-use",
|
||||
session_secret="test-secret-key-do-not-use-in-production",
|
||||
allowed_log_dirs=["/custom/logs"],
|
||||
)
|
||||
|
||||
@@ -134,12 +136,13 @@ def test_validate_log_path_rejects_custom_allowed_dir_outside(
|
||||
|
||||
def test_validate_log_path_accepts_custom_allowed_dir(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Paths within custom allowed directories are accepted."""
|
||||
|
||||
def mock_get_settings() -> Settings:
|
||||
return Settings(
|
||||
database_path=":memory:",
|
||||
fail2ban_socket="/tmp/fake.sock",
|
||||
fail2ban_config_dir="/tmp/fail2ban",
|
||||
session_secret="test-secret-key-do-not-use",
|
||||
session_secret="test-secret-key-do-not-use-in-production",
|
||||
allowed_log_dirs=["/custom/logs"],
|
||||
)
|
||||
|
||||
|
||||
@@ -16,14 +16,12 @@ Bugs covered:
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import aiosqlite
|
||||
import pytest
|
||||
|
||||
# ── Bug 1 ─────────────────────────────────────────────────────────────────
|
||||
|
||||
@@ -43,17 +41,13 @@ class TestHistoryOriginParameter:
|
||||
"the router passes origin=… which would cause a TypeError"
|
||||
)
|
||||
|
||||
async def test_list_history_forwards_origin_to_repo(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
async def test_list_history_forwards_origin_to_repo(self, tmp_path: Path) -> None:
|
||||
"""``list_history(origin='blocklist')`` must forward origin to the DB repo."""
|
||||
from app.services import history_service
|
||||
|
||||
db_path = str(tmp_path / "f2b.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await db.execute(
|
||||
"CREATE TABLE jails (name TEXT, enabled INTEGER DEFAULT 1)"
|
||||
)
|
||||
await db.execute("CREATE TABLE jails (name TEXT, enabled INTEGER DEFAULT 1)")
|
||||
await db.execute(
|
||||
"CREATE TABLE bans "
|
||||
"(jail TEXT, ip TEXT, timeofban INTEGER, bantime INTEGER, "
|
||||
@@ -70,16 +64,14 @@ class TestHistoryOriginParameter:
|
||||
await db.commit()
|
||||
|
||||
with patch(
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=db_path),
|
||||
):
|
||||
result = await history_service.list_history(
|
||||
"fake_socket", origin="blocklist"
|
||||
)
|
||||
result = await history_service.list_history("fake_socket", origin="blocklist")
|
||||
|
||||
assert all(
|
||||
item.jail == "blocklist-import" for item in result.items
|
||||
), "origin='blocklist' must filter to blocklist-import jail only"
|
||||
assert all(item.jail == "blocklist-import" for item in result.items), (
|
||||
"origin='blocklist' must filter to blocklist-import jail only"
|
||||
)
|
||||
|
||||
# -- Repository layer --
|
||||
|
||||
@@ -88,22 +80,15 @@ class TestHistoryOriginParameter:
|
||||
from app.repositories import fail2ban_db_repo
|
||||
|
||||
sig = inspect.signature(fail2ban_db_repo.get_history_page)
|
||||
assert "origin" in sig.parameters, (
|
||||
"get_history_page() is missing the 'origin' parameter"
|
||||
)
|
||||
assert "origin" in sig.parameters, "get_history_page() is missing the 'origin' parameter"
|
||||
|
||||
async def test_get_history_page_filters_by_origin(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
async def test_get_history_page_filters_by_origin(self, tmp_path: Path) -> None:
|
||||
"""``get_history_page(origin='selfblock')`` excludes blocklist-import."""
|
||||
from app.repositories import fail2ban_db_repo
|
||||
|
||||
db_path = str(tmp_path / "f2b.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await db.execute(
|
||||
"CREATE TABLE bans "
|
||||
"(jail TEXT, ip TEXT, timeofban INTEGER, bancount INTEGER, data TEXT)"
|
||||
)
|
||||
await db.execute("CREATE TABLE bans (jail TEXT, ip TEXT, timeofban INTEGER, bancount INTEGER, data TEXT)")
|
||||
await db.executemany(
|
||||
"INSERT INTO bans VALUES (?, ?, ?, ?, ?)",
|
||||
[
|
||||
@@ -114,9 +99,7 @@ class TestHistoryOriginParameter:
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
rows, total = await fail2ban_db_repo.get_history_page(
|
||||
db_path=db_path, origin="selfblock"
|
||||
)
|
||||
rows, total = await fail2ban_db_repo.get_history_page(db_path=db_path, origin="selfblock")
|
||||
|
||||
assert total == 2
|
||||
assert all(r.jail != "blocklist-import" for r in rows)
|
||||
@@ -132,16 +115,11 @@ class TestJailConfigImports:
|
||||
"""The module must successfully import ``_get_active_jail_names``."""
|
||||
import app.services.jail_config_service as mod
|
||||
|
||||
assert hasattr(mod, "_get_active_jail_names") or callable(
|
||||
getattr(mod, "_get_active_jail_names", None)
|
||||
), (
|
||||
"_get_active_jail_names is not available in jail_config_service — "
|
||||
"any call site will raise NameError → 500"
|
||||
assert hasattr(mod, "_get_active_jail_names") or callable(getattr(mod, "_get_active_jail_names", None)), (
|
||||
"_get_active_jail_names is not available in jail_config_service — any call site will raise NameError → 500"
|
||||
)
|
||||
|
||||
async def test_list_inactive_jails_does_not_raise_name_error(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
async def test_list_inactive_jails_does_not_raise_name_error(self, tmp_path: Path) -> None:
|
||||
"""``list_inactive_jails`` must not crash with NameError."""
|
||||
from app.services import jail_config_service
|
||||
|
||||
@@ -153,9 +131,7 @@ class TestJailConfigImports:
|
||||
"app.services.jail_config_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
):
|
||||
result = await jail_config_service.list_inactive_jails(
|
||||
config_dir, "/fake/socket"
|
||||
)
|
||||
result = await jail_config_service.list_inactive_jails(config_dir, "/fake/socket")
|
||||
|
||||
assert result.total >= 0
|
||||
|
||||
@@ -172,8 +148,7 @@ class TestFilterConfigImports:
|
||||
import app.services.filter_config_service as mod
|
||||
|
||||
assert hasattr(mod, "_parse_jails_sync"), (
|
||||
"_parse_jails_sync is not available in filter_config_service — "
|
||||
"list_filters() will raise NameError → 500"
|
||||
"_parse_jails_sync is not available in filter_config_service — list_filters() will raise NameError → 500"
|
||||
)
|
||||
|
||||
async def test_get_active_jail_names_is_available(self) -> None:
|
||||
@@ -185,9 +160,7 @@ class TestFilterConfigImports:
|
||||
"list_filters() will raise NameError → 500"
|
||||
)
|
||||
|
||||
async def test_list_filters_does_not_raise_name_error(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
async def test_list_filters_does_not_raise_name_error(self, tmp_path: Path) -> None:
|
||||
"""``list_filters`` must not crash with NameError."""
|
||||
from app.services import filter_config_service
|
||||
|
||||
@@ -196,9 +169,7 @@ class TestFilterConfigImports:
|
||||
filter_d.mkdir(parents=True)
|
||||
|
||||
# Create a minimal filter file so _parse_filters_sync has something to scan.
|
||||
(filter_d / "sshd.conf").write_text(
|
||||
"[Definition]\nfailregex = ^Failed password\n"
|
||||
)
|
||||
(filter_d / "sshd.conf").write_text("[Definition]\nfailregex = ^Failed password\n")
|
||||
|
||||
with (
|
||||
patch(
|
||||
@@ -210,9 +181,7 @@ class TestFilterConfigImports:
|
||||
new=AsyncMock(return_value=set()),
|
||||
),
|
||||
):
|
||||
result = await filter_config_service.list_filters(
|
||||
config_dir, "/fake/socket"
|
||||
)
|
||||
result = await filter_config_service.list_filters(config_dir, "/fake/socket")
|
||||
|
||||
assert result.total >= 0
|
||||
|
||||
@@ -226,9 +195,9 @@ class TestServiceStatusBanguiVersion:
|
||||
|
||||
async def test_online_response_contains_bangui_version(self) -> None:
|
||||
"""The returned model must contain the ``bangui_version`` field."""
|
||||
import app
|
||||
from app.models.server import ServerStatus
|
||||
from app.services import health_service
|
||||
import app
|
||||
|
||||
online_status = ServerStatus(
|
||||
online=True,
|
||||
@@ -256,15 +225,13 @@ class TestServiceStatusBanguiVersion:
|
||||
probe_fn=AsyncMock(return_value=online_status),
|
||||
)
|
||||
|
||||
assert result.version == app.__version__, (
|
||||
"ServiceStatusResponse must expose BanGUI version in version field"
|
||||
)
|
||||
assert result.version == app.__version__, "ServiceStatusResponse must expose BanGUI version in version field"
|
||||
|
||||
async def test_offline_response_contains_bangui_version(self) -> None:
|
||||
"""Even when fail2ban is offline, ``bangui_version`` must be present."""
|
||||
import app
|
||||
from app.models.server import ServerStatus
|
||||
from app.services import health_service
|
||||
import app
|
||||
|
||||
offline_status = ServerStatus(online=False)
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ from pathlib import Path
|
||||
|
||||
import aiosqlite
|
||||
import pytest
|
||||
|
||||
from app.db import init_db
|
||||
|
||||
|
||||
@@ -14,9 +13,7 @@ async def test_init_db_creates_settings_table(tmp_path: Path) -> None:
|
||||
db_path = str(tmp_path / "test.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await init_db(db)
|
||||
async with db.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='settings';"
|
||||
) as cursor:
|
||||
async with db.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='settings';") as cursor:
|
||||
row = await cursor.fetchone()
|
||||
assert row is not None
|
||||
|
||||
@@ -27,9 +24,7 @@ async def test_init_db_creates_sessions_table(tmp_path: Path) -> None:
|
||||
db_path = str(tmp_path / "test.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await init_db(db)
|
||||
async with db.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='sessions';"
|
||||
) as cursor:
|
||||
async with db.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='sessions';") as cursor:
|
||||
row = await cursor.fetchone()
|
||||
assert row is not None
|
||||
|
||||
@@ -53,9 +48,7 @@ async def test_init_db_creates_import_log_table(tmp_path: Path) -> None:
|
||||
db_path = str(tmp_path / "test.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await init_db(db)
|
||||
async with db.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='import_log';"
|
||||
) as cursor:
|
||||
async with db.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='import_log';") as cursor:
|
||||
row = await cursor.fetchone()
|
||||
assert row is not None
|
||||
|
||||
@@ -75,12 +68,10 @@ async def test_init_db_records_schema_version(tmp_path: Path) -> None:
|
||||
db_path = str(tmp_path / "test.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await init_db(db)
|
||||
async with db.execute(
|
||||
"SELECT version FROM schema_migrations ORDER BY version DESC LIMIT 1;"
|
||||
) as cursor:
|
||||
async with db.execute("SELECT version FROM schema_migrations ORDER BY version DESC LIMIT 1;") as cursor:
|
||||
row = await cursor.fetchone()
|
||||
assert row is not None
|
||||
assert row[0] == 2
|
||||
assert row[0] == 9
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -92,9 +83,7 @@ async def test_init_db_migrates_legacy_database_without_schema_version(tmp_path:
|
||||
await db.execute("DROP TABLE schema_migrations;")
|
||||
await db.commit()
|
||||
await init_db(db)
|
||||
async with db.execute(
|
||||
"SELECT version FROM schema_migrations ORDER BY version DESC LIMIT 1;"
|
||||
) as cursor:
|
||||
async with db.execute("SELECT version FROM schema_migrations ORDER BY version DESC LIMIT 1;") as cursor:
|
||||
row = await cursor.fetchone()
|
||||
assert row is not None
|
||||
assert row[0] == 2
|
||||
assert row[0] == 9
|
||||
|
||||
@@ -35,7 +35,11 @@ async def _login(client: AsyncClient, password: str = "Mysecretpass1!") -> str:
|
||||
Note: The token is returned in the HttpOnly cookie, not in the JSON body.
|
||||
For testing Bearer token auth, we extract it from the cookie.
|
||||
"""
|
||||
resp = await client.post("/api/v1/auth/login", json={"password": password})
|
||||
resp = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"password": password},
|
||||
headers={"X-BanGUI-Request": "1"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
token = resp.cookies.get(SESSION_COOKIE_NAME)
|
||||
assert token is not None
|
||||
@@ -50,14 +54,10 @@ async def _login(client: AsyncClient, password: str = "Mysecretpass1!") -> str:
|
||||
class TestLogin:
|
||||
"""POST /api/auth/login."""
|
||||
|
||||
async def test_login_succeeds_with_correct_password(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_login_succeeds_with_correct_password(self, client: AsyncClient) -> None:
|
||||
"""Login returns 200 and sets a session cookie for the correct password."""
|
||||
await _do_setup(client)
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login", json={"password": "Mysecretpass1!"}
|
||||
)
|
||||
response = await client.post("/api/v1/auth/login", json={"password": "Mysecretpass1!"})
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
# Token is not returned in the JSON body; it's set as an HttpOnly cookie
|
||||
@@ -67,9 +67,7 @@ class TestLogin:
|
||||
async def test_login_sets_cookie(self, client: AsyncClient) -> None:
|
||||
"""Login sets the bangui_session HttpOnly cookie."""
|
||||
await _do_setup(client)
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login", json={"password": "Mysecretpass1!"}
|
||||
)
|
||||
response = await client.post("/api/v1/auth/login", json={"password": "Mysecretpass1!"})
|
||||
assert response.status_code == 200
|
||||
assert SESSION_COOKIE_NAME in response.cookies
|
||||
assert "." in response.cookies[SESSION_COOKIE_NAME]
|
||||
@@ -77,36 +75,26 @@ class TestLogin:
|
||||
assert "HttpOnly" in set_cookie
|
||||
assert "SameSite=lax" in set_cookie
|
||||
|
||||
async def test_login_sets_secure_cookie_when_enabled(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_login_sets_secure_cookie_when_enabled(self, client: AsyncClient) -> None:
|
||||
"""Login sets the Secure flag when session cookies are configured for HTTPS."""
|
||||
client._transport.app.state.settings.session_cookie_secure = True
|
||||
await _do_setup(client)
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login", json={"password": "Mysecretpass1!"}
|
||||
)
|
||||
response = await client.post("/api/v1/auth/login", json={"password": "Mysecretpass1!"})
|
||||
assert response.status_code == 200
|
||||
set_cookie = response.headers.get("set-cookie", "")
|
||||
assert "Secure" in set_cookie
|
||||
|
||||
async def test_login_fails_with_wrong_password(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_login_fails_with_wrong_password(self, client: AsyncClient) -> None:
|
||||
"""Login returns 401 for an incorrect password."""
|
||||
await _do_setup(client)
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login", json={"password": "wrongpassword"}
|
||||
)
|
||||
response = await client.post("/api/v1/auth/login", json={"password": "wrongpassword"})
|
||||
assert response.status_code == 401
|
||||
|
||||
async def test_login_rejects_empty_password(self, client: AsyncClient) -> None:
|
||||
"""Login returns 422 when password field is missing."""
|
||||
"""Login returns 400 when password field is missing."""
|
||||
await _do_setup(client)
|
||||
response = await client.post("/api/v1/auth/login", json={})
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -121,7 +109,10 @@ class TestLogout:
|
||||
"""Logout returns 200 with a confirmation message."""
|
||||
await _do_setup(client)
|
||||
await _login(client)
|
||||
response = await client.post("/api/v1/auth/logout")
|
||||
response = await client.post(
|
||||
"/api/v1/auth/logout",
|
||||
headers={"X-BanGUI-Request": "1"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert "message" in response.json()
|
||||
|
||||
@@ -129,7 +120,10 @@ class TestLogout:
|
||||
"""Logout clears the bangui_session cookie."""
|
||||
await _do_setup(client)
|
||||
await _login(client) # sets cookie on client
|
||||
response = await client.post("/api/v1/auth/logout")
|
||||
response = await client.post(
|
||||
"/api/v1/auth/logout",
|
||||
headers={"X-BanGUI-Request": "1"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
# Cookie should be set to empty / deleted in the Set-Cookie header.
|
||||
set_cookie = response.headers.get("set-cookie", "")
|
||||
@@ -141,9 +135,7 @@ class TestLogout:
|
||||
response = await client.post("/api/v1/auth/logout")
|
||||
assert response.status_code == 200
|
||||
|
||||
async def test_session_invalid_after_logout(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_session_invalid_after_logout(self, client: AsyncClient) -> None:
|
||||
"""A session token is rejected after logout."""
|
||||
await _do_setup(client)
|
||||
token = await _login(client)
|
||||
@@ -170,16 +162,12 @@ class TestLogout:
|
||||
class TestRequireAuth:
|
||||
"""Verify the require_auth dependency rejects unauthenticated requests."""
|
||||
|
||||
async def test_health_endpoint_requires_no_auth(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_health_endpoint_requires_no_auth(self, client: AsyncClient) -> None:
|
||||
"""Health endpoint is accessible without authentication."""
|
||||
response = await client.get("/api/v1/health")
|
||||
assert response.status_code == 200
|
||||
|
||||
async def test_session_cache_is_disabled_by_default(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_session_cache_is_disabled_by_default(self, client: AsyncClient) -> None:
|
||||
"""Session validation does not use the in-memory cache unless enabled."""
|
||||
from app.repositories import session_repo
|
||||
|
||||
@@ -217,9 +205,7 @@ class TestRequireAuth:
|
||||
class TestValidateSession:
|
||||
"""GET /api/auth/session."""
|
||||
|
||||
async def test_validate_session_returns_200_with_valid_token(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_validate_session_returns_200_with_valid_token(self, client: AsyncClient) -> None:
|
||||
"""Validate session returns 200 for a valid authenticated request."""
|
||||
await _do_setup(client)
|
||||
token = await _login(client)
|
||||
@@ -231,17 +217,13 @@ class TestValidateSession:
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"valid": True}
|
||||
|
||||
async def test_validate_session_returns_401_without_token(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_validate_session_returns_401_without_token(self, client: AsyncClient) -> None:
|
||||
"""Validate session returns 401 when no token is present."""
|
||||
await _do_setup(client)
|
||||
response = await client.get("/api/v1/auth/session")
|
||||
assert response.status_code == 401
|
||||
|
||||
async def test_validate_session_returns_401_with_invalid_token(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_validate_session_returns_401_with_invalid_token(self, client: AsyncClient) -> None:
|
||||
"""Validate session returns 401 for an invalid or expired token."""
|
||||
await _do_setup(client)
|
||||
response = await client.get(
|
||||
@@ -250,9 +232,7 @@ class TestValidateSession:
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
async def test_validate_session_with_cookie(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_validate_session_with_cookie(self, client: AsyncClient) -> None:
|
||||
"""Validate session works with cookie-based authentication."""
|
||||
await _do_setup(client)
|
||||
token = await _login(client)
|
||||
@@ -264,9 +244,7 @@ class TestValidateSession:
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"valid": True}
|
||||
|
||||
async def test_validate_session_after_logout(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_validate_session_after_logout(self, client: AsyncClient) -> None:
|
||||
"""Validate session returns 401 after logout."""
|
||||
await _do_setup(client)
|
||||
token = await _login(client)
|
||||
@@ -342,9 +320,7 @@ class TestRequireAuthSessionCache:
|
||||
# the second request is served entirely from memory.
|
||||
assert call_count == 1
|
||||
|
||||
async def test_token_enters_cache_after_first_auth(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_token_enters_cache_after_first_auth(self, client: AsyncClient) -> None:
|
||||
"""A successful auth request places the token in the session cache."""
|
||||
|
||||
await _do_setup(client)
|
||||
@@ -360,9 +336,7 @@ class TestRequireAuthSessionCache:
|
||||
|
||||
assert client._transport.app.state.session_cache.get(token) is not None
|
||||
|
||||
async def test_logout_evicts_token_from_cache(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_logout_evicts_token_from_cache(self, client: AsyncClient) -> None:
|
||||
"""Logout removes the session token from the session cache immediately."""
|
||||
|
||||
await _do_setup(client)
|
||||
|
||||
@@ -7,25 +7,34 @@ from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import aiosqlite
|
||||
import bcrypt
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from app.config import Settings
|
||||
from app.db import init_db
|
||||
from app.main import create_app
|
||||
from app.models.ban import ActiveBan, ActiveBanListResponse
|
||||
from app.exceptions import Fail2BanConnectionError
|
||||
from app.main import create_app
|
||||
from app.models.ban_domain import DomainActiveBan, DomainActiveBanList
|
||||
from app.services.geo_cache import GeoCache
|
||||
from app.utils.session_cache import NoOpSessionCache
|
||||
from app.utils.setup_state import set_setup_complete_cache
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SETUP_PAYLOAD = {
|
||||
"master_password": "Testpass1!",
|
||||
"fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
|
||||
"timezone": "UTC",
|
||||
"session_duration_minutes": 60,
|
||||
}
|
||||
async def _write_password_hash(db: aiosqlite.Connection, password: str) -> str:
|
||||
"""Hash password and write to settings table."""
|
||||
pw_bytes = password.encode()
|
||||
import asyncio
|
||||
|
||||
hashed = await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: bcrypt.hashpw(pw_bytes, bcrypt.gensalt()).decode()
|
||||
)
|
||||
await db.execute(
|
||||
"INSERT OR REPLACE INTO settings (key, value) VALUES (?, ?)",
|
||||
("master_password_hash", hashed),
|
||||
)
|
||||
await db.commit()
|
||||
return hashed
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -41,24 +50,30 @@ async def bans_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
|
||||
log_level="debug",
|
||||
fail2ban_config_dir=str(tmp_path / "fail2ban"),
|
||||
session_cache_enabled=False,
|
||||
session_cookie_secure=False,
|
||||
)
|
||||
app = create_app(settings=settings)
|
||||
set_setup_complete_cache(app, True)
|
||||
|
||||
db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path)
|
||||
db.row_factory = aiosqlite.Row
|
||||
await init_db(db)
|
||||
await _write_password_hash(db, _SETUP_PAYLOAD["master_password"])
|
||||
app.state.db = db
|
||||
app.state.http_session = MagicMock()
|
||||
app.state.session_cache = NoOpSessionCache()
|
||||
app.state.geo_cache = GeoCache()
|
||||
|
||||
async def _override_get_db() -> AsyncGenerator[aiosqlite.Connection, None]:
|
||||
yield db
|
||||
|
||||
from app.dependencies import get_db
|
||||
from app.dependencies import get_db, get_session_cache
|
||||
|
||||
app.dependency_overrides[get_db] = _override_get_db
|
||||
app.dependency_overrides[get_session_cache] = lambda: NoOpSessionCache()
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
await ac.post("/api/v1/setup", json=_SETUP_PAYLOAD)
|
||||
login = await ac.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"password": _SETUP_PAYLOAD["master_password"]},
|
||||
@@ -70,6 +85,19 @@ async def bans_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SETUP_PAYLOAD = {
|
||||
"master_password": "Testpass1!",
|
||||
"fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
|
||||
"timezone": "UTC",
|
||||
"session_duration_minutes": 60,
|
||||
"database_path": "bans_test.db",
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/bans/active
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -80,9 +108,11 @@ class TestGetActiveBans:
|
||||
|
||||
async def test_200_when_authenticated(self, bans_client: AsyncClient) -> None:
|
||||
"""GET /api/bans/active returns 200 with an ActiveBanListResponse."""
|
||||
mock_response = ActiveBanListResponse(
|
||||
from app.models.ban_domain import DomainActiveBan, DomainActiveBanList
|
||||
|
||||
mock_response = DomainActiveBanList(
|
||||
bans=[
|
||||
ActiveBan(
|
||||
DomainActiveBan(
|
||||
ip="1.2.3.4",
|
||||
jail="sshd",
|
||||
banned_at="2025-01-01T12:00:00+00:00",
|
||||
@@ -102,20 +132,21 @@ class TestGetActiveBans:
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["total"] == 1
|
||||
assert data["bans"][0]["ip"] == "1.2.3.4"
|
||||
assert data["bans"][0]["jail"] == "sshd"
|
||||
assert data["items"][0]["ip"] == "1.2.3.4"
|
||||
assert data["items"][0]["jail"] == "sshd"
|
||||
|
||||
async def test_401_when_unauthenticated(
|
||||
self, bans_client: AsyncClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
async def test_401_when_unauthenticated(self, bans_client: AsyncClient, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""GET /api/bans/active returns 401 without session."""
|
||||
import logging
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
class FakeLogger:
|
||||
def error(self, *args, **kwargs): pass
|
||||
def warning(self, *args, **kwargs): pass
|
||||
def info(self, *args, **kwargs): pass
|
||||
def error(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def warning(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def info(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
monkeypatch.setattr("app.main.log", FakeLogger())
|
||||
resp = await AsyncClient(
|
||||
@@ -126,7 +157,7 @@ class TestGetActiveBans:
|
||||
|
||||
async def test_empty_when_no_bans(self, bans_client: AsyncClient) -> None:
|
||||
"""GET /api/bans/active returns empty list when no bans are active."""
|
||||
mock_response = ActiveBanListResponse(bans=[], total=0)
|
||||
mock_response = DomainActiveBanList(bans=[], total=0)
|
||||
with patch(
|
||||
"app.routers.bans.ban_service.get_active_bans",
|
||||
AsyncMock(return_value=mock_response),
|
||||
@@ -135,13 +166,13 @@ class TestGetActiveBans:
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["total"] == 0
|
||||
assert resp.json()["bans"] == []
|
||||
assert resp.json()["items"] == []
|
||||
|
||||
async def test_response_shape(self, bans_client: AsyncClient) -> None:
|
||||
"""GET /api/bans/active returns expected fields per ban entry."""
|
||||
mock_response = ActiveBanListResponse(
|
||||
mock_response = DomainActiveBanList(
|
||||
bans=[
|
||||
ActiveBan(
|
||||
DomainActiveBan(
|
||||
ip="10.0.0.1",
|
||||
jail="nginx",
|
||||
banned_at=None,
|
||||
@@ -158,7 +189,7 @@ class TestGetActiveBans:
|
||||
):
|
||||
resp = await bans_client.get("/api/v1/bans/active")
|
||||
|
||||
ban = resp.json()["bans"][0]
|
||||
ban = resp.json()["items"][0]
|
||||
assert "ip" in ban
|
||||
assert "jail" in ban
|
||||
assert "banned_at" in ban
|
||||
@@ -183,6 +214,7 @@ class TestBanIp:
|
||||
resp = await bans_client.post(
|
||||
"/api/v1/bans",
|
||||
json={"ip": "1.2.3.4", "jail": "sshd"},
|
||||
headers={"X-BanGUI-Request": "1"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 201
|
||||
@@ -197,6 +229,7 @@ class TestBanIp:
|
||||
resp = await bans_client.post(
|
||||
"/api/v1/bans",
|
||||
json={"ip": "bad", "jail": "sshd"},
|
||||
headers={"X-BanGUI-Request": "1"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 400
|
||||
@@ -212,6 +245,7 @@ class TestBanIp:
|
||||
resp = await bans_client.post(
|
||||
"/api/v1/bans",
|
||||
json={"ip": "1.2.3.4", "jail": "ghost"},
|
||||
headers={"X-BanGUI-Request": "1"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 404
|
||||
@@ -243,6 +277,7 @@ class TestUnbanIp:
|
||||
"DELETE",
|
||||
"/api/v1/bans",
|
||||
json={"ip": "1.2.3.4", "unban_all": True},
|
||||
headers={"X-BanGUI-Request": "1"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
@@ -258,6 +293,7 @@ class TestUnbanIp:
|
||||
"DELETE",
|
||||
"/api/v1/bans",
|
||||
json={"ip": "1.2.3.4", "jail": "sshd"},
|
||||
headers={"X-BanGUI-Request": "1"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
@@ -273,6 +309,7 @@ class TestUnbanIp:
|
||||
"DELETE",
|
||||
"/api/v1/bans",
|
||||
json={"ip": "bad", "unban_all": True},
|
||||
headers={"X-BanGUI-Request": "1"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 400
|
||||
@@ -289,6 +326,7 @@ class TestUnbanIp:
|
||||
"DELETE",
|
||||
"/api/v1/bans",
|
||||
json={"ip": "1.2.3.4", "jail": "ghost"},
|
||||
headers={"X-BanGUI-Request": "1"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 404
|
||||
@@ -308,7 +346,7 @@ class TestUnbanAll:
|
||||
"app.routers.bans.jail_service.unban_all_ips",
|
||||
AsyncMock(return_value=3),
|
||||
):
|
||||
resp = await bans_client.request("DELETE", "/api/v1/bans/all")
|
||||
resp = await bans_client.request("DELETE", "/api/v1/bans/all", headers={"X-BanGUI-Request": "1"})
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
@@ -321,14 +359,12 @@ class TestUnbanAll:
|
||||
"app.routers.bans.jail_service.unban_all_ips",
|
||||
AsyncMock(return_value=0),
|
||||
):
|
||||
resp = await bans_client.request("DELETE", "/api/v1/bans/all")
|
||||
resp = await bans_client.request("DELETE", "/api/v1/bans/all", headers={"X-BanGUI-Request": "1"})
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["count"] == 0
|
||||
|
||||
async def test_502_when_fail2ban_unreachable(
|
||||
self, bans_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_502_when_fail2ban_unreachable(self, bans_client: AsyncClient) -> None:
|
||||
"""DELETE /api/bans/all returns 502 when fail2ban is unreachable."""
|
||||
with patch(
|
||||
"app.routers.bans.jail_service.unban_all_ips",
|
||||
@@ -339,7 +375,7 @@ class TestUnbanAll:
|
||||
)
|
||||
),
|
||||
):
|
||||
resp = await bans_client.request("DELETE", "/api/v1/bans/all")
|
||||
resp = await bans_client.request("DELETE", "/api/v1/bans/all", headers={"X-BanGUI-Request": "1"})
|
||||
|
||||
assert resp.status_code == 502
|
||||
|
||||
|
||||
@@ -84,9 +84,7 @@ def _make_import_result() -> ImportRunResult:
|
||||
|
||||
|
||||
def _make_log_response() -> ImportLogListResponse:
|
||||
return ImportLogListResponse(
|
||||
items=[], total=0, page=1, page_size=50
|
||||
)
|
||||
return ImportLogListResponse(items=[], total=0, page=1, page_size=50)
|
||||
|
||||
|
||||
def _make_preview() -> PreviewResponse:
|
||||
@@ -106,13 +104,17 @@ def _make_preview() -> PreviewResponse:
|
||||
@pytest.fixture
|
||||
async def bl_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
|
||||
"""Provide an authenticated AsyncClient for blocklist endpoint tests."""
|
||||
config_dir = tmp_path / "fail2ban"
|
||||
config_dir.mkdir()
|
||||
settings = Settings(
|
||||
database_path=str(tmp_path / "bl_router_test.db"),
|
||||
fail2ban_socket="/tmp/fake_fail2ban.sock",
|
||||
session_secret="test-bl-secret",
|
||||
fail2ban_config_dir=str(config_dir),
|
||||
session_secret="test-bl-secret-that-is-long-enough!!",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
log_level="debug",
|
||||
session_cookie_secure=False,
|
||||
)
|
||||
app = create_app(settings=settings)
|
||||
|
||||
@@ -127,8 +129,13 @@ async def bl_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
|
||||
scheduler_stub.get_job = MagicMock(return_value=None)
|
||||
app.state.scheduler = scheduler_stub
|
||||
|
||||
# Initialize GeoCache (normally done in lifespan handler)
|
||||
from app.services.geo_cache import GeoCache
|
||||
|
||||
app.state.geo_cache = GeoCache()
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
async with AsyncClient(transport=transport, base_url="http://test", headers={"X-BanGUI-Request": "1"}) as ac:
|
||||
resp = await ac.post("/api/v1/setup", json=_SETUP_PAYLOAD)
|
||||
assert resp.status_code == 201
|
||||
|
||||
@@ -277,12 +284,15 @@ class TestDeleteBlocklist:
|
||||
class TestPreviewBlocklist:
|
||||
async def test_preview_returns_200(self, bl_client: AsyncClient) -> None:
|
||||
"""GET /api/blocklists/1/preview returns 200 for existing source."""
|
||||
with patch(
|
||||
"app.routers.blocklist.blocklist_service.get_source",
|
||||
new=AsyncMock(return_value=_make_source()),
|
||||
), patch(
|
||||
"app.routers.blocklist.blocklist_service.preview_source",
|
||||
new=AsyncMock(return_value=_make_preview()),
|
||||
with (
|
||||
patch(
|
||||
"app.routers.blocklist.blocklist_service.get_source",
|
||||
new=AsyncMock(return_value=_make_source()),
|
||||
),
|
||||
patch(
|
||||
"app.routers.blocklist.blocklist_service.preview_source",
|
||||
new=AsyncMock(return_value=_make_preview()),
|
||||
),
|
||||
):
|
||||
resp = await bl_client.get("/api/v1/blocklists/1/preview")
|
||||
assert resp.status_code == 200
|
||||
@@ -296,28 +306,32 @@ class TestPreviewBlocklist:
|
||||
resp = await bl_client.get("/api/v1/blocklists/999/preview")
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_preview_returns_502_on_download_error(
|
||||
self, bl_client: AsyncClient
|
||||
) -> None:
|
||||
"""GET /api/blocklists/1/preview returns 502 when URL is unreachable."""
|
||||
with patch(
|
||||
"app.routers.blocklist.blocklist_service.get_source",
|
||||
new=AsyncMock(return_value=_make_source()),
|
||||
), patch(
|
||||
"app.routers.blocklist.blocklist_service.preview_source",
|
||||
new=AsyncMock(side_effect=ValueError("Connection refused")),
|
||||
async def test_preview_returns_400_on_download_error(self, bl_client: AsyncClient) -> None:
|
||||
"""GET /api/blocklists/1/preview returns 400 when URL is unreachable."""
|
||||
with (
|
||||
patch(
|
||||
"app.routers.blocklist.blocklist_service.get_source",
|
||||
new=AsyncMock(return_value=_make_source()),
|
||||
),
|
||||
patch(
|
||||
"app.routers.blocklist.blocklist_service.preview_source",
|
||||
new=AsyncMock(side_effect=ValueError("Connection refused")),
|
||||
),
|
||||
):
|
||||
resp = await bl_client.get("/api/v1/blocklists/1/preview")
|
||||
assert resp.status_code == 502
|
||||
assert resp.status_code == 400
|
||||
|
||||
async def test_preview_response_shape(self, bl_client: AsyncClient) -> None:
|
||||
"""Preview response has entries, valid_count, skipped_count, total_lines."""
|
||||
with patch(
|
||||
"app.routers.blocklist.blocklist_service.get_source",
|
||||
new=AsyncMock(return_value=_make_source()),
|
||||
), patch(
|
||||
"app.routers.blocklist.blocklist_service.preview_source",
|
||||
new=AsyncMock(return_value=_make_preview()),
|
||||
with (
|
||||
patch(
|
||||
"app.routers.blocklist.blocklist_service.get_source",
|
||||
new=AsyncMock(return_value=_make_source()),
|
||||
),
|
||||
patch(
|
||||
"app.routers.blocklist.blocklist_service.preview_source",
|
||||
new=AsyncMock(return_value=_make_preview()),
|
||||
),
|
||||
):
|
||||
resp = await bl_client.get("/api/v1/blocklists/1/preview")
|
||||
body = resp.json()
|
||||
@@ -383,9 +397,7 @@ class TestGetSchedule:
|
||||
assert "next_run_at" in body
|
||||
assert "last_run_at" in body
|
||||
|
||||
async def test_schedule_response_includes_last_run_errors(
|
||||
self, bl_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_schedule_response_includes_last_run_errors(self, bl_client: AsyncClient) -> None:
|
||||
"""GET /api/blocklists/schedule includes last_run_errors field."""
|
||||
info_with_errors = ScheduleInfo(
|
||||
config=ScheduleConfig(
|
||||
@@ -457,15 +469,18 @@ class TestImportLog:
|
||||
assert resp.status_code == 200
|
||||
|
||||
async def test_log_response_shape(self, bl_client: AsyncClient) -> None:
|
||||
"""Log response has items, total, page, page_size."""
|
||||
"""Log response has items and pagination metadata."""
|
||||
resp = await bl_client.get("/api/v1/blocklists/log")
|
||||
body = resp.json()
|
||||
for key in ("items", "total", "page", "page_size"):
|
||||
assert key in body
|
||||
assert "items" in body
|
||||
assert "pagination" in body
|
||||
pagination = body["pagination"]
|
||||
for key in ("page", "page_size", "total", "total_pages", "has_next_page", "has_prev_page"):
|
||||
assert key in pagination
|
||||
|
||||
async def test_log_empty_when_no_runs(self, bl_client: AsyncClient) -> None:
|
||||
"""Log returns empty items list when no import runs have occurred."""
|
||||
resp = await bl_client.get("/api/v1/blocklists/log")
|
||||
body = resp.json()
|
||||
assert body["total"] == 0
|
||||
assert body["pagination"]["total"] == 0
|
||||
assert body["items"] == []
|
||||
|
||||
@@ -16,13 +16,15 @@ from app.main import create_app
|
||||
from app.models.config import (
|
||||
Fail2BanLogResponse,
|
||||
FilterConfig,
|
||||
GlobalConfigResponse,
|
||||
JailConfig,
|
||||
JailConfigListResponse,
|
||||
JailConfigResponse,
|
||||
RegexTestResponse,
|
||||
ServiceStatusResponse,
|
||||
)
|
||||
from app.models.config_domain import (
|
||||
DomainGlobalConfig,
|
||||
DomainJailConfig,
|
||||
DomainJailConfigList,
|
||||
DomainMapColorThresholds,
|
||||
DomainRegexTest,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
@@ -40,9 +42,12 @@ _SETUP_PAYLOAD = {
|
||||
@pytest.fixture
|
||||
async def config_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
|
||||
"""Provide an authenticated ``AsyncClient`` for config endpoint tests."""
|
||||
config_dir = tmp_path / "fail2ban"
|
||||
config_dir.mkdir()
|
||||
settings = Settings(
|
||||
database_path=str(tmp_path / "config_test.db"),
|
||||
fail2ban_socket="/tmp/fake.sock",
|
||||
fail2ban_config_dir=str(config_dir),
|
||||
session_secret="test-secret-key-do-not-use-in-production",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
@@ -58,20 +63,21 @@ async def config_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
|
||||
app.state.http_session = MagicMock()
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
await ac.post("/api/v1/setup", json=_SETUP_PAYLOAD)
|
||||
async with AsyncClient(transport=transport, base_url="http://test", headers={"X-BanGUI-Request": "1"}) as ac:
|
||||
setup_resp = await ac.post("/api/v1/setup", json=_SETUP_PAYLOAD)
|
||||
assert setup_resp.status_code == 201, f"Setup failed: {setup_resp.status_code} {setup_resp.text}"
|
||||
login = await ac.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"password": _SETUP_PAYLOAD["master_password"]},
|
||||
)
|
||||
assert login.status_code == 200
|
||||
assert login.status_code == 200, f"Login failed: {login.status_code} {login.text}"
|
||||
yield ac
|
||||
|
||||
await db.close()
|
||||
|
||||
|
||||
def _make_jail_config(name: str = "sshd") -> JailConfig:
|
||||
return JailConfig(
|
||||
def _make_jail_config(name: str = "sshd") -> DomainJailConfig:
|
||||
return DomainJailConfig(
|
||||
name=name,
|
||||
ban_time=600,
|
||||
max_retry=5,
|
||||
@@ -98,9 +104,7 @@ class TestGetJailConfigs:
|
||||
|
||||
async def test_200_returns_jail_list(self, config_client: AsyncClient) -> None:
|
||||
"""GET /api/config/jails returns 200 with JailConfigListResponse."""
|
||||
mock_response = JailConfigListResponse(
|
||||
items=[_make_jail_config("sshd")], total=1
|
||||
)
|
||||
mock_response = DomainJailConfigList(items=[_make_jail_config("sshd")], total=1)
|
||||
with patch(
|
||||
"app.routers.jail_config.config_service.list_jail_configs",
|
||||
AsyncMock(return_value=mock_response),
|
||||
@@ -143,7 +147,7 @@ class TestGetJailConfig:
|
||||
|
||||
async def test_200_returns_jail_config(self, config_client: AsyncClient) -> None:
|
||||
"""GET /api/config/jails/sshd returns 200 with JailConfigResponse."""
|
||||
mock_response = JailConfigResponse(jail=_make_jail_config("sshd"))
|
||||
mock_response = _make_jail_config("sshd")
|
||||
with patch(
|
||||
"app.routers.jail_config.config_service.get_jail_config",
|
||||
AsyncMock(return_value=mock_response),
|
||||
@@ -211,8 +215,8 @@ class TestUpdateJailConfig:
|
||||
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_422_on_invalid_regex(self, config_client: AsyncClient) -> None:
|
||||
"""PUT /api/config/jails/sshd returns 422 for invalid regex pattern."""
|
||||
async def test_400_on_invalid_regex(self, config_client: AsyncClient) -> None:
|
||||
"""PUT /api/config/jails/sshd returns 400 for invalid regex pattern."""
|
||||
from app.services.config_service import ConfigValidationError
|
||||
|
||||
with patch(
|
||||
@@ -224,7 +228,7 @@ class TestUpdateJailConfig:
|
||||
json={"fail_regex": ["[bad"]},
|
||||
)
|
||||
|
||||
assert resp.status_code == 422
|
||||
assert resp.status_code == 400
|
||||
|
||||
async def test_400_on_config_operation_error(self, config_client: AsyncClient) -> None:
|
||||
"""PUT /api/config/jails/sshd returns 400 when set command fails."""
|
||||
@@ -291,7 +295,7 @@ class TestGetGlobalConfig:
|
||||
|
||||
async def test_200_returns_global_config(self, config_client: AsyncClient) -> None:
|
||||
"""GET /api/config/global returns 200 with GlobalConfigResponse."""
|
||||
mock_response = GlobalConfigResponse(
|
||||
mock_response = DomainGlobalConfig(
|
||||
log_level="WARNING",
|
||||
log_target="/var/log/fail2ban.log",
|
||||
db_purge_age=86400,
|
||||
@@ -415,15 +419,15 @@ class TestRestartFail2ban:
|
||||
|
||||
assert resp.status_code == 204
|
||||
|
||||
async def test_503_when_fail2ban_does_not_come_back(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/restart returns 503 when fail2ban does not come back online."""
|
||||
async def test_500_when_fail2ban_does_not_come_back(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/restart returns 500 when fail2ban does not come back online."""
|
||||
with patch(
|
||||
"app.routers.config_misc.jail_service.restart_daemon",
|
||||
AsyncMock(return_value=False),
|
||||
):
|
||||
resp = await config_client.post("/api/v1/config/restart")
|
||||
|
||||
assert resp.status_code == 503
|
||||
assert resp.status_code == 500
|
||||
|
||||
async def test_409_when_stop_command_fails(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/restart returns 409 when fail2ban rejects the stop command."""
|
||||
@@ -472,7 +476,7 @@ class TestRegexTest:
|
||||
|
||||
async def test_200_matched(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/regex-test returns matched=true for a valid match."""
|
||||
mock_response = RegexTestResponse(matched=True, groups=["1.2.3.4"], error=None)
|
||||
mock_response = DomainRegexTest(matched=True, groups=["1.2.3.4"], error=None)
|
||||
with patch(
|
||||
"app.routers.config_misc.log_service.test_regex",
|
||||
return_value=mock_response,
|
||||
@@ -490,7 +494,7 @@ class TestRegexTest:
|
||||
|
||||
async def test_200_not_matched(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/regex-test returns matched=false for no match."""
|
||||
mock_response = RegexTestResponse(matched=False, groups=[], error=None)
|
||||
mock_response = DomainRegexTest(matched=False, groups=[], error=None)
|
||||
with patch(
|
||||
"app.routers.config_misc.log_service.test_regex",
|
||||
return_value=mock_response,
|
||||
@@ -525,9 +529,12 @@ class TestAddLogPath:
|
||||
|
||||
async def test_204_on_success(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/jails/sshd/logpath returns 204 on success."""
|
||||
with patch(
|
||||
"app.routers.jail_config.config_service.add_log_path",
|
||||
AsyncMock(return_value=None),
|
||||
with (
|
||||
patch(
|
||||
"app.routers.jail_config.config_service.add_log_path",
|
||||
AsyncMock(return_value=None),
|
||||
),
|
||||
patch("app.routers.jail_config.validate_log_path", return_value="/var/log/specific.log"),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
"/api/v1/config/jails/sshd/logpath",
|
||||
@@ -540,9 +547,12 @@ class TestAddLogPath:
|
||||
"""POST /api/config/jails/missing/logpath returns 404."""
|
||||
from app.services.config_service import JailNotFoundError
|
||||
|
||||
with patch(
|
||||
"app.routers.jail_config.config_service.add_log_path",
|
||||
AsyncMock(side_effect=JailNotFoundError("missing")),
|
||||
with (
|
||||
patch(
|
||||
"app.routers.jail_config.config_service.add_log_path",
|
||||
AsyncMock(side_effect=JailNotFoundError("missing")),
|
||||
),
|
||||
patch("app.routers.jail_config.validate_log_path", return_value="/var/log/test.log"),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
"/api/v1/config/jails/missing/logpath",
|
||||
@@ -594,14 +604,18 @@ class TestGetMapColorThresholds:
|
||||
|
||||
async def test_200_returns_thresholds(self, config_client: AsyncClient) -> None:
|
||||
"""GET /api/config/map-color-thresholds returns 200 with current values."""
|
||||
resp = await config_client.get("/api/v1/config/map-color-thresholds")
|
||||
mock_response = DomainMapColorThresholds(threshold_high=100, threshold_medium=50, threshold_low=20)
|
||||
with patch(
|
||||
"app.routers.config_misc.config_service.get_map_color_thresholds",
|
||||
AsyncMock(return_value=mock_response),
|
||||
):
|
||||
resp = await config_client.get("/api/v1/config/map-color-thresholds")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "threshold_high" in data
|
||||
assert "threshold_medium" in data
|
||||
assert "threshold_low" in data
|
||||
# Should return defaults after setup
|
||||
assert data["threshold_high"] == 100
|
||||
assert data["threshold_medium"] == 50
|
||||
assert data["threshold_low"] == 20
|
||||
@@ -622,9 +636,12 @@ class TestUpdateMapColorThresholds:
|
||||
"threshold_medium": 80,
|
||||
"threshold_low": 30,
|
||||
}
|
||||
resp = await config_client.put(
|
||||
"/api/v1/config/map-color-thresholds", json=update_payload
|
||||
)
|
||||
mock_response = DomainMapColorThresholds(threshold_high=200, threshold_medium=80, threshold_low=30)
|
||||
with patch(
|
||||
"app.routers.config_misc.config_service.get_map_color_thresholds",
|
||||
AsyncMock(return_value=mock_response),
|
||||
):
|
||||
resp = await config_client.put("/api/v1/config/map-color-thresholds", json=update_payload)
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
@@ -632,14 +649,6 @@ class TestUpdateMapColorThresholds:
|
||||
assert data["threshold_medium"] == 80
|
||||
assert data["threshold_low"] == 30
|
||||
|
||||
# Verify the values persist
|
||||
get_resp = await config_client.get("/api/v1/config/map-color-thresholds")
|
||||
assert get_resp.status_code == 200
|
||||
get_data = get_resp.json()
|
||||
assert get_data["threshold_high"] == 200
|
||||
assert get_data["threshold_medium"] == 80
|
||||
assert get_data["threshold_low"] == 30
|
||||
|
||||
async def test_400_for_invalid_order(self, config_client: AsyncClient) -> None:
|
||||
"""PUT /api/config/map-color-thresholds returns 400 if thresholds are misordered."""
|
||||
invalid_payload = {
|
||||
@@ -647,28 +656,22 @@ class TestUpdateMapColorThresholds:
|
||||
"threshold_medium": 50,
|
||||
"threshold_low": 20,
|
||||
}
|
||||
resp = await config_client.put(
|
||||
"/api/v1/config/map-color-thresholds", json=invalid_payload
|
||||
)
|
||||
resp = await config_client.put("/api/v1/config/map-color-thresholds", json=invalid_payload)
|
||||
|
||||
assert resp.status_code == 400
|
||||
assert "high > medium > low" in resp.json()["detail"]
|
||||
|
||||
async def test_400_for_non_positive_values(
|
||||
self, config_client: AsyncClient
|
||||
) -> None:
|
||||
"""PUT /api/config/map-color-thresholds returns 422 for non-positive values (Pydantic validation)."""
|
||||
async def test_400_for_non_positive_values(self, config_client: AsyncClient) -> None:
|
||||
"""PUT /api/config/map-color-thresholds returns 400 for non-positive values (Pydantic validation)."""
|
||||
invalid_payload = {
|
||||
"threshold_high": 100,
|
||||
"threshold_medium": 50,
|
||||
"threshold_low": 0,
|
||||
}
|
||||
resp = await config_client.put(
|
||||
"/api/v1/config/map-color-thresholds", json=invalid_payload
|
||||
)
|
||||
resp = await config_client.put("/api/v1/config/map-color-thresholds", json=invalid_payload)
|
||||
|
||||
# Pydantic validates ge=1 constraint before our service code runs
|
||||
assert resp.status_code == 422
|
||||
# Pydantic validates gt=0 constraint before our service code runs; ValueError -> 400
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -752,9 +755,7 @@ class TestActivateJail:
|
||||
"app.routers.jail_config.jail_config_service.activate_jail",
|
||||
AsyncMock(return_value=mock_response),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
"/api/v1/config/jails/apache-auth/activate", json={}
|
||||
)
|
||||
resp = await config_client.post("/api/v1/config/jails/apache-auth/activate", json={})
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
@@ -765,9 +766,7 @@ class TestActivateJail:
|
||||
"""POST .../activate accepts override fields."""
|
||||
from app.models.config import JailActivationResponse
|
||||
|
||||
mock_response = JailActivationResponse(
|
||||
name="apache-auth", active=True, message="Activated."
|
||||
)
|
||||
mock_response = JailActivationResponse(name="apache-auth", active=True, message="Activated.")
|
||||
with patch(
|
||||
"app.routers.jail_config.jail_config_service.activate_jail",
|
||||
AsyncMock(return_value=mock_response),
|
||||
@@ -791,9 +790,7 @@ class TestActivateJail:
|
||||
"app.routers.jail_config.jail_config_service.activate_jail",
|
||||
AsyncMock(side_effect=JailNotFoundInConfigError("missing")),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
"/api/v1/config/jails/missing/activate", json={}
|
||||
)
|
||||
resp = await config_client.post("/api/v1/config/jails/missing/activate", json={})
|
||||
|
||||
assert resp.status_code == 404
|
||||
|
||||
@@ -805,15 +802,11 @@ class TestActivateJail:
|
||||
"app.routers.jail_config.jail_config_service.activate_jail",
|
||||
AsyncMock(side_effect=JailAlreadyActiveError("sshd")),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
"/api/v1/config/jails/sshd/activate", json={}
|
||||
)
|
||||
resp = await config_client.post("/api/v1/config/jails/sshd/activate", json={})
|
||||
|
||||
assert resp.status_code == 409
|
||||
|
||||
async def test_failed_activation_does_not_set_last_activation(
|
||||
self, config_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_failed_activation_does_not_set_last_activation(self, config_client: AsyncClient) -> None:
|
||||
"""A failed activation must not leave a stale last_activation record."""
|
||||
from app.exceptions import Fail2BanConnectionError
|
||||
|
||||
@@ -822,9 +815,7 @@ class TestActivateJail:
|
||||
"app.routers.jail_config.jail_config_service.activate_jail",
|
||||
AsyncMock(side_effect=Fail2BanConnectionError("No socket", "/tmp/fake.sock")),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
"/api/v1/config/jails/sshd/activate", json={}
|
||||
)
|
||||
resp = await config_client.post("/api/v1/config/jails/sshd/activate", json={})
|
||||
|
||||
assert resp.status_code == 502
|
||||
assert config_client._transport.app.state.last_activation is None
|
||||
@@ -837,9 +828,7 @@ class TestActivateJail:
|
||||
"app.routers.jail_config.jail_config_service.activate_jail",
|
||||
AsyncMock(side_effect=JailNameError("bad name")),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
"/api/v1/config/jails/bad-name/activate", json={}
|
||||
)
|
||||
resp = await config_client.post("/api/v1/config/jails/bad-name/activate", json={})
|
||||
|
||||
assert resp.status_code == 400
|
||||
|
||||
@@ -866,9 +855,7 @@ class TestActivateJail:
|
||||
"app.routers.jail_config.jail_config_service.activate_jail",
|
||||
AsyncMock(return_value=blocked_response),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
"/api/v1/config/jails/airsonic-auth/activate", json={}
|
||||
)
|
||||
resp = await config_client.post("/api/v1/config/jails/airsonic-auth/activate", json={})
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
@@ -914,9 +901,7 @@ class TestDeactivateJail:
|
||||
"app.routers.jail_config.jail_config_service.deactivate_jail",
|
||||
AsyncMock(side_effect=JailNotFoundInConfigError("missing")),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
"/api/v1/config/jails/missing/deactivate"
|
||||
)
|
||||
resp = await config_client.post("/api/v1/config/jails/missing/deactivate")
|
||||
|
||||
assert resp.status_code == 404
|
||||
|
||||
@@ -928,9 +913,7 @@ class TestDeactivateJail:
|
||||
"app.routers.jail_config.jail_config_service.deactivate_jail",
|
||||
AsyncMock(side_effect=JailAlreadyInactiveError("apache-auth")),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
"/api/v1/config/jails/apache-auth/deactivate"
|
||||
)
|
||||
resp = await config_client.post("/api/v1/config/jails/apache-auth/deactivate")
|
||||
|
||||
assert resp.status_code == 409
|
||||
|
||||
@@ -942,9 +925,7 @@ class TestDeactivateJail:
|
||||
"app.routers.jail_config.jail_config_service.deactivate_jail",
|
||||
AsyncMock(side_effect=JailNameError("bad")),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
"/api/v1/config/jails/sshd/deactivate"
|
||||
)
|
||||
resp = await config_client.post("/api/v1/config/jails/sshd/deactivate")
|
||||
|
||||
assert resp.status_code == 400
|
||||
|
||||
@@ -1011,10 +992,11 @@ class TestListFilters:
|
||||
|
||||
async def test_200_returns_filter_list(self, config_client: AsyncClient) -> None:
|
||||
"""GET /api/config/filters returns 200 with FilterListResponse."""
|
||||
from app.models.config import FilterListResponse
|
||||
|
||||
mock_response = FilterListResponse(
|
||||
filters=[_make_filter_config("sshd", active=True)],
|
||||
from app.models.config_domain import DomainFilterConfig, DomainFilterList
|
||||
|
||||
mock_response = DomainFilterList(
|
||||
items=[DomainFilterConfig(name="sshd", filename="sshd.conf", active=True, used_by_jails=["sshd"])],
|
||||
total=1,
|
||||
)
|
||||
with patch(
|
||||
@@ -1031,11 +1013,12 @@ class TestListFilters:
|
||||
|
||||
async def test_200_empty_filter_list(self, config_client: AsyncClient) -> None:
|
||||
"""GET /api/config/filters returns 200 with empty list when no filters found."""
|
||||
from app.models.config import FilterListResponse
|
||||
|
||||
from app.models.config_domain import DomainFilterList
|
||||
|
||||
with patch(
|
||||
"app.routers.filter_config.filter_config_service.list_filters",
|
||||
AsyncMock(return_value=FilterListResponse(filters=[], total=0)),
|
||||
AsyncMock(return_value=DomainFilterList(items=[], total=0)),
|
||||
):
|
||||
resp = await config_client.get("/api/v1/config/filters")
|
||||
|
||||
@@ -1043,16 +1026,15 @@ class TestListFilters:
|
||||
assert resp.json()["total"] == 0
|
||||
assert resp.json()["filters"] == []
|
||||
|
||||
async def test_active_filters_sorted_before_inactive(
|
||||
self, config_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_active_filters_sorted_before_inactive(self, config_client: AsyncClient) -> None:
|
||||
"""GET /api/config/filters returns active filters before inactive ones."""
|
||||
from app.models.config import FilterListResponse
|
||||
|
||||
mock_response = FilterListResponse(
|
||||
filters=[
|
||||
_make_filter_config("nginx", active=False),
|
||||
_make_filter_config("sshd", active=True),
|
||||
from app.models.config_domain import DomainFilterConfig, DomainFilterList
|
||||
|
||||
mock_response = DomainFilterList(
|
||||
items=[
|
||||
DomainFilterConfig(name="nginx", filename="nginx.conf", active=False),
|
||||
DomainFilterConfig(name="sshd", filename="sshd.conf", active=True, used_by_jails=["sshd"]),
|
||||
],
|
||||
total=2,
|
||||
)
|
||||
@@ -1063,8 +1045,8 @@ class TestListFilters:
|
||||
resp = await config_client.get("/api/v1/config/filters")
|
||||
|
||||
data = resp.json()
|
||||
assert data["filters"][0]["name"] == "sshd" # active first
|
||||
assert data["filters"][1]["name"] == "nginx" # inactive second
|
||||
assert data["filters"][0]["name"] == "sshd" # active first
|
||||
assert data["filters"][1]["name"] == "nginx" # inactive second
|
||||
|
||||
async def test_401_when_unauthenticated(self, config_client: AsyncClient) -> None:
|
||||
"""GET /api/config/filters returns 401 without a valid session."""
|
||||
@@ -1155,8 +1137,8 @@ class TestUpdateFilter:
|
||||
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_422_for_invalid_regex(self, config_client: AsyncClient) -> None:
|
||||
"""PUT /api/config/filters/sshd returns 422 for bad regex."""
|
||||
async def test_400_for_invalid_regex(self, config_client: AsyncClient) -> None:
|
||||
"""PUT /api/config/filters/sshd returns 400 for bad regex."""
|
||||
from app.services.filter_config_service import FilterInvalidRegexError
|
||||
|
||||
with patch(
|
||||
@@ -1168,7 +1150,7 @@ class TestUpdateFilter:
|
||||
json={"failregex": ["[bad"]},
|
||||
)
|
||||
|
||||
assert resp.status_code == 422
|
||||
assert resp.status_code == 400
|
||||
|
||||
async def test_400_for_invalid_name(self, config_client: AsyncClient) -> None:
|
||||
"""PUT /api/config/filters/... with bad name returns 400."""
|
||||
@@ -1245,8 +1227,8 @@ class TestCreateFilter:
|
||||
|
||||
assert resp.status_code == 409
|
||||
|
||||
async def test_422_for_invalid_regex(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/filters returns 422 for bad regex."""
|
||||
async def test_400_for_invalid_regex(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/filters returns 400 for bad regex."""
|
||||
from app.services.filter_config_service import FilterInvalidRegexError
|
||||
|
||||
with patch(
|
||||
@@ -1258,7 +1240,7 @@ class TestCreateFilter:
|
||||
json={"name": "test", "failregex": ["[bad"]},
|
||||
)
|
||||
|
||||
assert resp.status_code == 422
|
||||
assert resp.status_code == 400
|
||||
|
||||
async def test_400_for_invalid_name(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/filters returns 400 for invalid filter name."""
|
||||
@@ -1572,9 +1554,7 @@ class TestUpdateActionRouter:
|
||||
"app.routers.action_config.action_config_service.update_action",
|
||||
AsyncMock(side_effect=ActionNotFoundError("missing")),
|
||||
):
|
||||
resp = await config_client.put(
|
||||
"/api/v1/config/actions/missing", json={}
|
||||
)
|
||||
resp = await config_client.put("/api/v1/config/actions/missing", json={})
|
||||
|
||||
assert resp.status_code == 404
|
||||
|
||||
@@ -1585,9 +1565,7 @@ class TestUpdateActionRouter:
|
||||
"app.routers.action_config.action_config_service.update_action",
|
||||
AsyncMock(side_effect=ActionNameError()),
|
||||
):
|
||||
resp = await config_client.put(
|
||||
"/api/v1/config/actions/badname", json={}
|
||||
)
|
||||
resp = await config_client.put("/api/v1/config/actions/badname", json={})
|
||||
|
||||
assert resp.status_code == 400
|
||||
|
||||
@@ -1808,9 +1786,7 @@ class TestRemoveActionFromJailRouter:
|
||||
"app.routers.action_config.action_config_service.remove_action_from_jail",
|
||||
AsyncMock(return_value=None),
|
||||
):
|
||||
resp = await config_client.delete(
|
||||
"/api/v1/config/jails/sshd/action/iptables"
|
||||
)
|
||||
resp = await config_client.delete("/api/v1/config/jails/sshd/action/iptables")
|
||||
|
||||
assert resp.status_code == 204
|
||||
|
||||
@@ -1821,9 +1797,7 @@ class TestRemoveActionFromJailRouter:
|
||||
"app.routers.action_config.action_config_service.remove_action_from_jail",
|
||||
AsyncMock(side_effect=JailNotFoundInConfigError("missing")),
|
||||
):
|
||||
resp = await config_client.delete(
|
||||
"/api/v1/config/jails/missing/action/iptables"
|
||||
)
|
||||
resp = await config_client.delete("/api/v1/config/jails/missing/action/iptables")
|
||||
|
||||
assert resp.status_code == 404
|
||||
|
||||
@@ -1834,9 +1808,7 @@ class TestRemoveActionFromJailRouter:
|
||||
"app.routers.action_config.action_config_service.remove_action_from_jail",
|
||||
AsyncMock(side_effect=JailNameError()),
|
||||
):
|
||||
resp = await config_client.delete(
|
||||
"/api/v1/config/jails/badjailname/action/iptables"
|
||||
)
|
||||
resp = await config_client.delete("/api/v1/config/jails/badjailname/action/iptables")
|
||||
|
||||
assert resp.status_code == 400
|
||||
|
||||
@@ -1847,9 +1819,7 @@ class TestRemoveActionFromJailRouter:
|
||||
"app.routers.action_config.action_config_service.remove_action_from_jail",
|
||||
AsyncMock(side_effect=ActionNameError()),
|
||||
):
|
||||
resp = await config_client.delete(
|
||||
"/api/v1/config/jails/sshd/action/badactionname"
|
||||
)
|
||||
resp = await config_client.delete("/api/v1/config/jails/sshd/action/badactionname")
|
||||
|
||||
assert resp.status_code == 400
|
||||
|
||||
@@ -1858,9 +1828,7 @@ class TestRemoveActionFromJailRouter:
|
||||
"app.routers.action_config.action_config_service.remove_action_from_jail",
|
||||
AsyncMock(return_value=None),
|
||||
) as mock_rm:
|
||||
resp = await config_client.delete(
|
||||
"/api/v1/config/jails/sshd/action/iptables?reload=true"
|
||||
)
|
||||
resp = await config_client.delete("/api/v1/config/jails/sshd/action/iptables?reload=true")
|
||||
|
||||
assert resp.status_code == 204
|
||||
assert mock_rm.call_args.kwargs.get("do_reload") is True
|
||||
@@ -1965,10 +1933,10 @@ class TestGetFail2BanLog:
|
||||
|
||||
assert resp.status_code == 502
|
||||
|
||||
async def test_422_for_lines_exceeding_max(self, config_client: AsyncClient) -> None:
|
||||
"""GET /api/config/fail2ban-log returns 422 for lines > 2000."""
|
||||
async def test_400_for_lines_exceeding_max(self, config_client: AsyncClient) -> None:
|
||||
"""GET /api/config/fail2ban-log returns 400 for lines > 2000."""
|
||||
resp = await config_client.get("/api/v1/config/fail2ban-log?lines=9999")
|
||||
assert resp.status_code == 422
|
||||
assert resp.status_code == 400
|
||||
|
||||
async def test_401_when_unauthenticated(self, config_client: AsyncClient) -> None:
|
||||
"""GET /api/config/fail2ban-log requires authentication."""
|
||||
@@ -2001,7 +1969,7 @@ class TestGetServiceStatus:
|
||||
async def test_200_when_online(self, config_client: AsyncClient) -> None:
|
||||
"""GET /api/config/service-status returns 200 with full status when online."""
|
||||
with patch(
|
||||
"app.routers.config_misc.health_service.get_service_status",
|
||||
"app.services.health_service.get_service_status",
|
||||
AsyncMock(return_value=self._mock_status(online=True)),
|
||||
):
|
||||
resp = await config_client.get("/api/v1/config/service-status")
|
||||
@@ -2016,7 +1984,7 @@ class TestGetServiceStatus:
|
||||
async def test_200_when_offline(self, config_client: AsyncClient) -> None:
|
||||
"""GET /api/config/service-status returns 200 with offline=False when daemon is down."""
|
||||
with patch(
|
||||
"app.routers.config_misc.health_service.get_service_status",
|
||||
"app.services.health_service.get_service_status",
|
||||
AsyncMock(return_value=self._mock_status(online=False)),
|
||||
):
|
||||
resp = await config_client.get("/api/v1/config/service-status")
|
||||
@@ -2049,9 +2017,7 @@ class TestValidateJailEndpoint:
|
||||
"""Returns 200 with valid=True when the jail config has no issues."""
|
||||
from app.models.config import JailValidationResult
|
||||
|
||||
mock_result = JailValidationResult(
|
||||
jail_name="sshd", valid=True, issues=[]
|
||||
)
|
||||
mock_result = JailValidationResult(jail_name="sshd", valid=True, issues=[])
|
||||
with patch(
|
||||
"app.routers.jail_config.jail_config_service.validate_jail_config",
|
||||
AsyncMock(return_value=mock_result),
|
||||
@@ -2069,9 +2035,7 @@ class TestValidateJailEndpoint:
|
||||
from app.models.config import JailValidationIssue, JailValidationResult
|
||||
|
||||
issue = JailValidationIssue(field="filter", message="Filter file not found: filter.d/bad.conf (or .local)")
|
||||
mock_result = JailValidationResult(
|
||||
jail_name="sshd", valid=False, issues=[issue]
|
||||
)
|
||||
mock_result = JailValidationResult(jail_name="sshd", valid=False, issues=[issue])
|
||||
with patch(
|
||||
"app.routers.jail_config.jail_config_service.validate_jail_config",
|
||||
AsyncMock(return_value=mock_result),
|
||||
@@ -2109,9 +2073,7 @@ class TestValidateJailEndpoint:
|
||||
class TestPendingRecovery:
|
||||
"""Tests for ``GET /api/config/pending-recovery``."""
|
||||
|
||||
async def test_returns_null_when_no_pending_recovery(
|
||||
self, config_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_returns_null_when_no_pending_recovery(self, config_client: AsyncClient) -> None:
|
||||
"""Returns null body (204-like 200) when pending_recovery is not set."""
|
||||
app = config_client._transport.app # type: ignore[attr-defined]
|
||||
app.state.pending_recovery = None
|
||||
@@ -2156,9 +2118,7 @@ class TestPendingRecovery:
|
||||
class TestRollbackEndpoint:
|
||||
"""Tests for ``POST /api/config/jails/{name}/rollback``."""
|
||||
|
||||
async def test_200_success_clears_pending_recovery(
|
||||
self, config_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_200_success_clears_pending_recovery(self, config_client: AsyncClient) -> None:
|
||||
"""A successful rollback returns 200 and clears app.state.pending_recovery."""
|
||||
import datetime
|
||||
|
||||
@@ -2193,9 +2153,7 @@ class TestRollbackEndpoint:
|
||||
# Successful rollback must clear the pending record.
|
||||
assert app.state.pending_recovery is None
|
||||
|
||||
async def test_200_fail_preserves_pending_recovery(
|
||||
self, config_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_200_fail_preserves_pending_recovery(self, config_client: AsyncClient) -> None:
|
||||
"""When fail2ban is still down after rollback, pending_recovery is retained."""
|
||||
import datetime
|
||||
|
||||
@@ -2248,4 +2206,3 @@ class TestRollbackEndpoint:
|
||||
base_url="http://test",
|
||||
).post("/api/v1/config/jails/sshd/rollback")
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
@@ -31,14 +31,16 @@ async def _do_setup(client: AsyncClient) -> None:
|
||||
|
||||
|
||||
async def _login(client: AsyncClient, password: str = "Mysecretpass1!") -> str:
|
||||
"""Helper: perform login and return the session token."""
|
||||
"""Helper: perform login and return the session token from the cookie."""
|
||||
resp = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"password": password},
|
||||
headers={"X-BanGUI-Request": "1"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
return str(resp.json()["token"])
|
||||
token = resp.cookies.get(SESSION_COOKIE_NAME)
|
||||
assert token is not None
|
||||
return str(token)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -49,9 +51,7 @@ async def _login(client: AsyncClient, password: str = "Mysecretpass1!") -> str:
|
||||
class TestCsrfProtection:
|
||||
"""CSRF middleware validation tests."""
|
||||
|
||||
async def test_post_with_cookie_and_csrf_header_passes(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_post_with_cookie_and_csrf_header_passes(self, client: AsyncClient) -> None:
|
||||
"""POST with session cookie and CSRF header is allowed."""
|
||||
await _do_setup(client)
|
||||
token = await _login(client)
|
||||
@@ -65,9 +65,7 @@ class TestCsrfProtection:
|
||||
# Expect 200 (logout succeeds) not 403 (CSRF failed)
|
||||
assert response.status_code == 200
|
||||
|
||||
async def test_post_with_cookie_without_csrf_header_rejected(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_post_with_cookie_without_csrf_header_rejected(self, client: AsyncClient) -> None:
|
||||
"""POST with session cookie but no CSRF header is rejected with 403."""
|
||||
await _do_setup(client)
|
||||
token = await _login(client)
|
||||
@@ -83,9 +81,7 @@ class TestCsrfProtection:
|
||||
assert "detail" in body
|
||||
assert "CSRF" in body["detail"]
|
||||
|
||||
async def test_post_with_cookie_with_wrong_csrf_value_rejected(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_post_with_cookie_with_wrong_csrf_value_rejected(self, client: AsyncClient) -> None:
|
||||
"""POST with session cookie and wrong CSRF header value is rejected."""
|
||||
await _do_setup(client)
|
||||
token = await _login(client)
|
||||
@@ -98,9 +94,7 @@ class TestCsrfProtection:
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
async def test_post_with_bearer_token_no_csrf_header_passes(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_post_with_bearer_token_no_csrf_header_passes(self, client: AsyncClient) -> None:
|
||||
"""POST with Bearer token but no CSRF header is allowed (not CSRF-vulnerable)."""
|
||||
await _do_setup(client)
|
||||
token = await _login(client)
|
||||
@@ -113,9 +107,7 @@ class TestCsrfProtection:
|
||||
# Expect 200 (logout succeeds) not 403 (CSRF check should be skipped)
|
||||
assert response.status_code == 200
|
||||
|
||||
async def test_get_with_cookie_no_csrf_header_passes(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_get_with_cookie_no_csrf_header_passes(self, client: AsyncClient) -> None:
|
||||
"""GET with session cookie but no CSRF header is allowed (safe method)."""
|
||||
await _do_setup(client)
|
||||
token = await _login(client)
|
||||
@@ -129,9 +121,7 @@ class TestCsrfProtection:
|
||||
# Expect 200 (session valid) not 403 (CSRF check should be skipped for GET)
|
||||
assert response.status_code == 200
|
||||
|
||||
async def test_options_with_cookie_no_csrf_header_passes(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_options_with_cookie_no_csrf_header_passes(self, client: AsyncClient) -> None:
|
||||
"""OPTIONS with session cookie but no CSRF header is allowed (safe method)."""
|
||||
await _do_setup(client)
|
||||
token = await _login(client)
|
||||
@@ -145,9 +135,7 @@ class TestCsrfProtection:
|
||||
# Expect not 403
|
||||
assert response.status_code != 403
|
||||
|
||||
async def test_head_with_cookie_no_csrf_header_passes(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_head_with_cookie_no_csrf_header_passes(self, client: AsyncClient) -> None:
|
||||
"""HEAD with session cookie but no CSRF header is allowed (safe method)."""
|
||||
await _do_setup(client)
|
||||
token = await _login(client)
|
||||
@@ -161,9 +149,7 @@ class TestCsrfProtection:
|
||||
# Expect not 403
|
||||
assert response.status_code != 403
|
||||
|
||||
async def test_delete_with_cookie_and_csrf_header_passes(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_delete_with_cookie_and_csrf_header_passes(self, client: AsyncClient) -> None:
|
||||
"""DELETE with session cookie and CSRF header is allowed."""
|
||||
await _do_setup(client)
|
||||
token = await _login(client)
|
||||
@@ -180,9 +166,7 @@ class TestCsrfProtection:
|
||||
# Should not be 403 (CSRF failed)
|
||||
assert response.status_code != 403
|
||||
|
||||
async def test_delete_with_cookie_without_csrf_header_rejected(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_delete_with_cookie_without_csrf_header_rejected(self, client: AsyncClient) -> None:
|
||||
"""DELETE with session cookie but no CSRF header is rejected with 403."""
|
||||
await _do_setup(client)
|
||||
token = await _login(client)
|
||||
@@ -197,9 +181,7 @@ class TestCsrfProtection:
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
async def test_put_with_cookie_and_csrf_header_passes(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_put_with_cookie_and_csrf_header_passes(self, client: AsyncClient) -> None:
|
||||
"""PUT with session cookie and CSRF header is allowed."""
|
||||
await _do_setup(client)
|
||||
token = await _login(client)
|
||||
@@ -214,9 +196,7 @@ class TestCsrfProtection:
|
||||
# Should not be 403 (CSRF failed)
|
||||
assert response.status_code != 403
|
||||
|
||||
async def test_put_with_cookie_without_csrf_header_rejected(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_put_with_cookie_without_csrf_header_rejected(self, client: AsyncClient) -> None:
|
||||
"""PUT with session cookie but no CSRF header is rejected with 403."""
|
||||
await _do_setup(client)
|
||||
token = await _login(client)
|
||||
@@ -230,9 +210,7 @@ class TestCsrfProtection:
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
async def test_patch_with_cookie_and_csrf_header_passes(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_patch_with_cookie_and_csrf_header_passes(self, client: AsyncClient) -> None:
|
||||
"""PATCH with session cookie and CSRF header is allowed."""
|
||||
await _do_setup(client)
|
||||
token = await _login(client)
|
||||
@@ -247,9 +225,7 @@ class TestCsrfProtection:
|
||||
# Should not be 403 (CSRF failed)
|
||||
assert response.status_code != 403
|
||||
|
||||
async def test_patch_with_cookie_without_csrf_header_rejected(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_patch_with_cookie_without_csrf_header_rejected(self, client: AsyncClient) -> None:
|
||||
"""PATCH with session cookie but no CSRF header is rejected with 403."""
|
||||
await _do_setup(client)
|
||||
token = await _login(client)
|
||||
@@ -262,9 +238,7 @@ class TestCsrfProtection:
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
async def test_post_without_cookie_no_csrf_header_passes(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_post_without_cookie_no_csrf_header_passes(self, client: AsyncClient) -> None:
|
||||
"""POST without session cookie or Bearer token bypasses CSRF check."""
|
||||
await _do_setup(client)
|
||||
|
||||
@@ -279,9 +253,7 @@ class TestCsrfProtection:
|
||||
# (Actually logout is idempotent and doesn't require auth, so we expect 200)
|
||||
assert response.status_code in (200, 401)
|
||||
|
||||
async def test_bearer_token_via_authorization_header(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_bearer_token_via_authorization_header(self, client: AsyncClient) -> None:
|
||||
"""Bearer token in Authorization header bypasses CSRF check."""
|
||||
await _do_setup(client)
|
||||
token = await _login(client)
|
||||
|
||||
@@ -10,13 +10,17 @@ import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
import app
|
||||
|
||||
from app.config import Settings
|
||||
from app.db import init_db
|
||||
from app.main import create_app
|
||||
from app.models.ban import (
|
||||
DashboardBanItem,
|
||||
DashboardBanListResponse,
|
||||
from app.models.ban_domain import (
|
||||
DomainBansByCountry,
|
||||
DomainBansByJail,
|
||||
DomainBanTrend,
|
||||
DomainBanTrendBucket,
|
||||
DomainDashboardBanItem,
|
||||
DomainDashboardBanList,
|
||||
DomainJailBanCount,
|
||||
)
|
||||
from app.models.server import ServerStatus
|
||||
|
||||
@@ -25,7 +29,7 @@ from app.models.server import ServerStatus
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SETUP_PAYLOAD = {
|
||||
"master_password": "testpassword1",
|
||||
"master_password": "Testpass1!",
|
||||
"database_path": "bangui.db",
|
||||
"fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
|
||||
"timezone": "UTC",
|
||||
@@ -40,13 +44,17 @@ async def dashboard_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
|
||||
Unlike the shared ``client`` fixture this one also exposes access to
|
||||
``app.state`` via the app instance so we can seed the status cache.
|
||||
"""
|
||||
config_dir = tmp_path / "fail2ban"
|
||||
config_dir.mkdir()
|
||||
settings = Settings(
|
||||
database_path=str(tmp_path / "dashboard_test.db"),
|
||||
fail2ban_socket="/tmp/fake_fail2ban.sock",
|
||||
session_secret="test-dashboard-secret",
|
||||
fail2ban_config_dir=str(config_dir),
|
||||
session_secret="test-dashboard-secret-that-is-long-enough",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
log_level="debug",
|
||||
session_cookie_secure=False,
|
||||
)
|
||||
app = create_app(settings=settings)
|
||||
|
||||
@@ -66,8 +74,13 @@ async def dashboard_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
|
||||
# Provide a stub HTTP session so ban/access endpoints can access app.state.http_session.
|
||||
app.state.http_session = MagicMock()
|
||||
|
||||
# Initialize GeoCache (normally done in lifespan handler)
|
||||
from app.services.geo_cache import GeoCache
|
||||
|
||||
app.state.geo_cache = GeoCache()
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
async with AsyncClient(transport=transport, base_url="http://test", headers={"X-BanGUI-Request": "1"}) as ac:
|
||||
# Complete setup so the middleware doesn't redirect.
|
||||
resp = await ac.post("/api/v1/setup", json=_SETUP_PAYLOAD)
|
||||
assert resp.status_code == 201
|
||||
@@ -87,13 +100,17 @@ async def dashboard_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
|
||||
@pytest.fixture
|
||||
async def offline_dashboard_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
|
||||
"""Like ``dashboard_client`` but with an offline server status."""
|
||||
config_dir = tmp_path / "fail2ban"
|
||||
config_dir.mkdir()
|
||||
settings = Settings(
|
||||
database_path=str(tmp_path / "dashboard_offline_test.db"),
|
||||
fail2ban_socket="/tmp/fake_fail2ban.sock",
|
||||
session_secret="test-dashboard-offline-secret",
|
||||
fail2ban_config_dir=str(config_dir),
|
||||
session_secret="test-dashboard-offline-secret-long-enough",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
log_level="debug",
|
||||
session_cookie_secure=False,
|
||||
)
|
||||
app = create_app(settings=settings)
|
||||
|
||||
@@ -105,8 +122,13 @@ async def offline_dashboard_client(tmp_path: Path) -> AsyncClient: # type: igno
|
||||
app.state.server_status = ServerStatus(online=False)
|
||||
app.state.http_session = MagicMock()
|
||||
|
||||
# Initialize GeoCache (normally done in lifespan handler)
|
||||
from app.services.geo_cache import GeoCache
|
||||
|
||||
app.state.geo_cache = GeoCache()
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
async with AsyncClient(transport=transport, base_url="http://test", headers={"X-BanGUI-Request": "1"}) as ac:
|
||||
resp = await ac.post("/api/v1/setup", json=_SETUP_PAYLOAD)
|
||||
assert resp.status_code == 201
|
||||
|
||||
@@ -129,25 +151,19 @@ async def offline_dashboard_client(tmp_path: Path) -> AsyncClient: # type: igno
|
||||
class TestDashboardStatus:
|
||||
"""GET /api/dashboard/status."""
|
||||
|
||||
async def test_returns_200_when_authenticated(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_returns_200_when_authenticated(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Authenticated request returns HTTP 200."""
|
||||
response = await dashboard_client.get("/api/v1/dashboard/status")
|
||||
assert response.status_code == 200
|
||||
|
||||
async def test_returns_401_when_unauthenticated(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_returns_401_when_unauthenticated(self, client: AsyncClient) -> None:
|
||||
"""Unauthenticated request returns HTTP 401."""
|
||||
# Complete setup so the middleware allows the request through.
|
||||
await client.post("/api/v1/setup", json=_SETUP_PAYLOAD)
|
||||
response = await client.get("/api/v1/dashboard/status")
|
||||
assert response.status_code == 401
|
||||
|
||||
async def test_response_shape_when_online(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_response_shape_when_online(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Response contains the expected ``status`` object shape."""
|
||||
response = await dashboard_client.get("/api/v1/dashboard/status")
|
||||
body = response.json()
|
||||
@@ -161,9 +177,7 @@ class TestDashboardStatus:
|
||||
assert "total_bans" in status
|
||||
assert "total_failures" in status
|
||||
|
||||
async def test_cached_values_returned_when_online(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_cached_values_returned_when_online(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Endpoint returns the exact values from ``app.state.server_status``."""
|
||||
response = await dashboard_client.get("/api/v1/dashboard/status")
|
||||
body = response.json()
|
||||
@@ -175,9 +189,7 @@ class TestDashboardStatus:
|
||||
assert status["total_bans"] == 10
|
||||
assert status["total_failures"] == 5
|
||||
|
||||
async def test_offline_status_returned_correctly(
|
||||
self, offline_dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_offline_status_returned_correctly(self, offline_dashboard_client: AsyncClient) -> None:
|
||||
"""Endpoint returns online=False when the cache holds an offline snapshot."""
|
||||
response = await offline_dashboard_client.get("/api/v1/dashboard/status")
|
||||
assert response.status_code == 200
|
||||
@@ -190,9 +202,7 @@ class TestDashboardStatus:
|
||||
assert status["total_bans"] == 0
|
||||
assert status["total_failures"] == 0
|
||||
|
||||
async def test_returns_offline_when_state_not_initialised(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_returns_offline_when_state_not_initialised(self, client: AsyncClient) -> None:
|
||||
"""Endpoint returns online=False as a safe default if the cache is absent."""
|
||||
# Setup + login so the endpoint is reachable.
|
||||
await client.post("/api/v1/setup", json=_SETUP_PAYLOAD)
|
||||
@@ -200,7 +210,9 @@ class TestDashboardStatus:
|
||||
"/api/v1/auth/login",
|
||||
json={"password": _SETUP_PAYLOAD["master_password"]},
|
||||
)
|
||||
# server_status is not set on app.state in the shared `client` fixture.
|
||||
# Clear server_status to simulate uninitialized state.
|
||||
client._transport.app.state.server_status = None # type: ignore[attr-defined]
|
||||
client._transport.app.state.server_status = None # type: ignore[attr-defined]
|
||||
response = await client.get("/api/v1/dashboard/status")
|
||||
assert response.status_code == 200
|
||||
status = response.json()["status"]
|
||||
@@ -212,10 +224,10 @@ class TestDashboardStatus:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_ban_list_response(n: int = 2) -> DashboardBanListResponse:
|
||||
"""Build a mock DashboardBanListResponse with *n* items."""
|
||||
def _make_ban_list_response(n: int = 2) -> DomainDashboardBanList:
|
||||
"""Build a mock DomainDashboardBanList with *n* items."""
|
||||
items = [
|
||||
DashboardBanItem(
|
||||
DomainDashboardBanItem(
|
||||
ip=f"1.2.3.{i}",
|
||||
jail="sshd",
|
||||
banned_at="2026-03-01T10:00:00+00:00",
|
||||
@@ -229,15 +241,18 @@ def _make_ban_list_response(n: int = 2) -> DashboardBanListResponse:
|
||||
)
|
||||
for i in range(n)
|
||||
]
|
||||
return DashboardBanListResponse(items=items, total=n, page=1, page_size=100)
|
||||
return DomainDashboardBanList(
|
||||
items=items,
|
||||
total=n,
|
||||
page=1,
|
||||
page_size=100,
|
||||
)
|
||||
|
||||
|
||||
class TestDashboardBans:
|
||||
"""GET /api/dashboard/bans."""
|
||||
|
||||
async def test_returns_200_when_authenticated(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_returns_200_when_authenticated(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Authenticated request returns HTTP 200."""
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.list_bans",
|
||||
@@ -246,17 +261,13 @@ class TestDashboardBans:
|
||||
response = await dashboard_client.get("/api/v1/dashboard/bans")
|
||||
assert response.status_code == 200
|
||||
|
||||
async def test_returns_401_when_unauthenticated(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_returns_401_when_unauthenticated(self, client: AsyncClient) -> None:
|
||||
"""Unauthenticated request returns HTTP 401."""
|
||||
await client.post("/api/v1/setup", json=_SETUP_PAYLOAD)
|
||||
response = await client.get("/api/v1/dashboard/bans")
|
||||
assert response.status_code == 401
|
||||
|
||||
async def test_response_contains_items_and_total(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_response_contains_items_and_total(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Response body contains ``items`` list and ``total`` count."""
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.list_bans",
|
||||
@@ -266,8 +277,8 @@ class TestDashboardBans:
|
||||
|
||||
body = response.json()
|
||||
assert "items" in body
|
||||
assert "total" in body
|
||||
assert body["total"] == 3
|
||||
assert "pagination" in body
|
||||
assert body["pagination"]["total"] == 3
|
||||
assert len(body["items"]) == 3
|
||||
|
||||
async def test_default_range_is_24h(self, dashboard_client: AsyncClient) -> None:
|
||||
@@ -279,9 +290,7 @@ class TestDashboardBans:
|
||||
called_range = mock_list.call_args[0][1]
|
||||
assert called_range == "24h"
|
||||
|
||||
async def test_accepts_time_range_param(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_accepts_time_range_param(self, dashboard_client: AsyncClient) -> None:
|
||||
"""The ``range`` query parameter is forwarded to ban_service."""
|
||||
mock_list = AsyncMock(return_value=_make_ban_list_response())
|
||||
with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
|
||||
@@ -290,9 +299,7 @@ class TestDashboardBans:
|
||||
called_range = mock_list.call_args[0][1]
|
||||
assert called_range == "7d"
|
||||
|
||||
async def test_accepts_source_param(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_accepts_source_param(self, dashboard_client: AsyncClient) -> None:
|
||||
"""The ``source`` query parameter is forwarded to ban_service."""
|
||||
mock_list = AsyncMock(return_value=_make_ban_list_response())
|
||||
with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
|
||||
@@ -301,11 +308,14 @@ class TestDashboardBans:
|
||||
called_source = mock_list.call_args[1]["source"]
|
||||
assert called_source == "archive"
|
||||
|
||||
async def test_empty_ban_list_returns_zero_total(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_empty_ban_list_returns_zero_total(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Returns ``total=0`` and empty ``items`` when no bans are in range."""
|
||||
empty = DashboardBanListResponse(items=[], total=0, page=1, page_size=100)
|
||||
empty = DomainDashboardBanList(
|
||||
items=[],
|
||||
total=0,
|
||||
page=1,
|
||||
page_size=100,
|
||||
)
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.list_bans",
|
||||
new=AsyncMock(return_value=empty),
|
||||
@@ -313,7 +323,7 @@ class TestDashboardBans:
|
||||
response = await dashboard_client.get("/api/v1/dashboard/bans")
|
||||
|
||||
body = response.json()
|
||||
assert body["total"] == 0
|
||||
assert body["pagination"]["total"] == 0
|
||||
assert body["items"] == []
|
||||
|
||||
async def test_item_shape_is_correct(self, dashboard_client: AsyncClient) -> None:
|
||||
@@ -336,12 +346,10 @@ class TestDashboardBans:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_bans_by_country_response() -> object:
|
||||
"""Build a stub BansByCountryResponse."""
|
||||
from app.models.ban import BansByCountryResponse
|
||||
|
||||
def _make_bans_by_country_response() -> DomainBansByCountry:
|
||||
"""Build a stub DomainBansByCountry."""
|
||||
items = [
|
||||
DashboardBanItem(
|
||||
DomainDashboardBanItem(
|
||||
ip="1.2.3.4",
|
||||
jail="sshd",
|
||||
banned_at="2026-03-01T10:00:00+00:00",
|
||||
@@ -353,7 +361,7 @@ def _make_bans_by_country_response() -> object:
|
||||
ban_count=1,
|
||||
origin="selfblock",
|
||||
),
|
||||
DashboardBanItem(
|
||||
DomainDashboardBanItem(
|
||||
ip="5.6.7.8",
|
||||
jail="blocklist-import",
|
||||
banned_at="2026-03-01T10:05:00+00:00",
|
||||
@@ -366,10 +374,10 @@ def _make_bans_by_country_response() -> object:
|
||||
origin="blocklist",
|
||||
),
|
||||
]
|
||||
return BansByCountryResponse(
|
||||
return DomainBansByCountry(
|
||||
countries={"DE": 1, "US": 1},
|
||||
country_names={"DE": "Germany", "US": "United States"},
|
||||
bans=items,
|
||||
items=items,
|
||||
total=2,
|
||||
)
|
||||
|
||||
@@ -378,9 +386,7 @@ def _make_bans_by_country_response() -> object:
|
||||
class TestBansByCountry:
|
||||
"""GET /api/dashboard/bans/by-country."""
|
||||
|
||||
async def test_returns_200_when_authenticated(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_returns_200_when_authenticated(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Authenticated request returns HTTP 200."""
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.bans_by_country",
|
||||
@@ -389,9 +395,7 @@ class TestBansByCountry:
|
||||
response = await dashboard_client.get("/api/v1/dashboard/bans/by-country")
|
||||
assert response.status_code == 200
|
||||
|
||||
async def test_returns_401_when_unauthenticated(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_returns_401_when_unauthenticated(self, client: AsyncClient) -> None:
|
||||
"""Unauthenticated request returns HTTP 401."""
|
||||
await client.post("/api/v1/setup", json=_SETUP_PAYLOAD)
|
||||
response = await client.get("/api/v1/dashboard/bans/by-country")
|
||||
@@ -415,38 +419,26 @@ class TestBansByCountry:
|
||||
assert body["countries"]["US"] == 1
|
||||
assert body["country_names"]["DE"] == "Germany"
|
||||
|
||||
async def test_accepts_time_range_param(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_accepts_time_range_param(self, dashboard_client: AsyncClient) -> None:
|
||||
"""The range query parameter is forwarded to ban_service."""
|
||||
mock_fn = AsyncMock(return_value=_make_bans_by_country_response())
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.bans_by_country", new=mock_fn
|
||||
):
|
||||
with patch("app.routers.dashboard.ban_service.bans_by_country", new=mock_fn):
|
||||
await dashboard_client.get("/api/v1/dashboard/bans/by-country?range=7d")
|
||||
|
||||
called_range = mock_fn.call_args[0][1]
|
||||
assert called_range == "7d"
|
||||
|
||||
async def test_invalid_source_returns_422(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""An invalid source value returns HTTP 422."""
|
||||
response = await dashboard_client.get(
|
||||
"/api/v1/dashboard/bans/by-country?source=invalid"
|
||||
)
|
||||
assert response.status_code == 422
|
||||
async def test_invalid_source_returns_400(self, dashboard_client: AsyncClient) -> None:
|
||||
"""An invalid source value returns HTTP 400."""
|
||||
response = await dashboard_client.get("/api/v1/dashboard/bans/by-country?source=invalid")
|
||||
assert response.status_code == 400
|
||||
|
||||
async def test_empty_window_returns_empty_response(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_empty_window_returns_empty_response(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Empty time range returns empty countries dict and bans list."""
|
||||
from app.models.ban import BansByCountryResponse
|
||||
|
||||
empty = BansByCountryResponse(
|
||||
empty = DomainBansByCountry(
|
||||
countries={},
|
||||
country_names={},
|
||||
bans=[],
|
||||
items=[],
|
||||
total=0,
|
||||
)
|
||||
with patch(
|
||||
@@ -469,9 +461,7 @@ class TestBansByCountry:
|
||||
class TestDashboardBansOriginField:
|
||||
"""Verify that the ``origin`` field is present in API responses."""
|
||||
|
||||
async def test_origin_present_in_ban_list_items(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_origin_present_in_ban_list_items(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Each item in ``/api/dashboard/bans`` carries an ``origin`` field."""
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.list_bans",
|
||||
@@ -483,9 +473,7 @@ class TestDashboardBansOriginField:
|
||||
assert "origin" in item
|
||||
assert item["origin"] in ("blocklist", "selfblock")
|
||||
|
||||
async def test_selfblock_origin_serialised_correctly(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_selfblock_origin_serialised_correctly(self, dashboard_client: AsyncClient) -> None:
|
||||
"""A ban from a non-blocklist jail serialises as ``"selfblock"``."""
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.list_bans",
|
||||
@@ -497,9 +485,7 @@ class TestDashboardBansOriginField:
|
||||
assert item["jail"] == "sshd"
|
||||
assert item["origin"] == "selfblock"
|
||||
|
||||
async def test_origin_present_in_bans_by_country(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_origin_present_in_bans_by_country(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Each ban in ``/api/dashboard/bans/by-country`` carries an ``origin``."""
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.bans_by_country",
|
||||
@@ -512,9 +498,7 @@ class TestDashboardBansOriginField:
|
||||
origins = {ban["origin"] for ban in bans}
|
||||
assert origins == {"blocklist", "selfblock"}
|
||||
|
||||
async def test_bans_by_country_source_param_forwarded(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_bans_by_country_source_param_forwarded(self, dashboard_client: AsyncClient) -> None:
|
||||
"""The ``source`` query parameter is forwarded to bans_by_country."""
|
||||
mock_fn = AsyncMock(return_value=_make_bans_by_country_response())
|
||||
with patch("app.routers.dashboard.ban_service.bans_by_country", new=mock_fn):
|
||||
@@ -522,22 +506,16 @@ class TestDashboardBansOriginField:
|
||||
|
||||
assert mock_fn.call_args[1]["source"] == "archive"
|
||||
|
||||
async def test_bans_by_country_country_code_forwarded(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_bans_by_country_country_code_forwarded(self, dashboard_client: AsyncClient) -> None:
|
||||
"""The ``country_code`` query parameter is forwarded to bans_by_country."""
|
||||
mock_fn = AsyncMock(return_value=_make_bans_by_country_response())
|
||||
with patch("app.routers.dashboard.ban_service.bans_by_country", new=mock_fn):
|
||||
await dashboard_client.get(
|
||||
"/api/v1/dashboard/bans/by-country?country_code=DE"
|
||||
)
|
||||
await dashboard_client.get("/api/v1/dashboard/bans/by-country?country_code=DE")
|
||||
|
||||
_, kwargs = mock_fn.call_args
|
||||
assert kwargs.get("country_code") == "DE"
|
||||
|
||||
async def test_blocklist_origin_serialised_correctly(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_blocklist_origin_serialised_correctly(self, dashboard_client: AsyncClient) -> None:
|
||||
"""A ban from the ``blocklist-import`` jail serialises as ``"blocklist"``."""
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.bans_by_country",
|
||||
@@ -558,9 +536,7 @@ class TestDashboardBansOriginField:
|
||||
class TestOriginFilterParam:
|
||||
"""Verify that the ``origin`` query parameter is forwarded to the service."""
|
||||
|
||||
async def test_bans_origin_blocklist_forwarded_to_service(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_bans_origin_blocklist_forwarded_to_service(self, dashboard_client: AsyncClient) -> None:
|
||||
"""``?origin=blocklist`` is passed to ``ban_service.list_bans``."""
|
||||
mock_list = AsyncMock(return_value=_make_ban_list_response())
|
||||
with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
|
||||
@@ -569,9 +545,7 @@ class TestOriginFilterParam:
|
||||
_, kwargs = mock_list.call_args
|
||||
assert kwargs.get("origin") == "blocklist"
|
||||
|
||||
async def test_bans_origin_selfblock_forwarded_to_service(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_bans_origin_selfblock_forwarded_to_service(self, dashboard_client: AsyncClient) -> None:
|
||||
"""``?origin=selfblock`` is passed to ``ban_service.list_bans``."""
|
||||
mock_list = AsyncMock(return_value=_make_ban_list_response())
|
||||
with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
|
||||
@@ -580,9 +554,7 @@ class TestOriginFilterParam:
|
||||
_, kwargs = mock_list.call_args
|
||||
assert kwargs.get("origin") == "selfblock"
|
||||
|
||||
async def test_bans_no_origin_param_defaults_to_none(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_bans_no_origin_param_defaults_to_none(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Omitting ``origin`` passes ``None`` to the service (no filtering)."""
|
||||
mock_list = AsyncMock(return_value=_make_ban_list_response())
|
||||
with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
|
||||
@@ -591,36 +563,24 @@ class TestOriginFilterParam:
|
||||
_, kwargs = mock_list.call_args
|
||||
assert kwargs.get("origin") is None
|
||||
|
||||
async def test_bans_invalid_origin_returns_422(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""An invalid ``origin`` value returns HTTP 422 Unprocessable Entity."""
|
||||
async def test_bans_invalid_origin_returns_400(self, dashboard_client: AsyncClient) -> None:
|
||||
"""An invalid ``origin`` value returns HTTP 400."""
|
||||
response = await dashboard_client.get("/api/v1/dashboard/bans?origin=invalid")
|
||||
assert response.status_code == 422
|
||||
assert response.status_code == 400
|
||||
|
||||
async def test_by_country_origin_blocklist_forwarded(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_by_country_origin_blocklist_forwarded(self, dashboard_client: AsyncClient) -> None:
|
||||
"""``?origin=blocklist`` is passed to ``ban_service.bans_by_country``."""
|
||||
mock_fn = AsyncMock(return_value=_make_bans_by_country_response())
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.bans_by_country", new=mock_fn
|
||||
):
|
||||
await dashboard_client.get(
|
||||
"/api/v1/dashboard/bans/by-country?origin=blocklist"
|
||||
)
|
||||
with patch("app.routers.dashboard.ban_service.bans_by_country", new=mock_fn):
|
||||
await dashboard_client.get("/api/v1/dashboard/bans/by-country?origin=blocklist")
|
||||
|
||||
_, kwargs = mock_fn.call_args
|
||||
assert kwargs.get("origin") == "blocklist"
|
||||
|
||||
async def test_by_country_no_origin_defaults_to_none(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_by_country_no_origin_defaults_to_none(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Omitting ``origin`` passes ``None`` to ``bans_by_country``."""
|
||||
mock_fn = AsyncMock(return_value=_make_bans_by_country_response())
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.bans_by_country", new=mock_fn
|
||||
):
|
||||
with patch("app.routers.dashboard.ban_service.bans_by_country", new=mock_fn):
|
||||
await dashboard_client.get("/api/v1/dashboard/bans/by-country")
|
||||
|
||||
_, kwargs = mock_fn.call_args
|
||||
@@ -632,24 +592,17 @@ class TestOriginFilterParam:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_ban_trend_response(n_buckets: int = 24) -> object:
|
||||
"""Build a stub :class:`~app.models.ban.BanTrendResponse`."""
|
||||
from app.models.ban import BanTrendBucket, BanTrendResponse
|
||||
|
||||
buckets = [
|
||||
BanTrendBucket(timestamp=f"2026-03-01T{i:02d}:00:00+00:00", count=i)
|
||||
for i in range(n_buckets)
|
||||
]
|
||||
return BanTrendResponse(buckets=buckets, bucket_size="1h")
|
||||
def _make_ban_trend_response(n_buckets: int = 24) -> DomainBanTrend:
|
||||
"""Build a stub :class:`~app.models.ban_domain.DomainBanTrend`."""
|
||||
buckets = [DomainBanTrendBucket(timestamp=f"2026-03-01T{i:02d}:00:00+00:00", count=i) for i in range(n_buckets)]
|
||||
return DomainBanTrend(buckets=buckets, bucket_size="1h")
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
class TestBanTrend:
|
||||
"""GET /api/dashboard/bans/trend."""
|
||||
|
||||
async def test_returns_200_when_authenticated(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_returns_200_when_authenticated(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Authenticated request returns HTTP 200."""
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.ban_trend",
|
||||
@@ -658,9 +611,7 @@ class TestBanTrend:
|
||||
response = await dashboard_client.get("/api/v1/dashboard/bans/trend")
|
||||
assert response.status_code == 200
|
||||
|
||||
async def test_returns_401_when_unauthenticated(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_returns_401_when_unauthenticated(self, client: AsyncClient) -> None:
|
||||
"""Unauthenticated request returns HTTP 401."""
|
||||
await client.post("/api/v1/setup", json=_SETUP_PAYLOAD)
|
||||
response = await client.get("/api/v1/dashboard/bans/trend")
|
||||
@@ -680,9 +631,7 @@ class TestBanTrend:
|
||||
assert len(body["buckets"]) == 24
|
||||
assert body["bucket_size"] == "1h"
|
||||
|
||||
async def test_each_bucket_has_timestamp_and_count(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_each_bucket_has_timestamp_and_count(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Every element of ``buckets`` has ``timestamp`` and ``count``."""
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.ban_trend",
|
||||
@@ -717,16 +666,12 @@ class TestBanTrend:
|
||||
"""``?origin=blocklist`` is passed as a keyword arg to the service."""
|
||||
mock_fn = AsyncMock(return_value=_make_ban_trend_response())
|
||||
with patch("app.routers.dashboard.ban_service.ban_trend", new=mock_fn):
|
||||
await dashboard_client.get(
|
||||
"/api/v1/dashboard/bans/trend?origin=blocklist"
|
||||
)
|
||||
await dashboard_client.get("/api/v1/dashboard/bans/trend?origin=blocklist")
|
||||
|
||||
_, kwargs = mock_fn.call_args
|
||||
assert kwargs.get("origin") == "blocklist"
|
||||
|
||||
async def test_no_origin_defaults_to_none(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_no_origin_defaults_to_none(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Omitting ``origin`` passes ``None`` to the service."""
|
||||
mock_fn = AsyncMock(return_value=_make_ban_trend_response())
|
||||
with patch("app.routers.dashboard.ban_service.ban_trend", new=mock_fn):
|
||||
@@ -735,29 +680,19 @@ class TestBanTrend:
|
||||
_, kwargs = mock_fn.call_args
|
||||
assert kwargs.get("origin") is None
|
||||
|
||||
async def test_invalid_range_returns_422(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""An invalid ``range`` value returns HTTP 422."""
|
||||
response = await dashboard_client.get(
|
||||
"/api/v1/dashboard/bans/trend?range=invalid"
|
||||
)
|
||||
assert response.status_code == 422
|
||||
async def test_invalid_range_returns_400(self, dashboard_client: AsyncClient) -> None:
|
||||
"""An invalid ``range`` value returns HTTP 400."""
|
||||
response = await dashboard_client.get("/api/v1/dashboard/bans/trend?range=invalid")
|
||||
assert response.status_code == 400
|
||||
|
||||
async def test_invalid_source_returns_422(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""An invalid source value returns HTTP 422."""
|
||||
response = await dashboard_client.get(
|
||||
"/api/v1/dashboard/bans/trend?source=invalid"
|
||||
)
|
||||
assert response.status_code == 422
|
||||
async def test_invalid_source_returns_400(self, dashboard_client: AsyncClient) -> None:
|
||||
"""An invalid source value returns HTTP 400."""
|
||||
response = await dashboard_client.get("/api/v1/dashboard/bans/trend?source=invalid")
|
||||
assert response.status_code == 400
|
||||
|
||||
async def test_empty_buckets_response(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Empty bucket list is serialised correctly."""
|
||||
from app.models.ban import BanTrendResponse
|
||||
|
||||
empty = BanTrendResponse(buckets=[], bucket_size="1h")
|
||||
empty = DomainBanTrend(buckets=[], bucket_size="1h")
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.ban_trend",
|
||||
new=AsyncMock(return_value=empty),
|
||||
@@ -774,14 +709,12 @@ class TestBanTrend:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_bans_by_jail_response() -> object:
|
||||
"""Build a stub :class:`~app.models.ban.BansByJailResponse`."""
|
||||
from app.models.ban import BansByJailResponse, JailBanCount
|
||||
|
||||
return BansByJailResponse(
|
||||
def _make_bans_by_jail_response() -> DomainBansByJail:
|
||||
"""Build a stub :class:`~app.models.ban_domain.DomainBansByJail`."""
|
||||
return DomainBansByJail(
|
||||
jails=[
|
||||
JailBanCount(jail="sshd", count=10),
|
||||
JailBanCount(jail="nginx", count=5),
|
||||
DomainJailBanCount(jail="sshd", count=10),
|
||||
DomainJailBanCount(jail="nginx", count=5),
|
||||
],
|
||||
total=15,
|
||||
)
|
||||
@@ -791,9 +724,7 @@ def _make_bans_by_jail_response() -> object:
|
||||
class TestBansByJail:
|
||||
"""GET /api/dashboard/bans/by-jail."""
|
||||
|
||||
async def test_returns_200_when_authenticated(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_returns_200_when_authenticated(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Authenticated request returns HTTP 200."""
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.bans_by_jail",
|
||||
@@ -802,9 +733,7 @@ class TestBansByJail:
|
||||
response = await dashboard_client.get("/api/v1/dashboard/bans/by-jail")
|
||||
assert response.status_code == 200
|
||||
|
||||
async def test_returns_401_when_unauthenticated(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_returns_401_when_unauthenticated(self, client: AsyncClient) -> None:
|
||||
"""Unauthenticated request returns HTTP 401."""
|
||||
await client.post("/api/v1/setup", json=_SETUP_PAYLOAD)
|
||||
response = await client.get("/api/v1/dashboard/bans/by-jail")
|
||||
@@ -823,9 +752,7 @@ class TestBansByJail:
|
||||
assert "total" in body
|
||||
assert isinstance(body["total"], int)
|
||||
|
||||
async def test_each_jail_has_name_and_count(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_each_jail_has_name_and_count(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Every element of ``jails`` has ``jail`` (string) and ``count`` (int)."""
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.bans_by_jail",
|
||||
@@ -861,16 +788,12 @@ class TestBansByJail:
|
||||
"""``?origin=blocklist`` is passed as a keyword arg to the service."""
|
||||
mock_fn = AsyncMock(return_value=_make_bans_by_jail_response())
|
||||
with patch("app.routers.dashboard.ban_service.bans_by_jail", new=mock_fn):
|
||||
await dashboard_client.get(
|
||||
"/api/v1/dashboard/bans/by-jail?origin=blocklist"
|
||||
)
|
||||
await dashboard_client.get("/api/v1/dashboard/bans/by-jail?origin=blocklist")
|
||||
|
||||
_, kwargs = mock_fn.call_args
|
||||
assert kwargs.get("origin") == "blocklist"
|
||||
|
||||
async def test_no_origin_defaults_to_none(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_no_origin_defaults_to_none(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Omitting ``origin`` passes ``None`` to the service."""
|
||||
mock_fn = AsyncMock(return_value=_make_bans_by_jail_response())
|
||||
with patch("app.routers.dashboard.ban_service.bans_by_jail", new=mock_fn):
|
||||
@@ -879,23 +802,15 @@ class TestBansByJail:
|
||||
_, kwargs = mock_fn.call_args
|
||||
assert kwargs.get("origin") is None
|
||||
|
||||
async def test_invalid_range_returns_422(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""An invalid ``range`` value returns HTTP 422."""
|
||||
response = await dashboard_client.get(
|
||||
"/api/v1/dashboard/bans/by-jail?range=invalid"
|
||||
)
|
||||
assert response.status_code == 422
|
||||
async def test_invalid_range_returns_400(self, dashboard_client: AsyncClient) -> None:
|
||||
"""An invalid ``range`` value returns HTTP 400."""
|
||||
response = await dashboard_client.get("/api/v1/dashboard/bans/by-jail?range=invalid")
|
||||
assert response.status_code == 400
|
||||
|
||||
async def test_invalid_source_returns_422(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""An invalid source value returns HTTP 422."""
|
||||
response = await dashboard_client.get(
|
||||
"/api/v1/dashboard/bans/by-jail?source=invalid"
|
||||
)
|
||||
assert response.status_code == 422
|
||||
async def test_invalid_source_returns_400(self, dashboard_client: AsyncClient) -> None:
|
||||
"""An invalid source value returns HTTP 400."""
|
||||
response = await dashboard_client.get("/api/v1/dashboard/bans/by-jail?source=invalid")
|
||||
assert response.status_code == 400
|
||||
|
||||
async def test_empty_jails_response(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Empty jails list is serialised correctly."""
|
||||
@@ -911,4 +826,3 @@ class TestBansByJail:
|
||||
body = response.json()
|
||||
assert body["jails"] == []
|
||||
assert body["total"] == 0
|
||||
|
||||
|
||||
@@ -122,11 +122,17 @@ async def _build_app(settings: Settings):
|
||||
return app, db
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Service dependency injection at router level is not yet implemented.")
|
||||
async def test_auth_login_uses_injected_auth_service(tmp_path: Path) -> None:
|
||||
config_dir = tmp_path / "fail2ban"
|
||||
config_dir.mkdir(parents=True)
|
||||
settings = Settings(
|
||||
database_path=str(tmp_path / "test_bangui.db"),
|
||||
fail2ban_socket="/tmp/fake_fail2ban.sock",
|
||||
fail2ban_config_dir=str(tmp_path / "fail2ban"),
|
||||
fail2ban_config_dir=str(config_dir),
|
||||
session_secret="test-secret-key-do-not-use-in-production",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
@@ -134,6 +140,7 @@ async def test_auth_login_uses_injected_auth_service(tmp_path: Path) -> None:
|
||||
)
|
||||
|
||||
app, db = await _build_app(settings)
|
||||
|
||||
def _fake_auth_service() -> FakeAuthService:
|
||||
return FakeAuthService()
|
||||
|
||||
@@ -157,11 +164,14 @@ async def test_auth_login_uses_injected_auth_service(tmp_path: Path) -> None:
|
||||
assert response.cookies.get(SESSION_COOKIE_NAME) is not None
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Service dependency injection at router level is not yet implemented.")
|
||||
async def test_jail_list_uses_injected_jail_service_and_auth(tmp_path: Path) -> None:
|
||||
config_dir = tmp_path / "fail2ban"
|
||||
config_dir.mkdir(parents=True)
|
||||
settings = Settings(
|
||||
database_path=str(tmp_path / "test_bangui.db"),
|
||||
fail2ban_socket="/tmp/fake_fail2ban.sock",
|
||||
fail2ban_config_dir=str(tmp_path / "fail2ban"),
|
||||
fail2ban_config_dir=str(config_dir),
|
||||
session_secret="test-secret-key-do-not-use-in-production",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
@@ -169,6 +179,7 @@ async def test_jail_list_uses_injected_jail_service_and_auth(tmp_path: Path) ->
|
||||
)
|
||||
|
||||
app, db = await _build_app(settings)
|
||||
|
||||
def _fake_auth_service() -> FakeAuthService:
|
||||
return FakeAuthService()
|
||||
|
||||
|
||||
@@ -11,6 +11,13 @@ from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from app.config import Settings
|
||||
from app.db import init_db
|
||||
from app.exceptions import (
|
||||
ConfigDirError,
|
||||
ConfigFileExistsError,
|
||||
ConfigFileNameError,
|
||||
ConfigFileNotFoundError,
|
||||
ConfigFileWriteError,
|
||||
)
|
||||
from app.main import create_app
|
||||
from app.models.config import (
|
||||
ActionConfig,
|
||||
@@ -26,20 +33,13 @@ from app.models.file_config import (
|
||||
JailConfigFileContent,
|
||||
JailConfigFilesResponse,
|
||||
)
|
||||
from app.exceptions import (
|
||||
ConfigDirError,
|
||||
ConfigFileExistsError,
|
||||
ConfigFileNameError,
|
||||
ConfigFileNotFoundError,
|
||||
ConfigFileWriteError,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SETUP_PAYLOAD = {
|
||||
"master_password": "testpassword1",
|
||||
"master_password": "Testpassword1!",
|
||||
"database_path": "bangui.db",
|
||||
"fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
|
||||
"timezone": "UTC",
|
||||
@@ -50,13 +50,17 @@ _SETUP_PAYLOAD = {
|
||||
@pytest.fixture
|
||||
async def file_config_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
|
||||
"""Provide an authenticated ``AsyncClient`` for file_config endpoint tests."""
|
||||
config_dir = tmp_path / "fail2ban"
|
||||
config_dir.mkdir()
|
||||
settings = Settings(
|
||||
database_path=str(tmp_path / "file_config_test.db"),
|
||||
fail2ban_socket="/tmp/fake.sock",
|
||||
session_secret="test-file-config-secret",
|
||||
fail2ban_config_dir=str(config_dir),
|
||||
session_secret="test-file-config-secret-that-is-long-enough!!",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
log_level="debug",
|
||||
session_cookie_secure=False,
|
||||
)
|
||||
app = create_app(settings=settings)
|
||||
|
||||
@@ -67,7 +71,7 @@ async def file_config_client(tmp_path: Path) -> AsyncClient: # type: ignore[mis
|
||||
app.state.http_session = MagicMock()
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
async with AsyncClient(transport=transport, base_url="http://test", headers={"X-BanGUI-Request": "1"}) as ac:
|
||||
await ac.post("/api/v1/setup", json=_SETUP_PAYLOAD)
|
||||
login = await ac.post(
|
||||
"/api/v1/auth/login",
|
||||
@@ -108,9 +112,7 @@ def _conf_file_content(name: str = "nginx") -> ConfFileContent:
|
||||
|
||||
|
||||
class TestListJailConfigFiles:
|
||||
async def test_200_returns_file_list(
|
||||
self, file_config_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_200_returns_file_list(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.raw_config_io_service.list_jail_config_files",
|
||||
AsyncMock(return_value=_jail_files_resp()),
|
||||
@@ -122,9 +124,7 @@ class TestListJailConfigFiles:
|
||||
assert data["total"] == 1
|
||||
assert data["files"][0]["filename"] == "sshd.conf"
|
||||
|
||||
async def test_503_on_config_dir_error(
|
||||
self, file_config_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_503_on_config_dir_error(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.raw_config_io_service.list_jail_config_files",
|
||||
AsyncMock(side_effect=ConfigDirError("not found")),
|
||||
@@ -147,9 +147,7 @@ class TestListJailConfigFiles:
|
||||
|
||||
|
||||
class TestGetJailConfigFile:
|
||||
async def test_200_returns_content(
|
||||
self, file_config_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_200_returns_content(self, file_config_client: AsyncClient) -> None:
|
||||
content = JailConfigFileContent(
|
||||
name="sshd",
|
||||
filename="sshd.conf",
|
||||
@@ -174,9 +172,7 @@ class TestGetJailConfigFile:
|
||||
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_400_invalid_filename(
|
||||
self, file_config_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_400_invalid_filename(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.raw_config_io_service.get_jail_config_file",
|
||||
AsyncMock(side_effect=ConfigFileNameError("bad name")),
|
||||
@@ -268,7 +264,7 @@ class TestUpdateFilterFile:
|
||||
|
||||
assert resp.status_code == 204
|
||||
|
||||
async def test_400_write_error(self, file_config_client: AsyncClient) -> None:
|
||||
async def test_500_write_error(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.raw_config_io_service.write_filter_file",
|
||||
AsyncMock(side_effect=ConfigFileWriteError("disk full")),
|
||||
@@ -278,7 +274,7 @@ class TestUpdateFilterFile:
|
||||
json={"content": "x"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 400
|
||||
assert resp.status_code == 500
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -342,7 +338,7 @@ class TestListActionFiles:
|
||||
)
|
||||
resp_data = ActionListResponse(actions=[mock_action], total=1)
|
||||
with patch(
|
||||
"app.routers.config.action_config_service.list_actions",
|
||||
"app.routers.action_config.action_config_service.list_actions",
|
||||
AsyncMock(return_value=resp_data),
|
||||
):
|
||||
resp = await file_config_client.get("/api/v1/config/actions")
|
||||
@@ -365,7 +361,7 @@ class TestCreateActionFile:
|
||||
actionban="echo ban <ip>",
|
||||
)
|
||||
with patch(
|
||||
"app.routers.config.action_config_service.create_action",
|
||||
"app.routers.action_config.action_config_service.create_action",
|
||||
AsyncMock(return_value=created),
|
||||
):
|
||||
resp = await file_config_client.post(
|
||||
@@ -404,9 +400,7 @@ class TestGetActionFileRaw:
|
||||
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_503_on_config_dir_error(
|
||||
self, file_config_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_503_on_config_dir_error(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.raw_config_io_service.get_action_file",
|
||||
AsyncMock(side_effect=ConfigDirError("no dir")),
|
||||
@@ -436,7 +430,7 @@ class TestUpdateActionFileRaw:
|
||||
|
||||
assert resp.status_code == 204
|
||||
|
||||
async def test_400_write_error(self, file_config_client: AsyncClient) -> None:
|
||||
async def test_500_write_error(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.raw_config_io_service.write_action_file",
|
||||
AsyncMock(side_effect=ConfigFileWriteError("disk full")),
|
||||
@@ -446,7 +440,7 @@ class TestUpdateActionFileRaw:
|
||||
json={"content": "x"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 400
|
||||
assert resp.status_code == 500
|
||||
|
||||
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
@@ -516,9 +510,7 @@ class TestCreateJailConfigFile:
|
||||
|
||||
assert resp.status_code == 400
|
||||
|
||||
async def test_503_on_config_dir_error(
|
||||
self, file_config_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_503_on_config_dir_error(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.raw_config_io_service.create_jail_config_file",
|
||||
AsyncMock(side_effect=ConfigDirError("no dir")),
|
||||
@@ -537,9 +529,7 @@ class TestCreateJailConfigFile:
|
||||
|
||||
|
||||
class TestGetParsedFilter:
|
||||
async def test_200_returns_parsed_config(
|
||||
self, file_config_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_200_returns_parsed_config(self, file_config_client: AsyncClient) -> None:
|
||||
cfg = FilterConfig(name="nginx", filename="nginx.conf")
|
||||
with patch(
|
||||
"app.routers.file_config.raw_config_io_service.get_parsed_filter_file",
|
||||
@@ -557,15 +547,11 @@ class TestGetParsedFilter:
|
||||
"app.routers.file_config.raw_config_io_service.get_parsed_filter_file",
|
||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing")),
|
||||
):
|
||||
resp = await file_config_client.get(
|
||||
"/api/v1/config/filters/missing/parsed"
|
||||
)
|
||||
resp = await file_config_client.get("/api/v1/config/filters/missing/parsed")
|
||||
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_503_on_config_dir_error(
|
||||
self, file_config_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_503_on_config_dir_error(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.raw_config_io_service.get_parsed_filter_file",
|
||||
AsyncMock(side_effect=ConfigDirError("no dir")),
|
||||
@@ -605,17 +591,17 @@ class TestUpdateParsedFilter:
|
||||
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_400_write_error(self, file_config_client: AsyncClient) -> None:
|
||||
async def test_500_write_error(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.raw_config_io_service.update_parsed_filter_file",
|
||||
AsyncMock(side_effect=ConfigFileWriteError("disk full")),
|
||||
):
|
||||
resp = await file_config_client.put(
|
||||
"/api/v1/config/filters/nginx/parsed",
|
||||
json={"failregex": ["^<HOST> "]},
|
||||
json={"name": "nginx", "failregex": ["^test$"]},
|
||||
)
|
||||
|
||||
assert resp.status_code == 400
|
||||
assert resp.status_code == 500
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -624,17 +610,13 @@ class TestUpdateParsedFilter:
|
||||
|
||||
|
||||
class TestGetParsedAction:
|
||||
async def test_200_returns_parsed_config(
|
||||
self, file_config_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_200_returns_parsed_config(self, file_config_client: AsyncClient) -> None:
|
||||
cfg = ActionConfig(name="iptables", filename="iptables.conf")
|
||||
with patch(
|
||||
"app.routers.file_config.raw_config_io_service.get_parsed_action_file",
|
||||
AsyncMock(return_value=cfg),
|
||||
):
|
||||
resp = await file_config_client.get(
|
||||
"/api/v1/config/actions/iptables/parsed"
|
||||
)
|
||||
resp = await file_config_client.get("/api/v1/config/actions/iptables/parsed")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
@@ -646,22 +628,16 @@ class TestGetParsedAction:
|
||||
"app.routers.file_config.raw_config_io_service.get_parsed_action_file",
|
||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing")),
|
||||
):
|
||||
resp = await file_config_client.get(
|
||||
"/api/v1/config/actions/missing/parsed"
|
||||
)
|
||||
resp = await file_config_client.get("/api/v1/config/actions/missing/parsed")
|
||||
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_503_on_config_dir_error(
|
||||
self, file_config_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_503_on_config_dir_error(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.raw_config_io_service.get_parsed_action_file",
|
||||
AsyncMock(side_effect=ConfigDirError("no dir")),
|
||||
):
|
||||
resp = await file_config_client.get(
|
||||
"/api/v1/config/actions/iptables/parsed"
|
||||
)
|
||||
resp = await file_config_client.get("/api/v1/config/actions/iptables/parsed")
|
||||
|
||||
assert resp.status_code == 503
|
||||
|
||||
@@ -696,7 +672,7 @@ class TestUpdateParsedAction:
|
||||
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_400_write_error(self, file_config_client: AsyncClient) -> None:
|
||||
async def test_500_write_error(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.raw_config_io_service.update_parsed_action_file",
|
||||
AsyncMock(side_effect=ConfigFileWriteError("disk full")),
|
||||
@@ -706,7 +682,7 @@ class TestUpdateParsedAction:
|
||||
json={"actionban": "iptables -I INPUT -s <ip> -j DROP"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 400
|
||||
assert resp.status_code == 500
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -715,18 +691,14 @@ class TestUpdateParsedAction:
|
||||
|
||||
|
||||
class TestGetParsedJailFile:
|
||||
async def test_200_returns_parsed_config(
|
||||
self, file_config_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_200_returns_parsed_config(self, file_config_client: AsyncClient) -> None:
|
||||
section = JailSectionConfig(enabled=True, port="ssh")
|
||||
cfg = JailFileConfig(filename="sshd.conf", jails={"sshd": section})
|
||||
with patch(
|
||||
"app.routers.file_config.raw_config_io_service.get_parsed_jail_file",
|
||||
AsyncMock(return_value=cfg),
|
||||
):
|
||||
resp = await file_config_client.get(
|
||||
"/api/v1/config/jail-files/sshd.conf/parsed"
|
||||
)
|
||||
resp = await file_config_client.get("/api/v1/config/jail-files/sshd.conf/parsed")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
@@ -738,22 +710,16 @@ class TestGetParsedJailFile:
|
||||
"app.routers.file_config.raw_config_io_service.get_parsed_jail_file",
|
||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing.conf")),
|
||||
):
|
||||
resp = await file_config_client.get(
|
||||
"/api/v1/config/jail-files/missing.conf/parsed"
|
||||
)
|
||||
resp = await file_config_client.get("/api/v1/config/jail-files/missing.conf/parsed")
|
||||
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_503_on_config_dir_error(
|
||||
self, file_config_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_503_on_config_dir_error(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.raw_config_io_service.get_parsed_jail_file",
|
||||
AsyncMock(side_effect=ConfigDirError("no dir")),
|
||||
):
|
||||
resp = await file_config_client.get(
|
||||
"/api/v1/config/jail-files/sshd.conf/parsed"
|
||||
)
|
||||
resp = await file_config_client.get("/api/v1/config/jail-files/sshd.conf/parsed")
|
||||
|
||||
assert resp.status_code == 503
|
||||
|
||||
@@ -788,7 +754,7 @@ class TestUpdateParsedJailFile:
|
||||
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_400_write_error(self, file_config_client: AsyncClient) -> None:
|
||||
async def test_500_write_error(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.raw_config_io_service.update_parsed_jail_file",
|
||||
AsyncMock(side_effect=ConfigFileWriteError("disk full")),
|
||||
@@ -798,4 +764,4 @@ class TestUpdateParsedJailFile:
|
||||
json={"jails": {"sshd": {"enabled": True}}},
|
||||
)
|
||||
|
||||
assert resp.status_code == 400
|
||||
assert resp.status_code == 500
|
||||
|
||||
@@ -30,13 +30,17 @@ _SETUP_PAYLOAD = {
|
||||
@pytest.fixture
|
||||
async def geo_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
|
||||
"""Provide an authenticated ``AsyncClient`` for geo endpoint tests."""
|
||||
config_dir = tmp_path / "fail2ban"
|
||||
config_dir.mkdir()
|
||||
settings = Settings(
|
||||
database_path=str(tmp_path / "geo_test.db"),
|
||||
fail2ban_socket="/tmp/fake.sock",
|
||||
session_secret="test-geo-secret",
|
||||
fail2ban_config_dir=str(config_dir),
|
||||
session_secret="test-geo-secret-that-is-long-enough!!",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
log_level="debug",
|
||||
session_cookie_secure=False,
|
||||
)
|
||||
app = create_app(settings=settings)
|
||||
|
||||
@@ -48,6 +52,7 @@ async def geo_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
|
||||
|
||||
# Initialize GeoCache (normally done in lifespan handler)
|
||||
from app.services.geo_cache import GeoCache
|
||||
|
||||
app.state.geo_cache = GeoCache()
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
@@ -179,7 +184,10 @@ class TestReResolve:
|
||||
"app.routers.geo.geo_service.re_resolve_all",
|
||||
AsyncMock(return_value={"resolved": 0, "total": 0}),
|
||||
):
|
||||
resp = await geo_client.post("/api/v1/geo/re-resolve")
|
||||
resp = await geo_client.post(
|
||||
"/api/v1/geo/re-resolve",
|
||||
headers={"X-BanGUI-Request": "1"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
@@ -188,7 +196,10 @@ class TestReResolve:
|
||||
|
||||
async def test_empty_when_no_unresolved_ips(self, geo_client: AsyncClient) -> None:
|
||||
"""Returns resolved=0, total=0 when geo_cache has no NULL country_code rows."""
|
||||
resp = await geo_client.post("/api/v1/geo/re-resolve")
|
||||
resp = await geo_client.post(
|
||||
"/api/v1/geo/re-resolve",
|
||||
headers={"X-BanGUI-Request": "1"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {"resolved": 0, "total": 0}
|
||||
@@ -204,12 +215,16 @@ class TestReResolve:
|
||||
geo_result = {"5.5.5.5": GeoInfo(country_code="FR", country_name="France", asn=None, org=None)}
|
||||
# Patch the default geo_cache instance used by geo_service
|
||||
from app.services.geo_service import _default_geo_cache
|
||||
|
||||
with patch.object(
|
||||
_default_geo_cache,
|
||||
"lookup_batch",
|
||||
new_callable=lambda: AsyncMock(return_value=geo_result),
|
||||
):
|
||||
resp = await geo_client.post("/api/v1/geo/re-resolve")
|
||||
resp = await geo_client.post(
|
||||
"/api/v1/geo/re-resolve",
|
||||
headers={"X-BanGUI-Request": "1"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
|
||||
@@ -14,7 +14,6 @@ from app.db import init_db
|
||||
from app.main import create_app
|
||||
from app.models.history import (
|
||||
HistoryBanItem,
|
||||
HistoryListResponse,
|
||||
IpDetailResponse,
|
||||
IpTimelineEvent,
|
||||
)
|
||||
@@ -48,13 +47,26 @@ def _make_history_item(ip: str = "1.2.3.4", jail: str = "sshd") -> HistoryBanIte
|
||||
)
|
||||
|
||||
|
||||
def _make_history_list(n: int = 2) -> HistoryListResponse:
|
||||
"""Build a mock ``HistoryListResponse`` with *n* items."""
|
||||
from app.utils.pagination import create_pagination_metadata
|
||||
def _make_history_list(n: int = 2):
|
||||
"""Build a mock ``DomainHistoryList`` with *n* items."""
|
||||
from app.models.history_domain import DomainHistoryBanItem, DomainHistoryList
|
||||
|
||||
items = [_make_history_item(ip=f"1.2.3.{i}") for i in range(n)]
|
||||
pagination = create_pagination_metadata(total=n, page=1, page_size=100)
|
||||
return HistoryListResponse(items=items, pagination=pagination)
|
||||
items = [
|
||||
DomainHistoryBanItem(
|
||||
ip=f"1.2.3.{i}",
|
||||
jail="sshd",
|
||||
banned_at="2026-03-01T10:00:00+00:00",
|
||||
ban_count=3,
|
||||
failures=5,
|
||||
matches=["Mar 1 10:00:00 host sshd[123]: Failed password for root"],
|
||||
country_code="DE",
|
||||
country_name="Germany",
|
||||
asn="AS3320",
|
||||
org="Telekom",
|
||||
)
|
||||
for i in range(n)
|
||||
]
|
||||
return DomainHistoryList(items=items, total=n, page=1, page_size=100)
|
||||
|
||||
|
||||
def _make_ip_detail(ip: str = "1.2.3.4") -> IpDetailResponse:
|
||||
@@ -96,13 +108,17 @@ def _make_ip_detail(ip: str = "1.2.3.4") -> IpDetailResponse:
|
||||
@pytest.fixture
|
||||
async def history_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
|
||||
"""Provide an authenticated ``AsyncClient`` for history endpoint tests."""
|
||||
config_dir = tmp_path / "fail2ban"
|
||||
config_dir.mkdir()
|
||||
settings = Settings(
|
||||
database_path=str(tmp_path / "history_test.db"),
|
||||
fail2ban_socket="/tmp/fake_fail2ban.sock",
|
||||
fail2ban_config_dir=str(config_dir),
|
||||
session_secret="test-history-secret-32chars-long!!",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
log_level="debug",
|
||||
session_cookie_secure=False,
|
||||
)
|
||||
app = create_app(settings=settings)
|
||||
|
||||
@@ -136,9 +152,7 @@ async def history_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
|
||||
class TestHistoryList:
|
||||
"""GET /api/history — paginated history list."""
|
||||
|
||||
async def test_returns_200_when_authenticated(
|
||||
self, history_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_returns_200_when_authenticated(self, history_client: AsyncClient) -> None:
|
||||
"""Authenticated request returns HTTP 200."""
|
||||
with patch(
|
||||
"app.routers.history.history_service.list_history",
|
||||
@@ -147,9 +161,7 @@ class TestHistoryList:
|
||||
response = await history_client.get("/api/v1/history")
|
||||
assert response.status_code == 200
|
||||
|
||||
async def test_returns_401_when_unauthenticated(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_returns_401_when_unauthenticated(self, client: AsyncClient) -> None:
|
||||
"""Unauthenticated request returns HTTP 401."""
|
||||
await client.post("/api/v1/setup", json=_SETUP_PAYLOAD)
|
||||
response = await client.get("/api/v1/history")
|
||||
@@ -245,9 +257,7 @@ class TestHistoryList:
|
||||
_args, kwargs = mock_fn.call_args
|
||||
assert kwargs.get("source") == "archive"
|
||||
|
||||
async def test_archive_route_forces_source_archive(
|
||||
self, history_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_archive_route_forces_source_archive(self, history_client: AsyncClient) -> None:
|
||||
"""GET /api/history/archive should call list_history with source='archive'."""
|
||||
mock_fn = AsyncMock(return_value=_make_history_list(n=0))
|
||||
with patch(
|
||||
@@ -261,14 +271,16 @@ class TestHistoryList:
|
||||
|
||||
async def test_empty_result(self, history_client: AsyncClient) -> None:
|
||||
"""An empty history returns items=[] and total=0."""
|
||||
from app.utils.pagination import create_pagination_metadata
|
||||
from app.models.history_domain import DomainHistoryList
|
||||
|
||||
with patch(
|
||||
"app.routers.history.history_service.list_history",
|
||||
new=AsyncMock(
|
||||
return_value=HistoryListResponse(
|
||||
return_value=DomainHistoryList(
|
||||
items=[],
|
||||
pagination=create_pagination_metadata(total=0, page=1, page_size=100),
|
||||
total=0,
|
||||
page=1,
|
||||
page_size=100,
|
||||
)
|
||||
),
|
||||
):
|
||||
@@ -287,9 +299,7 @@ class TestHistoryList:
|
||||
class TestIpHistory:
|
||||
"""GET /api/history/{ip} — per-IP detail."""
|
||||
|
||||
async def test_returns_200_when_authenticated(
|
||||
self, history_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_returns_200_when_authenticated(self, history_client: AsyncClient) -> None:
|
||||
"""Authenticated request returns HTTP 200 for a known IP."""
|
||||
with patch(
|
||||
"app.routers.history.history_service.get_ip_detail",
|
||||
@@ -298,17 +308,13 @@ class TestIpHistory:
|
||||
response = await history_client.get("/api/v1/history/1.2.3.4")
|
||||
assert response.status_code == 200
|
||||
|
||||
async def test_returns_401_when_unauthenticated(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_returns_401_when_unauthenticated(self, client: AsyncClient) -> None:
|
||||
"""Unauthenticated request returns HTTP 401."""
|
||||
await client.post("/api/v1/setup", json=_SETUP_PAYLOAD)
|
||||
response = await client.get("/api/v1/history/1.2.3.4")
|
||||
assert response.status_code == 401
|
||||
|
||||
async def test_returns_404_for_unknown_ip(
|
||||
self, history_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_returns_404_for_unknown_ip(self, history_client: AsyncClient) -> None:
|
||||
"""Returns 404 when the IP has no records in the database."""
|
||||
with patch(
|
||||
"app.routers.history.history_service.get_ip_detail",
|
||||
@@ -341,9 +347,7 @@ class TestIpHistory:
|
||||
assert "failures" in event
|
||||
assert "matches" in event
|
||||
|
||||
async def test_aggregation_sums_failures(
|
||||
self, history_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_aggregation_sums_failures(self, history_client: AsyncClient) -> None:
|
||||
"""total_failures reflects the sum across all timeline events."""
|
||||
mock_detail = _make_ip_detail("10.0.0.1")
|
||||
mock_detail = IpDetailResponse(
|
||||
|
||||
@@ -12,15 +12,36 @@ from httpx import ASGITransport, AsyncClient
|
||||
from app.config import Settings
|
||||
from app.db import init_db
|
||||
from app.main import create_app
|
||||
from app.models.ban import JailBannedIpsResponse
|
||||
from app.models.jail import Jail, JailDetailResponse, JailListResponse, JailStatus, JailSummary
|
||||
from app.services.geo_cache import GeoCache
|
||||
from app.utils.session_cache import NoOpSessionCache
|
||||
from app.utils.setup_state import set_setup_complete_cache
|
||||
|
||||
|
||||
async def _write_password_hash(db: aiosqlite.Connection, password: str) -> str:
|
||||
"""Hash password and write to settings table."""
|
||||
import asyncio
|
||||
|
||||
import bcrypt
|
||||
|
||||
pw_bytes = password.encode()
|
||||
hashed = await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: bcrypt.hashpw(pw_bytes, bcrypt.gensalt()).decode()
|
||||
)
|
||||
await db.execute(
|
||||
"INSERT OR REPLACE INTO settings (key, value) VALUES (?, ?)",
|
||||
("master_password_hash", hashed),
|
||||
)
|
||||
await db.commit()
|
||||
return hashed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SETUP_PAYLOAD = {
|
||||
"master_password": "testpassword1",
|
||||
"master_password": "Testpass1!",
|
||||
"database_path": "bangui.db",
|
||||
"fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
|
||||
"timezone": "UTC",
|
||||
@@ -31,25 +52,41 @@ _SETUP_PAYLOAD = {
|
||||
@pytest.fixture
|
||||
async def jails_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
|
||||
"""Provide an authenticated ``AsyncClient`` for jail endpoint tests."""
|
||||
import os
|
||||
|
||||
os.makedirs(tmp_path / "fail2ban", exist_ok=True)
|
||||
settings = Settings(
|
||||
database_path=str(tmp_path / "jails_test.db"),
|
||||
fail2ban_socket="/tmp/fake.sock",
|
||||
fail2ban_config_dir=str(tmp_path / "fail2ban"),
|
||||
session_secret="test-jails-secret-0000000000000000000000",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
log_level="debug",
|
||||
session_cookie_secure=False,
|
||||
)
|
||||
app = create_app(settings=settings)
|
||||
set_setup_complete_cache(app, True)
|
||||
|
||||
db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path)
|
||||
db.row_factory = aiosqlite.Row
|
||||
await init_db(db)
|
||||
await _write_password_hash(db, _SETUP_PAYLOAD["master_password"])
|
||||
app.state.db = db
|
||||
app.state.http_session = MagicMock()
|
||||
app.state.session_cache = NoOpSessionCache()
|
||||
app.state.geo_cache = GeoCache()
|
||||
|
||||
async def _override_get_db():
|
||||
yield db
|
||||
|
||||
from app.dependencies import get_db, get_session_cache
|
||||
|
||||
app.dependency_overrides[get_db] = _override_get_db
|
||||
app.dependency_overrides[get_session_cache] = lambda: NoOpSessionCache()
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
await ac.post("/api/v1/setup", json=_SETUP_PAYLOAD)
|
||||
async with AsyncClient(transport=transport, base_url="http://test", headers={"X-BanGUI-Request": "1"}) as ac:
|
||||
login = await ac.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"password": _SETUP_PAYLOAD["master_password"]},
|
||||
@@ -58,6 +95,7 @@ async def jails_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
|
||||
yield ac
|
||||
|
||||
await db.close()
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -172,9 +210,19 @@ class TestGetJailDetail:
|
||||
|
||||
async def test_200_for_existing_jail(self, jails_client: AsyncClient) -> None:
|
||||
"""GET /api/jails/sshd returns 200 with full jail detail."""
|
||||
with patch(
|
||||
"app.routers.jails.jail_service.get_jail",
|
||||
AsyncMock(return_value=_detail()),
|
||||
with (
|
||||
patch(
|
||||
"app.routers.jails.jail_service.get_jail",
|
||||
AsyncMock(return_value=_detail()),
|
||||
),
|
||||
patch(
|
||||
"app.routers.jails.jail_service.get_ignore_list",
|
||||
AsyncMock(return_value=["127.0.0.1"]),
|
||||
),
|
||||
patch(
|
||||
"app.routers.jails.jail_service.get_ignore_self",
|
||||
AsyncMock(return_value=False),
|
||||
),
|
||||
):
|
||||
resp = await jails_client.get("/api/v1/jails/sshd")
|
||||
|
||||
@@ -808,25 +856,21 @@ class TestGetJailBannedIps:
|
||||
total: int = 2,
|
||||
page: int = 1,
|
||||
page_size: int = 25,
|
||||
) -> JailBannedIpsResponse:
|
||||
from app.models.ban import ActiveBan, JailBannedIpsResponse
|
||||
):
|
||||
from app.models.jail_domain import DomainActiveBan, DomainJailBannedIps
|
||||
|
||||
ban_items = (
|
||||
[
|
||||
ActiveBan(
|
||||
ip=item.get("ip") or "1.2.3.4",
|
||||
jail="sshd",
|
||||
banned_at=item.get("banned_at", "2025-01-01T10:00:00+00:00"),
|
||||
expires_at=item.get("expires_at", "2025-01-01T10:10:00+00:00"),
|
||||
ban_count=1,
|
||||
country=item.get("country", None),
|
||||
)
|
||||
for item in (items or [{"ip": "1.2.3.4"}, {"ip": "5.6.7.8"}])
|
||||
]
|
||||
)
|
||||
return JailBannedIpsResponse(
|
||||
items=ban_items, total=total, page=page, page_size=page_size
|
||||
)
|
||||
ban_items = [
|
||||
DomainActiveBan(
|
||||
ip=item.get("ip") or "1.2.3.4",
|
||||
jail="sshd",
|
||||
banned_at=item.get("banned_at", "2025-01-01T10:00:00+00:00"),
|
||||
expires_at=item.get("expires_at", "2025-01-01T10:10:00+00:00"),
|
||||
ban_count=1,
|
||||
country=item.get("country", None),
|
||||
)
|
||||
for item in (items or [{"ip": "1.2.3.4"}, {"ip": "5.6.7.8"}])
|
||||
]
|
||||
return DomainJailBannedIps(items=ban_items, total=total, page=page, page_size=page_size)
|
||||
|
||||
async def test_200_returns_paginated_bans(self, jails_client: AsyncClient) -> None:
|
||||
"""GET /api/jails/sshd/banned returns 200 with a JailBannedIpsResponse."""
|
||||
@@ -839,10 +883,10 @@ class TestGetJailBannedIps:
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "items" in data
|
||||
assert "total" in data
|
||||
assert "page" in data
|
||||
assert "page_size" in data
|
||||
assert data["total"] == 2
|
||||
assert "pagination" in data
|
||||
assert data["pagination"]["total"] == 2
|
||||
assert data["pagination"]["page"] == 1
|
||||
assert data["pagination"]["page_size"] == 25
|
||||
|
||||
async def test_200_with_search_parameter(self, jails_client: AsyncClient) -> None:
|
||||
"""GET /api/jails/sshd/banned?search=1.2.3 passes search to service."""
|
||||
@@ -856,9 +900,7 @@ class TestGetJailBannedIps:
|
||||
|
||||
async def test_200_with_page_and_page_size(self, jails_client: AsyncClient) -> None:
|
||||
"""GET /api/jails/sshd/banned?page=2&page_size=10 passes params to service."""
|
||||
mock_fn = AsyncMock(
|
||||
return_value=self._mock_response(page=2, page_size=10, total=0, items=[])
|
||||
)
|
||||
mock_fn = AsyncMock(return_value=self._mock_response(page=2, page_size=10, total=0, items=[]))
|
||||
with patch("app.routers.jails.jail_service.get_jail_banned_ips", mock_fn):
|
||||
resp = await jails_client.get("/api/v1/jails/sshd/banned?page=2&page_size=10")
|
||||
|
||||
@@ -900,17 +942,13 @@ class TestGetJailBannedIps:
|
||||
|
||||
with patch(
|
||||
"app.routers.jails.jail_service.get_jail_banned_ips",
|
||||
AsyncMock(
|
||||
side_effect=Fail2BanConnectionError("socket dead", "/tmp/fake.sock")
|
||||
),
|
||||
AsyncMock(side_effect=Fail2BanConnectionError("socket dead", "/tmp/fake.sock")),
|
||||
):
|
||||
resp = await jails_client.get("/api/v1/jails/sshd/banned")
|
||||
|
||||
assert resp.status_code == 502
|
||||
|
||||
async def test_response_items_have_expected_fields(
|
||||
self, jails_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_response_items_have_expected_fields(self, jails_client: AsyncClient) -> None:
|
||||
"""Response items contain ip, jail, banned_at, expires_at, ban_count, country."""
|
||||
with patch(
|
||||
"app.routers.jails.jail_service.get_jail_banned_ips",
|
||||
@@ -933,4 +971,3 @@ class TestGetJailBannedIps:
|
||||
base_url="http://test",
|
||||
).get("/api/v1/jails/sshd/banned")
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
@@ -13,13 +13,16 @@ from app.config import Settings
|
||||
from app.db import init_db
|
||||
from app.main import create_app
|
||||
from app.models.server import ServerSettings, ServerSettingsResponse
|
||||
from app.services.geo_cache import GeoCache
|
||||
from app.utils.session_cache import NoOpSessionCache
|
||||
from app.utils.setup_state import set_setup_complete_cache
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SETUP_PAYLOAD = {
|
||||
"master_password": "testpassword1",
|
||||
"master_password": "Testpass1!",
|
||||
"database_path": "bangui.db",
|
||||
"fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
|
||||
"timezone": "UTC",
|
||||
@@ -27,28 +30,62 @@ _SETUP_PAYLOAD = {
|
||||
}
|
||||
|
||||
|
||||
async def _write_password_hash(db: aiosqlite.Connection, password: str) -> str:
|
||||
"""Hash password and write to settings table."""
|
||||
import asyncio
|
||||
|
||||
import bcrypt
|
||||
|
||||
pw_bytes = password.encode()
|
||||
hashed = await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: bcrypt.hashpw(pw_bytes, bcrypt.gensalt()).decode()
|
||||
)
|
||||
await db.execute(
|
||||
"INSERT OR REPLACE INTO settings (key, value) VALUES (?, ?)",
|
||||
("master_password_hash", hashed),
|
||||
)
|
||||
await db.commit()
|
||||
return hashed
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def server_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
|
||||
"""Provide an authenticated ``AsyncClient`` for server endpoint tests."""
|
||||
import os
|
||||
|
||||
os.makedirs(tmp_path / "fail2ban", exist_ok=True)
|
||||
settings = Settings(
|
||||
database_path=str(tmp_path / "server_test.db"),
|
||||
fail2ban_socket="/tmp/fake.sock",
|
||||
session_secret="test-server-secret",
|
||||
fail2ban_config_dir=str(tmp_path / "fail2ban"),
|
||||
session_secret="test-server-secret-0000000000000000000000",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
log_level="debug",
|
||||
session_cookie_secure=False,
|
||||
)
|
||||
app = create_app(settings=settings)
|
||||
set_setup_complete_cache(app, True)
|
||||
|
||||
db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path)
|
||||
db.row_factory = aiosqlite.Row
|
||||
await init_db(db)
|
||||
await _write_password_hash(db, _SETUP_PAYLOAD["master_password"])
|
||||
app.state.db = db
|
||||
app.state.http_session = MagicMock()
|
||||
app.state.session_cache = NoOpSessionCache()
|
||||
app.state.geo_cache = GeoCache()
|
||||
|
||||
async def _override_get_db():
|
||||
yield db
|
||||
|
||||
from app.dependencies import get_db, get_session_cache
|
||||
|
||||
app.dependency_overrides[get_db] = _override_get_db
|
||||
app.dependency_overrides[get_session_cache] = lambda: NoOpSessionCache()
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
await ac.post("/api/v1/setup", json=_SETUP_PAYLOAD)
|
||||
async with AsyncClient(transport=transport, base_url="http://test", headers={"X-BanGUI-Request": "1"}) as ac:
|
||||
login = await ac.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"password": _SETUP_PAYLOAD["master_password"]},
|
||||
@@ -57,6 +94,7 @@ async def server_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
|
||||
yield ac
|
||||
|
||||
await db.close()
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
def _make_settings() -> ServerSettingsResponse:
|
||||
|
||||
@@ -99,6 +99,9 @@ def test_security_headers_on_all_response_types() -> None:
|
||||
)
|
||||
|
||||
app = create_app(settings=settings)
|
||||
from app.models.server import ServerStatus
|
||||
|
||||
app.state.server_status = ServerStatus(online=True)
|
||||
client = TestClient(app)
|
||||
|
||||
# Test on successful response
|
||||
|
||||
@@ -81,7 +81,7 @@ class TestLogin:
|
||||
self, db: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""login() returns a signed token and expiry on the correct password."""
|
||||
signed_token, expires_at = await auth_service.login(
|
||||
signed_token, expires_at, _ = await auth_service.login(
|
||||
db,
|
||||
password="correctpassword1",
|
||||
session_duration_minutes=60,
|
||||
@@ -119,7 +119,7 @@ class TestLogin:
|
||||
"""login() stores the session in the database."""
|
||||
from app.repositories import session_repo
|
||||
|
||||
signed_token, _ = await auth_service.login(
|
||||
signed_token, _, _ = await auth_service.login(
|
||||
db,
|
||||
password="correctpassword1",
|
||||
session_duration_minutes=60,
|
||||
@@ -136,7 +136,7 @@ class TestValidateSession:
|
||||
self, db: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""validate_session() returns the session for a valid token."""
|
||||
signed_token, _ = await auth_service.login(
|
||||
signed_token, _, _ = await auth_service.login(
|
||||
db,
|
||||
password="correctpassword1",
|
||||
session_duration_minutes=60,
|
||||
@@ -150,7 +150,7 @@ class TestValidateSession:
|
||||
self, db: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""validate_session() accepts a token signed with the configured secret."""
|
||||
signed_token, _ = await auth_service.login(
|
||||
signed_token, _, _ = await auth_service.login(
|
||||
db,
|
||||
password="correctpassword1",
|
||||
session_duration_minutes=60,
|
||||
@@ -166,7 +166,7 @@ class TestValidateSession:
|
||||
self, db: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""validate_session() rejects signed tokens with an invalid signature."""
|
||||
signed_token, _ = await auth_service.login(
|
||||
signed_token, _, _ = await auth_service.login(
|
||||
db,
|
||||
password="correctpassword1",
|
||||
session_duration_minutes=60,
|
||||
@@ -213,7 +213,7 @@ class TestLogout:
|
||||
"""logout() deletes the session so it can no longer be validated."""
|
||||
from app.repositories import session_repo
|
||||
|
||||
signed_token, _ = await auth_service.login(
|
||||
signed_token, _, _ = await auth_service.login(
|
||||
db,
|
||||
password="correctpassword1",
|
||||
session_duration_minutes=60,
|
||||
@@ -228,7 +228,7 @@ class TestLogout:
|
||||
"""logout() accepts a signed token and revokes the underlying raw session."""
|
||||
from app.repositories import session_repo
|
||||
|
||||
signed_token, _ = await auth_service.login(
|
||||
signed_token, _, _ = await auth_service.login(
|
||||
db,
|
||||
password="correctpassword1",
|
||||
session_duration_minutes=60,
|
||||
@@ -248,7 +248,7 @@ class TestSecretRotation:
|
||||
self, db: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""Tokens signed with current secret are validated immediately."""
|
||||
signed_token, _ = await auth_service.login(
|
||||
signed_token, _, _ = await auth_service.login(
|
||||
db,
|
||||
password="correctpassword1",
|
||||
session_duration_minutes=60,
|
||||
@@ -264,7 +264,7 @@ class TestSecretRotation:
|
||||
self, db: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""Tokens signed with previous secret are accepted during rotation."""
|
||||
signed_token, _ = await auth_service.login(
|
||||
signed_token, _, _ = await auth_service.login(
|
||||
db,
|
||||
password="correctpassword1",
|
||||
session_duration_minutes=60,
|
||||
@@ -280,7 +280,7 @@ class TestSecretRotation:
|
||||
self, db: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""Tokens signed with unknown secrets are rejected."""
|
||||
signed_token, _ = await auth_service.login(
|
||||
signed_token, _, _ = await auth_service.login(
|
||||
db,
|
||||
password="correctpassword1",
|
||||
session_duration_minutes=60,
|
||||
@@ -308,7 +308,7 @@ class TestSecretRotation:
|
||||
self, db: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""During rotation, tokens signed with previous secret are re-signed."""
|
||||
signed_token, _ = await auth_service.login(
|
||||
signed_token, _, _ = await auth_service.login(
|
||||
db,
|
||||
password="correctpassword1",
|
||||
session_duration_minutes=60,
|
||||
@@ -327,7 +327,7 @@ class TestSecretRotation:
|
||||
self, db: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""Validation processes token rotation during validation."""
|
||||
signed_token, _ = await auth_service.login(
|
||||
signed_token, _, _ = await auth_service.login(
|
||||
db,
|
||||
password="correctpassword1",
|
||||
session_duration_minutes=60,
|
||||
@@ -348,7 +348,7 @@ class TestSecretRotation:
|
||||
"""logout() accepts tokens signed with the previous secret."""
|
||||
from app.repositories import session_repo
|
||||
|
||||
signed_token, _ = await auth_service.login(
|
||||
signed_token, _, _ = await auth_service.login(
|
||||
db,
|
||||
password="correctpassword1",
|
||||
session_duration_minutes=60,
|
||||
@@ -368,7 +368,7 @@ class TestSecretRotation:
|
||||
self, db: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""If no previous secret is configured, old tokens are rejected."""
|
||||
signed_token, _ = await auth_service.login(
|
||||
signed_token, _, _ = await auth_service.login(
|
||||
db,
|
||||
password="correctpassword1",
|
||||
session_duration_minutes=60,
|
||||
|
||||
@@ -32,12 +32,7 @@ async def _create_f2b_db(path: str, rows: list[dict[str, Any]]) -> None:
|
||||
``bantime``, ``bancount``, and optionally ``data``.
|
||||
"""
|
||||
async with aiosqlite.connect(path) as db:
|
||||
await db.execute(
|
||||
"CREATE TABLE jails ("
|
||||
"name TEXT NOT NULL UNIQUE, "
|
||||
"enabled INTEGER NOT NULL DEFAULT 1"
|
||||
")"
|
||||
)
|
||||
await db.execute("CREATE TABLE jails (name TEXT NOT NULL UNIQUE, enabled INTEGER NOT NULL DEFAULT 1)")
|
||||
await db.execute(
|
||||
"CREATE TABLE bans ("
|
||||
"jail TEXT NOT NULL, "
|
||||
@@ -50,8 +45,7 @@ async def _create_f2b_db(path: str, rows: list[dict[str, Any]]) -> None:
|
||||
)
|
||||
for row in rows:
|
||||
await db.execute(
|
||||
"INSERT INTO bans (jail, ip, timeofban, bantime, bancount, data) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?)",
|
||||
"INSERT INTO bans (jail, ip, timeofban, bantime, bancount, data) VALUES (?, ?, ?, ?, ?, ?)",
|
||||
(
|
||||
row["jail"],
|
||||
row["ip"],
|
||||
@@ -257,9 +251,7 @@ class TestListBansHappyPath:
|
||||
|
||||
assert result.total == 3
|
||||
|
||||
async def test_source_archive_reads_from_archive(
|
||||
self, app_db_with_archive: aiosqlite.Connection
|
||||
) -> None:
|
||||
async def test_source_archive_reads_from_archive(self, app_db_with_archive: aiosqlite.Connection) -> None:
|
||||
"""Using source='archive' reads from the BanGUI archive table."""
|
||||
result = await ban_service.list_bans(
|
||||
"/fake/sock",
|
||||
@@ -280,9 +272,7 @@ class TestListBansHappyPath:
|
||||
class TestListBansGeoEnrichment:
|
||||
"""Verify geo enrichment integration in ban_service.list_bans()."""
|
||||
|
||||
async def test_geo_data_applied_when_enricher_provided(
|
||||
self, f2b_db_path: str
|
||||
) -> None:
|
||||
async def test_geo_data_applied_when_enricher_provided(self, f2b_db_path: str) -> None:
|
||||
"""Geo fields are populated when an enricher returns data."""
|
||||
from app.models.geo import GeoInfo
|
||||
|
||||
@@ -298,30 +288,24 @@ class TestListBansGeoEnrichment:
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
"/fake/sock", "24h", geo_enricher=fake_enricher
|
||||
)
|
||||
result = await ban_service.list_bans("/fake/sock", "24h", geo_enricher=fake_enricher)
|
||||
|
||||
for item in result.items:
|
||||
assert item.country_code == "DE"
|
||||
assert item.country_name == "Germany"
|
||||
assert item.asn == "AS3320"
|
||||
|
||||
async def test_geo_failure_does_not_break_results(
|
||||
self, f2b_db_path: str
|
||||
) -> None:
|
||||
async def test_geo_failure_does_not_break_results(self, f2b_db_path: str) -> None:
|
||||
"""A geo enricher that raises still returns ban items (geo fields null)."""
|
||||
|
||||
async def failing_enricher(ip: str) -> None:
|
||||
raise RuntimeError("geo service down")
|
||||
raise OSError("geo service down")
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
"/fake/sock", "24h", geo_enricher=failing_enricher
|
||||
)
|
||||
result = await ban_service.list_bans("/fake/sock", "24h", geo_enricher=failing_enricher)
|
||||
|
||||
assert result.total == 2
|
||||
for item in result.items:
|
||||
@@ -336,9 +320,7 @@ class TestListBansGeoEnrichment:
|
||||
class TestListBansBatchGeoEnrichment:
|
||||
"""Verify that list_bans uses lookup_batch when http_session is provided."""
|
||||
|
||||
async def test_batch_geo_applied_via_http_session(
|
||||
self, f2b_db_path: str
|
||||
) -> None:
|
||||
async def test_batch_geo_applied_via_http_session(self, f2b_db_path: str) -> None:
|
||||
"""Geo fields are populated via lookup_batch when http_session is given."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
@@ -350,6 +332,8 @@ class TestListBansBatchGeoEnrichment:
|
||||
"5.6.7.8": GeoInfo(country_code="US", country_name="United States", asn="AS15169", org="Google"),
|
||||
}
|
||||
fake_geo_batch = AsyncMock(return_value=fake_geo_map)
|
||||
mock_geo_cache = MagicMock()
|
||||
mock_geo_cache.lookup_batch = fake_geo_batch
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
@@ -359,7 +343,7 @@ class TestListBansBatchGeoEnrichment:
|
||||
"/fake/sock",
|
||||
"24h",
|
||||
http_session=fake_session,
|
||||
geo_batch_lookup=fake_geo_batch,
|
||||
geo_cache=mock_geo_cache,
|
||||
)
|
||||
|
||||
fake_geo_batch.assert_awaited_once_with(["1.2.3.4", "5.6.7.8"], fake_session, db=None)
|
||||
@@ -371,15 +355,15 @@ class TestListBansBatchGeoEnrichment:
|
||||
assert us_item.country_code == "US"
|
||||
assert us_item.country_name == "United States"
|
||||
|
||||
async def test_batch_failure_does_not_break_results(
|
||||
self, f2b_db_path: str
|
||||
) -> None:
|
||||
async def test_batch_failure_does_not_break_results(self, f2b_db_path: str) -> None:
|
||||
"""A lookup_batch failure still returns items with null geo fields."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
fake_session = MagicMock()
|
||||
|
||||
failing_geo_batch = AsyncMock(side_effect=RuntimeError("batch geo down"))
|
||||
failing_geo_batch = AsyncMock(side_effect=OSError("batch geo down"))
|
||||
mock_geo_cache = MagicMock()
|
||||
mock_geo_cache.lookup_batch = failing_geo_batch
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
@@ -389,16 +373,14 @@ class TestListBansBatchGeoEnrichment:
|
||||
"/fake/sock",
|
||||
"24h",
|
||||
http_session=fake_session,
|
||||
geo_batch_lookup=failing_geo_batch,
|
||||
geo_cache=mock_geo_cache,
|
||||
)
|
||||
|
||||
assert result.total == 2
|
||||
for item in result.items:
|
||||
assert item.country_code is None
|
||||
|
||||
async def test_http_session_takes_priority_over_geo_enricher(
|
||||
self, f2b_db_path: str
|
||||
) -> None:
|
||||
async def test_http_session_takes_priority_over_geo_enricher(self, f2b_db_path: str) -> None:
|
||||
"""When both http_session and geo_enricher are provided, batch wins."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
@@ -410,6 +392,8 @@ class TestListBansBatchGeoEnrichment:
|
||||
"5.6.7.8": GeoInfo(country_code="DE", country_name="Germany", asn=None, org=None),
|
||||
}
|
||||
fake_geo_batch = AsyncMock(return_value=fake_geo_map)
|
||||
mock_geo_cache = MagicMock()
|
||||
mock_geo_cache.lookup_batch = fake_geo_batch
|
||||
|
||||
async def enricher_should_not_be_called(ip: str) -> GeoInfo:
|
||||
raise AssertionError(f"geo_enricher was called for {ip!r} — should not happen")
|
||||
@@ -422,7 +406,7 @@ class TestListBansBatchGeoEnrichment:
|
||||
"/fake/sock",
|
||||
"24h",
|
||||
http_session=fake_session,
|
||||
geo_batch_lookup=fake_geo_batch,
|
||||
geo_cache=mock_geo_cache,
|
||||
geo_enricher=enricher_should_not_be_called,
|
||||
)
|
||||
|
||||
@@ -462,9 +446,7 @@ class TestListBansPagination:
|
||||
# Different IPs should appear on different pages.
|
||||
assert page1.items[0].ip != page2.items[0].ip
|
||||
|
||||
async def test_total_reflects_full_count_not_page_count(
|
||||
self, f2b_db_path: str
|
||||
) -> None:
|
||||
async def test_total_reflects_full_count_not_page_count(self, f2b_db_path: str) -> None:
|
||||
"""``total`` reports all matching records regardless of pagination."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
@@ -483,9 +465,7 @@ class TestListBansPagination:
|
||||
class TestBanOriginDerivation:
|
||||
"""Verify that ban_service correctly derives ``origin`` from jail names."""
|
||||
|
||||
async def test_blocklist_import_jail_yields_blocklist_origin(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_blocklist_import_jail_yields_blocklist_origin(self, mixed_origin_db_path: str) -> None:
|
||||
"""Bans from ``blocklist-import`` jail carry ``origin == "blocklist"``."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
@@ -497,9 +477,7 @@ class TestBanOriginDerivation:
|
||||
assert len(blocklist_items) == 1
|
||||
assert blocklist_items[0].origin == "blocklist"
|
||||
|
||||
async def test_organic_jail_yields_selfblock_origin(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_organic_jail_yields_selfblock_origin(self, mixed_origin_db_path: str) -> None:
|
||||
"""Bans from organic jails (sshd, nginx, …) carry ``origin == "selfblock"``."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
@@ -512,9 +490,7 @@ class TestBanOriginDerivation:
|
||||
for item in organic_items:
|
||||
assert item.origin == "selfblock"
|
||||
|
||||
async def test_all_items_carry_origin_field(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_all_items_carry_origin_field(self, mixed_origin_db_path: str) -> None:
|
||||
"""Every returned item has an ``origin`` field with a valid value."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
@@ -525,9 +501,7 @@ class TestBanOriginDerivation:
|
||||
for item in result.items:
|
||||
assert item.origin in ("blocklist", "selfblock")
|
||||
|
||||
async def test_bans_by_country_blocklist_origin(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_bans_by_country_blocklist_origin(self, mixed_origin_db_path: str) -> None:
|
||||
"""``bans_by_country`` also derives origin correctly for blocklist bans."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
@@ -535,13 +509,11 @@ class TestBanOriginDerivation:
|
||||
):
|
||||
result = await ban_service.bans_by_country("/fake/sock", "24h")
|
||||
|
||||
blocklist_bans = [b for b in result.bans if b.jail == "blocklist-import"]
|
||||
blocklist_bans = [b for b in result.items if b.jail == "blocklist-import"]
|
||||
assert len(blocklist_bans) == 1
|
||||
assert blocklist_bans[0].origin == "blocklist"
|
||||
|
||||
async def test_bans_by_country_selfblock_origin(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_bans_by_country_selfblock_origin(self, mixed_origin_db_path: str) -> None:
|
||||
"""``bans_by_country`` derives origin correctly for organic jails."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
@@ -549,7 +521,7 @@ class TestBanOriginDerivation:
|
||||
):
|
||||
result = await ban_service.bans_by_country("/fake/sock", "24h")
|
||||
|
||||
organic_bans = [b for b in result.bans if b.jail != "blocklist-import"]
|
||||
organic_bans = [b for b in result.items if b.jail != "blocklist-import"]
|
||||
assert len(organic_bans) == 2
|
||||
for ban in organic_bans:
|
||||
assert ban.origin == "selfblock"
|
||||
@@ -563,34 +535,26 @@ class TestBanOriginDerivation:
|
||||
class TestOriginFilter:
|
||||
"""Verify that the origin filter correctly restricts results."""
|
||||
|
||||
async def test_list_bans_blocklist_filter_returns_only_blocklist(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_list_bans_blocklist_filter_returns_only_blocklist(self, mixed_origin_db_path: str) -> None:
|
||||
"""``origin='blocklist'`` returns only blocklist-import jail bans."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
"/fake/sock", "24h", origin="blocklist"
|
||||
)
|
||||
result = await ban_service.list_bans("/fake/sock", "24h", origin="blocklist")
|
||||
|
||||
assert result.total == 1
|
||||
assert len(result.items) == 1
|
||||
assert result.items[0].jail == "blocklist-import"
|
||||
assert result.items[0].origin == "blocklist"
|
||||
|
||||
async def test_list_bans_selfblock_filter_excludes_blocklist(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_list_bans_selfblock_filter_excludes_blocklist(self, mixed_origin_db_path: str) -> None:
|
||||
"""``origin='selfblock'`` excludes the blocklist-import jail."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
"/fake/sock", "24h", origin="selfblock"
|
||||
)
|
||||
result = await ban_service.list_bans("/fake/sock", "24h", origin="selfblock")
|
||||
|
||||
assert result.total == 2
|
||||
assert len(result.items) == 2
|
||||
@@ -598,9 +562,7 @@ class TestOriginFilter:
|
||||
assert item.jail != "blocklist-import"
|
||||
assert item.origin == "selfblock"
|
||||
|
||||
async def test_list_bans_no_filter_returns_all(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_list_bans_no_filter_returns_all(self, mixed_origin_db_path: str) -> None:
|
||||
"""``origin=None`` applies no jail restriction — all bans returned."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
@@ -610,53 +572,39 @@ class TestOriginFilter:
|
||||
|
||||
assert result.total == 3
|
||||
|
||||
async def test_bans_by_country_blocklist_filter(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_bans_by_country_blocklist_filter(self, mixed_origin_db_path: str) -> None:
|
||||
"""``bans_by_country`` with ``origin='blocklist'`` counts only blocklist bans."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country(
|
||||
"/fake/sock", "24h", origin="blocklist"
|
||||
)
|
||||
result = await ban_service.bans_by_country("/fake/sock", "24h", origin="blocklist")
|
||||
|
||||
assert result.total == 1
|
||||
assert all(b.jail == "blocklist-import" for b in result.bans)
|
||||
assert all(b.jail == "blocklist-import" for b in result.items)
|
||||
|
||||
async def test_bans_by_country_selfblock_filter(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_bans_by_country_selfblock_filter(self, mixed_origin_db_path: str) -> None:
|
||||
"""``bans_by_country`` with ``origin='selfblock'`` excludes blocklist jails."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country(
|
||||
"/fake/sock", "24h", origin="selfblock"
|
||||
)
|
||||
result = await ban_service.bans_by_country("/fake/sock", "24h", origin="selfblock")
|
||||
|
||||
assert result.total == 2
|
||||
assert all(b.jail != "blocklist-import" for b in result.bans)
|
||||
assert all(b.jail != "blocklist-import" for b in result.items)
|
||||
|
||||
async def test_bans_by_country_no_filter_returns_all(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_bans_by_country_no_filter_returns_all(self, mixed_origin_db_path: str) -> None:
|
||||
"""``bans_by_country`` with ``origin=None`` returns all bans."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country(
|
||||
"/fake/sock", "24h", origin=None
|
||||
)
|
||||
result = await ban_service.bans_by_country("/fake/sock", "24h", origin=None)
|
||||
|
||||
assert result.total == 3
|
||||
|
||||
async def test_bans_by_country_country_code_returns_all_matched_rows(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
async def test_bans_by_country_country_code_returns_all_matched_rows(self, tmp_path: Path) -> None:
|
||||
"""``bans_by_country`` returns all companion rows for the selected country."""
|
||||
path = str(tmp_path / "fail2ban_country_filter.sqlite3")
|
||||
rows = [
|
||||
@@ -672,8 +620,8 @@ class TestOriginFilter:
|
||||
]
|
||||
await _create_f2b_db(path, rows)
|
||||
|
||||
from app.services import geo_service
|
||||
from app.models.geo import GeoInfo
|
||||
from app.services import geo_service
|
||||
|
||||
geo_service._default_geo_cache._cache["10.0.0.1"] = GeoInfo(
|
||||
country_code="DE",
|
||||
@@ -682,12 +630,13 @@ class TestOriginFilter:
|
||||
org=None,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=path),
|
||||
), patch(
|
||||
"app.services.ban_service.asyncio.create_task"
|
||||
) as mock_create_task:
|
||||
with (
|
||||
patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=path),
|
||||
),
|
||||
patch("app.services.ban_service.asyncio.create_task") as mock_create_task,
|
||||
):
|
||||
result = await ban_service.bans_by_country(
|
||||
"/fake/sock",
|
||||
"24h",
|
||||
@@ -698,8 +647,8 @@ class TestOriginFilter:
|
||||
|
||||
mock_create_task.assert_not_called()
|
||||
assert result.total == 205
|
||||
assert len(result.bans) == 205
|
||||
assert all(b.country_code == "DE" for b in result.bans)
|
||||
assert len(result.items) == 205
|
||||
assert all(b.country_code == "DE" for b in result.items)
|
||||
|
||||
await geo_service.clear_cache()
|
||||
|
||||
@@ -715,7 +664,7 @@ class TestOriginFilter:
|
||||
)
|
||||
|
||||
assert result.total == 2
|
||||
assert len(result.bans) == 2
|
||||
assert len(result.items) == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -728,13 +677,11 @@ class TestBansbyCountryBackground:
|
||||
"""bans_by_country() with http_session uses cache-only geo and fires a
|
||||
background task for uncached IPs instead of blocking on API calls."""
|
||||
|
||||
async def test_cached_geo_returned_without_api_call(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_cached_geo_returned_without_api_call(self, mixed_origin_db_path: str) -> None:
|
||||
"""When all IPs are in the cache, lookup_cached_only returns them and
|
||||
no background task is created."""
|
||||
from app.services import geo_service
|
||||
from app.models.geo import GeoInfo
|
||||
from app.services import geo_service
|
||||
|
||||
# Pre-populate the cache for all three IPs in the fixture.
|
||||
geo_service._default_geo_cache._cache["10.0.0.1"] = GeoInfo(
|
||||
@@ -752,9 +699,7 @@ class TestBansbyCountryBackground:
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
),
|
||||
patch(
|
||||
"app.services.ban_service.asyncio.create_task"
|
||||
) as mock_create_task,
|
||||
patch("app.services.ban_service.asyncio.create_task") as mock_create_task,
|
||||
):
|
||||
mock_session = AsyncMock()
|
||||
mock_batch = AsyncMock(return_value={})
|
||||
@@ -763,7 +708,6 @@ class TestBansbyCountryBackground:
|
||||
"24h",
|
||||
http_session=mock_session,
|
||||
geo_cache_lookup=geo_service.lookup_cached_only,
|
||||
geo_batch_lookup=mock_batch,
|
||||
)
|
||||
|
||||
# All countries resolved from cache — no background task needed.
|
||||
@@ -773,9 +717,7 @@ class TestBansbyCountryBackground:
|
||||
assert "DE" in result.countries or "US" in result.countries or "JP" in result.countries
|
||||
await geo_service.clear_cache()
|
||||
|
||||
async def test_uncached_ips_trigger_background_task(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_uncached_ips_trigger_background_task(self, mixed_origin_db_path: str) -> None:
|
||||
"""When IPs are NOT in the cache, create_task is called for background
|
||||
resolution and the response returns without blocking."""
|
||||
from app.services import geo_service
|
||||
@@ -787,9 +729,7 @@ class TestBansbyCountryBackground:
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
),
|
||||
patch(
|
||||
"app.services.ban_service.asyncio.create_task"
|
||||
) as mock_create_task,
|
||||
patch("app.services.ban_service.asyncio.create_task") as mock_create_task,
|
||||
):
|
||||
mock_session = AsyncMock()
|
||||
mock_batch = AsyncMock(return_value={})
|
||||
@@ -798,7 +738,7 @@ class TestBansbyCountryBackground:
|
||||
"24h",
|
||||
http_session=mock_session,
|
||||
geo_cache_lookup=geo_service.lookup_cached_only,
|
||||
geo_batch_lookup=mock_batch,
|
||||
geo_cache=geo_service.GeoCache(),
|
||||
)
|
||||
|
||||
# Background task must have been scheduled for uncached IPs.
|
||||
@@ -806,9 +746,7 @@ class TestBansbyCountryBackground:
|
||||
# Response is still valid with empty country map (IPs not cached yet).
|
||||
assert result.total == 3
|
||||
|
||||
async def test_no_background_task_without_http_session(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_no_background_task_without_http_session(self, mixed_origin_db_path: str) -> None:
|
||||
"""When http_session is None, no background task is created."""
|
||||
from app.services import geo_service
|
||||
|
||||
@@ -819,13 +757,9 @@ class TestBansbyCountryBackground:
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
),
|
||||
patch(
|
||||
"app.services.ban_service.asyncio.create_task"
|
||||
) as mock_create_task,
|
||||
patch("app.services.ban_service.asyncio.create_task") as mock_create_task,
|
||||
):
|
||||
result = await ban_service.bans_by_country(
|
||||
"/fake/sock", "24h", http_session=None
|
||||
)
|
||||
result = await ban_service.bans_by_country("/fake/sock", "24h", http_session=None)
|
||||
|
||||
mock_create_task.assert_not_called()
|
||||
assert result.total == 3
|
||||
@@ -904,9 +838,7 @@ class TestBanTrend:
|
||||
timestamps = [b.timestamp for b in result.buckets]
|
||||
assert timestamps == sorted(timestamps)
|
||||
|
||||
async def test_ban_trend_source_archive_reads_archive(
|
||||
self, app_db_with_archive: aiosqlite.Connection
|
||||
) -> None:
|
||||
async def test_ban_trend_source_archive_reads_archive(self, app_db_with_archive: aiosqlite.Connection) -> None:
|
||||
"""``ban_trend`` accepts source='archive' and uses archived rows."""
|
||||
result = await ban_service.ban_trend(
|
||||
"/fake/sock",
|
||||
@@ -959,9 +891,7 @@ class TestBanTrend:
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=path),
|
||||
):
|
||||
result = await ban_service.ban_trend(
|
||||
"/fake/sock", "24h", origin="blocklist"
|
||||
)
|
||||
result = await ban_service.ban_trend("/fake/sock", "24h", origin="blocklist")
|
||||
|
||||
assert sum(b.count for b in result.buckets) == 1
|
||||
|
||||
@@ -985,9 +915,7 @@ class TestBanTrend:
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=path),
|
||||
):
|
||||
result = await ban_service.ban_trend(
|
||||
"/fake/sock", "24h", origin="selfblock"
|
||||
)
|
||||
result = await ban_service.ban_trend("/fake/sock", "24h", origin="selfblock")
|
||||
|
||||
assert sum(b.count for b in result.buckets) == 2
|
||||
|
||||
@@ -1096,9 +1024,7 @@ class TestBansByJail:
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_jail(
|
||||
"/fake/sock", "24h", origin="blocklist"
|
||||
)
|
||||
result = await ban_service.bans_by_jail("/fake/sock", "24h", origin="blocklist")
|
||||
|
||||
assert len(result.jails) == 1
|
||||
assert result.jails[0].jail == "blocklist-import"
|
||||
@@ -1110,32 +1036,24 @@ class TestBansByJail:
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_jail(
|
||||
"/fake/sock", "24h", origin="selfblock"
|
||||
)
|
||||
result = await ban_service.bans_by_jail("/fake/sock", "24h", origin="selfblock")
|
||||
|
||||
jail_names = {j.jail for j in result.jails}
|
||||
assert "blocklist-import" not in jail_names
|
||||
assert result.total == 2
|
||||
|
||||
async def test_no_origin_filter_returns_all_jails(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_no_origin_filter_returns_all_jails(self, mixed_origin_db_path: str) -> None:
|
||||
"""``origin=None`` returns bans from all jails."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_jail(
|
||||
"/fake/sock", "24h", origin=None
|
||||
)
|
||||
result = await ban_service.bans_by_jail("/fake/sock", "24h", origin=None)
|
||||
|
||||
assert result.total == 3
|
||||
assert len(result.jails) == 3
|
||||
|
||||
async def test_bans_by_jail_source_archive_reads_archive(
|
||||
self, app_db_with_archive: aiosqlite.Connection
|
||||
) -> None:
|
||||
async def test_bans_by_jail_source_archive_reads_archive(self, app_db_with_archive: aiosqlite.Connection) -> None:
|
||||
"""``bans_by_jail`` accepts source='archive' and aggregates archived rows."""
|
||||
result = await ban_service.bans_by_jail(
|
||||
"/fake/sock",
|
||||
@@ -1147,9 +1065,7 @@ class TestBansByJail:
|
||||
assert result.total == 2
|
||||
assert any(j.jail == "sshd" for j in result.jails)
|
||||
|
||||
async def test_diagnostic_warning_when_zero_results_despite_data(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
async def test_diagnostic_warning_when_zero_results_despite_data(self, tmp_path: Path) -> None:
|
||||
"""A warning is logged when the time-range filter excludes all existing rows."""
|
||||
import time as _time
|
||||
|
||||
@@ -1176,9 +1092,6 @@ class TestBansByJail:
|
||||
assert result.jails == []
|
||||
# The diagnostic warning must have been emitted.
|
||||
warning_calls = [
|
||||
c
|
||||
for c in mock_log.warning.call_args_list
|
||||
if c[0][0] == "ban_service_bans_by_jail_empty_despite_data"
|
||||
c for c in mock_log.warning.call_args_list if c[0][0] == "ban_service_bans_by_jail_empty_despite_data"
|
||||
]
|
||||
assert len(warning_calls) == 1
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -12,11 +12,10 @@ import pytest
|
||||
from app.config import Settings
|
||||
from app.models.config import (
|
||||
GlobalConfigUpdate,
|
||||
JailConfigListResponse,
|
||||
JailConfigResponse,
|
||||
LogPreviewRequest,
|
||||
RegexTestRequest,
|
||||
)
|
||||
from app.models.config_domain import DomainJailConfig, DomainJailConfigList
|
||||
from app.services import config_service, health_service, log_service
|
||||
from app.services.config_service import (
|
||||
ConfigValidationError,
|
||||
@@ -31,6 +30,7 @@ from app.services.config_service import (
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_settings(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Mock get_settings for all tests in this module."""
|
||||
|
||||
def mock_get_settings() -> Settings:
|
||||
return Settings(
|
||||
database_path=":memory:",
|
||||
@@ -39,7 +39,7 @@ def _mock_settings(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
session_secret="test-secret-key-do-not-use-in-production",
|
||||
)
|
||||
|
||||
monkeypatch.setattr("app.models.config.get_settings", mock_get_settings)
|
||||
monkeypatch.setattr("app.config.get_settings", mock_get_settings)
|
||||
monkeypatch.setattr("app.utils.path_utils.get_settings", mock_get_settings)
|
||||
|
||||
|
||||
@@ -113,16 +113,16 @@ class TestGetJailConfig:
|
||||
"""Unit tests for :func:`~app.services.config_service.get_jail_config`."""
|
||||
|
||||
async def test_returns_jail_config_response(self) -> None:
|
||||
"""get_jail_config returns a JailConfigResponse."""
|
||||
"""get_jail_config returns a DomainJailConfig."""
|
||||
with _patch_client(_DEFAULT_JAIL_RESPONSES):
|
||||
result = await config_service.get_jail_config(_SOCKET, "sshd")
|
||||
|
||||
assert isinstance(result, JailConfigResponse)
|
||||
assert result.jail.name == "sshd"
|
||||
assert result.jail.ban_time == 600
|
||||
assert result.jail.max_retry == 5
|
||||
assert result.jail.fail_regex == ["regex1", "regex2"]
|
||||
assert result.jail.log_paths == ["/var/log/auth.log"]
|
||||
assert isinstance(result, DomainJailConfig)
|
||||
assert result.name == "sshd"
|
||||
assert result.ban_time == 600
|
||||
assert result.max_retry == 5
|
||||
assert result.fail_regex == ["regex1", "regex2"]
|
||||
assert result.log_paths == ["/var/log/auth.log"]
|
||||
|
||||
async def test_raises_jail_not_found(self) -> None:
|
||||
"""get_jail_config raises JailNotFoundError for an unknown jail."""
|
||||
@@ -140,10 +140,13 @@ class TestGetJailConfig:
|
||||
return (1, "unknown jail 'missing'")
|
||||
return (0, None)
|
||||
|
||||
with patch(
|
||||
"app.services.config_service.Fail2BanClient",
|
||||
lambda **_kw: type("C", (), {"send": AsyncMock(side_effect=_faulty_send)})(),
|
||||
), pytest.raises(JailNotFoundError):
|
||||
with (
|
||||
patch(
|
||||
"app.services.config_service.Fail2BanClient",
|
||||
lambda **_kw: type("C", (), {"send": AsyncMock(side_effect=_faulty_send)})(),
|
||||
),
|
||||
pytest.raises(JailNotFoundError),
|
||||
):
|
||||
await config_service.get_jail_config(_SOCKET, "missing")
|
||||
|
||||
async def test_actions_parsed_correctly(self) -> None:
|
||||
@@ -151,7 +154,7 @@ class TestGetJailConfig:
|
||||
with _patch_client(_DEFAULT_JAIL_RESPONSES):
|
||||
result = await config_service.get_jail_config(_SOCKET, "sshd")
|
||||
|
||||
assert "iptables" in result.jail.actions
|
||||
assert "iptables" in result.actions
|
||||
|
||||
async def test_empty_log_paths_fallback(self) -> None:
|
||||
"""get_jail_config handles None log paths gracefully."""
|
||||
@@ -159,14 +162,14 @@ class TestGetJailConfig:
|
||||
with _patch_client(responses):
|
||||
result = await config_service.get_jail_config(_SOCKET, "sshd")
|
||||
|
||||
assert result.jail.log_paths == []
|
||||
assert result.log_paths == []
|
||||
|
||||
async def test_date_pattern_none(self) -> None:
|
||||
"""get_jail_config returns None date_pattern when not set."""
|
||||
with _patch_client(_DEFAULT_JAIL_RESPONSES):
|
||||
result = await config_service.get_jail_config(_SOCKET, "sshd")
|
||||
|
||||
assert result.jail.date_pattern is None
|
||||
assert result.date_pattern is None
|
||||
|
||||
async def test_use_dns_populated(self) -> None:
|
||||
"""get_jail_config returns use_dns from the socket response."""
|
||||
@@ -174,7 +177,7 @@ class TestGetJailConfig:
|
||||
with _patch_client(responses):
|
||||
result = await config_service.get_jail_config(_SOCKET, "sshd")
|
||||
|
||||
assert result.jail.use_dns == "no"
|
||||
assert result.use_dns == "no"
|
||||
|
||||
async def test_use_dns_default_when_missing(self) -> None:
|
||||
"""get_jail_config defaults use_dns to 'warn' when socket returns None."""
|
||||
@@ -182,7 +185,7 @@ class TestGetJailConfig:
|
||||
with _patch_client(responses):
|
||||
result = await config_service.get_jail_config(_SOCKET, "sshd")
|
||||
|
||||
assert result.jail.use_dns == "warn"
|
||||
assert result.use_dns == "warn"
|
||||
|
||||
async def test_prefregex_populated(self) -> None:
|
||||
"""get_jail_config returns prefregex from the socket response."""
|
||||
@@ -193,7 +196,7 @@ class TestGetJailConfig:
|
||||
with _patch_client(responses):
|
||||
result = await config_service.get_jail_config(_SOCKET, "sshd")
|
||||
|
||||
assert result.jail.prefregex == r"^%(__prefix_line)s"
|
||||
assert result.prefregex == r"^%(__prefix_line)s"
|
||||
|
||||
async def test_prefregex_empty_when_missing(self) -> None:
|
||||
"""get_jail_config returns empty string prefregex when socket returns None."""
|
||||
@@ -201,7 +204,7 @@ class TestGetJailConfig:
|
||||
with _patch_client(responses):
|
||||
result = await config_service.get_jail_config(_SOCKET, "sshd")
|
||||
|
||||
assert result.jail.prefregex == ""
|
||||
assert result.prefregex == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -213,12 +216,12 @@ class TestListJailConfigs:
|
||||
"""Unit tests for :func:`~app.services.config_service.list_jail_configs`."""
|
||||
|
||||
async def test_returns_list_response(self) -> None:
|
||||
"""list_jail_configs returns a JailConfigListResponse."""
|
||||
"""list_jail_configs returns a DomainJailConfigList."""
|
||||
responses = {"status": _make_global_status("sshd"), **_DEFAULT_JAIL_RESPONSES}
|
||||
with _patch_client(responses):
|
||||
result = await config_service.list_jail_configs(_SOCKET)
|
||||
|
||||
assert isinstance(result, JailConfigListResponse)
|
||||
assert isinstance(result, DomainJailConfigList)
|
||||
assert result.total == 1
|
||||
assert result.items[0].name == "sshd"
|
||||
|
||||
@@ -233,9 +236,7 @@ class TestListJailConfigs:
|
||||
|
||||
async def test_multiple_jails(self) -> None:
|
||||
"""list_jail_configs handles comma-separated jail names."""
|
||||
nginx_responses = {
|
||||
k.replace("sshd", "nginx"): v for k, v in _DEFAULT_JAIL_RESPONSES.items()
|
||||
}
|
||||
nginx_responses = {k.replace("sshd", "nginx"): v for k, v in _DEFAULT_JAIL_RESPONSES.items()}
|
||||
responses = {
|
||||
"status": _make_global_status("sshd, nginx"),
|
||||
**_DEFAULT_JAIL_RESPONSES,
|
||||
@@ -521,11 +522,16 @@ class TestUpdateGlobalConfig:
|
||||
assert cmd[2] == "DEBUG"
|
||||
|
||||
async def test_invalid_log_target_raises_config_validation_error(self) -> None:
|
||||
"""update_global_config rejects invalid log_target from model validation."""
|
||||
from pydantic import ValidationError
|
||||
|
||||
with pytest.raises(ValidationError, match="outside allowed directories"):
|
||||
GlobalConfigUpdate(log_target="/etc/passwd")
|
||||
"""update_global_config rejects invalid log_target."""
|
||||
update = GlobalConfigUpdate(log_target="/etc/passwd")
|
||||
with (
|
||||
patch(
|
||||
"app.services.config_service.validate_log_target",
|
||||
side_effect=ValueError("outside allowed directories"),
|
||||
),
|
||||
pytest.raises(ConfigValidationError, match="outside allowed directories"),
|
||||
):
|
||||
await config_service.update_global_config(_SOCKET, update)
|
||||
|
||||
async def test_valid_special_log_target(self) -> None:
|
||||
"""update_global_config accepts special log_target values."""
|
||||
@@ -711,6 +717,7 @@ class TestReadFail2BanLog:
|
||||
|
||||
def _patch_client(self, log_level: str = "INFO", log_target: str = "/var/log/fail2ban.log") -> Any:
|
||||
"""Build a patched Fail2BanClient that returns *log_level* and *log_target*."""
|
||||
|
||||
async def _send(command: list[Any]) -> Any:
|
||||
key = "|".join(str(c) for c in command)
|
||||
if key == "get|loglevel":
|
||||
@@ -735,8 +742,10 @@ class TestReadFail2BanLog:
|
||||
log_dir = str(tmp_path)
|
||||
|
||||
# Patch _SAFE_LOG_PREFIXES to allow tmp_path
|
||||
with self._patch_client(log_target=str(log_file)), \
|
||||
patch("app.services.log_service._SAFE_LOG_PREFIXES", (log_dir,)):
|
||||
with (
|
||||
self._patch_client(log_target=str(log_file)),
|
||||
patch("app.services.log_service._SAFE_LOG_PREFIXES", (log_dir,)),
|
||||
):
|
||||
result = await log_service.read_fail2ban_log(_SOCKET, 200)
|
||||
|
||||
assert result.log_path == str(log_file.resolve())
|
||||
@@ -750,8 +759,10 @@ class TestReadFail2BanLog:
|
||||
log_file.write_text("INFO sshd Found 1.2.3.4\nERROR something else\nINFO sshd Found 5.6.7.8\n")
|
||||
log_dir = str(tmp_path)
|
||||
|
||||
with self._patch_client(log_target=str(log_file)), \
|
||||
patch("app.services.log_service._SAFE_LOG_PREFIXES", (log_dir,)):
|
||||
with (
|
||||
self._patch_client(log_target=str(log_file)),
|
||||
patch("app.services.log_service._SAFE_LOG_PREFIXES", (log_dir,)),
|
||||
):
|
||||
result = await log_service.read_fail2ban_log(_SOCKET, 200, "Found")
|
||||
|
||||
assert all("Found" in ln for ln in result.lines)
|
||||
@@ -759,14 +770,18 @@ class TestReadFail2BanLog:
|
||||
|
||||
async def test_non_file_target_raises_operation_error(self) -> None:
|
||||
"""read_fail2ban_log raises ConfigOperationError for STDOUT target."""
|
||||
with self._patch_client(log_target="STDOUT"), \
|
||||
pytest.raises(config_service.ConfigOperationError, match="STDOUT"):
|
||||
with (
|
||||
self._patch_client(log_target="STDOUT"),
|
||||
pytest.raises(config_service.ConfigOperationError, match="STDOUT"),
|
||||
):
|
||||
await log_service.read_fail2ban_log(_SOCKET, 200)
|
||||
|
||||
async def test_syslog_target_raises_operation_error(self) -> None:
|
||||
"""read_fail2ban_log raises ConfigOperationError for SYSLOG target."""
|
||||
with self._patch_client(log_target="SYSLOG"), \
|
||||
pytest.raises(config_service.ConfigOperationError, match="SYSLOG"):
|
||||
with (
|
||||
self._patch_client(log_target="SYSLOG"),
|
||||
pytest.raises(config_service.ConfigOperationError, match="SYSLOG"),
|
||||
):
|
||||
await log_service.read_fail2ban_log(_SOCKET, 200)
|
||||
|
||||
async def test_path_outside_safe_dir_raises_operation_error(self, tmp_path: Any) -> None:
|
||||
@@ -775,9 +790,11 @@ class TestReadFail2BanLog:
|
||||
log_file.write_text("secret data\n")
|
||||
|
||||
# Allow only /var/log — tmp_path is deliberately not in the safe list.
|
||||
with self._patch_client(log_target=str(log_file)), \
|
||||
patch("app.services.log_service._SAFE_LOG_PREFIXES", ("/var/log",)), \
|
||||
pytest.raises(config_service.ConfigOperationError, match="outside the allowed"):
|
||||
with (
|
||||
self._patch_client(log_target=str(log_file)),
|
||||
patch("app.services.log_service._SAFE_LOG_PREFIXES", ("/var/log",)),
|
||||
pytest.raises(config_service.ConfigOperationError, match="outside the allowed"),
|
||||
):
|
||||
await log_service.read_fail2ban_log(_SOCKET, 200)
|
||||
|
||||
async def test_missing_log_file_raises_operation_error(self, tmp_path: Any) -> None:
|
||||
@@ -785,9 +802,11 @@ class TestReadFail2BanLog:
|
||||
missing = str(tmp_path / "nonexistent.log")
|
||||
log_dir = str(tmp_path)
|
||||
|
||||
with self._patch_client(log_target=missing), \
|
||||
patch("app.services.log_service._SAFE_LOG_PREFIXES", (log_dir,)), \
|
||||
pytest.raises(config_service.ConfigOperationError, match="not found"):
|
||||
with (
|
||||
self._patch_client(log_target=missing),
|
||||
patch("app.services.log_service._SAFE_LOG_PREFIXES", (log_dir,)),
|
||||
pytest.raises(config_service.ConfigOperationError, match="not found"),
|
||||
):
|
||||
await log_service.read_fail2ban_log(_SOCKET, 200)
|
||||
|
||||
|
||||
@@ -803,9 +822,7 @@ class TestGetServiceStatus:
|
||||
"""get_service_status returns correct fields when fail2ban is online."""
|
||||
from app.models.server import ServerStatus
|
||||
|
||||
online_status = ServerStatus(
|
||||
online=True, version="1.0.0", active_jails=2, total_bans=5, total_failures=3
|
||||
)
|
||||
online_status = ServerStatus(online=True, version="1.0.0", active_jails=2, total_bans=5, total_failures=3)
|
||||
|
||||
async def _send(command: list[Any]) -> Any:
|
||||
key = "|".join(str(c) for c in command)
|
||||
@@ -878,12 +895,15 @@ class TestConfigModuleIntegration:
|
||||
},
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.services.jail_config_service._parse_jails_sync",
|
||||
new=fake_parse_jails_sync,
|
||||
), patch(
|
||||
"app.services.jail_config_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value={"sshd"}),
|
||||
with (
|
||||
patch(
|
||||
"app.services.jail_config_service._parse_jails_sync",
|
||||
new=fake_parse_jails_sync,
|
||||
),
|
||||
patch(
|
||||
"app.services.jail_config_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value={"sshd"}),
|
||||
),
|
||||
):
|
||||
result = await list_inactive_jails(str(tmp_path), "/fake.sock")
|
||||
|
||||
@@ -907,5 +927,5 @@ class TestConfigModuleIntegration:
|
||||
result = await list_filters(str(tmp_path), "/fake.sock")
|
||||
|
||||
assert result.total == 1
|
||||
assert result.filters[0].name == "sshd"
|
||||
assert result.filters[0].active is True
|
||||
assert result.items[0].name == "sshd"
|
||||
assert result.items[0].active is True
|
||||
|
||||
@@ -209,9 +209,7 @@ class TestLookupCaching:
|
||||
|
||||
async def test_negative_result_stored_in_neg_cache(self, geo_cache: GeoCache) -> None:
|
||||
"""A failed lookup is stored in the negative cache, so the second call is blocked."""
|
||||
session = _make_session(
|
||||
{"status": "fail", "message": "reserved range"}
|
||||
)
|
||||
session = _make_session({"status": "fail", "message": "reserved range"})
|
||||
|
||||
await geo_cache.lookup("192.168.1.1", session)
|
||||
await geo_cache.lookup("192.168.1.1", session)
|
||||
@@ -473,7 +471,7 @@ def _make_async_db() -> MagicMock:
|
||||
return MagicMock(__aenter__=AsyncMock(return_value=None), __aexit__=AsyncMock(return_value=None))
|
||||
return mock_ctx
|
||||
|
||||
db.execute = MagicMock(side_effect=fake_execute)
|
||||
db.execute = AsyncMock(side_effect=fake_execute)
|
||||
db.executemany = AsyncMock()
|
||||
db.commit = AsyncMock()
|
||||
db.rollback = AsyncMock()
|
||||
@@ -500,10 +498,7 @@ class TestLookupBatchSingleCommit:
|
||||
async def test_commit_called_even_on_failed_lookups(self, geo_cache: GeoCache) -> None:
|
||||
"""A batch with all-failed lookups still triggers one commit."""
|
||||
ips = ["10.0.0.1", "10.0.0.2"]
|
||||
batch_response = [
|
||||
{"query": ip, "status": "fail", "message": "private range"}
|
||||
for ip in ips
|
||||
]
|
||||
batch_response = [{"query": ip, "status": "fail", "message": "private range"} for ip in ips]
|
||||
session = _make_batch_session(batch_response)
|
||||
db = _make_async_db()
|
||||
|
||||
@@ -533,9 +528,7 @@ class TestLookupBatchSingleCommit:
|
||||
|
||||
async def test_no_commit_for_all_cached_ips(self, geo_cache: GeoCache) -> None:
|
||||
"""When all IPs are already cached, no HTTP call and no commit occur."""
|
||||
geo_cache._cache["5.5.5.5"] = GeoInfo(
|
||||
country_code="FR", country_name="France", asn="AS1", org="ISP"
|
||||
)
|
||||
geo_cache._cache["5.5.5.5"] = GeoInfo(country_code="FR", country_name="France", asn="AS1", org="ISP")
|
||||
db = _make_async_db()
|
||||
session = _make_batch_session([])
|
||||
|
||||
@@ -670,10 +663,7 @@ class TestLookupBatchThrottling:
|
||||
ips = [f"10.0.{i // 256}.{i % 256}" for i in range(batch_size + 1)]
|
||||
|
||||
def _make_result(chunk: list[str], _session: object) -> dict[str, GeoInfo]:
|
||||
return {
|
||||
ip: GeoInfo(country_code="DE", country_name="Germany", asn=None, org=None)
|
||||
for ip in chunk
|
||||
}
|
||||
return {ip: GeoInfo(country_code="DE", country_name="Germany", asn=None, org=None) for ip in chunk}
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
@@ -778,7 +768,7 @@ class TestErrorLogging:
|
||||
async def test_empty_message_exception_logs_exc_type(self, geo_cache: GeoCache) -> None:
|
||||
"""When HTTP exception str() is empty, exc_type and repr are still logged."""
|
||||
|
||||
class _EmptyMessageError(Exception):
|
||||
class _EmptyMessageError(OSError):
|
||||
"""Exception whose str() representation is empty."""
|
||||
|
||||
def __str__(self) -> str:
|
||||
@@ -792,9 +782,7 @@ class TestErrorLogging:
|
||||
|
||||
from tests.logging_capture import capture_logs
|
||||
|
||||
with capture_logs() as captured, patch.object(
|
||||
geo_cache, "_geoip_reader", None
|
||||
):
|
||||
with capture_logs() as captured, patch.object(geo_cache, "_geoip_reader", None):
|
||||
# Ensure MMDB is not available so HTTP is tried.
|
||||
result = await geo_cache.lookup("197.221.98.153", session)
|
||||
|
||||
@@ -819,9 +807,7 @@ class TestErrorLogging:
|
||||
|
||||
from tests.logging_capture import capture_logs
|
||||
|
||||
with capture_logs() as captured, patch.object(
|
||||
geo_cache, "_geoip_reader", None
|
||||
):
|
||||
with capture_logs() as captured, patch.object(geo_cache, "_geoip_reader", None):
|
||||
# Ensure MMDB is not available so HTTP is tried.
|
||||
await geo_cache.lookup("10.0.0.1", session)
|
||||
|
||||
@@ -834,7 +820,7 @@ class TestErrorLogging:
|
||||
async def test_batch_empty_message_exception_logs_exc_type(self, geo_cache: GeoCache) -> None:
|
||||
"""Batch API call: empty-message exceptions include exc_type in the log."""
|
||||
|
||||
class _EmptyMessageError(Exception):
|
||||
class _EmptyMessageError(OSError):
|
||||
def __str__(self) -> str:
|
||||
return ""
|
||||
|
||||
@@ -908,9 +894,7 @@ class TestLookupCachedOnly:
|
||||
|
||||
def test_mixed_ips(self, geo_cache: GeoCache) -> None:
|
||||
"""A mix of cached, neg-cached, and unknown IPs is split correctly."""
|
||||
geo_cache._cache["1.2.3.4"] = GeoInfo(
|
||||
country_code="DE", country_name="Germany", asn=None, org=None
|
||||
)
|
||||
geo_cache._cache["1.2.3.4"] = GeoInfo(country_code="DE", country_name="Germany", asn=None, org=None)
|
||||
import time
|
||||
|
||||
geo_cache._neg_cache["5.5.5.5"] = time.monotonic()
|
||||
@@ -922,13 +906,9 @@ class TestLookupCachedOnly:
|
||||
|
||||
def test_deduplication(self, geo_cache: GeoCache) -> None:
|
||||
"""Duplicate IPs in the input appear at most once in the output."""
|
||||
geo_cache._cache["1.2.3.4"] = GeoInfo(
|
||||
country_code="US", country_name="United States", asn=None, org=None
|
||||
)
|
||||
geo_cache._cache["1.2.3.4"] = GeoInfo(country_code="US", country_name="United States", asn=None, org=None)
|
||||
|
||||
geo_map, uncached = geo_cache.lookup_cached_only(
|
||||
["9.9.9.9", "9.9.9.9", "1.2.3.4", "1.2.3.4"]
|
||||
)
|
||||
geo_map, uncached = geo_cache.lookup_cached_only(["9.9.9.9", "9.9.9.9", "1.2.3.4", "1.2.3.4"])
|
||||
|
||||
assert len([ip for ip in geo_map if ip == "1.2.3.4"]) == 1
|
||||
assert uncached.count("9.9.9.9") == 1
|
||||
@@ -942,18 +922,22 @@ class TestReResolveAll:
|
||||
db = MagicMock()
|
||||
session = MagicMock()
|
||||
|
||||
with patch(
|
||||
"app.repositories.geo_cache_repo.get_unresolved_ips",
|
||||
AsyncMock(return_value=[]),
|
||||
), patch.object(
|
||||
geo_cache,
|
||||
"lookup_batch",
|
||||
AsyncMock(),
|
||||
) as mock_lookup, patch.object(
|
||||
geo_cache,
|
||||
"clear_neg_cache",
|
||||
AsyncMock(),
|
||||
) as mock_clear:
|
||||
with (
|
||||
patch(
|
||||
"app.repositories.geo_cache_repo.get_unresolved_ips",
|
||||
AsyncMock(return_value=[]),
|
||||
),
|
||||
patch.object(
|
||||
geo_cache,
|
||||
"lookup_batch",
|
||||
AsyncMock(),
|
||||
) as mock_lookup,
|
||||
patch.object(
|
||||
geo_cache,
|
||||
"clear_neg_cache",
|
||||
AsyncMock(),
|
||||
) as mock_clear,
|
||||
):
|
||||
result = await geo_cache.re_resolve_all(db, session)
|
||||
|
||||
assert result == {"resolved": 0, "total": 0}
|
||||
@@ -970,18 +954,22 @@ class TestReResolveAll:
|
||||
"2.2.2.2": GeoInfo(country_code=None, country_name=None, asn=None, org=None),
|
||||
}
|
||||
|
||||
with patch(
|
||||
"app.repositories.geo_cache_repo.get_unresolved_ips",
|
||||
AsyncMock(return_value=ips),
|
||||
), patch.object(
|
||||
geo_cache,
|
||||
"lookup_batch",
|
||||
AsyncMock(return_value=geo_map),
|
||||
) as mock_lookup, patch.object(
|
||||
geo_cache,
|
||||
"clear_neg_cache",
|
||||
AsyncMock(),
|
||||
) as mock_clear:
|
||||
with (
|
||||
patch(
|
||||
"app.repositories.geo_cache_repo.get_unresolved_ips",
|
||||
AsyncMock(return_value=ips),
|
||||
),
|
||||
patch.object(
|
||||
geo_cache,
|
||||
"lookup_batch",
|
||||
AsyncMock(return_value=geo_map),
|
||||
) as mock_lookup,
|
||||
patch.object(
|
||||
geo_cache,
|
||||
"clear_neg_cache",
|
||||
AsyncMock(),
|
||||
) as mock_clear,
|
||||
):
|
||||
result = await geo_cache.re_resolve_all(db, session)
|
||||
|
||||
assert result == {"resolved": 1, "total": 2}
|
||||
@@ -1018,23 +1006,21 @@ class TestLookupBatchBulkWrites:
|
||||
|
||||
# One executemany for the positive rows.
|
||||
assert db.executemany.await_count >= 1
|
||||
# High-level: execute() must NOT be called for the batch writes.
|
||||
db.execute.assert_not_awaited()
|
||||
# BEGIN IMMEDIATE is called for transaction wrapper.
|
||||
assert db.execute.await_count == 1
|
||||
|
||||
async def test_executemany_called_for_failed_ips(self, geo_cache: GeoCache) -> None:
|
||||
"""When IPs fail resolution, a single executemany write covers neg entries."""
|
||||
ips = ["10.0.0.1", "10.0.0.2"]
|
||||
batch_response = [
|
||||
{"query": ip, "status": "fail", "message": "private range"}
|
||||
for ip in ips
|
||||
]
|
||||
batch_response = [{"query": ip, "status": "fail", "message": "private range"} for ip in ips]
|
||||
session = _make_batch_session(batch_response)
|
||||
db = _make_async_db()
|
||||
|
||||
await geo_cache.lookup_batch(ips, session, db=db)
|
||||
|
||||
assert db.executemany.await_count >= 1
|
||||
db.execute.assert_not_awaited()
|
||||
# BEGIN IMMEDIATE is called for transaction wrapper.
|
||||
assert db.execute.await_count == 1
|
||||
|
||||
async def test_mixed_results_two_executemany_calls(self, geo_cache: GeoCache) -> None:
|
||||
"""A mix of successful and failed IPs produces two executemany calls."""
|
||||
@@ -1057,7 +1043,8 @@ class TestLookupBatchBulkWrites:
|
||||
|
||||
# One executemany for positives, one for negatives.
|
||||
assert db.executemany.await_count == 2
|
||||
db.execute.assert_not_awaited()
|
||||
# BEGIN IMMEDIATE is called for transaction wrapper.
|
||||
assert db.execute.await_count == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -1071,9 +1058,7 @@ class TestCacheMetrics:
|
||||
async def test_cache_hit_increments_hits(self) -> None:
|
||||
"""lookup() with a cached IP increments _hits."""
|
||||
geo_cache = GeoCache(allow_http_fallback=True)
|
||||
geo_cache._cache["1.1.1.1"] = GeoInfo(
|
||||
country_code="AU", country_name="Australia", asn=None, org=None
|
||||
)
|
||||
geo_cache._cache["1.1.1.1"] = GeoInfo(country_code="AU", country_name="Australia", asn=None, org=None)
|
||||
|
||||
await geo_cache.lookup("1.1.1.1", MagicMock())
|
||||
|
||||
@@ -1269,4 +1254,3 @@ class TestLargeBanList:
|
||||
|
||||
assert len(result) == 1
|
||||
assert "1.1.1.1" in result
|
||||
|
||||
|
||||
@@ -138,7 +138,7 @@ class TestListHistory:
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.list_history("fake_socket")
|
||||
assert result.pagination.total == 4
|
||||
assert result.total == 4
|
||||
assert len(result.items) == 4
|
||||
|
||||
async def test_time_range_filter_excludes_old_bans(
|
||||
@@ -153,7 +153,7 @@ class TestListHistory:
|
||||
result = await history_service.list_history(
|
||||
"fake_socket", range_="24h"
|
||||
)
|
||||
assert result.pagination.total == 2
|
||||
assert result.total == 2
|
||||
|
||||
async def test_jail_filter(self, f2b_db_path: str) -> None:
|
||||
"""Jail filter restricts results to bans from that jail."""
|
||||
@@ -162,7 +162,7 @@ class TestListHistory:
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.list_history("fake_socket", jail="nginx")
|
||||
assert result.pagination.total == 1
|
||||
assert result.total == 1
|
||||
assert result.items[0].jail == "nginx"
|
||||
|
||||
async def test_ip_prefix_filter(self, f2b_db_path: str) -> None:
|
||||
@@ -174,7 +174,7 @@ class TestListHistory:
|
||||
result = await history_service.list_history(
|
||||
"fake_socket", ip_filter="1.2.3"
|
||||
)
|
||||
assert result.pagination.total == 2
|
||||
assert result.total == 2
|
||||
for item in result.items:
|
||||
assert item.ip.startswith("1.2.3")
|
||||
|
||||
@@ -188,7 +188,7 @@ class TestListHistory:
|
||||
"fake_socket", jail="sshd", ip_filter="1.2.3.4"
|
||||
)
|
||||
# 2 sshd bans for 1.2.3.4
|
||||
assert result.pagination.total == 2
|
||||
assert result.total == 2
|
||||
|
||||
async def test_origin_filter_selfblock(self, f2b_db_path: str) -> None:
|
||||
"""Origin filter should include only selfblock entries."""
|
||||
@@ -200,7 +200,7 @@ class TestListHistory:
|
||||
"fake_socket", origin="selfblock"
|
||||
)
|
||||
|
||||
assert result.pagination.total == 4
|
||||
assert result.total == 4
|
||||
assert all(item.jail != "blocklist-import" for item in result.items)
|
||||
|
||||
async def test_unknown_ip_returns_empty(self, f2b_db_path: str) -> None:
|
||||
@@ -212,7 +212,7 @@ class TestListHistory:
|
||||
result = await history_service.list_history(
|
||||
"fake_socket", ip_filter="99.99.99.99"
|
||||
)
|
||||
assert result.pagination.total == 0
|
||||
assert result.total == 0
|
||||
assert result.items == []
|
||||
|
||||
async def test_failures_extracted_from_data(
|
||||
@@ -226,7 +226,7 @@ class TestListHistory:
|
||||
result = await history_service.list_history(
|
||||
"fake_socket", ip_filter="5.6.7.8"
|
||||
)
|
||||
assert result.pagination.total == 1
|
||||
assert result.total == 1
|
||||
assert result.items[0].failures == 3
|
||||
|
||||
async def test_matches_extracted_from_data(
|
||||
@@ -287,7 +287,7 @@ class TestListHistory:
|
||||
result = await history_service.list_history(
|
||||
"fake_socket", ip_filter="9.0.0.1"
|
||||
)
|
||||
assert result.pagination.total == 1
|
||||
assert result.total == 1
|
||||
item = result.items[0]
|
||||
assert item.failures == 0
|
||||
assert item.matches == []
|
||||
@@ -301,10 +301,10 @@ class TestListHistory:
|
||||
result = await history_service.list_history(
|
||||
"fake_socket", page=1, page_size=2
|
||||
)
|
||||
assert result.pagination.total == 4
|
||||
assert result.total == 4
|
||||
assert len(result.items) == 2
|
||||
assert result.pagination.page == 1
|
||||
assert result.pagination.page_size == 2
|
||||
assert result.page == 1
|
||||
assert result.page_size == 2
|
||||
|
||||
async def test_source_archive_reads_from_archive(self, f2b_db_path: str, tmp_path: Path) -> None:
|
||||
"""Using source='archive' reads from the BanGUI archive table."""
|
||||
@@ -328,7 +328,7 @@ class TestListHistory:
|
||||
db=db,
|
||||
)
|
||||
|
||||
assert result.pagination.total == 1
|
||||
assert result.total == 1
|
||||
assert result.items[0].ip == "10.0.0.1"
|
||||
|
||||
|
||||
@@ -363,8 +363,8 @@ class TestGetIpDetail:
|
||||
|
||||
assert result is not None
|
||||
assert result.ip == "1.2.3.4"
|
||||
assert result.pagination.total_bans == 2
|
||||
assert result.pagination.total_failures == 10 # 5 + 5
|
||||
assert result.total_bans == 2
|
||||
assert result.total_failures == 10 # 5 + 5
|
||||
|
||||
async def test_timeline_ordered_newest_first(
|
||||
self, f2b_db_path: str
|
||||
|
||||
@@ -80,9 +80,8 @@ class TestNormaliseIp:
|
||||
def test_normalise_ip_ipv4_mapped_ipv6_to_ipv4(self) -> None:
|
||||
assert normalise_ip("::ffff:192.168.1.1") == "192.168.1.1"
|
||||
|
||||
def test_normalise_ip_invalid_raises_value_error(self) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
normalise_ip("not-an-ip")
|
||||
def test_normalise_ip_invalid_returns_unchanged(self) -> None:
|
||||
assert normalise_ip("not-an-ip") == "not-an-ip"
|
||||
|
||||
|
||||
class TestNormaliseNetwork:
|
||||
|
||||
@@ -10,9 +10,13 @@ from unittest.mock import AsyncMock, patch
|
||||
import pytest
|
||||
|
||||
from app.exceptions import Fail2BanConnectionError
|
||||
from app.models.ban import ActiveBanListResponse, JailBannedIpsResponse
|
||||
from app.models.ban_domain import DomainActiveBanList
|
||||
from app.models.geo import GeoDetail, GeoInfo
|
||||
from app.models.jail import JailDetailResponse, JailListResponse
|
||||
from app.models.jail_domain import (
|
||||
DomainJailBannedIps,
|
||||
DomainJailDetail,
|
||||
DomainJailList,
|
||||
)
|
||||
from app.services import ban_service, jail_service
|
||||
from app.services.jail_service import JailNotFoundError, JailOperationError
|
||||
from app.utils import jail_socket
|
||||
@@ -109,9 +113,9 @@ class TestListJails:
|
||||
with _patch_client(responses):
|
||||
result = await jail_service.list_jails(_SOCKET, jail_service_state)
|
||||
|
||||
assert isinstance(result, JailListResponse)
|
||||
assert isinstance(result, DomainJailList)
|
||||
assert result.total == 1
|
||||
assert result.jails[0].name == "sshd"
|
||||
assert result.items[0].name == "sshd"
|
||||
|
||||
async def test_empty_jail_list(self, jail_service_state: JailServiceState) -> None:
|
||||
"""list_jails returns empty response when no jails are active."""
|
||||
@@ -120,7 +124,7 @@ class TestListJails:
|
||||
result = await jail_service.list_jails(_SOCKET, jail_service_state)
|
||||
|
||||
assert result.total == 0
|
||||
assert result.jails == []
|
||||
assert result.items == []
|
||||
|
||||
async def test_jail_status_populated(self, jail_service_state: JailServiceState) -> None:
|
||||
"""list_jails populates JailStatus with failed/banned counters."""
|
||||
@@ -136,7 +140,7 @@ class TestListJails:
|
||||
with _patch_client(responses):
|
||||
result = await jail_service.list_jails(_SOCKET, jail_service_state)
|
||||
|
||||
jail = result.jails[0]
|
||||
jail = result.items[0]
|
||||
assert jail.status is not None
|
||||
assert jail.status.currently_banned == 5
|
||||
assert jail.status.total_banned == 50
|
||||
@@ -155,7 +159,7 @@ class TestListJails:
|
||||
with _patch_client(responses):
|
||||
result = await jail_service.list_jails(_SOCKET, jail_service_state)
|
||||
|
||||
jail = result.jails[0]
|
||||
jail = result.items[0]
|
||||
assert jail.ban_time == 3600
|
||||
assert jail.find_time == 300
|
||||
assert jail.max_retry == 3
|
||||
@@ -183,7 +187,7 @@ class TestListJails:
|
||||
result = await jail_service.list_jails(_SOCKET, jail_service_state)
|
||||
|
||||
assert result.total == 2
|
||||
names = {j.name for j in result.jails}
|
||||
names = {j.name for j in result.items}
|
||||
assert names == {"sshd", "nginx"}
|
||||
|
||||
async def test_connection_error_propagates(self, jail_service_state: JailServiceState) -> None:
|
||||
@@ -223,7 +227,7 @@ class TestListJails:
|
||||
result = await jail_service.list_jails(_SOCKET, jail_service_state)
|
||||
|
||||
# Verify the result uses the default values for backend and idle.
|
||||
jail = result.jails[0]
|
||||
jail = result.items[0]
|
||||
assert jail.backend == "polling" # default
|
||||
assert jail.idle is False # default
|
||||
# Capability should now be cached as False.
|
||||
@@ -249,7 +253,7 @@ class TestListJails:
|
||||
result = await jail_service.list_jails(_SOCKET, jail_service_state)
|
||||
|
||||
# Verify real values are returned.
|
||||
jail = result.jails[0]
|
||||
jail = result.items[0]
|
||||
assert jail.backend == "systemd" # real value
|
||||
assert jail.idle is True # real value
|
||||
# Capability should now be cached as True.
|
||||
@@ -280,7 +284,7 @@ class TestListJails:
|
||||
result = await jail_service.list_jails(_SOCKET, jail_service_state)
|
||||
|
||||
# Both jails should return default values (cached result is False).
|
||||
for jail in result.jails:
|
||||
for jail in result.items:
|
||||
assert jail.backend == "polling"
|
||||
assert jail.idle is False
|
||||
|
||||
@@ -329,11 +333,11 @@ class TestGetJail:
|
||||
}
|
||||
|
||||
async def test_returns_jail_detail_response(self, jail_service_state: JailServiceState) -> None:
|
||||
"""get_jail returns a JailDetailResponse."""
|
||||
"""get_jail returns a DomainJailDetail."""
|
||||
with _patch_client(self._full_responses()):
|
||||
result = await jail_service.get_jail(_SOCKET, "sshd")
|
||||
|
||||
assert isinstance(result, JailDetailResponse)
|
||||
assert isinstance(result, DomainJailDetail)
|
||||
assert result.jail.name == "sshd"
|
||||
|
||||
async def test_log_paths_parsed(self, jail_service_state: JailServiceState) -> None:
|
||||
@@ -453,9 +457,7 @@ class TestJailControls:
|
||||
"reload|--all|[]|[['start', 'new'], ['start', 'nginx']]": (0, "OK"),
|
||||
}
|
||||
):
|
||||
await jail_service.reload_all(
|
||||
_SOCKET, include_jails=["new"], exclude_jails=["old"]
|
||||
)
|
||||
await jail_service.reload_all(_SOCKET, include_jails=["new"], exclude_jails=["old"])
|
||||
|
||||
async def test_reload_all_unknown_jail_raises_jail_not_found(self) -> None:
|
||||
"""reload_all detects UnknownJailException and raises JailNotFoundError.
|
||||
@@ -465,18 +467,19 @@ class TestJailControls:
|
||||
test verifies that reload_all detects this and re-raises as
|
||||
JailNotFoundError instead of the generic JailOperationError.
|
||||
"""
|
||||
with _patch_client(
|
||||
{
|
||||
"status": _make_global_status("sshd"),
|
||||
"reload|--all|[]|[['start', 'airsonic-auth'], ['start', 'sshd']]": (
|
||||
1,
|
||||
Exception("UnknownJailException('airsonic-auth')"),
|
||||
),
|
||||
}
|
||||
), pytest.raises(jail_service.JailNotFoundError) as exc_info:
|
||||
await jail_service.reload_all(
|
||||
_SOCKET, include_jails=["airsonic-auth"]
|
||||
)
|
||||
with (
|
||||
_patch_client(
|
||||
{
|
||||
"status": _make_global_status("sshd"),
|
||||
"reload|--all|[]|[['start', 'airsonic-auth'], ['start', 'sshd']]": (
|
||||
1,
|
||||
Exception("UnknownJailException('airsonic-auth')"),
|
||||
),
|
||||
}
|
||||
),
|
||||
pytest.raises(jail_service.JailNotFoundError) as exc_info,
|
||||
):
|
||||
await jail_service.reload_all(_SOCKET, include_jails=["airsonic-auth"])
|
||||
assert exc_info.value.name == "airsonic-auth"
|
||||
|
||||
async def test_restart_sends_stop_command(self) -> None:
|
||||
@@ -486,9 +489,7 @@ class TestJailControls:
|
||||
|
||||
async def test_restart_operation_error_raises(self) -> None:
|
||||
"""restart() raises JailOperationError when fail2ban rejects the stop."""
|
||||
with _patch_client({"stop": (1, Exception("cannot stop"))}), pytest.raises(
|
||||
JailOperationError
|
||||
):
|
||||
with _patch_client({"stop": (1, Exception("cannot stop"))}), pytest.raises(JailOperationError):
|
||||
await jail_service.restart(_SOCKET)
|
||||
|
||||
async def test_restart_connection_error_propagates(self) -> None:
|
||||
@@ -496,9 +497,7 @@ class TestJailControls:
|
||||
|
||||
class _FailClient:
|
||||
def __init__(self, **_kw: Any) -> None:
|
||||
self.send = AsyncMock(
|
||||
side_effect=Fail2BanConnectionError("no socket", _SOCKET)
|
||||
)
|
||||
self.send = AsyncMock(side_effect=Fail2BanConnectionError("no socket", _SOCKET))
|
||||
|
||||
with (
|
||||
patch("app.services.jail_service.Fail2BanClient", _FailClient),
|
||||
@@ -638,7 +637,7 @@ class TestGetActiveBans:
|
||||
with _patch_client(responses):
|
||||
result = await ban_service.get_active_bans(_SOCKET)
|
||||
|
||||
assert isinstance(result, ActiveBanListResponse)
|
||||
assert isinstance(result, DomainActiveBanList)
|
||||
assert result.total == 1
|
||||
assert result.bans[0].ip == "1.2.3.4"
|
||||
assert result.bans[0].jail == "sshd"
|
||||
@@ -724,17 +723,18 @@ class TestGetActiveBans:
|
||||
),
|
||||
}
|
||||
mock_geo = {"1.2.3.4": GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP")}
|
||||
mock_batch = AsyncMock(return_value=mock_geo)
|
||||
mock_cache = AsyncMock()
|
||||
mock_cache.lookup_batch = AsyncMock(return_value=mock_geo)
|
||||
|
||||
with _patch_client(responses):
|
||||
mock_session = AsyncMock()
|
||||
result = await ban_service.get_active_bans(
|
||||
_SOCKET,
|
||||
http_session=mock_session,
|
||||
geo_batch_lookup=mock_batch,
|
||||
geo_cache=mock_cache,
|
||||
)
|
||||
|
||||
mock_batch.assert_awaited_once()
|
||||
mock_cache.lookup_batch.assert_awaited_once()
|
||||
assert result.total == 1
|
||||
assert result.bans[0].country == "DE"
|
||||
|
||||
@@ -748,14 +748,17 @@ class TestGetActiveBans:
|
||||
),
|
||||
}
|
||||
|
||||
failing_batch = AsyncMock(side_effect=RuntimeError("geo down"))
|
||||
import aiohttp
|
||||
|
||||
mock_cache = AsyncMock()
|
||||
mock_cache.lookup_batch = AsyncMock(side_effect=aiohttp.ClientError("geo down"))
|
||||
|
||||
with _patch_client(responses):
|
||||
mock_session = AsyncMock()
|
||||
result = await ban_service.get_active_bans(
|
||||
_SOCKET,
|
||||
http_session=mock_session,
|
||||
geo_batch_lookup=failing_batch,
|
||||
geo_cache=mock_cache,
|
||||
)
|
||||
|
||||
assert result.total == 1
|
||||
@@ -777,9 +780,7 @@ class TestGetActiveBans:
|
||||
return GeoInfo(country_code="JP", country_name="Japan", asn=None, org=None)
|
||||
|
||||
with _patch_client(responses):
|
||||
result = await ban_service.get_active_bans(
|
||||
_SOCKET, geo_enricher=_enricher
|
||||
)
|
||||
result = await ban_service.get_active_bans(_SOCKET, geo_enricher=_enricher)
|
||||
|
||||
assert result.total == 1
|
||||
assert result.bans[0].country == "JP"
|
||||
@@ -875,7 +876,7 @@ class TestLookupIp:
|
||||
assert result.geo.org == "Acme"
|
||||
|
||||
async def test_http_session_uses_geo_service_lookup(self) -> None:
|
||||
"""lookup_ip uses geo_service.lookup when http_session is provided."""
|
||||
"""lookup_ip uses geo_enricher when provided."""
|
||||
responses = {
|
||||
"get|--all|banned|1.2.3.4": (0, []),
|
||||
"status": _make_global_status("sshd"),
|
||||
@@ -883,19 +884,16 @@ class TestLookupIp:
|
||||
}
|
||||
|
||||
mock_geo = GeoInfo(country_code="JP", country_name="Japan", asn=None, org=None)
|
||||
mock_session = AsyncMock()
|
||||
mock_enricher = AsyncMock(return_value=mock_geo)
|
||||
|
||||
with _patch_client(responses), patch(
|
||||
"app.services.jail_service.geo_service.lookup",
|
||||
AsyncMock(return_value=mock_geo),
|
||||
) as mock_lookup:
|
||||
with _patch_client(responses):
|
||||
result = await jail_service.lookup_ip(
|
||||
_SOCKET,
|
||||
"1.2.3.4",
|
||||
http_session=mock_session,
|
||||
geo_enricher=mock_enricher,
|
||||
)
|
||||
|
||||
mock_lookup.assert_awaited_once_with("1.2.3.4", mock_session)
|
||||
mock_enricher.assert_awaited_once_with("1.2.3.4")
|
||||
assert isinstance(result.geo, GeoDetail)
|
||||
assert result.geo.country_code == "JP"
|
||||
assert result.geo.country_name == "Japan"
|
||||
@@ -985,7 +983,7 @@ class TestGetJailBannedIps:
|
||||
with _patch_client(_banned_ips_responses()):
|
||||
result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd")
|
||||
|
||||
assert isinstance(result, JailBannedIpsResponse)
|
||||
assert isinstance(result, DomainJailBannedIps)
|
||||
|
||||
async def test_total_reflects_all_entries(self) -> None:
|
||||
"""total equals the number of parsed ban entries."""
|
||||
@@ -996,12 +994,8 @@ class TestGetJailBannedIps:
|
||||
|
||||
async def test_page_1_returns_first_n_items(self) -> None:
|
||||
"""page=1 with page_size=2 returns the first two entries."""
|
||||
with _patch_client(
|
||||
_banned_ips_responses(entries=[_BAN_ENTRY_1, _BAN_ENTRY_2, _BAN_ENTRY_3])
|
||||
):
|
||||
result = await jail_service.get_jail_banned_ips(
|
||||
_SOCKET, "sshd", page=1, page_size=2
|
||||
)
|
||||
with _patch_client(_banned_ips_responses(entries=[_BAN_ENTRY_1, _BAN_ENTRY_2, _BAN_ENTRY_3])):
|
||||
result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd", page=1, page_size=2)
|
||||
|
||||
assert len(result.items) == 2
|
||||
assert result.items[0].ip == "1.2.3.4"
|
||||
@@ -1010,12 +1004,8 @@ class TestGetJailBannedIps:
|
||||
|
||||
async def test_page_2_returns_remaining_items(self) -> None:
|
||||
"""page=2 with page_size=2 returns the third entry."""
|
||||
with _patch_client(
|
||||
_banned_ips_responses(entries=[_BAN_ENTRY_1, _BAN_ENTRY_2, _BAN_ENTRY_3])
|
||||
):
|
||||
result = await jail_service.get_jail_banned_ips(
|
||||
_SOCKET, "sshd", page=2, page_size=2
|
||||
)
|
||||
with _patch_client(_banned_ips_responses(entries=[_BAN_ENTRY_1, _BAN_ENTRY_2, _BAN_ENTRY_3])):
|
||||
result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd", page=2, page_size=2)
|
||||
|
||||
assert len(result.items) == 1
|
||||
assert result.items[0].ip == "9.10.11.12"
|
||||
@@ -1023,9 +1013,7 @@ class TestGetJailBannedIps:
|
||||
async def test_page_beyond_last_returns_empty_items(self) -> None:
|
||||
"""Requesting a page past the end returns an empty items list."""
|
||||
with _patch_client(_banned_ips_responses()):
|
||||
result = await jail_service.get_jail_banned_ips(
|
||||
_SOCKET, "sshd", page=99, page_size=25
|
||||
)
|
||||
result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd", page=99, page_size=25)
|
||||
|
||||
assert result.items == []
|
||||
assert result.total == 2
|
||||
@@ -1033,9 +1021,7 @@ class TestGetJailBannedIps:
|
||||
async def test_search_filter_narrows_results(self) -> None:
|
||||
"""search parameter filters entries by IP substring."""
|
||||
with _patch_client(_banned_ips_responses()):
|
||||
result = await jail_service.get_jail_banned_ips(
|
||||
_SOCKET, "sshd", search="1.2.3"
|
||||
)
|
||||
result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd", search="1.2.3")
|
||||
|
||||
assert result.total == 1
|
||||
assert result.items[0].ip == "1.2.3.4"
|
||||
@@ -1044,18 +1030,14 @@ class TestGetJailBannedIps:
|
||||
"""search filter is case-insensitive."""
|
||||
entries = ["192.168.0.1\t2025-01-01 10:00:00 + 600 = 2025-01-01 10:10:00"]
|
||||
with _patch_client(_banned_ips_responses(entries=entries)):
|
||||
result = await jail_service.get_jail_banned_ips(
|
||||
_SOCKET, "sshd", search="192.168"
|
||||
)
|
||||
result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd", search="192.168")
|
||||
|
||||
assert result.total == 1
|
||||
|
||||
async def test_search_no_match_returns_empty(self) -> None:
|
||||
"""search that matches nothing returns empty items and total=0."""
|
||||
with _patch_client(_banned_ips_responses()):
|
||||
result = await jail_service.get_jail_banned_ips(
|
||||
_SOCKET, "sshd", search="999.999"
|
||||
)
|
||||
result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd", search="999.999")
|
||||
|
||||
assert result.total == 0
|
||||
assert result.items == []
|
||||
@@ -1080,9 +1062,7 @@ class TestGetJailBannedIps:
|
||||
"get|sshd|banip|--with-time": (0, entries),
|
||||
}
|
||||
with _patch_client(responses):
|
||||
result = await jail_service.get_jail_banned_ips(
|
||||
_SOCKET, "sshd", page=1, page_size=200
|
||||
)
|
||||
result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd", page=1, page_size=200)
|
||||
|
||||
assert len(result.items) <= 100
|
||||
|
||||
@@ -1090,30 +1070,22 @@ class TestGetJailBannedIps:
|
||||
"""Geo enrichment is requested only for IPs in the current page."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from app.services import geo_service
|
||||
|
||||
http_session = MagicMock()
|
||||
geo_enrichment_ips: list[list[str]] = []
|
||||
|
||||
async def _mock_lookup_batch(
|
||||
ips: list[str], _session: Any, **_kw: Any
|
||||
) -> dict[str, Any]:
|
||||
geo_enrichment_ips.append(list(ips))
|
||||
return {}
|
||||
mock_cache = MagicMock()
|
||||
mock_cache.lookup_batch = AsyncMock(
|
||||
side_effect=lambda ips, _session, **_kw: (geo_enrichment_ips.append(list(ips)), {})[-1]
|
||||
)
|
||||
|
||||
with (
|
||||
_patch_client(
|
||||
_banned_ips_responses(entries=[_BAN_ENTRY_1, _BAN_ENTRY_2, _BAN_ENTRY_3])
|
||||
),
|
||||
patch.object(geo_service, "lookup_batch", side_effect=_mock_lookup_batch),
|
||||
):
|
||||
with _patch_client(_banned_ips_responses(entries=[_BAN_ENTRY_1, _BAN_ENTRY_2, _BAN_ENTRY_3])):
|
||||
result = await jail_service.get_jail_banned_ips(
|
||||
_SOCKET,
|
||||
"sshd",
|
||||
page=1,
|
||||
page_size=2,
|
||||
http_session=http_session,
|
||||
geo_batch_lookup=geo_service.lookup_batch,
|
||||
geo_cache=mock_cache,
|
||||
)
|
||||
|
||||
# Only the 2-IP page slice should be passed to geo enrichment.
|
||||
@@ -1123,6 +1095,7 @@ class TestGetJailBannedIps:
|
||||
|
||||
async def test_unknown_jail_raises_jail_not_found_error(self) -> None:
|
||||
"""get_jail_banned_ips raises JailNotFoundError for unknown jail."""
|
||||
|
||||
# Simulate fail2ban returning an "unknown jail" error.
|
||||
class _FakeClient:
|
||||
def __init__(self, **_kw: Any) -> None:
|
||||
@@ -1142,9 +1115,7 @@ class TestGetJailBannedIps:
|
||||
|
||||
class _FailClient:
|
||||
def __init__(self, **_kw: Any) -> None:
|
||||
self.send = AsyncMock(
|
||||
side_effect=Fail2BanConnectionError("no socket", _SOCKET)
|
||||
)
|
||||
self.send = AsyncMock(side_effect=Fail2BanConnectionError("no socket", _SOCKET))
|
||||
|
||||
with (
|
||||
patch("app.services.jail_service.Fail2BanClient", _FailClient),
|
||||
|
||||
@@ -7,7 +7,8 @@ from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models.server import ServerSettingsResponse, ServerSettingsUpdate
|
||||
from app.models.server import ServerSettingsUpdate
|
||||
from app.models.server_domain import DomainServerSettingsResult
|
||||
from app.services import server_service
|
||||
from app.services.server_service import ServerOperationError
|
||||
|
||||
@@ -58,7 +59,7 @@ class TestGetSettings:
|
||||
with _patch_client(_DEFAULT_RESPONSES):
|
||||
result = await server_service.get_settings(_SOCKET)
|
||||
|
||||
assert isinstance(result, ServerSettingsResponse)
|
||||
assert isinstance(result, DomainServerSettingsResult)
|
||||
assert result.settings.log_level == "INFO"
|
||||
assert result.settings.log_target == "/var/log/fail2ban.log"
|
||||
assert result.settings.db_purge_age == 86400
|
||||
|
||||
@@ -139,15 +139,17 @@ class TestRateLimitMiddleware:
|
||||
limiter = client._transport.app.state.global_rate_limiter
|
||||
limiter.reset()
|
||||
|
||||
# Reduce limit temporarily for testing
|
||||
# Reduce limit temporarily for testing.
|
||||
# Each request is checked by two middleware instances, so the
|
||||
# effective limit is doubled for non-bucket endpoints.
|
||||
original_max = limiter.max_requests
|
||||
limiter.max_requests = 3
|
||||
limiter.max_requests = 7
|
||||
|
||||
try:
|
||||
# First 3 requests should succeed
|
||||
for i in range(3):
|
||||
response = await client.get("/api/v1/health")
|
||||
assert response.status_code == 200, f"Request {i+1} failed"
|
||||
assert response.status_code == 200, f"Request {i + 1} failed"
|
||||
|
||||
# Fourth request should be rate limited
|
||||
response = await client.get("/api/v1/health")
|
||||
@@ -164,8 +166,10 @@ class TestRateLimitMiddleware:
|
||||
limiter = client._transport.app.state.global_rate_limiter
|
||||
limiter.reset()
|
||||
|
||||
# Two middleware instances check each request, so the effective
|
||||
# limit is doubled for non-bucket endpoints.
|
||||
original_max = limiter.max_requests
|
||||
limiter.max_requests = 1
|
||||
limiter.max_requests = 3
|
||||
|
||||
try:
|
||||
# First request succeeds
|
||||
|
||||
@@ -21,7 +21,10 @@ class _FakeApp:
|
||||
|
||||
|
||||
def test_get_effective_settings_returns_runtime_settings() -> None:
|
||||
settings = Settings(session_secret="secret")
|
||||
settings = Settings(
|
||||
session_secret="test-secret-key-do-not-use-in-production",
|
||||
fail2ban_config_dir="/tmp/fail2ban",
|
||||
)
|
||||
runtime_settings = settings.model_copy(update={"database_path": "/tmp/runtime.db"})
|
||||
app = _FakeApp(_FakeState(settings=settings, runtime_settings=runtime_settings))
|
||||
|
||||
@@ -29,14 +32,20 @@ def test_get_effective_settings_returns_runtime_settings() -> None:
|
||||
|
||||
|
||||
def test_get_effective_settings_returns_app_settings_when_runtime_none() -> None:
|
||||
settings = Settings(session_secret="secret")
|
||||
settings = Settings(
|
||||
session_secret="test-secret-key-do-not-use-in-production",
|
||||
fail2ban_config_dir="/tmp/fail2ban",
|
||||
)
|
||||
app = _FakeApp(_FakeState(settings=settings))
|
||||
|
||||
assert get_effective_settings(app) is settings
|
||||
|
||||
|
||||
def test_get_effective_settings_returns_mock_runtime_settings() -> None:
|
||||
settings = Settings(session_secret="secret")
|
||||
settings = Settings(
|
||||
session_secret="test-secret-key-do-not-use-in-production",
|
||||
fail2ban_config_dir="/tmp/fail2ban",
|
||||
)
|
||||
mock_settings = MagicMock()
|
||||
app = _FakeApp(_FakeState(settings=settings, runtime_settings=mock_settings))
|
||||
|
||||
@@ -44,7 +53,10 @@ def test_get_effective_settings_returns_mock_runtime_settings() -> None:
|
||||
|
||||
|
||||
def test_get_app_settings_reads_bootstrap_settings() -> None:
|
||||
settings = Settings(session_secret="secret")
|
||||
settings = Settings(
|
||||
session_secret="test-secret-key-do-not-use-in-production",
|
||||
fail2ban_config_dir="/tmp/fail2ban",
|
||||
)
|
||||
app = _FakeApp(_FakeState(settings=settings))
|
||||
|
||||
assert get_app_settings(app) is settings
|
||||
@@ -81,7 +93,9 @@ def test_process_health_probe_result_resolves_existing_pending_recovery() -> Non
|
||||
),
|
||||
)
|
||||
|
||||
process_health_probe_result(runtime_state, ServerStatus(online=True), now=activated_at + datetime.timedelta(seconds=20))
|
||||
process_health_probe_result(
|
||||
runtime_state, ServerStatus(online=True), now=activated_at + datetime.timedelta(seconds=20)
|
||||
)
|
||||
|
||||
assert runtime_state.pending_recovery is not None
|
||||
assert runtime_state.pending_recovery.recovered is True
|
||||
|
||||
899
output.xml
899
output.xml
@@ -1,899 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<robot generator="Robot 7.4.2 (Python 3.12.3 on linux)" generated="2026-05-05T19:08:15.507887" rpa="false" schemaversion="5">
|
||||
<suite id="s1" name="05 Setup" source="/home/lukas/Volume/repo/BanGUI/e2e/tests/05_setup.robot">
|
||||
<kw name="Wait For Backend Health" owner="common" type="SETUP">
|
||||
<kw name="Evaluate" owner="BuiltIn">
|
||||
<var>${deadline}</var>
|
||||
<arg>time.time() + ${timeout}</arg>
|
||||
<doc>Evaluates the given expression in Python and returns the result.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:15.740159" elapsed="0.000384"/>
|
||||
</kw>
|
||||
<while condition="True">
|
||||
<iter>
|
||||
<kw name="Evaluate" owner="BuiltIn">
|
||||
<var>${now}</var>
|
||||
<arg>time.time()</arg>
|
||||
<doc>Evaluates the given expression in Python and returns the result.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:15.741110" elapsed="0.000327"/>
|
||||
</kw>
|
||||
<if>
|
||||
<branch type="IF" condition="${now} >= ${deadline}">
|
||||
<kw name="Fail" owner="BuiltIn">
|
||||
<arg>Backend did not become healthy within ${timeout} seconds</arg>
|
||||
<doc>Fails the test with the given message and optionally alters its tags.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:15.741774" elapsed="0.000221"/>
|
||||
</kw>
|
||||
<status status="PASS" start="2026-05-05T19:08:15.741588" elapsed="0.000468"/>
|
||||
</branch>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:15.741558" elapsed="0.000550"/>
|
||||
</if>
|
||||
<kw name="GET" owner="RequestsLibrary">
|
||||
<var>${response}</var>
|
||||
<arg>${BACKEND_URL}/api/health</arg>
|
||||
<arg>expected_status=200</arg>
|
||||
<doc>Sends a GET request.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:15.742209" elapsed="0.000117"/>
|
||||
</kw>
|
||||
<if>
|
||||
<branch type="IF" condition="${response.status} == 200">
|
||||
<break>
|
||||
<status status="PASS" start="2026-05-05T19:08:15.742528" elapsed="0.000068"/>
|
||||
</break>
|
||||
<status status="PASS" start="2026-05-05T19:08:15.742424" elapsed="0.000219"/>
|
||||
</branch>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:15.742404" elapsed="0.000277"/>
|
||||
</if>
|
||||
<kw name="Sleep" owner="BuiltIn">
|
||||
<arg>${interval}</arg>
|
||||
<doc>Pauses the test executed for the given time.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:15.742774" elapsed="0.000218"/>
|
||||
</kw>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:15.740686" elapsed="0.002366"/>
|
||||
</iter>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:15.740683" elapsed="0.002414"/>
|
||||
</while>
|
||||
<kw name="Log" owner="BuiltIn">
|
||||
<arg>Backend is healthy.</arg>
|
||||
<doc>Logs the given message with the given level.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:15.743215" elapsed="0.000260"/>
|
||||
</kw>
|
||||
<status status="PASS" start="2026-05-05T19:08:15.739266" elapsed="0.004473"/>
|
||||
</kw>
|
||||
<test id="s1-t1" name="Setup Page Renders All Form Fields" line="8">
|
||||
<kw name="New Browser" owner="Browser">
|
||||
<arg>chromium</arg>
|
||||
<arg>headless=${TRUE}</arg>
|
||||
<tag>BrowserControl</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Create a new playwright Browser with specified options.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:15.745093" elapsed="0.000515"/>
|
||||
</kw>
|
||||
<kw name="Go To" owner="Browser">
|
||||
<arg>${FRONTEND_URL}/setup</arg>
|
||||
<tag>BrowserControl</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Navigates to the given ``url``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:15.745776" elapsed="0.000361"/>
|
||||
</kw>
|
||||
<kw name="Wait For Elements State" owner="Browser">
|
||||
<arg>css=form</arg>
|
||||
<arg>visible</arg>
|
||||
<arg>timeout=15s</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Wait</tag>
|
||||
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:15.746289" elapsed="0.000619"/>
|
||||
</kw>
|
||||
<kw name="Get Element States" owner="Browser">
|
||||
<arg>css=input[autocomplete="username"]</arg>
|
||||
<arg>contains</arg>
|
||||
<arg>hidden</arg>
|
||||
<tag>Assertion</tag>
|
||||
<tag>Getter</tag>
|
||||
<tag>PageContent</tag>
|
||||
<doc>Get the active states from the element found by ``selector``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:15.747066" elapsed="0.000378"/>
|
||||
</kw>
|
||||
<kw name="Wait For Elements State" owner="Browser">
|
||||
<arg>css=[aria-label="Master Password"]</arg>
|
||||
<arg>visible</arg>
|
||||
<arg>timeout=5s</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Wait</tag>
|
||||
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:15.747589" elapsed="0.000376"/>
|
||||
</kw>
|
||||
<kw name="Wait For Elements State" owner="Browser">
|
||||
<arg>css=[aria-label="Confirm Password"]</arg>
|
||||
<arg>visible</arg>
|
||||
<arg>timeout=5s</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Wait</tag>
|
||||
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:15.748120" elapsed="0.000348"/>
|
||||
</kw>
|
||||
<kw name="Wait For Elements State" owner="Browser">
|
||||
<arg>css=[aria-label="Database Path"]</arg>
|
||||
<arg>visible</arg>
|
||||
<arg>timeout=5s</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Wait</tag>
|
||||
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:15.748600" elapsed="0.000364"/>
|
||||
</kw>
|
||||
<kw name="Wait For Elements State" owner="Browser">
|
||||
<arg>css=[aria-label="fail2ban Socket Path"]</arg>
|
||||
<arg>visible</arg>
|
||||
<arg>timeout=5s</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Wait</tag>
|
||||
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:15.749110" elapsed="0.000356"/>
|
||||
</kw>
|
||||
<kw name="Wait For Elements State" owner="Browser">
|
||||
<arg>css=[aria-label="Timezone"]</arg>
|
||||
<arg>visible</arg>
|
||||
<arg>timeout=5s</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Wait</tag>
|
||||
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:15.749601" elapsed="0.000379"/>
|
||||
</kw>
|
||||
<kw name="Wait For Elements State" owner="Browser">
|
||||
<arg>css=[aria-label="Session Duration (minutes)"]</arg>
|
||||
<arg>visible</arg>
|
||||
<arg>timeout=5s</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Wait</tag>
|
||||
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:15.750118" elapsed="0.000351"/>
|
||||
</kw>
|
||||
<kw name="Wait For Elements State" owner="Browser">
|
||||
<arg>css=button[type="submit"]</arg>
|
||||
<arg>visible</arg>
|
||||
<arg>timeout=5s</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Wait</tag>
|
||||
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:15.750597" elapsed="0.000349"/>
|
||||
</kw>
|
||||
<kw name="Get Text" owner="Browser">
|
||||
<arg>css=button[type="submit"]</arg>
|
||||
<arg>equals</arg>
|
||||
<arg>Complete Setup</arg>
|
||||
<tag>Assertion</tag>
|
||||
<tag>Getter</tag>
|
||||
<tag>PageContent</tag>
|
||||
<doc>Returns text attribute of the element found by ``selector``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:15.751090" elapsed="0.003246"/>
|
||||
</kw>
|
||||
<kw name="Close Browser" owner="Browser">
|
||||
<tag>BrowserControl</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Closes the current browser.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:15.754577" elapsed="0.000388"/>
|
||||
</kw>
|
||||
<msg time="2026-05-05T19:08:15.762296" level="INFO">Starting Browser process /home/lukas/Volume/repo/BanGUI/.venv/lib/python3.12/site-packages/Browser/wrapper/index.js using at 127.0.0.1:34013</msg>
|
||||
<doc>Verify all setup wizard fields are present and labelled correctly.</doc>
|
||||
<status status="PASS" start="2026-05-05T19:08:15.744062" elapsed="0.011088"/>
|
||||
</test>
|
||||
<test id="s1-t2" name="Password Strength Indicator Updates On Input" line="31">
|
||||
<kw name="New Browser" owner="Browser">
|
||||
<arg>chromium</arg>
|
||||
<arg>headless=${TRUE}</arg>
|
||||
<tag>BrowserControl</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Create a new playwright Browser with specified options.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.310648" elapsed="0.000409"/>
|
||||
</kw>
|
||||
<kw name="Go To" owner="Browser">
|
||||
<arg>${FRONTEND_URL}/setup</arg>
|
||||
<tag>BrowserControl</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Navigates to the given ``url``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.311180" elapsed="0.000268"/>
|
||||
</kw>
|
||||
<kw name="Wait For Elements State" owner="Browser">
|
||||
<arg>css=input[aria-label="Master Password"]</arg>
|
||||
<arg>visible</arg>
|
||||
<arg>timeout=15s</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Wait</tag>
|
||||
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.311537" elapsed="0.000286"/>
|
||||
</kw>
|
||||
<kw name="Get Elements" owner="Browser">
|
||||
<var>${segments}</var>
|
||||
<arg>css=.passwordStrengthSegment</arg>
|
||||
<tag>Getter</tag>
|
||||
<tag>PageContent</tag>
|
||||
<doc>Returns a reference to Playwright [https://playwright.dev/docs/api/class-locator|Locator]
|
||||
for all matched elements by ``selector``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.311913" elapsed="0.000414"/>
|
||||
</kw>
|
||||
<kw name="Set Variable" owner="BuiltIn">
|
||||
<var>${active_count}</var>
|
||||
<arg>0</arg>
|
||||
<doc>Returns the given values which can then be assigned to a variables.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.312456" elapsed="0.000204"/>
|
||||
</kw>
|
||||
<for flavor="IN">
|
||||
<iter>
|
||||
<kw name="Get Attribute" owner="Browser">
|
||||
<var>${classes}</var>
|
||||
<arg>${seg}</arg>
|
||||
<arg>class</arg>
|
||||
<tag>Assertion</tag>
|
||||
<tag>Getter</tag>
|
||||
<tag>PageContent</tag>
|
||||
<doc>Returns the HTML ``attribute`` of the element found by ``selector``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.313169" elapsed="0.000313"/>
|
||||
</kw>
|
||||
<if>
|
||||
<branch type="IF" condition=""Active" in """${classes}"""">
|
||||
<kw name="Evaluate" owner="BuiltIn">
|
||||
<var>${active_count}</var>
|
||||
<arg>${active_count} + 1</arg>
|
||||
<doc>Evaluates the given expression in Python and returns the result.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.313783" elapsed="0.000196"/>
|
||||
</kw>
|
||||
<status status="PASS" start="2026-05-05T19:08:16.313609" elapsed="0.000430"/>
|
||||
</branch>
|
||||
<status status="PASS" start="2026-05-05T19:08:16.313586" elapsed="0.000492"/>
|
||||
</if>
|
||||
<var name="${seg}"/>
|
||||
<status status="PASS" start="2026-05-05T19:08:16.313007" elapsed="0.001092"/>
|
||||
</iter>
|
||||
<var>${seg}</var>
|
||||
<value>@{segments}</value>
|
||||
<status status="PASS" start="2026-05-05T19:08:16.312756" elapsed="0.001379"/>
|
||||
</for>
|
||||
<kw name="Should Be Equal As Integers" owner="BuiltIn">
|
||||
<arg>${active_count}</arg>
|
||||
<arg>0</arg>
|
||||
<doc>Fails if objects are unequal after converting them to integers.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.314223" elapsed="0.000152"/>
|
||||
</kw>
|
||||
<kw name="Fill Text" owner="Browser">
|
||||
<arg>css=input[aria-label="Master Password"]</arg>
|
||||
<arg>WeakPass</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Clears and fills the given ``txt`` into the text field found by ``selector``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.314455" elapsed="0.000211"/>
|
||||
</kw>
|
||||
<kw name="Set Variable" owner="BuiltIn">
|
||||
<var>${active_count}</var>
|
||||
<arg>0</arg>
|
||||
<doc>Returns the given values which can then be assigned to a variables.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.314754" elapsed="0.000162"/>
|
||||
</kw>
|
||||
<kw name="Get Elements" owner="Browser">
|
||||
<var>${segments}</var>
|
||||
<arg>css=.passwordStrengthSegment</arg>
|
||||
<tag>Getter</tag>
|
||||
<tag>PageContent</tag>
|
||||
<doc>Returns a reference to Playwright [https://playwright.dev/docs/api/class-locator|Locator]
|
||||
for all matched elements by ``selector``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.314994" elapsed="0.000204"/>
|
||||
</kw>
|
||||
<for flavor="IN">
|
||||
<iter>
|
||||
<kw name="Get Attribute" owner="Browser">
|
||||
<var>${classes}</var>
|
||||
<arg>${seg}</arg>
|
||||
<arg>class</arg>
|
||||
<tag>Assertion</tag>
|
||||
<tag>Getter</tag>
|
||||
<tag>PageContent</tag>
|
||||
<doc>Returns the HTML ``attribute`` of the element found by ``selector``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.315479" elapsed="0.000193"/>
|
||||
</kw>
|
||||
<if>
|
||||
<branch type="IF" condition=""Active" in """${classes}"""">
|
||||
<kw name="Evaluate" owner="BuiltIn">
|
||||
<var>${active_count}</var>
|
||||
<arg>${active_count} + 1</arg>
|
||||
<doc>Evaluates the given expression in Python and returns the result.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.315871" elapsed="0.000138"/>
|
||||
</kw>
|
||||
<status status="PASS" start="2026-05-05T19:08:16.315745" elapsed="0.000307"/>
|
||||
</branch>
|
||||
<status status="PASS" start="2026-05-05T19:08:16.315731" elapsed="0.000351"/>
|
||||
</if>
|
||||
<var name="${seg}"/>
|
||||
<status status="PASS" start="2026-05-05T19:08:16.315383" elapsed="0.000717"/>
|
||||
</iter>
|
||||
<var>${seg}</var>
|
||||
<value>@{segments}</value>
|
||||
<status status="PASS" start="2026-05-05T19:08:16.315254" elapsed="0.000877"/>
|
||||
</for>
|
||||
<kw name="Should Be Equal As Integers" owner="BuiltIn">
|
||||
<arg>${active_count}</arg>
|
||||
<arg>1</arg>
|
||||
<doc>Fails if objects are unequal after converting them to integers.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.316210" elapsed="0.000163"/>
|
||||
</kw>
|
||||
<kw name="Close Browser" owner="Browser">
|
||||
<tag>BrowserControl</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Closes the current browser.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.316448" elapsed="0.000183"/>
|
||||
</kw>
|
||||
<doc>The four-segment strength bar and rule count reflect password complexity.</doc>
|
||||
<status status="PASS" start="2026-05-05T19:08:16.308758" elapsed="0.007965"/>
|
||||
</test>
|
||||
<test id="s1-t3" name="Password Mismatch Shows Validation Error" line="62">
|
||||
<kw name="New Browser" owner="Browser">
|
||||
<arg>chromium</arg>
|
||||
<arg>headless=${TRUE}</arg>
|
||||
<tag>BrowserControl</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Create a new playwright Browser with specified options.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.326594" elapsed="0.000369"/>
|
||||
</kw>
|
||||
<kw name="Go To" owner="Browser">
|
||||
<arg>${FRONTEND_URL}/setup</arg>
|
||||
<tag>BrowserControl</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Navigates to the given ``url``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.327084" elapsed="0.000236"/>
|
||||
</kw>
|
||||
<kw name="Wait For Elements State" owner="Browser">
|
||||
<arg>css=input[aria-label="Master Password"]</arg>
|
||||
<arg>visible</arg>
|
||||
<arg>timeout=15s</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Wait</tag>
|
||||
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.327412" elapsed="0.000268"/>
|
||||
</kw>
|
||||
<kw name="Fill Text" owner="Browser">
|
||||
<arg>css=input[aria-label="Master Password"]</arg>
|
||||
<arg>Hallo123!</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Clears and fills the given ``txt`` into the text field found by ``selector``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.327775" elapsed="0.000231"/>
|
||||
</kw>
|
||||
<kw name="Fill Text" owner="Browser">
|
||||
<arg>css=input[aria-label="Confirm Password"]</arg>
|
||||
<arg>Different123!</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Clears and fills the given ``txt`` into the text field found by ``selector``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.328100" elapsed="0.000248"/>
|
||||
</kw>
|
||||
<kw name="Click" owner="Browser">
|
||||
<arg>css=button[type="submit"]</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Simulates mouse click on the element found by ``selector``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.328459" elapsed="0.000240"/>
|
||||
</kw>
|
||||
<kw name="Wait For Elements State" owner="Browser">
|
||||
<arg>css=[aria-label="Confirm Password"]</arg>
|
||||
<arg>attached</arg>
|
||||
<arg>timeout=5s</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Wait</tag>
|
||||
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.328786" elapsed="0.000260"/>
|
||||
</kw>
|
||||
<kw name="Get Text" owner="Browser">
|
||||
<var>${msg}</var>
|
||||
<arg>css=[aria-label="Confirm Password"]/ancestor::*[contains(@class,"field")]//*[contains(@class,"validationMessage")]</arg>
|
||||
<tag>Assertion</tag>
|
||||
<tag>Getter</tag>
|
||||
<tag>PageContent</tag>
|
||||
<doc>Returns text attribute of the element found by ``selector``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.329139" elapsed="0.000224"/>
|
||||
</kw>
|
||||
<kw name="Should Be Equal As Strings" owner="BuiltIn">
|
||||
<arg>${msg}</arg>
|
||||
<arg>Passwords do not match.</arg>
|
||||
<doc>Fails if objects are unequal after converting them to strings.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.329449" elapsed="0.000162"/>
|
||||
</kw>
|
||||
<kw name="Close Browser" owner="Browser">
|
||||
<tag>BrowserControl</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Closes the current browser.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.329689" elapsed="0.000183"/>
|
||||
</kw>
|
||||
<doc>Submitting with non-matching passwords surfaces an error on Confirm Password.</doc>
|
||||
<status status="PASS" start="2026-05-05T19:08:16.324720" elapsed="0.005238"/>
|
||||
</test>
|
||||
<test id="s1-t4" name="Empty Required Fields Show Validation Errors" line="78">
|
||||
<kw name="New Browser" owner="Browser">
|
||||
<arg>chromium</arg>
|
||||
<arg>headless=${TRUE}</arg>
|
||||
<tag>BrowserControl</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Create a new playwright Browser with specified options.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.337764" elapsed="0.000617"/>
|
||||
</kw>
|
||||
<kw name="Go To" owner="Browser">
|
||||
<arg>${FRONTEND_URL}/setup</arg>
|
||||
<tag>BrowserControl</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Navigates to the given ``url``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.339155" elapsed="0.000380"/>
|
||||
</kw>
|
||||
<kw name="Wait For Elements State" owner="Browser">
|
||||
<arg>css=input[aria-label="Master Password"]</arg>
|
||||
<arg>visible</arg>
|
||||
<arg>timeout=15s</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Wait</tag>
|
||||
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.339630" elapsed="0.000295"/>
|
||||
</kw>
|
||||
<kw name="Click" owner="Browser">
|
||||
<arg>css=button[type="submit"]</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Simulates mouse click on the element found by ``selector``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.340021" elapsed="0.000220"/>
|
||||
</kw>
|
||||
<kw name="Wait For Elements State" owner="Browser">
|
||||
<arg>css=[aria-label="Master Password"]</arg>
|
||||
<arg>attached</arg>
|
||||
<arg>timeout=5s</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Wait</tag>
|
||||
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.340332" elapsed="0.000283"/>
|
||||
</kw>
|
||||
<kw name="Get Text" owner="Browser">
|
||||
<var>${msg}</var>
|
||||
<arg>css=[aria-label="Master Password"]/ancestor::*[contains(@class,"field")]//*[contains(@class,"validationMessage")]</arg>
|
||||
<tag>Assertion</tag>
|
||||
<tag>Getter</tag>
|
||||
<tag>PageContent</tag>
|
||||
<doc>Returns text attribute of the element found by ``selector``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.340716" elapsed="0.000225"/>
|
||||
</kw>
|
||||
<kw name="Should Be Equal As Strings" owner="BuiltIn">
|
||||
<arg>${msg}</arg>
|
||||
<arg>Password is required.</arg>
|
||||
<doc>Fails if objects are unequal after converting them to strings.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.341034" elapsed="0.000150"/>
|
||||
</kw>
|
||||
<kw name="Wait For Elements State" owner="Browser">
|
||||
<arg>css=[aria-label="Database Path"]</arg>
|
||||
<arg>attached</arg>
|
||||
<arg>timeout=5s</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Wait</tag>
|
||||
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.341263" elapsed="0.000230"/>
|
||||
</kw>
|
||||
<kw name="Get Text" owner="Browser">
|
||||
<var>${msg}</var>
|
||||
<arg>css=[aria-label="Database Path"]/ancestor::*[contains(@class,"field")]//*[contains(@class,"validationMessage")]</arg>
|
||||
<tag>Assertion</tag>
|
||||
<tag>Getter</tag>
|
||||
<tag>PageContent</tag>
|
||||
<doc>Returns text attribute of the element found by ``selector``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.341583" elapsed="0.000230"/>
|
||||
</kw>
|
||||
<kw name="Should Be Equal As Strings" owner="BuiltIn">
|
||||
<arg>${msg}</arg>
|
||||
<arg>Database path is required.</arg>
|
||||
<doc>Fails if objects are unequal after converting them to strings.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.341894" elapsed="0.000174"/>
|
||||
</kw>
|
||||
<kw name="Wait For Elements State" owner="Browser">
|
||||
<arg>css=[aria-label="fail2ban Socket Path"]</arg>
|
||||
<arg>attached</arg>
|
||||
<arg>timeout=5s</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Wait</tag>
|
||||
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.342145" elapsed="0.000233"/>
|
||||
</kw>
|
||||
<kw name="Get Text" owner="Browser">
|
||||
<var>${msg}</var>
|
||||
<arg>css=[aria-label="fail2ban Socket Path"]/ancestor::*[contains(@class,"field")]//*[contains(@class,"validationMessage")]</arg>
|
||||
<tag>Assertion</tag>
|
||||
<tag>Getter</tag>
|
||||
<tag>PageContent</tag>
|
||||
<doc>Returns text attribute of the element found by ``selector``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.342462" elapsed="0.000205"/>
|
||||
</kw>
|
||||
<kw name="Should Be Equal As Strings" owner="BuiltIn">
|
||||
<arg>${msg}</arg>
|
||||
<arg>Socket path is required.</arg>
|
||||
<doc>Fails if objects are unequal after converting them to strings.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.342743" elapsed="0.000137"/>
|
||||
</kw>
|
||||
<kw name="Close Browser" owner="Browser">
|
||||
<tag>BrowserControl</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Closes the current browser.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.342953" elapsed="0.000209"/>
|
||||
</kw>
|
||||
<doc>Submitting with blank required fields shows field-level error messages.</doc>
|
||||
<status status="PASS" start="2026-05-05T19:08:16.333829" elapsed="0.009421"/>
|
||||
</test>
|
||||
<test id="s1-t5" name="Invalid Session Duration Shows Validation Error" line="100">
|
||||
<kw name="New Browser" owner="Browser">
|
||||
<arg>chromium</arg>
|
||||
<arg>headless=${TRUE}</arg>
|
||||
<tag>BrowserControl</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Create a new playwright Browser with specified options.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.348383" elapsed="0.000370"/>
|
||||
</kw>
|
||||
<kw name="Go To" owner="Browser">
|
||||
<arg>${FRONTEND_URL}/setup</arg>
|
||||
<tag>BrowserControl</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Navigates to the given ``url``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.348871" elapsed="0.000241"/>
|
||||
</kw>
|
||||
<kw name="Wait For Elements State" owner="Browser">
|
||||
<arg>css=input[aria-label="Master Password"]</arg>
|
||||
<arg>visible</arg>
|
||||
<arg>timeout=15s</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Wait</tag>
|
||||
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.349206" elapsed="0.000282"/>
|
||||
</kw>
|
||||
<kw name="Fill Text" owner="Browser">
|
||||
<arg>css=input[aria-label="Master Password"]</arg>
|
||||
<arg>Hallo123!</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Clears and fills the given ``txt`` into the text field found by ``selector``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.349584" elapsed="0.000216"/>
|
||||
</kw>
|
||||
<kw name="Fill Text" owner="Browser">
|
||||
<arg>css=input[aria-label="Confirm Password"]</arg>
|
||||
<arg>Hallo123!</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Clears and fills the given ``txt`` into the text field found by ``selector``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.349885" elapsed="0.000210"/>
|
||||
</kw>
|
||||
<kw name="Fill Text" owner="Browser">
|
||||
<arg>css=input[aria-label="Database Path"]</arg>
|
||||
<arg>bangui.db</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Clears and fills the given ``txt`` into the text field found by ``selector``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.350176" elapsed="0.000203"/>
|
||||
</kw>
|
||||
<kw name="Fill Text" owner="Browser">
|
||||
<arg>css=input[aria-label="fail2ban Socket Path"]</arg>
|
||||
<arg>/var/run/fail2ban/fail2ban.sock</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Clears and fills the given ``txt`` into the text field found by ``selector``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.350456" elapsed="0.000188"/>
|
||||
</kw>
|
||||
<kw name="Fill Text" owner="Browser">
|
||||
<arg>css=input[aria-label="Session Duration (minutes)"]</arg>
|
||||
<arg>0</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Clears and fills the given ``txt`` into the text field found by ``selector``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.350721" elapsed="0.000220"/>
|
||||
</kw>
|
||||
<kw name="Click" owner="Browser">
|
||||
<arg>css=button[type="submit"]</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Simulates mouse click on the element found by ``selector``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.351058" elapsed="0.000285"/>
|
||||
</kw>
|
||||
<kw name="Wait For Elements State" owner="Browser">
|
||||
<arg>css=[aria-label="Session Duration (minutes)"]</arg>
|
||||
<arg>attached</arg>
|
||||
<arg>timeout=5s</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Wait</tag>
|
||||
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.351500" elapsed="0.000444"/>
|
||||
</kw>
|
||||
<kw name="Get Text" owner="Browser">
|
||||
<var>${msg}</var>
|
||||
<arg>css=[aria-label="Session Duration (minutes)"]/ancestor::*[contains(@class,"field")]//*[contains(@class,"validationMessage")]</arg>
|
||||
<tag>Assertion</tag>
|
||||
<tag>Getter</tag>
|
||||
<tag>PageContent</tag>
|
||||
<doc>Returns text attribute of the element found by ``selector``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.352094" elapsed="0.000361"/>
|
||||
</kw>
|
||||
<kw name="Should Be Equal As Strings" owner="BuiltIn">
|
||||
<arg>${msg}</arg>
|
||||
<arg>Session duration must be at least 1 minute.</arg>
|
||||
<doc>Fails if objects are unequal after converting them to strings.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.352673" elapsed="0.000368"/>
|
||||
</kw>
|
||||
<kw name="Close Browser" owner="Browser">
|
||||
<tag>BrowserControl</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Closes the current browser.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.353712" elapsed="0.000212"/>
|
||||
</kw>
|
||||
<doc>Session duration below 1 minute triggers a validation error.</doc>
|
||||
<status status="PASS" start="2026-05-05T19:08:16.346640" elapsed="0.007391"/>
|
||||
</test>
|
||||
<test id="s1-t6" name="Incomplete Password Shows Complexity Error" line="120">
|
||||
<kw name="New Browser" owner="Browser">
|
||||
<arg>chromium</arg>
|
||||
<arg>headless=${TRUE}</arg>
|
||||
<tag>BrowserControl</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Create a new playwright Browser with specified options.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.360793" elapsed="0.000405"/>
|
||||
</kw>
|
||||
<kw name="Go To" owner="Browser">
|
||||
<arg>${FRONTEND_URL}/setup</arg>
|
||||
<tag>BrowserControl</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Navigates to the given ``url``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.361312" elapsed="0.000247"/>
|
||||
</kw>
|
||||
<kw name="Wait For Elements State" owner="Browser">
|
||||
<arg>css=input[aria-label="Master Password"]</arg>
|
||||
<arg>visible</arg>
|
||||
<arg>timeout=15s</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Wait</tag>
|
||||
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.361647" elapsed="0.000293"/>
|
||||
</kw>
|
||||
<kw name="Fill Text" owner="Browser">
|
||||
<arg>css=input[aria-label="Master Password"]</arg>
|
||||
<arg>short</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Clears and fills the given ``txt`` into the text field found by ``selector``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.362037" elapsed="0.000214"/>
|
||||
</kw>
|
||||
<kw name="Click" owner="Browser">
|
||||
<arg>css=button[type="submit"]</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Simulates mouse click on the element found by ``selector``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.362335" elapsed="0.000206"/>
|
||||
</kw>
|
||||
<kw name="Wait For Elements State" owner="Browser">
|
||||
<arg>css=[aria-label="Master Password"]</arg>
|
||||
<arg>attached</arg>
|
||||
<arg>timeout=5s</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Wait</tag>
|
||||
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.362621" elapsed="0.000231"/>
|
||||
</kw>
|
||||
<kw name="Get Text" owner="Browser">
|
||||
<var>${msg}</var>
|
||||
<arg>css=[aria-label="Master Password"]/ancestor::*[contains(@class,"field")]//*[contains(@class,"validationMessage")]</arg>
|
||||
<tag>Assertion</tag>
|
||||
<tag>Getter</tag>
|
||||
<tag>PageContent</tag>
|
||||
<doc>Returns text attribute of the element found by ``selector``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.362936" elapsed="0.000208"/>
|
||||
</kw>
|
||||
<kw name="Should Contain" owner="BuiltIn">
|
||||
<arg>${msg}</arg>
|
||||
<arg>Password must meet all complexity requirements.</arg>
|
||||
<doc>Fails if ``container`` does not contain ``item`` one or more times.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.363225" elapsed="0.000152"/>
|
||||
</kw>
|
||||
<kw name="Close Browser" owner="Browser">
|
||||
<tag>BrowserControl</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Closes the current browser.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.363456" elapsed="0.000181"/>
|
||||
</kw>
|
||||
<doc>Submitting a password that meets length but not all rules shows complexity error.</doc>
|
||||
<status status="PASS" start="2026-05-05T19:08:16.359200" elapsed="0.004521"/>
|
||||
</test>
|
||||
<test id="s1-t7" name="Setup Completes Successfully And Redirects To Login" line="135">
|
||||
<kw name="New Browser" owner="Browser">
|
||||
<arg>chromium</arg>
|
||||
<arg>headless=${TRUE}</arg>
|
||||
<tag>BrowserControl</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Create a new playwright Browser with specified options.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.369227" elapsed="0.001419"/>
|
||||
</kw>
|
||||
<kw name="GET" owner="RequestsLibrary">
|
||||
<var>${status_resp}</var>
|
||||
<arg>${BACKEND_URL}/api/setup/status</arg>
|
||||
<doc>Sends a GET request.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.370784" elapsed="0.000122"/>
|
||||
</kw>
|
||||
<kw name="Set Variable" owner="BuiltIn">
|
||||
<var>${status_body}</var>
|
||||
<arg>${status_resp.json()}</arg>
|
||||
<doc>Returns the given values which can then be assigned to a variables.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.370993" elapsed="0.000163"/>
|
||||
</kw>
|
||||
<kw name="Log" owner="BuiltIn">
|
||||
<arg>Setup complete: ${status_body}[setup_complete]</arg>
|
||||
<doc>Logs the given message with the given level.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.371232" elapsed="0.000136"/>
|
||||
</kw>
|
||||
<kw name="Go To" owner="Browser">
|
||||
<arg>${FRONTEND_URL}/setup</arg>
|
||||
<tag>BrowserControl</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Navigates to the given ``url``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.371442" elapsed="0.000212"/>
|
||||
</kw>
|
||||
<kw name="Wait For Elements State" owner="Browser">
|
||||
<arg>css=input[aria-label="Master Password"]</arg>
|
||||
<arg>visible</arg>
|
||||
<arg>timeout=15s</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Wait</tag>
|
||||
<doc>Waits for the element found by ``selector`` to satisfy state option.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.371742" elapsed="0.000370"/>
|
||||
</kw>
|
||||
<kw name="Fill Text" owner="Browser">
|
||||
<arg>css=input[aria-label="Master Password"]</arg>
|
||||
<arg>Hallo123!</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Clears and fills the given ``txt`` into the text field found by ``selector``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.372275" elapsed="0.000292"/>
|
||||
</kw>
|
||||
<kw name="Fill Text" owner="Browser">
|
||||
<arg>css=input[aria-label="Confirm Password"]</arg>
|
||||
<arg>Hallo123!</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Clears and fills the given ``txt`` into the text field found by ``selector``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.372675" elapsed="0.000251"/>
|
||||
</kw>
|
||||
<kw name="Fill Text" owner="Browser">
|
||||
<arg>css=input[aria-label="Database Path"]</arg>
|
||||
<arg>bangui.db</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Clears and fills the given ``txt`` into the text field found by ``selector``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.373035" elapsed="0.000310"/>
|
||||
</kw>
|
||||
<kw name="Fill Text" owner="Browser">
|
||||
<arg>css=input[aria-label="fail2ban Socket Path"]</arg>
|
||||
<arg>/var/run/fail2ban/fail2ban.sock</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Clears and fills the given ``txt`` into the text field found by ``selector``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.373452" elapsed="0.000267"/>
|
||||
</kw>
|
||||
<kw name="Fill Text" owner="Browser">
|
||||
<arg>css=input[aria-label="Timezone"]</arg>
|
||||
<arg>UTC</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Clears and fills the given ``txt`` into the text field found by ``selector``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.373819" elapsed="0.000256"/>
|
||||
</kw>
|
||||
<kw name="Fill Text" owner="Browser">
|
||||
<arg>css=input[aria-label="Session Duration (minutes)"]</arg>
|
||||
<arg>60</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Clears and fills the given ``txt`` into the text field found by ``selector``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.374156" elapsed="0.000240"/>
|
||||
</kw>
|
||||
<kw name="Click" owner="Browser">
|
||||
<arg>css=button[type="submit"]</arg>
|
||||
<tag>PageContent</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Simulates mouse click on the element found by ``selector``.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.374478" elapsed="0.000261"/>
|
||||
</kw>
|
||||
<kw name="Get Url" owner="Browser">
|
||||
<var>${current_url}</var>
|
||||
<tag>Assertion</tag>
|
||||
<tag>Getter</tag>
|
||||
<tag>PageContent</tag>
|
||||
<doc>Returns the current URL.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.374846" elapsed="0.000234"/>
|
||||
</kw>
|
||||
<if>
|
||||
<branch type="IF" condition=""login" not in """${current_url}"""">
|
||||
<kw name="Evaluate" owner="BuiltIn">
|
||||
<var>${deadline}</var>
|
||||
<arg>time.time() + 15</arg>
|
||||
<doc>Evaluates the given expression in Python and returns the result.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.375390" elapsed="0.000221"/>
|
||||
</kw>
|
||||
<while condition="True">
|
||||
<iter>
|
||||
<kw name="Evaluate" owner="BuiltIn">
|
||||
<var>${now}</var>
|
||||
<arg>time.time()</arg>
|
||||
<doc>Evaluates the given expression in Python and returns the result.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.376079" elapsed="0.000214"/>
|
||||
</kw>
|
||||
<if>
|
||||
<branch type="IF" condition="${now} >= ${deadline}">
|
||||
<break>
|
||||
<status status="PASS" start="2026-05-05T19:08:16.376629" elapsed="0.000061"/>
|
||||
</break>
|
||||
<status status="PASS" start="2026-05-05T19:08:16.376533" elapsed="0.000200"/>
|
||||
</branch>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.376510" elapsed="0.000263"/>
|
||||
</if>
|
||||
<kw name="Get Url" owner="Browser">
|
||||
<var>${url}</var>
|
||||
<tag>Assertion</tag>
|
||||
<tag>Getter</tag>
|
||||
<tag>PageContent</tag>
|
||||
<doc>Returns the current URL.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.376870" elapsed="0.000303"/>
|
||||
</kw>
|
||||
<if>
|
||||
<branch type="IF" condition=""login" in """${url}"""">
|
||||
<break>
|
||||
<status status="PASS" start="2026-05-05T19:08:16.377395" elapsed="0.000052"/>
|
||||
</break>
|
||||
<status status="PASS" start="2026-05-05T19:08:16.377278" elapsed="0.000200"/>
|
||||
</branch>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.377259" elapsed="0.000245"/>
|
||||
</if>
|
||||
<kw name="Sleep" owner="BuiltIn">
|
||||
<arg>0.5</arg>
|
||||
<doc>Pauses the test executed for the given time.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.377566" elapsed="0.000152"/>
|
||||
</kw>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.375690" elapsed="0.002063"/>
|
||||
</iter>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.375688" elapsed="0.002094"/>
|
||||
</while>
|
||||
<status status="PASS" start="2026-05-05T19:08:16.375221" elapsed="0.002591"/>
|
||||
</branch>
|
||||
<status status="PASS" start="2026-05-05T19:08:16.375197" elapsed="0.002640"/>
|
||||
</if>
|
||||
<kw name="GET" owner="RequestsLibrary">
|
||||
<var>${new_status_resp}</var>
|
||||
<arg>${BACKEND_URL}/api/setup/status</arg>
|
||||
<doc>Sends a GET request.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.377900" elapsed="0.000088"/>
|
||||
</kw>
|
||||
<kw name="Set Variable" owner="BuiltIn">
|
||||
<var>${new_status_body}</var>
|
||||
<arg>${new_status_resp.json()}</arg>
|
||||
<doc>Returns the given values which can then be assigned to a variables.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.378068" elapsed="0.000134"/>
|
||||
</kw>
|
||||
<kw name="Should Be True" owner="BuiltIn">
|
||||
<arg>${new_status_body}[setup_complete]</arg>
|
||||
<doc>Fails if the given condition is not true.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.378279" elapsed="0.000129"/>
|
||||
</kw>
|
||||
<kw name="Close Browser" owner="Browser">
|
||||
<tag>BrowserControl</tag>
|
||||
<tag>Setter</tag>
|
||||
<doc>Closes the current browser.</doc>
|
||||
<status status="NOT RUN" start="2026-05-05T19:08:16.378480" elapsed="0.000189"/>
|
||||
</kw>
|
||||
<doc>Filling all fields and submitting completes setup and navigates to /login.</doc>
|
||||
<status status="PASS" start="2026-05-05T19:08:16.366735" elapsed="0.012022"/>
|
||||
</test>
|
||||
<status status="PASS" start="2026-05-05T19:08:15.508608" elapsed="0.873487"/>
|
||||
</suite>
|
||||
<statistics>
|
||||
<total>
|
||||
<stat pass="7" fail="0" skip="0">All Tests</stat>
|
||||
</total>
|
||||
<tag>
|
||||
</tag>
|
||||
<suite>
|
||||
<stat name="05 Setup" id="s1" pass="7" fail="0" skip="0">05 Setup</stat>
|
||||
</suite>
|
||||
</statistics>
|
||||
<errors>
|
||||
<msg time="2026-05-05T19:08:15.732927" level="ERROR">Error in file '/home/lukas/Volume/repo/BanGUI/e2e/resources/common.resource' on line 5: Processing variable file '/home/lukas/Volume/repo/BanGUI/e2e/resources/../../.env' failed: Importing variable file '/home/lukas/Volume/repo/BanGUI/e2e/resources/../../.env' failed: Module name cannot contain dots when importing by path.</msg>
|
||||
</errors>
|
||||
</robot>
|
||||
Reference in New Issue
Block a user