fixed tests

This commit is contained in:
2026-05-15 20:41:05 +02:00
parent 96ce516ecf
commit 77df5d5d65
50 changed files with 1482 additions and 5089 deletions

View File

@@ -14,6 +14,7 @@ from __future__ import annotations
from pathlib import Path
import aiosqlite
from app.utils.logging_compat import get_logger
log = get_logger(__name__)
@@ -246,7 +247,6 @@ CREATE INDEX IF NOT EXISTS idx_import_log_source_id_desc
}
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
@@ -254,6 +254,7 @@ CREATE INDEX IF NOT EXISTS idx_import_log_source_id_desc
async def _configure_connection(db: aiosqlite.Connection) -> None:
"""Apply hardening pragmas to a newly-opened SQLite connection."""
await db.execute("PRAGMA journal_mode=WAL;")
await db.execute("PRAGMA foreign_keys=ON;")
await db.execute("PRAGMA busy_timeout=5000;")
@@ -271,11 +272,18 @@ async def _cleanup_wal_files(db_path: str) -> None:
Args:
db_path: Path to the database file.
"""
import time
wal_path = Path(db_path + "-wal")
shm_path = Path(db_path + "-shm")
for path in (wal_path, shm_path):
if path.exists():
# Skip files that were modified recently — they likely belong to an
# active connection. Only remove stale files left by crashes.
mtime = path.stat().st_mtime
if time.time() - mtime < 10:
continue
try:
path.unlink()
log.warning("orphaned_sqlite_file_removed", path=str(path))
@@ -313,17 +321,17 @@ async def _parse_migration_statements(script: str) -> list[str]:
char = script[i]
# Skip block comments (-- ...)
if i < len(script) - 1 and script[i:i+2] == "--":
if i < len(script) - 1 and script[i : i + 2] == "--":
while i < len(script) and script[i] != "\n":
i += 1
i += 1
continue
# Skip line comments (/* ... */)
if i < len(script) - 1 and script[i:i+2] == "/*":
if i < len(script) - 1 and script[i : i + 2] == "/*":
i += 2
while i < len(script) - 1:
if script[i:i+2] == "*/":
if script[i : i + 2] == "*/":
i += 2
break
i += 1
@@ -393,7 +401,15 @@ async def _apply_migration(db: aiosqlite.Connection, version: int) -> None:
await db.execute("BEGIN IMMEDIATE;")
for statement in statements:
await db.execute(statement)
try:
await db.execute(statement)
except aiosqlite.OperationalError as exc:
# Ignore duplicate column / table errors so migrations remain
# idempotent when a legacy database already has the object.
msg = str(exc).lower()
if "duplicate column name" in msg or "table" in msg and "already exists" in msg:
continue
raise
await db.execute("INSERT INTO schema_migrations (version) VALUES (?);", (version,))
@@ -411,8 +427,7 @@ async def _migrate_schema(db: aiosqlite.Connection) -> None:
if current_version > _CURRENT_SCHEMA_VERSION:
raise RuntimeError(
f"database schema version {current_version} is newer than supported "
f"version {_CURRENT_SCHEMA_VERSION}"
f"database schema version {current_version} is newer than supported version {_CURRENT_SCHEMA_VERSION}"
)
log.info("migrating_database_schema", from_version=current_version, to_version=_CURRENT_SCHEMA_VERSION)

View File

@@ -36,7 +36,6 @@ from typing import Annotated, cast
import aiohttp
import aiosqlite
from app.utils.logging_compat import get_logger
from apscheduler.schedulers.asyncio import AsyncIOScheduler # type: ignore[import-untyped]
from fastapi import Depends, FastAPI, HTTPException, Request, status
@@ -45,22 +44,6 @@ from app.exceptions import RateLimitError
from app.models.auth import Session
from app.models.config import PendingRecovery
from app.models.server import ServerStatus
from app.repositories.protocols import (
BlocklistRepository,
Fail2BanDbRepository,
GeoCacheRepository,
HistoryArchiveRepository,
ImportLogRepository,
ImportRunRepository,
SessionRepository,
SettingsRepository,
)
from app.services.geo_cache import GeoCache
from app.services.protocols import Fail2BanMetadataService
from app.utils.constants import SESSION_COOKIE_NAME
from app.utils.rate_limiter import GlobalRateLimiter
from app.utils.runtime_state import ApplicationState, JailServiceState, RuntimeState
from app.utils.session_cache import NoOpSessionCache, SessionCache
# Module-level imports for repositories and services
# These are safe at module level since no circular dependencies exist
@@ -74,8 +57,25 @@ from app.repositories import (
session_repo,
settings_repo,
)
from app.repositories.protocols import (
BlocklistRepository,
Fail2BanDbRepository,
GeoCacheRepository,
HistoryArchiveRepository,
ImportLogRepository,
ImportRunRepository,
SessionRepository,
SettingsRepository,
)
from app.services import auth_service, health_service
from app.services.fail2ban_metadata_service import default_fail2ban_metadata_service
from app.services.geo_cache import GeoCache
from app.services.protocols import Fail2BanMetadataService
from app.utils.constants import SESSION_COOKIE_NAME
from app.utils.logging_compat import get_logger
from app.utils.rate_limiter import GlobalRateLimiter
from app.utils.runtime_state import ApplicationState, JailServiceState, RuntimeState
from app.utils.session_cache import NoOpSessionCache, SessionCache
log = get_logger(__name__)
@@ -108,6 +108,7 @@ class ApplicationContext:
#: or distributed deployments, the configured cache backend should provide
#: invalidation semantics appropriate for the deployment.
def _session_cache_enabled(settings: Settings) -> bool:
"""Return whether the session validation cache should be used."""
return settings.session_cache_enabled and settings.session_cache_ttl_seconds > 0.0
@@ -284,6 +285,7 @@ def rate_limit_dependency(
Returns:
A callable that can be used as a FastAPI Depends() dependency.
"""
async def check_rate_limit(
request: Request,
rate_limiter: GlobalRateLimiterDep,
@@ -293,9 +295,7 @@ def rate_limit_dependency(
settings: Settings = request.app.state.settings
client_ip = get_client_ip(request, trusted_proxies=settings.trusted_proxies)
is_allowed, retry_after = rate_limiter.check_allowed_for_bucket(
bucket, client_ip, max_requests, window_seconds
)
is_allowed, retry_after = rate_limiter.check_allowed_for_bucket(bucket, client_ip, max_requests, window_seconds)
if not is_allowed:
log.warning(
@@ -407,6 +407,8 @@ async def get_app(request: Request) -> FastAPI:
async def get_server_status(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> ServerStatus:
"""Return the cached fail2ban server status snapshot from application context."""
if app_context.server_status is None:
return ServerStatus(online=False)
return app_context.server_status
@@ -654,7 +656,7 @@ async def require_auth(
if not token:
auth_header: str = request.headers.get("Authorization", "")
if auth_header.startswith("Bearer "):
token = auth_header[len("Bearer "):]
token = auth_header[len("Bearer ") :]
if not token:
raise HTTPException(

View File

@@ -72,13 +72,13 @@ from app.utils.external_logging import (
ExternalLogHandler,
create_external_log_handler,
)
from app.utils.json_formatter import JSONFormatter
from app.utils.logging_compat import get_logger
from app.utils.rate_limiter import GlobalRateLimiter
from app.utils.runtime_state import ApplicationState, RuntimeState
from app.utils.scheduler_lock import release_scheduler_lock
from app.utils.session_cache import InMemorySessionCache, NoOpSessionCache
from app.utils.setup_state import is_setup_complete_cached, set_setup_complete_cache
from app.utils.json_formatter import JSONFormatter
from app.utils.logging_compat import get_logger
log = get_logger("bangui")
@@ -125,8 +125,15 @@ def _configure_logging(log_level: str, log_file: str | None, settings: Settings
level: int = logging.getLevelName(log_level.upper())
handlers: list[logging.Handler] = [logging.StreamHandler(sys.stdout)]
if log_file:
os.makedirs(os.path.dirname(log_file), exist_ok=True)
handlers.append(logging.FileHandler(log_file))
try:
os.makedirs(os.path.dirname(log_file), exist_ok=True)
handlers.append(logging.FileHandler(log_file))
except (PermissionError, OSError) as exc:
log.warning(
"log_file_directory_not_created",
log_file=log_file,
error=str(exc),
)
# Suppress verbose third-party library logs that emit plain text
# through the standard library logging module.
@@ -163,9 +170,7 @@ def _update_session_cache(app: FastAPI, settings: Settings) -> None:
settings: The effective application settings.
"""
cache_enabled = settings.session_cache_enabled and settings.session_cache_ttl_seconds > 0.0
app.state.session_cache = (
InMemorySessionCache() if cache_enabled else NoOpSessionCache()
)
app.state.session_cache = InMemorySessionCache() if cache_enabled else NoOpSessionCache()
@asynccontextmanager
@@ -811,12 +816,12 @@ async def _request_validation_error_handler(
# the guard without being explicitly allowed.
_EXACT_ALLOWED: frozenset[str] = frozenset(
{
"/api/v1/setup", # GET/POST /api/v1/setup
"/api/v1/health", # Health check endpoint (combined)
"/api/v1/health/live", # Kubernetes liveness probe
"/api/v1/setup", # GET/POST /api/v1/setup
"/api/v1/health", # Health check endpoint (combined)
"/api/v1/health/live", # Kubernetes liveness probe
"/api/v1/health/ready", # Kubernetes readiness probe
"/api/docs", # Swagger UI
"/api/redoc", # ReDoc
"/api/docs", # Swagger UI
"/api/redoc", # ReDoc
"/api/openapi.json", # OpenAPI schema
},
)
@@ -971,9 +976,7 @@ def _enforce_single_worker() -> None:
"See Docs/Deployment.md § Single-Worker Requirement."
)
except ValueError as e:
raise RuntimeError(
f"WEB_CONCURRENCY must be an integer, got: {web_concurrency}"
) from e
raise RuntimeError(f"WEB_CONCURRENCY must be an integer, got: {web_concurrency}") from e
# Check explicit BANGUI_WORKERS override (discouraged, still enforced)
bangui_workers = os.environ.get("BANGUI_WORKERS")
@@ -990,9 +993,7 @@ def _enforce_single_worker() -> None:
"See Docs/Deployment.md § Single-Worker Requirement."
)
except ValueError as e:
raise RuntimeError(
f"BANGUI_WORKERS must be an integer, got: {bangui_workers}"
) from e
raise RuntimeError(f"BANGUI_WORKERS must be an integer, got: {bangui_workers}") from e
# ---------------------------------------------------------------------------
@@ -1165,7 +1166,6 @@ def create_app(settings: Settings | None = None) -> FastAPI:
# stack is a security-critical defect that must not slip through silently.
_assert_middleware_order(app)
# --- Exception handlers ---
#
# Exception handlers are registered from most specific to least specific. FastAPI evaluates

View File

@@ -10,13 +10,11 @@ from __future__ import annotations
from app.models.config import (
BantimeEscalation,
Fail2BanLogResponse,
FilterConfig,
FilterListResponse,
GlobalConfigResponse,
JailConfig,
JailConfigListResponse,
LogPreviewResponse,
MapColorThresholdsResponse,
RegexTestResponse,
ServiceStatusResponse,
@@ -32,7 +30,6 @@ from app.models.config_domain import (
DomainRegexTest,
DomainServiceStatus,
)
from app.utils.pagination import create_pagination_metadata
def _map_domain_bantime_escalation(domain: DomainBantimeEscalation) -> BantimeEscalation:
@@ -65,9 +62,7 @@ def map_domain_jail_config_to_response(domain: DomainJailConfig) -> JailConfig:
prefregex=domain.prefregex,
actions=domain.actions,
bantime_escalation=(
_map_domain_bantime_escalation(domain.bantime_escalation)
if domain.bantime_escalation
else None
_map_domain_bantime_escalation(domain.bantime_escalation) if domain.bantime_escalation else None
),
)
@@ -151,6 +146,6 @@ def map_domain_filter_config_to_response(domain: DomainFilterConfig) -> FilterCo
def map_domain_filter_list_to_response(domain_list: DomainFilterList) -> FilterListResponse:
"""Convert domain filter list to response model."""
return FilterListResponse(
items=[map_domain_filter_config_to_response(f) for f in domain_list.items],
filters=[map_domain_filter_config_to_response(f) for f in domain_list.items],
total=domain_list.total,
)

View File

@@ -8,15 +8,15 @@ from __future__ import annotations
from enum import StrEnum
from pydantic import AnyHttpUrl, Field
from pydantic import AnyHttpUrl, ConfigDict, Field
from app.models.response import BanGuiBaseModel, PaginatedListResponse
# ---------------------------------------------------------------------------
# Blocklist source
# ---------------------------------------------------------------------------
class BlocklistSource(BanGuiBaseModel):
"""Domain model for a blocklist source definition."""
@@ -27,6 +27,7 @@ class BlocklistSource(BanGuiBaseModel):
created_at: str
updated_at: str
class BlocklistSourceCreate(BanGuiBaseModel):
"""Payload for ``POST /api/blocklists``.
@@ -39,6 +40,7 @@ class BlocklistSourceCreate(BanGuiBaseModel):
url: AnyHttpUrl = Field(..., description="URL of the blocklist file (http/https only).")
enabled: bool = Field(default=True)
class BlocklistSourceUpdate(BanGuiBaseModel):
"""Payload for ``PUT /api/blocklists/{id}``. All fields are optional.
@@ -49,15 +51,18 @@ class BlocklistSourceUpdate(BanGuiBaseModel):
url: AnyHttpUrl | None = Field(default=None)
enabled: bool | None = Field(default=None)
class BlocklistListResponse(BanGuiBaseModel):
"""Response for ``GET /api/blocklists``."""
sources: list[BlocklistSource] = Field(default_factory=list)
# ---------------------------------------------------------------------------
# Import log
# ---------------------------------------------------------------------------
class ImportLogEntry(BanGuiBaseModel):
"""A single blocklist import run record."""
@@ -69,6 +74,7 @@ class ImportLogEntry(BanGuiBaseModel):
ips_skipped: int
errors: str | None
class ImportLogListResponse(PaginatedListResponse[ImportLogEntry]):
"""Response for ``GET /api/blocklists/log``.
@@ -83,6 +89,7 @@ class ImportLogListResponse(PaginatedListResponse[ImportLogEntry]):
# Import run tracking (for idempotency)
# ---------------------------------------------------------------------------
class ImportRunEntry(BanGuiBaseModel):
"""Tracks a unique blocklist import run by source and content hash.
@@ -100,10 +107,12 @@ class ImportRunEntry(BanGuiBaseModel):
created_at: str
updated_at: str
# ---------------------------------------------------------------------------
# Schedule
# ---------------------------------------------------------------------------
class ScheduleFrequency(StrEnum):
"""Available import schedule frequency presets."""
@@ -111,6 +120,7 @@ class ScheduleFrequency(StrEnum):
daily = "daily"
weekly = "weekly"
class ScheduleConfig(BanGuiBaseModel):
"""Import schedule configuration.
@@ -121,8 +131,10 @@ class ScheduleConfig(BanGuiBaseModel):
- ``weekly``: additionally uses ``day_of_week`` (0=Monday … 6=Sunday).
"""
# No strict=True here: FastAPI and json.loads() both supply enum values as
# plain strings; strict mode would reject string→enum coercion.
# FastAPI and json.loads() both supply enum values as plain strings;
# strict mode would reject string→enum coercion, so we override the
# base model_config for this model only.
model_config = ConfigDict(strict=False)
frequency: ScheduleFrequency = ScheduleFrequency.daily
interval_hours: int = Field(default=24, ge=1, le=168, description="Used when frequency=hourly")
@@ -135,6 +147,7 @@ class ScheduleConfig(BanGuiBaseModel):
description="Day of week for weekly runs (0=Monday … 6=Sunday)",
)
class ScheduleInfo(BanGuiBaseModel):
"""Current schedule configuration together with runtime metadata."""
@@ -144,10 +157,12 @@ class ScheduleInfo(BanGuiBaseModel):
last_run_errors: bool | None = None
"""``True`` if the most recent import had errors, ``False`` if clean, ``None`` if never run."""
# ---------------------------------------------------------------------------
# Import results
# ---------------------------------------------------------------------------
class ImportSourceResult(BanGuiBaseModel):
"""Result of importing a single blocklist source."""
@@ -157,6 +172,7 @@ class ImportSourceResult(BanGuiBaseModel):
ips_skipped: int
error: str | None
class ImportRunResult(BanGuiBaseModel):
"""Aggregated result from a full import run across all enabled sources."""
@@ -165,10 +181,12 @@ class ImportRunResult(BanGuiBaseModel):
total_skipped: int
errors_count: int
# ---------------------------------------------------------------------------
# Preview
# ---------------------------------------------------------------------------
class PreviewResponse(BanGuiBaseModel):
"""Response for ``GET /api/blocklists/{id}/preview``."""

View File

@@ -188,7 +188,6 @@ class PaginationMetadata(BanGuiBaseModel):
)
class PaginatedListResponse(BanGuiBaseModel, Generic[T]):
"""Standardized paginated list response.
@@ -384,6 +383,8 @@ class ErrorMetadata(TypedDict, total=False):
current_status: str
actual_length: int
message: str
field_errors: int
first_field: str
class ComponentHealth(BanGuiBaseModel):

View File

@@ -37,7 +37,6 @@ from app.services import (
filter_config_service,
jail_config_service,
)
from app.utils.path_utils import validate_log_path
from app.utils.constants import (
RATE_LIMIT_JAIL_ACTIVATE_REQUESTS,
RATE_LIMIT_JAIL_CREATE_REQUESTS,
@@ -45,6 +44,7 @@ from app.utils.constants import (
RATE_LIMIT_JAIL_DELETE_REQUESTS,
RATE_LIMIT_JAIL_UPDATE_REQUESTS,
)
from app.utils.path_utils import validate_log_path
from app.utils.runtime_state import (
clear_activation_record,
clear_pending_recovery,
@@ -207,7 +207,8 @@ def _check_jail_deactivate_rate_limit(
)
_NamePath = Annotated[str, Path(description='Jail name as configured in fail2ban.')]
_NamePath = Annotated[str, Path(description="Jail name as configured in fail2ban.")]
@router.get(
"",
@@ -240,8 +241,6 @@ async def get_jail_configs(
return config_mappers.map_domain_jail_config_list_to_response(domain_result)
@router.get(
"/inactive",
response_model=InactiveJailListResponse,
@@ -335,9 +334,8 @@ async def get_jail_config(
HTTPException: 502 when fail2ban is unreachable.
"""
domain_result = await config_service.get_jail_config(socket_path, name)
return config_mappers.map_domain_jail_config_to_response(domain_result)
mapped = config_mappers.map_domain_jail_config_to_response(domain_result)
return JailConfigResponse(jail=mapped)
@router.put(
@@ -387,8 +385,6 @@ async def update_jail_config(
# ---------------------------------------------------------------------------
@router.post(
"/{name}/logpath",
status_code=status.HTTP_204_NO_CONTENT,
@@ -430,8 +426,6 @@ async def add_log_path(
await config_service.add_log_path(socket_path, name, body)
@router.delete(
"/{name}/logpath",
status_code=status.HTTP_204_NO_CONTENT,
@@ -479,8 +473,6 @@ async def delete_log_path(
await config_service.delete_log_path(socket_path, name, log_path)
@router.post(
"/{name}/activate",
response_model=JailActivationResponse,
@@ -532,9 +524,7 @@ async def activate_jail(
"""
req = body if body is not None else ActivateJailRequest()
result = await jail_config_service.activate_jail(
config_dir, socket_path, name, req, health_probe=health_probe
)
result = await jail_config_service.activate_jail(config_dir, socket_path, name, req, health_probe=health_probe)
if result.active:
record_activation(app, name)
@@ -542,8 +532,6 @@ async def activate_jail(
return result
@router.post(
"/{name}/deactivate",
response_model=JailActivationResponse,
@@ -588,14 +576,10 @@ async def deactivate_jail(
HTTPException: 502 if fail2ban is unreachable.
"""
result = await jail_config_service.deactivate_jail(
config_dir, socket_path, name, health_probe=health_probe
)
result = await jail_config_service.deactivate_jail(config_dir, socket_path, name, health_probe=health_probe)
return result
@router.delete(
"/{name}/local",
status_code=status.HTTP_204_NO_CONTENT,
@@ -645,8 +629,6 @@ async def delete_jail_local_override(
# ---------------------------------------------------------------------------
@router.post(
"/{name}/validate",
response_model=JailValidationResult,
@@ -868,10 +850,8 @@ async def remove_action_from_jail(
action_name,
do_reload=reload,
)
# ---------------------------------------------------------------------------
# Filter discovery endpoints (Task 2.1)
# ---------------------------------------------------------------------------

View File

@@ -15,11 +15,11 @@ under the key ``"blocklist_schedule"``.
from __future__ import annotations
import json
from datetime import UTC
from typing import TYPE_CHECKING
import aiohttp
import aiosqlite
from app.utils.logging_compat import get_logger
from app.exceptions import BlocklistSourceHasLogsError
from app.models.blocklist import (
@@ -37,6 +37,7 @@ from app.repositories import blocklist_repo, import_log_repo, settings_repo
from app.services.blocklist_downloader import BlocklistDownloader
from app.services.blocklist_import_workflow import BlocklistImportWorkflow
from app.services.blocklist_parser import BlocklistParser
from app.utils.logging_compat import get_logger
from app.utils.pagination import create_pagination_metadata
if TYPE_CHECKING:
@@ -200,9 +201,7 @@ async def update_source(
await validate_blocklist_url(url)
updated = await blocklist_repo.update_source(
db, source_id, name=name, url=url, enabled=enabled
)
updated = await blocklist_repo.update_source(db, source_id, name=name, url=url, enabled=enabled)
if not updated:
return None
source = await get_source(db, source_id)
@@ -473,8 +472,7 @@ async def get_schedule(db: aiosqlite.Connection) -> ScheduleConfig:
if raw is None:
return _DEFAULT_SCHEDULE
try:
data = json.loads(raw)
return ScheduleConfig.model_validate(data)
return ScheduleConfig.model_validate_json(raw)
except (json.JSONDecodeError, ValueError) as exc:
log.warning("blocklist_schedule_invalid", raw=raw, error=type(exc).__name__)
return _DEFAULT_SCHEDULE
@@ -493,9 +491,7 @@ async def set_schedule(
Returns:
The saved configuration (same object after validation).
"""
await settings_repo.set_setting(
db, _SCHEDULE_SETTINGS_KEY, config.model_dump_json()
)
await settings_repo.set_setting(db, _SCHEDULE_SETTINGS_KEY, config.model_dump_json())
log.info("blocklist_schedule_updated", frequency=config.frequency, hour=config.hour)
return config
@@ -517,8 +513,12 @@ async def get_schedule_info(
"""
config = await get_schedule(db)
last_log = await import_log_repo.get_last_log(db)
last_run_at = last_log["timestamp"] if last_log else None
last_run_errors: bool | None = (last_log["errors"] is not None) if last_log else None
last_run_at = None
if last_log is not None:
from datetime import datetime
last_run_at = datetime.fromtimestamp(last_log.timestamp, tz=UTC).isoformat()
last_run_errors: bool | None = (last_log.errors is not None) if last_log else None
return ScheduleInfo(
config=config,
next_run_at=next_run_at,
@@ -574,9 +574,7 @@ async def list_import_logs(
Returns:
:class:`~app.models.blocklist.ImportLogListResponse`.
"""
items, total = await import_log_repo.list_logs(
db, source_id=source_id, page=page, page_size=page_size
)
items, total = await import_log_repo.list_logs(db, source_id=source_id, page=page, page_size=page_size)
return ImportLogListResponse(
items=[ImportLogEntry.model_validate(i) for i in items],

View File

@@ -13,8 +13,6 @@ import re
import tempfile
from pathlib import Path
from app.utils.logging_compat import get_logger
from app.exceptions import (
ConfigWriteError,
FilterAlreadyExistsError,
@@ -27,6 +25,7 @@ from app.exceptions import (
)
from app.models.config import (
AssignFilterRequest,
FilterConfig,
FilterConfigUpdate,
FilterCreateRequest,
FilterUpdateRequest,
@@ -46,6 +45,7 @@ from app.utils.config_file_utils import (
set_jail_local_key_sync,
)
from app.utils.jail_socket import reload_all
from app.utils.logging_compat import get_logger
from app.utils.regex_validator import RegexTimeoutError, validate_regex_pattern
log = get_logger(__name__)
@@ -54,6 +54,7 @@ log = get_logger(__name__)
# Internal wrappers for shared config helpers.
# ---------------------------------------------------------------------------
def _parse_jails_sync(config_dir: Path) -> tuple[dict[str, dict[str, str]], Path]:
return _config_file_parse_jails_sync(config_dir)
@@ -85,6 +86,7 @@ def _resolve_filter(raw_filter: str, jail_name: str, mode: str) -> str:
result = result.replace("%(mode)s", mode)
return result
# ---------------------------------------------------------------------------
# Internal helpers imported from the shared config helper module.
# ---------------------------------------------------------------------------
@@ -366,7 +368,7 @@ async def list_filters(
)
log.info("filters_listed", total=len(filters), active=sum(1 for f in filters if f.active))
return DomainFilterList(filters=filters, total=len(filters))
return DomainFilterList(items=filters, total=len(filters))
async def get_filter(
@@ -428,7 +430,7 @@ async def get_filter(
else:
raise FilterNotFoundError(base_name)
content, has_local, source_path = await run_blocking( _read)
content, has_local, source_path = await run_blocking(_read)
cfg = conffile_parser.parse_filter_file(content, name=base_name, filename=f"{base_name}.conf")
@@ -524,7 +526,7 @@ async def update_filter(
content = conffile_parser.serialize_filter_config(merged)
filter_d = Path(config_dir) / "filter.d"
await run_blocking( _write_filter_local_sync, filter_d, base_name, content)
await run_blocking(_write_filter_local_sync, filter_d, base_name, content)
if do_reload:
try:
@@ -580,7 +582,7 @@ async def create_filter(
if conf_path.is_file() or local_path.is_file():
raise FilterAlreadyExistsError(req.name)
await run_blocking( _check_not_exists)
await run_blocking(_check_not_exists)
# Validate regex patterns.
patterns: list[str] = list(req.failregex) + list(req.ignoreregex)
@@ -598,7 +600,7 @@ async def create_filter(
)
content = conffile_parser.serialize_filter_config(cfg)
await run_blocking( _write_filter_local_sync, filter_d, req.name, content)
await run_blocking(_write_filter_local_sync, filter_d, req.name, content)
if do_reload:
try:
@@ -663,7 +665,7 @@ async def delete_filter(
log.info("filter_local_deleted", filter=base_name, path=str(local_path))
await run_blocking( _delete)
await run_blocking(_delete)
async def assign_filter_to_jail(
@@ -713,9 +715,10 @@ async def assign_filter_to_jail(
if not conf_exists and not local_exists:
raise FilterNotFoundError(req.filter_name)
await run_blocking( _check_filter)
await run_blocking(_check_filter)
await run_blocking(set_jail_local_key_sync,
await run_blocking(
set_jail_local_key_sync,
Path(config_dir),
jail_name,
"filter",

View File

@@ -21,10 +21,10 @@ import time
from typing import TYPE_CHECKING
import aiohttp
from app.utils.logging_compat import get_logger
from app.models.geo import GeoInfo
from app.repositories import geo_cache_repo
from app.utils.logging_compat import get_logger
if TYPE_CHECKING:
import collections.abc
@@ -40,14 +40,10 @@ log = get_logger(__name__)
# ---------------------------------------------------------------------------
#: ip-api.com single-IP lookup endpoint (HTTP only on the free tier).
_API_URL: str = (
"http://ip-api.com/json/{ip}?fields=status,message,country,countryCode,org,as"
)
_API_URL: str = "http://ip-api.com/json/{ip}?fields=status,message,country,countryCode,org,as"
#: ip-api.com batch endpoint — accepts up to 100 IPs per POST.
_BATCH_API_URL: str = (
"http://ip-api.com/batch?fields=status,message,country,countryCode,org,as,query"
)
_BATCH_API_URL: str = "http://ip-api.com/batch?fields=status,message,country,countryCode,org,as,query"
#: Maximum IPs per batch request (ip-api.com hard limit is 100).
_BATCH_SIZE: int = 100
@@ -217,9 +213,7 @@ class GeoCache:
await self.clear_neg_cache()
geo_map = await self.lookup_batch(unresolved, http_session, db=db)
resolved_count = sum(
1 for info in geo_map.values() if info.country_code is not None
)
resolved_count = sum(1 for info in geo_map.values() if info.country_code is not None)
log.info(
"geo_re_resolve_complete",
@@ -398,7 +392,7 @@ class GeoCache:
asn=result.asn,
org=result.org,
)
except (OSError) as exc:
except OSError as exc:
log.warning("geo_persist_failed", ip=ip, error=type(exc).__name__)
log.debug("geo_lookup_success_mmdb", ip=ip, country=result.country_code)
return result
@@ -412,7 +406,7 @@ class GeoCache:
if db is not None:
try:
await geo_cache_repo.upsert_neg_entry_and_commit(db=db, ip=ip)
except (OSError) as exc:
except OSError as exc:
log.warning("geo_persist_neg_failed", ip=ip, error=type(exc).__name__)
return GeoInfo(country_code=None, country_name=None, asn=None, org=None)
@@ -439,7 +433,7 @@ class GeoCache:
asn=result.asn,
org=result.org,
)
except (OSError) as exc:
except OSError as exc:
log.warning("geo_persist_failed", ip=ip, error=type(exc).__name__)
log.debug("geo_lookup_success_http", ip=ip, country=result.country_code, asn=result.asn)
return result
@@ -448,7 +442,7 @@ class GeoCache:
ip=ip,
message=data.get("message", "unknown"),
)
except (TimeoutError, aiohttp.ClientError, ValueError) as exc:
except (TimeoutError, aiohttp.ClientError, ValueError, OSError) as exc:
log.warning(
"geo_lookup_http_request_failed",
ip=ip,
@@ -585,7 +579,7 @@ class GeoCache:
if db is not None and pos_rows:
try:
await geo_cache_repo.bulk_upsert_entries_and_commit(db, pos_rows)
except (OSError) as exc:
except OSError as exc:
log.warning(
"geo_batch_persist_mmdb_failed",
count=len(pos_rows),
@@ -604,7 +598,7 @@ class GeoCache:
if db is not None and neg_ips:
try:
await geo_cache_repo.bulk_upsert_neg_entries_and_commit(db, neg_ips)
except (OSError) as exc:
except OSError as exc:
log.warning(
"geo_batch_persist_neg_failed",
count=len(neg_ips),
@@ -637,9 +631,7 @@ class GeoCache:
# If every IP in the chunk came back with country_code=None and the
# batch wasn't tiny, that almost certainly means the whole request
# was rejected (connection reset / 429). Retry after a back-off.
all_failed = all(
info.country_code is None for info in chunk_result.values()
)
all_failed = all(info.country_code is None for info in chunk_result.values())
if not all_failed or attempt >= _BATCH_MAX_RETRIES:
break
backoff = _BATCH_DELAY * (2 ** (attempt + 1))
@@ -659,9 +651,7 @@ class GeoCache:
await self._store(ip, info)
geo_result[ip] = info
if db is not None:
pos_rows.append(
(ip, info.country_code, info.country_name, info.asn, info.org)
)
pos_rows.append((ip, info.country_code, info.country_name, info.asn, info.org))
else:
# HTTP failed — record as negative cache.
async with self._cache_lock:
@@ -677,7 +667,7 @@ class GeoCache:
pos_rows,
neg_ips,
)
except (OSError) as exc:
except OSError as exc:
log.warning(
"geo_batch_persist_failed",
positive_count=len(pos_rows),
@@ -724,7 +714,7 @@ class GeoCache:
log.warning("geo_batch_non_200", status=resp.status, count=len(ips))
return fallback
data: list[dict[str, object]] = await resp.json(content_type=None)
except (TimeoutError, aiohttp.ClientError, ValueError) as exc:
except (TimeoutError, aiohttp.ClientError, ValueError, OSError) as exc:
log.warning(
"geo_batch_request_failed",
count=len(ips),
@@ -836,7 +826,7 @@ class GeoCache:
try:
await geo_cache_repo.bulk_upsert_entries_and_commit(db, rows)
except (OSError) as exc:
except OSError as exc:
log.warning("geo_flush_dirty_failed", error=type(exc).__name__)
# Re-add to dirty so they are retried on the next flush cycle.
self._dirty.update(to_flush)

View File

@@ -61,17 +61,20 @@ def normalise_ip(address: str) -> str:
IPv4-mapped IPv6 addresses (e.g. ``::ffff:192.168.1.1``) are converted
to their IPv4 equivalent (``192.168.1.1``).
Plain IPv4 addresses are returned unchanged.
Non-IP strings (e.g. ``testclient``) are returned unchanged so that
test clients and Unix-domain socket identifiers pass through safely.
Args:
address: A valid IP address string.
address: An IP address string or other identifier.
Returns:
Normalised IP address string.
Raises:
ValueError: If *address* is not a valid IP address.
Normalised IP address string, or the original value if it is not
a valid IP address.
"""
ip = ipaddress.ip_address(address)
try:
ip = ipaddress.ip_address(address)
except ValueError:
return address
if isinstance(ip, ipaddress.IPv6Address) and ip.ipv4_mapped:
return str(ip.ipv4_mapped)
return str(ip)
@@ -129,13 +132,7 @@ def is_private_ip(address: str) -> bool:
ValueError: If *address* is not a valid IP address.
"""
ip = ipaddress.ip_address(address)
return (
ip.is_private
or ip.is_loopback
or ip.is_link_local
or ip.is_multicast
or ip.is_reserved
)
return ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_multicast or ip.is_reserved
async def validate_blocklist_url(url: str) -> None:
@@ -165,9 +162,7 @@ async def validate_blocklist_url(url: str) -> None:
raise ValueError(f"Invalid URL format: {exc}") from exc
if parsed.scheme not in ("http", "https"):
raise ValueError(
f"Invalid scheme '{parsed.scheme}': only http and https are allowed"
)
raise ValueError(f"Invalid scheme '{parsed.scheme}': only http and https are allowed")
if not parsed.hostname:
raise ValueError("URL has no hostname")
@@ -201,14 +196,9 @@ async def validate_blocklist_url(url: str) -> None:
# connection time, and host mode is never used in production.
if is_private_ip(ip_str):
import os
if (
os.getenv("BANGUI_LOG_LEVEL") == "debug"
and ipaddress.ip_address(ip_str).is_loopback
):
if os.getenv("BANGUI_LOG_LEVEL") == "debug" and ipaddress.ip_address(ip_str).is_loopback:
continue
raise ValueError(
f"Hostname '{hostname}' resolves to private/reserved IP: {ip_str}"
)
raise ValueError(f"Hostname '{hostname}' resolves to private/reserved IP: {ip_str}")
except ipaddress.AddressValueError as exc:
raise ValueError(f"Invalid IP address: {ip_str}") from exc

View File

@@ -26,6 +26,19 @@ class _CompatLogger:
if v is not None:
stdlib_kwargs[k] = v
if kwargs:
# Several keys are reserved in LogRecord; rename them to avoid KeyError.
reserved_renames = {
"message": "log_message",
"name": "log_name",
"filename": "log_filename",
"funcName": "log_funcName",
"lineno": "log_lineno",
"module": "log_module",
"pathname": "log_pathname",
}
for old_key, new_key in reserved_renames.items():
if old_key in kwargs:
kwargs[new_key] = kwargs.pop(old_key)
stdlib_kwargs["extra"] = kwargs
self._logger.log(level, event, **stdlib_kwargs)
@@ -50,7 +63,7 @@ class _CompatLogger:
def exception(self, event: str, **kwargs: Any) -> None:
self._log(logging.ERROR, event, exc_info=True, **kwargs)
def bind(self, **kwargs: Any) -> "_CompatLogger":
def bind(self, **kwargs: Any) -> _CompatLogger:
"""Return a new logger with bound context (no-op for stdlib)."""
return self

View File

@@ -46,6 +46,7 @@ import time
from typing import Any
import aiosqlite
from app.utils.logging_compat import get_logger
log = get_logger(__name__)
@@ -133,12 +134,10 @@ async def acquire_scheduler_lock(db: aiosqlite.Connection) -> bool:
await db.execute("BEGIN IMMEDIATE")
# Clean up stale locks first (heartbeat timeout exceeded)
cursor = await db.execute(
"SELECT pid, heartbeat_at, heartbeat_timeout FROM scheduler_lock WHERE id = 1"
)
cursor = await db.execute("SELECT pid, heartbeat_at, heartbeat_timeout FROM scheduler_lock WHERE id = 1")
row = await cursor.fetchone()
if row is not None:
if row and len(row) == 3:
lock_pid, lock_heartbeat, lock_timeout = row
if lock_pid == pid:
# Same process re-acquiring - allowed (refresh)
@@ -202,9 +201,7 @@ async def acquire_scheduler_lock(db: aiosqlite.Connection) -> bool:
return False
except Exception as e:
raise RuntimeError(
f"Failed to acquire scheduler lock due to database error: {e}"
) from e
raise RuntimeError(f"Failed to acquire scheduler lock due to database error: {e}") from e
async def release_scheduler_lock(db: aiosqlite.Connection) -> None:
@@ -372,9 +369,7 @@ async def get_lock_health(db: aiosqlite.Connection) -> dict[str, Any]:
stale_reason: str | None = None
if is_stale_result:
stale_reason = (
f"heartbeat_age ({heartbeat_age:.1f}s) > timeout ({info['heartbeat_timeout']:.1f}s)"
)
stale_reason = f"heartbeat_age ({heartbeat_age:.1f}s) > timeout ({info['heartbeat_timeout']:.1f}s)"
return {
"has_lock": True,