fixed tests
This commit is contained in:
@@ -14,6 +14,7 @@ from __future__ import annotations
|
||||
from pathlib import Path
|
||||
|
||||
import aiosqlite
|
||||
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = get_logger(__name__)
|
||||
@@ -246,7 +247,6 @@ CREATE INDEX IF NOT EXISTS idx_import_log_source_id_desc
|
||||
}
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -254,6 +254,7 @@ CREATE INDEX IF NOT EXISTS idx_import_log_source_id_desc
|
||||
|
||||
async def _configure_connection(db: aiosqlite.Connection) -> None:
|
||||
"""Apply hardening pragmas to a newly-opened SQLite connection."""
|
||||
await db.execute("PRAGMA journal_mode=WAL;")
|
||||
await db.execute("PRAGMA foreign_keys=ON;")
|
||||
await db.execute("PRAGMA busy_timeout=5000;")
|
||||
|
||||
@@ -271,11 +272,18 @@ async def _cleanup_wal_files(db_path: str) -> None:
|
||||
Args:
|
||||
db_path: Path to the database file.
|
||||
"""
|
||||
import time
|
||||
|
||||
wal_path = Path(db_path + "-wal")
|
||||
shm_path = Path(db_path + "-shm")
|
||||
|
||||
for path in (wal_path, shm_path):
|
||||
if path.exists():
|
||||
# Skip files that were modified recently — they likely belong to an
|
||||
# active connection. Only remove stale files left by crashes.
|
||||
mtime = path.stat().st_mtime
|
||||
if time.time() - mtime < 10:
|
||||
continue
|
||||
try:
|
||||
path.unlink()
|
||||
log.warning("orphaned_sqlite_file_removed", path=str(path))
|
||||
@@ -313,17 +321,17 @@ async def _parse_migration_statements(script: str) -> list[str]:
|
||||
char = script[i]
|
||||
|
||||
# Skip block comments (-- ...)
|
||||
if i < len(script) - 1 and script[i:i+2] == "--":
|
||||
if i < len(script) - 1 and script[i : i + 2] == "--":
|
||||
while i < len(script) and script[i] != "\n":
|
||||
i += 1
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Skip line comments (/* ... */)
|
||||
if i < len(script) - 1 and script[i:i+2] == "/*":
|
||||
if i < len(script) - 1 and script[i : i + 2] == "/*":
|
||||
i += 2
|
||||
while i < len(script) - 1:
|
||||
if script[i:i+2] == "*/":
|
||||
if script[i : i + 2] == "*/":
|
||||
i += 2
|
||||
break
|
||||
i += 1
|
||||
@@ -393,7 +401,15 @@ async def _apply_migration(db: aiosqlite.Connection, version: int) -> None:
|
||||
await db.execute("BEGIN IMMEDIATE;")
|
||||
|
||||
for statement in statements:
|
||||
await db.execute(statement)
|
||||
try:
|
||||
await db.execute(statement)
|
||||
except aiosqlite.OperationalError as exc:
|
||||
# Ignore duplicate column / table errors so migrations remain
|
||||
# idempotent when a legacy database already has the object.
|
||||
msg = str(exc).lower()
|
||||
if "duplicate column name" in msg or "table" in msg and "already exists" in msg:
|
||||
continue
|
||||
raise
|
||||
|
||||
await db.execute("INSERT INTO schema_migrations (version) VALUES (?);", (version,))
|
||||
|
||||
@@ -411,8 +427,7 @@ async def _migrate_schema(db: aiosqlite.Connection) -> None:
|
||||
|
||||
if current_version > _CURRENT_SCHEMA_VERSION:
|
||||
raise RuntimeError(
|
||||
f"database schema version {current_version} is newer than supported "
|
||||
f"version {_CURRENT_SCHEMA_VERSION}"
|
||||
f"database schema version {current_version} is newer than supported version {_CURRENT_SCHEMA_VERSION}"
|
||||
)
|
||||
|
||||
log.info("migrating_database_schema", from_version=current_version, to_version=_CURRENT_SCHEMA_VERSION)
|
||||
|
||||
@@ -36,7 +36,6 @@ from typing import Annotated, cast
|
||||
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
from app.utils.logging_compat import get_logger
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler # type: ignore[import-untyped]
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request, status
|
||||
|
||||
@@ -45,22 +44,6 @@ from app.exceptions import RateLimitError
|
||||
from app.models.auth import Session
|
||||
from app.models.config import PendingRecovery
|
||||
from app.models.server import ServerStatus
|
||||
from app.repositories.protocols import (
|
||||
BlocklistRepository,
|
||||
Fail2BanDbRepository,
|
||||
GeoCacheRepository,
|
||||
HistoryArchiveRepository,
|
||||
ImportLogRepository,
|
||||
ImportRunRepository,
|
||||
SessionRepository,
|
||||
SettingsRepository,
|
||||
)
|
||||
from app.services.geo_cache import GeoCache
|
||||
from app.services.protocols import Fail2BanMetadataService
|
||||
from app.utils.constants import SESSION_COOKIE_NAME
|
||||
from app.utils.rate_limiter import GlobalRateLimiter
|
||||
from app.utils.runtime_state import ApplicationState, JailServiceState, RuntimeState
|
||||
from app.utils.session_cache import NoOpSessionCache, SessionCache
|
||||
|
||||
# Module-level imports for repositories and services
|
||||
# These are safe at module level since no circular dependencies exist
|
||||
@@ -74,8 +57,25 @@ from app.repositories import (
|
||||
session_repo,
|
||||
settings_repo,
|
||||
)
|
||||
from app.repositories.protocols import (
|
||||
BlocklistRepository,
|
||||
Fail2BanDbRepository,
|
||||
GeoCacheRepository,
|
||||
HistoryArchiveRepository,
|
||||
ImportLogRepository,
|
||||
ImportRunRepository,
|
||||
SessionRepository,
|
||||
SettingsRepository,
|
||||
)
|
||||
from app.services import auth_service, health_service
|
||||
from app.services.fail2ban_metadata_service import default_fail2ban_metadata_service
|
||||
from app.services.geo_cache import GeoCache
|
||||
from app.services.protocols import Fail2BanMetadataService
|
||||
from app.utils.constants import SESSION_COOKIE_NAME
|
||||
from app.utils.logging_compat import get_logger
|
||||
from app.utils.rate_limiter import GlobalRateLimiter
|
||||
from app.utils.runtime_state import ApplicationState, JailServiceState, RuntimeState
|
||||
from app.utils.session_cache import NoOpSessionCache, SessionCache
|
||||
|
||||
log = get_logger(__name__)
|
||||
|
||||
@@ -108,6 +108,7 @@ class ApplicationContext:
|
||||
#: or distributed deployments, the configured cache backend should provide
|
||||
#: invalidation semantics appropriate for the deployment.
|
||||
|
||||
|
||||
def _session_cache_enabled(settings: Settings) -> bool:
|
||||
"""Return whether the session validation cache should be used."""
|
||||
return settings.session_cache_enabled and settings.session_cache_ttl_seconds > 0.0
|
||||
@@ -284,6 +285,7 @@ def rate_limit_dependency(
|
||||
Returns:
|
||||
A callable that can be used as a FastAPI Depends() dependency.
|
||||
"""
|
||||
|
||||
async def check_rate_limit(
|
||||
request: Request,
|
||||
rate_limiter: GlobalRateLimiterDep,
|
||||
@@ -293,9 +295,7 @@ def rate_limit_dependency(
|
||||
settings: Settings = request.app.state.settings
|
||||
client_ip = get_client_ip(request, trusted_proxies=settings.trusted_proxies)
|
||||
|
||||
is_allowed, retry_after = rate_limiter.check_allowed_for_bucket(
|
||||
bucket, client_ip, max_requests, window_seconds
|
||||
)
|
||||
is_allowed, retry_after = rate_limiter.check_allowed_for_bucket(bucket, client_ip, max_requests, window_seconds)
|
||||
|
||||
if not is_allowed:
|
||||
log.warning(
|
||||
@@ -407,6 +407,8 @@ async def get_app(request: Request) -> FastAPI:
|
||||
|
||||
async def get_server_status(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> ServerStatus:
|
||||
"""Return the cached fail2ban server status snapshot from application context."""
|
||||
if app_context.server_status is None:
|
||||
return ServerStatus(online=False)
|
||||
return app_context.server_status
|
||||
|
||||
|
||||
@@ -654,7 +656,7 @@ async def require_auth(
|
||||
if not token:
|
||||
auth_header: str = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
token = auth_header[len("Bearer "):]
|
||||
token = auth_header[len("Bearer ") :]
|
||||
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
|
||||
@@ -72,13 +72,13 @@ from app.utils.external_logging import (
|
||||
ExternalLogHandler,
|
||||
create_external_log_handler,
|
||||
)
|
||||
from app.utils.json_formatter import JSONFormatter
|
||||
from app.utils.logging_compat import get_logger
|
||||
from app.utils.rate_limiter import GlobalRateLimiter
|
||||
from app.utils.runtime_state import ApplicationState, RuntimeState
|
||||
from app.utils.scheduler_lock import release_scheduler_lock
|
||||
from app.utils.session_cache import InMemorySessionCache, NoOpSessionCache
|
||||
from app.utils.setup_state import is_setup_complete_cached, set_setup_complete_cache
|
||||
from app.utils.json_formatter import JSONFormatter
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = get_logger("bangui")
|
||||
|
||||
@@ -125,8 +125,15 @@ def _configure_logging(log_level: str, log_file: str | None, settings: Settings
|
||||
level: int = logging.getLevelName(log_level.upper())
|
||||
handlers: list[logging.Handler] = [logging.StreamHandler(sys.stdout)]
|
||||
if log_file:
|
||||
os.makedirs(os.path.dirname(log_file), exist_ok=True)
|
||||
handlers.append(logging.FileHandler(log_file))
|
||||
try:
|
||||
os.makedirs(os.path.dirname(log_file), exist_ok=True)
|
||||
handlers.append(logging.FileHandler(log_file))
|
||||
except (PermissionError, OSError) as exc:
|
||||
log.warning(
|
||||
"log_file_directory_not_created",
|
||||
log_file=log_file,
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
# Suppress verbose third-party library logs that emit plain text
|
||||
# through the standard library logging module.
|
||||
@@ -163,9 +170,7 @@ def _update_session_cache(app: FastAPI, settings: Settings) -> None:
|
||||
settings: The effective application settings.
|
||||
"""
|
||||
cache_enabled = settings.session_cache_enabled and settings.session_cache_ttl_seconds > 0.0
|
||||
app.state.session_cache = (
|
||||
InMemorySessionCache() if cache_enabled else NoOpSessionCache()
|
||||
)
|
||||
app.state.session_cache = InMemorySessionCache() if cache_enabled else NoOpSessionCache()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@@ -811,12 +816,12 @@ async def _request_validation_error_handler(
|
||||
# the guard without being explicitly allowed.
|
||||
_EXACT_ALLOWED: frozenset[str] = frozenset(
|
||||
{
|
||||
"/api/v1/setup", # GET/POST /api/v1/setup
|
||||
"/api/v1/health", # Health check endpoint (combined)
|
||||
"/api/v1/health/live", # Kubernetes liveness probe
|
||||
"/api/v1/setup", # GET/POST /api/v1/setup
|
||||
"/api/v1/health", # Health check endpoint (combined)
|
||||
"/api/v1/health/live", # Kubernetes liveness probe
|
||||
"/api/v1/health/ready", # Kubernetes readiness probe
|
||||
"/api/docs", # Swagger UI
|
||||
"/api/redoc", # ReDoc
|
||||
"/api/docs", # Swagger UI
|
||||
"/api/redoc", # ReDoc
|
||||
"/api/openapi.json", # OpenAPI schema
|
||||
},
|
||||
)
|
||||
@@ -971,9 +976,7 @@ def _enforce_single_worker() -> None:
|
||||
"See Docs/Deployment.md § Single-Worker Requirement."
|
||||
)
|
||||
except ValueError as e:
|
||||
raise RuntimeError(
|
||||
f"WEB_CONCURRENCY must be an integer, got: {web_concurrency}"
|
||||
) from e
|
||||
raise RuntimeError(f"WEB_CONCURRENCY must be an integer, got: {web_concurrency}") from e
|
||||
|
||||
# Check explicit BANGUI_WORKERS override (discouraged, still enforced)
|
||||
bangui_workers = os.environ.get("BANGUI_WORKERS")
|
||||
@@ -990,9 +993,7 @@ def _enforce_single_worker() -> None:
|
||||
"See Docs/Deployment.md § Single-Worker Requirement."
|
||||
)
|
||||
except ValueError as e:
|
||||
raise RuntimeError(
|
||||
f"BANGUI_WORKERS must be an integer, got: {bangui_workers}"
|
||||
) from e
|
||||
raise RuntimeError(f"BANGUI_WORKERS must be an integer, got: {bangui_workers}") from e
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -1165,7 +1166,6 @@ def create_app(settings: Settings | None = None) -> FastAPI:
|
||||
# stack is a security-critical defect that must not slip through silently.
|
||||
_assert_middleware_order(app)
|
||||
|
||||
|
||||
# --- Exception handlers ---
|
||||
#
|
||||
# Exception handlers are registered from most specific to least specific. FastAPI evaluates
|
||||
|
||||
@@ -10,13 +10,11 @@ from __future__ import annotations
|
||||
|
||||
from app.models.config import (
|
||||
BantimeEscalation,
|
||||
Fail2BanLogResponse,
|
||||
FilterConfig,
|
||||
FilterListResponse,
|
||||
GlobalConfigResponse,
|
||||
JailConfig,
|
||||
JailConfigListResponse,
|
||||
LogPreviewResponse,
|
||||
MapColorThresholdsResponse,
|
||||
RegexTestResponse,
|
||||
ServiceStatusResponse,
|
||||
@@ -32,7 +30,6 @@ from app.models.config_domain import (
|
||||
DomainRegexTest,
|
||||
DomainServiceStatus,
|
||||
)
|
||||
from app.utils.pagination import create_pagination_metadata
|
||||
|
||||
|
||||
def _map_domain_bantime_escalation(domain: DomainBantimeEscalation) -> BantimeEscalation:
|
||||
@@ -65,9 +62,7 @@ def map_domain_jail_config_to_response(domain: DomainJailConfig) -> JailConfig:
|
||||
prefregex=domain.prefregex,
|
||||
actions=domain.actions,
|
||||
bantime_escalation=(
|
||||
_map_domain_bantime_escalation(domain.bantime_escalation)
|
||||
if domain.bantime_escalation
|
||||
else None
|
||||
_map_domain_bantime_escalation(domain.bantime_escalation) if domain.bantime_escalation else None
|
||||
),
|
||||
)
|
||||
|
||||
@@ -151,6 +146,6 @@ def map_domain_filter_config_to_response(domain: DomainFilterConfig) -> FilterCo
|
||||
def map_domain_filter_list_to_response(domain_list: DomainFilterList) -> FilterListResponse:
|
||||
"""Convert domain filter list to response model."""
|
||||
return FilterListResponse(
|
||||
items=[map_domain_filter_config_to_response(f) for f in domain_list.items],
|
||||
filters=[map_domain_filter_config_to_response(f) for f in domain_list.items],
|
||||
total=domain_list.total,
|
||||
)
|
||||
|
||||
@@ -8,15 +8,15 @@ from __future__ import annotations
|
||||
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import AnyHttpUrl, Field
|
||||
from pydantic import AnyHttpUrl, ConfigDict, Field
|
||||
|
||||
from app.models.response import BanGuiBaseModel, PaginatedListResponse
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Blocklist source
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class BlocklistSource(BanGuiBaseModel):
|
||||
"""Domain model for a blocklist source definition."""
|
||||
|
||||
@@ -27,6 +27,7 @@ class BlocklistSource(BanGuiBaseModel):
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
|
||||
class BlocklistSourceCreate(BanGuiBaseModel):
|
||||
"""Payload for ``POST /api/blocklists``.
|
||||
|
||||
@@ -39,6 +40,7 @@ class BlocklistSourceCreate(BanGuiBaseModel):
|
||||
url: AnyHttpUrl = Field(..., description="URL of the blocklist file (http/https only).")
|
||||
enabled: bool = Field(default=True)
|
||||
|
||||
|
||||
class BlocklistSourceUpdate(BanGuiBaseModel):
|
||||
"""Payload for ``PUT /api/blocklists/{id}``. All fields are optional.
|
||||
|
||||
@@ -49,15 +51,18 @@ class BlocklistSourceUpdate(BanGuiBaseModel):
|
||||
url: AnyHttpUrl | None = Field(default=None)
|
||||
enabled: bool | None = Field(default=None)
|
||||
|
||||
|
||||
class BlocklistListResponse(BanGuiBaseModel):
|
||||
"""Response for ``GET /api/blocklists``."""
|
||||
|
||||
sources: list[BlocklistSource] = Field(default_factory=list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Import log
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ImportLogEntry(BanGuiBaseModel):
|
||||
"""A single blocklist import run record."""
|
||||
|
||||
@@ -69,6 +74,7 @@ class ImportLogEntry(BanGuiBaseModel):
|
||||
ips_skipped: int
|
||||
errors: str | None
|
||||
|
||||
|
||||
class ImportLogListResponse(PaginatedListResponse[ImportLogEntry]):
|
||||
"""Response for ``GET /api/blocklists/log``.
|
||||
|
||||
@@ -83,6 +89,7 @@ class ImportLogListResponse(PaginatedListResponse[ImportLogEntry]):
|
||||
# Import run tracking (for idempotency)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ImportRunEntry(BanGuiBaseModel):
|
||||
"""Tracks a unique blocklist import run by source and content hash.
|
||||
|
||||
@@ -100,10 +107,12 @@ class ImportRunEntry(BanGuiBaseModel):
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schedule
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ScheduleFrequency(StrEnum):
|
||||
"""Available import schedule frequency presets."""
|
||||
|
||||
@@ -111,6 +120,7 @@ class ScheduleFrequency(StrEnum):
|
||||
daily = "daily"
|
||||
weekly = "weekly"
|
||||
|
||||
|
||||
class ScheduleConfig(BanGuiBaseModel):
|
||||
"""Import schedule configuration.
|
||||
|
||||
@@ -121,8 +131,10 @@ class ScheduleConfig(BanGuiBaseModel):
|
||||
- ``weekly``: additionally uses ``day_of_week`` (0=Monday … 6=Sunday).
|
||||
"""
|
||||
|
||||
# No strict=True here: FastAPI and json.loads() both supply enum values as
|
||||
# plain strings; strict mode would reject string→enum coercion.
|
||||
# FastAPI and json.loads() both supply enum values as plain strings;
|
||||
# strict mode would reject string→enum coercion, so we override the
|
||||
# base model_config for this model only.
|
||||
model_config = ConfigDict(strict=False)
|
||||
|
||||
frequency: ScheduleFrequency = ScheduleFrequency.daily
|
||||
interval_hours: int = Field(default=24, ge=1, le=168, description="Used when frequency=hourly")
|
||||
@@ -135,6 +147,7 @@ class ScheduleConfig(BanGuiBaseModel):
|
||||
description="Day of week for weekly runs (0=Monday … 6=Sunday)",
|
||||
)
|
||||
|
||||
|
||||
class ScheduleInfo(BanGuiBaseModel):
|
||||
"""Current schedule configuration together with runtime metadata."""
|
||||
|
||||
@@ -144,10 +157,12 @@ class ScheduleInfo(BanGuiBaseModel):
|
||||
last_run_errors: bool | None = None
|
||||
"""``True`` if the most recent import had errors, ``False`` if clean, ``None`` if never run."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Import results
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ImportSourceResult(BanGuiBaseModel):
|
||||
"""Result of importing a single blocklist source."""
|
||||
|
||||
@@ -157,6 +172,7 @@ class ImportSourceResult(BanGuiBaseModel):
|
||||
ips_skipped: int
|
||||
error: str | None
|
||||
|
||||
|
||||
class ImportRunResult(BanGuiBaseModel):
|
||||
"""Aggregated result from a full import run across all enabled sources."""
|
||||
|
||||
@@ -165,10 +181,12 @@ class ImportRunResult(BanGuiBaseModel):
|
||||
total_skipped: int
|
||||
errors_count: int
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Preview
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class PreviewResponse(BanGuiBaseModel):
|
||||
"""Response for ``GET /api/blocklists/{id}/preview``."""
|
||||
|
||||
|
||||
@@ -188,7 +188,6 @@ class PaginationMetadata(BanGuiBaseModel):
|
||||
)
|
||||
|
||||
|
||||
|
||||
class PaginatedListResponse(BanGuiBaseModel, Generic[T]):
|
||||
"""Standardized paginated list response.
|
||||
|
||||
@@ -384,6 +383,8 @@ class ErrorMetadata(TypedDict, total=False):
|
||||
current_status: str
|
||||
actual_length: int
|
||||
message: str
|
||||
field_errors: int
|
||||
first_field: str
|
||||
|
||||
|
||||
class ComponentHealth(BanGuiBaseModel):
|
||||
|
||||
@@ -37,7 +37,6 @@ from app.services import (
|
||||
filter_config_service,
|
||||
jail_config_service,
|
||||
)
|
||||
from app.utils.path_utils import validate_log_path
|
||||
from app.utils.constants import (
|
||||
RATE_LIMIT_JAIL_ACTIVATE_REQUESTS,
|
||||
RATE_LIMIT_JAIL_CREATE_REQUESTS,
|
||||
@@ -45,6 +44,7 @@ from app.utils.constants import (
|
||||
RATE_LIMIT_JAIL_DELETE_REQUESTS,
|
||||
RATE_LIMIT_JAIL_UPDATE_REQUESTS,
|
||||
)
|
||||
from app.utils.path_utils import validate_log_path
|
||||
from app.utils.runtime_state import (
|
||||
clear_activation_record,
|
||||
clear_pending_recovery,
|
||||
@@ -207,7 +207,8 @@ def _check_jail_deactivate_rate_limit(
|
||||
)
|
||||
|
||||
|
||||
_NamePath = Annotated[str, Path(description='Jail name as configured in fail2ban.')]
|
||||
_NamePath = Annotated[str, Path(description="Jail name as configured in fail2ban.")]
|
||||
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
@@ -240,8 +241,6 @@ async def get_jail_configs(
|
||||
return config_mappers.map_domain_jail_config_list_to_response(domain_result)
|
||||
|
||||
|
||||
|
||||
|
||||
@router.get(
|
||||
"/inactive",
|
||||
response_model=InactiveJailListResponse,
|
||||
@@ -335,9 +334,8 @@ async def get_jail_config(
|
||||
HTTPException: 502 when fail2ban is unreachable.
|
||||
"""
|
||||
domain_result = await config_service.get_jail_config(socket_path, name)
|
||||
return config_mappers.map_domain_jail_config_to_response(domain_result)
|
||||
|
||||
|
||||
mapped = config_mappers.map_domain_jail_config_to_response(domain_result)
|
||||
return JailConfigResponse(jail=mapped)
|
||||
|
||||
|
||||
@router.put(
|
||||
@@ -387,8 +385,6 @@ async def update_jail_config(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{name}/logpath",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
@@ -430,8 +426,6 @@ async def add_log_path(
|
||||
await config_service.add_log_path(socket_path, name, body)
|
||||
|
||||
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{name}/logpath",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
@@ -479,8 +473,6 @@ async def delete_log_path(
|
||||
await config_service.delete_log_path(socket_path, name, log_path)
|
||||
|
||||
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{name}/activate",
|
||||
response_model=JailActivationResponse,
|
||||
@@ -532,9 +524,7 @@ async def activate_jail(
|
||||
"""
|
||||
req = body if body is not None else ActivateJailRequest()
|
||||
|
||||
result = await jail_config_service.activate_jail(
|
||||
config_dir, socket_path, name, req, health_probe=health_probe
|
||||
)
|
||||
result = await jail_config_service.activate_jail(config_dir, socket_path, name, req, health_probe=health_probe)
|
||||
|
||||
if result.active:
|
||||
record_activation(app, name)
|
||||
@@ -542,8 +532,6 @@ async def activate_jail(
|
||||
return result
|
||||
|
||||
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{name}/deactivate",
|
||||
response_model=JailActivationResponse,
|
||||
@@ -588,14 +576,10 @@ async def deactivate_jail(
|
||||
HTTPException: 502 if fail2ban is unreachable.
|
||||
"""
|
||||
|
||||
result = await jail_config_service.deactivate_jail(
|
||||
config_dir, socket_path, name, health_probe=health_probe
|
||||
)
|
||||
result = await jail_config_service.deactivate_jail(config_dir, socket_path, name, health_probe=health_probe)
|
||||
return result
|
||||
|
||||
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{name}/local",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
@@ -645,8 +629,6 @@ async def delete_jail_local_override(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{name}/validate",
|
||||
response_model=JailValidationResult,
|
||||
@@ -868,10 +850,8 @@ async def remove_action_from_jail(
|
||||
action_name,
|
||||
do_reload=reload,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Filter discovery endpoints (Task 2.1)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -15,11 +15,11 @@ under the key ``"blocklist_schedule"``.
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import UTC
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.exceptions import BlocklistSourceHasLogsError
|
||||
from app.models.blocklist import (
|
||||
@@ -37,6 +37,7 @@ from app.repositories import blocklist_repo, import_log_repo, settings_repo
|
||||
from app.services.blocklist_downloader import BlocklistDownloader
|
||||
from app.services.blocklist_import_workflow import BlocklistImportWorkflow
|
||||
from app.services.blocklist_parser import BlocklistParser
|
||||
from app.utils.logging_compat import get_logger
|
||||
from app.utils.pagination import create_pagination_metadata
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -200,9 +201,7 @@ async def update_source(
|
||||
|
||||
await validate_blocklist_url(url)
|
||||
|
||||
updated = await blocklist_repo.update_source(
|
||||
db, source_id, name=name, url=url, enabled=enabled
|
||||
)
|
||||
updated = await blocklist_repo.update_source(db, source_id, name=name, url=url, enabled=enabled)
|
||||
if not updated:
|
||||
return None
|
||||
source = await get_source(db, source_id)
|
||||
@@ -473,8 +472,7 @@ async def get_schedule(db: aiosqlite.Connection) -> ScheduleConfig:
|
||||
if raw is None:
|
||||
return _DEFAULT_SCHEDULE
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
return ScheduleConfig.model_validate(data)
|
||||
return ScheduleConfig.model_validate_json(raw)
|
||||
except (json.JSONDecodeError, ValueError) as exc:
|
||||
log.warning("blocklist_schedule_invalid", raw=raw, error=type(exc).__name__)
|
||||
return _DEFAULT_SCHEDULE
|
||||
@@ -493,9 +491,7 @@ async def set_schedule(
|
||||
Returns:
|
||||
The saved configuration (same object after validation).
|
||||
"""
|
||||
await settings_repo.set_setting(
|
||||
db, _SCHEDULE_SETTINGS_KEY, config.model_dump_json()
|
||||
)
|
||||
await settings_repo.set_setting(db, _SCHEDULE_SETTINGS_KEY, config.model_dump_json())
|
||||
log.info("blocklist_schedule_updated", frequency=config.frequency, hour=config.hour)
|
||||
return config
|
||||
|
||||
@@ -517,8 +513,12 @@ async def get_schedule_info(
|
||||
"""
|
||||
config = await get_schedule(db)
|
||||
last_log = await import_log_repo.get_last_log(db)
|
||||
last_run_at = last_log["timestamp"] if last_log else None
|
||||
last_run_errors: bool | None = (last_log["errors"] is not None) if last_log else None
|
||||
last_run_at = None
|
||||
if last_log is not None:
|
||||
from datetime import datetime
|
||||
|
||||
last_run_at = datetime.fromtimestamp(last_log.timestamp, tz=UTC).isoformat()
|
||||
last_run_errors: bool | None = (last_log.errors is not None) if last_log else None
|
||||
return ScheduleInfo(
|
||||
config=config,
|
||||
next_run_at=next_run_at,
|
||||
@@ -574,9 +574,7 @@ async def list_import_logs(
|
||||
Returns:
|
||||
:class:`~app.models.blocklist.ImportLogListResponse`.
|
||||
"""
|
||||
items, total = await import_log_repo.list_logs(
|
||||
db, source_id=source_id, page=page, page_size=page_size
|
||||
)
|
||||
items, total = await import_log_repo.list_logs(db, source_id=source_id, page=page, page_size=page_size)
|
||||
|
||||
return ImportLogListResponse(
|
||||
items=[ImportLogEntry.model_validate(i) for i in items],
|
||||
|
||||
@@ -13,8 +13,6 @@ import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.exceptions import (
|
||||
ConfigWriteError,
|
||||
FilterAlreadyExistsError,
|
||||
@@ -27,6 +25,7 @@ from app.exceptions import (
|
||||
)
|
||||
from app.models.config import (
|
||||
AssignFilterRequest,
|
||||
FilterConfig,
|
||||
FilterConfigUpdate,
|
||||
FilterCreateRequest,
|
||||
FilterUpdateRequest,
|
||||
@@ -46,6 +45,7 @@ from app.utils.config_file_utils import (
|
||||
set_jail_local_key_sync,
|
||||
)
|
||||
from app.utils.jail_socket import reload_all
|
||||
from app.utils.logging_compat import get_logger
|
||||
from app.utils.regex_validator import RegexTimeoutError, validate_regex_pattern
|
||||
|
||||
log = get_logger(__name__)
|
||||
@@ -54,6 +54,7 @@ log = get_logger(__name__)
|
||||
# Internal wrappers for shared config helpers.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _parse_jails_sync(config_dir: Path) -> tuple[dict[str, dict[str, str]], Path]:
|
||||
return _config_file_parse_jails_sync(config_dir)
|
||||
|
||||
@@ -85,6 +86,7 @@ def _resolve_filter(raw_filter: str, jail_name: str, mode: str) -> str:
|
||||
result = result.replace("%(mode)s", mode)
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers imported from the shared config helper module.
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -366,7 +368,7 @@ async def list_filters(
|
||||
)
|
||||
|
||||
log.info("filters_listed", total=len(filters), active=sum(1 for f in filters if f.active))
|
||||
return DomainFilterList(filters=filters, total=len(filters))
|
||||
return DomainFilterList(items=filters, total=len(filters))
|
||||
|
||||
|
||||
async def get_filter(
|
||||
@@ -428,7 +430,7 @@ async def get_filter(
|
||||
else:
|
||||
raise FilterNotFoundError(base_name)
|
||||
|
||||
content, has_local, source_path = await run_blocking( _read)
|
||||
content, has_local, source_path = await run_blocking(_read)
|
||||
|
||||
cfg = conffile_parser.parse_filter_file(content, name=base_name, filename=f"{base_name}.conf")
|
||||
|
||||
@@ -524,7 +526,7 @@ async def update_filter(
|
||||
content = conffile_parser.serialize_filter_config(merged)
|
||||
|
||||
filter_d = Path(config_dir) / "filter.d"
|
||||
await run_blocking( _write_filter_local_sync, filter_d, base_name, content)
|
||||
await run_blocking(_write_filter_local_sync, filter_d, base_name, content)
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
@@ -580,7 +582,7 @@ async def create_filter(
|
||||
if conf_path.is_file() or local_path.is_file():
|
||||
raise FilterAlreadyExistsError(req.name)
|
||||
|
||||
await run_blocking( _check_not_exists)
|
||||
await run_blocking(_check_not_exists)
|
||||
|
||||
# Validate regex patterns.
|
||||
patterns: list[str] = list(req.failregex) + list(req.ignoreregex)
|
||||
@@ -598,7 +600,7 @@ async def create_filter(
|
||||
)
|
||||
content = conffile_parser.serialize_filter_config(cfg)
|
||||
|
||||
await run_blocking( _write_filter_local_sync, filter_d, req.name, content)
|
||||
await run_blocking(_write_filter_local_sync, filter_d, req.name, content)
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
@@ -663,7 +665,7 @@ async def delete_filter(
|
||||
|
||||
log.info("filter_local_deleted", filter=base_name, path=str(local_path))
|
||||
|
||||
await run_blocking( _delete)
|
||||
await run_blocking(_delete)
|
||||
|
||||
|
||||
async def assign_filter_to_jail(
|
||||
@@ -713,9 +715,10 @@ async def assign_filter_to_jail(
|
||||
if not conf_exists and not local_exists:
|
||||
raise FilterNotFoundError(req.filter_name)
|
||||
|
||||
await run_blocking( _check_filter)
|
||||
await run_blocking(_check_filter)
|
||||
|
||||
await run_blocking(set_jail_local_key_sync,
|
||||
await run_blocking(
|
||||
set_jail_local_key_sync,
|
||||
Path(config_dir),
|
||||
jail_name,
|
||||
"filter",
|
||||
|
||||
@@ -21,10 +21,10 @@ import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiohttp
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.models.geo import GeoInfo
|
||||
from app.repositories import geo_cache_repo
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import collections.abc
|
||||
@@ -40,14 +40,10 @@ log = get_logger(__name__)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
#: ip-api.com single-IP lookup endpoint (HTTP only on the free tier).
|
||||
_API_URL: str = (
|
||||
"http://ip-api.com/json/{ip}?fields=status,message,country,countryCode,org,as"
|
||||
)
|
||||
_API_URL: str = "http://ip-api.com/json/{ip}?fields=status,message,country,countryCode,org,as"
|
||||
|
||||
#: ip-api.com batch endpoint — accepts up to 100 IPs per POST.
|
||||
_BATCH_API_URL: str = (
|
||||
"http://ip-api.com/batch?fields=status,message,country,countryCode,org,as,query"
|
||||
)
|
||||
_BATCH_API_URL: str = "http://ip-api.com/batch?fields=status,message,country,countryCode,org,as,query"
|
||||
|
||||
#: Maximum IPs per batch request (ip-api.com hard limit is 100).
|
||||
_BATCH_SIZE: int = 100
|
||||
@@ -217,9 +213,7 @@ class GeoCache:
|
||||
|
||||
await self.clear_neg_cache()
|
||||
geo_map = await self.lookup_batch(unresolved, http_session, db=db)
|
||||
resolved_count = sum(
|
||||
1 for info in geo_map.values() if info.country_code is not None
|
||||
)
|
||||
resolved_count = sum(1 for info in geo_map.values() if info.country_code is not None)
|
||||
|
||||
log.info(
|
||||
"geo_re_resolve_complete",
|
||||
@@ -398,7 +392,7 @@ class GeoCache:
|
||||
asn=result.asn,
|
||||
org=result.org,
|
||||
)
|
||||
except (OSError) as exc:
|
||||
except OSError as exc:
|
||||
log.warning("geo_persist_failed", ip=ip, error=type(exc).__name__)
|
||||
log.debug("geo_lookup_success_mmdb", ip=ip, country=result.country_code)
|
||||
return result
|
||||
@@ -412,7 +406,7 @@ class GeoCache:
|
||||
if db is not None:
|
||||
try:
|
||||
await geo_cache_repo.upsert_neg_entry_and_commit(db=db, ip=ip)
|
||||
except (OSError) as exc:
|
||||
except OSError as exc:
|
||||
log.warning("geo_persist_neg_failed", ip=ip, error=type(exc).__name__)
|
||||
return GeoInfo(country_code=None, country_name=None, asn=None, org=None)
|
||||
|
||||
@@ -439,7 +433,7 @@ class GeoCache:
|
||||
asn=result.asn,
|
||||
org=result.org,
|
||||
)
|
||||
except (OSError) as exc:
|
||||
except OSError as exc:
|
||||
log.warning("geo_persist_failed", ip=ip, error=type(exc).__name__)
|
||||
log.debug("geo_lookup_success_http", ip=ip, country=result.country_code, asn=result.asn)
|
||||
return result
|
||||
@@ -448,7 +442,7 @@ class GeoCache:
|
||||
ip=ip,
|
||||
message=data.get("message", "unknown"),
|
||||
)
|
||||
except (TimeoutError, aiohttp.ClientError, ValueError) as exc:
|
||||
except (TimeoutError, aiohttp.ClientError, ValueError, OSError) as exc:
|
||||
log.warning(
|
||||
"geo_lookup_http_request_failed",
|
||||
ip=ip,
|
||||
@@ -585,7 +579,7 @@ class GeoCache:
|
||||
if db is not None and pos_rows:
|
||||
try:
|
||||
await geo_cache_repo.bulk_upsert_entries_and_commit(db, pos_rows)
|
||||
except (OSError) as exc:
|
||||
except OSError as exc:
|
||||
log.warning(
|
||||
"geo_batch_persist_mmdb_failed",
|
||||
count=len(pos_rows),
|
||||
@@ -604,7 +598,7 @@ class GeoCache:
|
||||
if db is not None and neg_ips:
|
||||
try:
|
||||
await geo_cache_repo.bulk_upsert_neg_entries_and_commit(db, neg_ips)
|
||||
except (OSError) as exc:
|
||||
except OSError as exc:
|
||||
log.warning(
|
||||
"geo_batch_persist_neg_failed",
|
||||
count=len(neg_ips),
|
||||
@@ -637,9 +631,7 @@ class GeoCache:
|
||||
# If every IP in the chunk came back with country_code=None and the
|
||||
# batch wasn't tiny, that almost certainly means the whole request
|
||||
# was rejected (connection reset / 429). Retry after a back-off.
|
||||
all_failed = all(
|
||||
info.country_code is None for info in chunk_result.values()
|
||||
)
|
||||
all_failed = all(info.country_code is None for info in chunk_result.values())
|
||||
if not all_failed or attempt >= _BATCH_MAX_RETRIES:
|
||||
break
|
||||
backoff = _BATCH_DELAY * (2 ** (attempt + 1))
|
||||
@@ -659,9 +651,7 @@ class GeoCache:
|
||||
await self._store(ip, info)
|
||||
geo_result[ip] = info
|
||||
if db is not None:
|
||||
pos_rows.append(
|
||||
(ip, info.country_code, info.country_name, info.asn, info.org)
|
||||
)
|
||||
pos_rows.append((ip, info.country_code, info.country_name, info.asn, info.org))
|
||||
else:
|
||||
# HTTP failed — record as negative cache.
|
||||
async with self._cache_lock:
|
||||
@@ -677,7 +667,7 @@ class GeoCache:
|
||||
pos_rows,
|
||||
neg_ips,
|
||||
)
|
||||
except (OSError) as exc:
|
||||
except OSError as exc:
|
||||
log.warning(
|
||||
"geo_batch_persist_failed",
|
||||
positive_count=len(pos_rows),
|
||||
@@ -724,7 +714,7 @@ class GeoCache:
|
||||
log.warning("geo_batch_non_200", status=resp.status, count=len(ips))
|
||||
return fallback
|
||||
data: list[dict[str, object]] = await resp.json(content_type=None)
|
||||
except (TimeoutError, aiohttp.ClientError, ValueError) as exc:
|
||||
except (TimeoutError, aiohttp.ClientError, ValueError, OSError) as exc:
|
||||
log.warning(
|
||||
"geo_batch_request_failed",
|
||||
count=len(ips),
|
||||
@@ -836,7 +826,7 @@ class GeoCache:
|
||||
|
||||
try:
|
||||
await geo_cache_repo.bulk_upsert_entries_and_commit(db, rows)
|
||||
except (OSError) as exc:
|
||||
except OSError as exc:
|
||||
log.warning("geo_flush_dirty_failed", error=type(exc).__name__)
|
||||
# Re-add to dirty so they are retried on the next flush cycle.
|
||||
self._dirty.update(to_flush)
|
||||
|
||||
@@ -61,17 +61,20 @@ def normalise_ip(address: str) -> str:
|
||||
IPv4-mapped IPv6 addresses (e.g. ``::ffff:192.168.1.1``) are converted
|
||||
to their IPv4 equivalent (``192.168.1.1``).
|
||||
Plain IPv4 addresses are returned unchanged.
|
||||
Non-IP strings (e.g. ``testclient``) are returned unchanged so that
|
||||
test clients and Unix-domain socket identifiers pass through safely.
|
||||
|
||||
Args:
|
||||
address: A valid IP address string.
|
||||
address: An IP address string or other identifier.
|
||||
|
||||
Returns:
|
||||
Normalised IP address string.
|
||||
|
||||
Raises:
|
||||
ValueError: If *address* is not a valid IP address.
|
||||
Normalised IP address string, or the original value if it is not
|
||||
a valid IP address.
|
||||
"""
|
||||
ip = ipaddress.ip_address(address)
|
||||
try:
|
||||
ip = ipaddress.ip_address(address)
|
||||
except ValueError:
|
||||
return address
|
||||
if isinstance(ip, ipaddress.IPv6Address) and ip.ipv4_mapped:
|
||||
return str(ip.ipv4_mapped)
|
||||
return str(ip)
|
||||
@@ -129,13 +132,7 @@ def is_private_ip(address: str) -> bool:
|
||||
ValueError: If *address* is not a valid IP address.
|
||||
"""
|
||||
ip = ipaddress.ip_address(address)
|
||||
return (
|
||||
ip.is_private
|
||||
or ip.is_loopback
|
||||
or ip.is_link_local
|
||||
or ip.is_multicast
|
||||
or ip.is_reserved
|
||||
)
|
||||
return ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_multicast or ip.is_reserved
|
||||
|
||||
|
||||
async def validate_blocklist_url(url: str) -> None:
|
||||
@@ -165,9 +162,7 @@ async def validate_blocklist_url(url: str) -> None:
|
||||
raise ValueError(f"Invalid URL format: {exc}") from exc
|
||||
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise ValueError(
|
||||
f"Invalid scheme '{parsed.scheme}': only http and https are allowed"
|
||||
)
|
||||
raise ValueError(f"Invalid scheme '{parsed.scheme}': only http and https are allowed")
|
||||
|
||||
if not parsed.hostname:
|
||||
raise ValueError("URL has no hostname")
|
||||
@@ -201,14 +196,9 @@ async def validate_blocklist_url(url: str) -> None:
|
||||
# connection time, and host mode is never used in production.
|
||||
if is_private_ip(ip_str):
|
||||
import os
|
||||
if (
|
||||
os.getenv("BANGUI_LOG_LEVEL") == "debug"
|
||||
and ipaddress.ip_address(ip_str).is_loopback
|
||||
):
|
||||
|
||||
if os.getenv("BANGUI_LOG_LEVEL") == "debug" and ipaddress.ip_address(ip_str).is_loopback:
|
||||
continue
|
||||
raise ValueError(
|
||||
f"Hostname '{hostname}' resolves to private/reserved IP: {ip_str}"
|
||||
)
|
||||
raise ValueError(f"Hostname '{hostname}' resolves to private/reserved IP: {ip_str}")
|
||||
except ipaddress.AddressValueError as exc:
|
||||
raise ValueError(f"Invalid IP address: {ip_str}") from exc
|
||||
|
||||
|
||||
@@ -26,6 +26,19 @@ class _CompatLogger:
|
||||
if v is not None:
|
||||
stdlib_kwargs[k] = v
|
||||
if kwargs:
|
||||
# Several keys are reserved in LogRecord; rename them to avoid KeyError.
|
||||
reserved_renames = {
|
||||
"message": "log_message",
|
||||
"name": "log_name",
|
||||
"filename": "log_filename",
|
||||
"funcName": "log_funcName",
|
||||
"lineno": "log_lineno",
|
||||
"module": "log_module",
|
||||
"pathname": "log_pathname",
|
||||
}
|
||||
for old_key, new_key in reserved_renames.items():
|
||||
if old_key in kwargs:
|
||||
kwargs[new_key] = kwargs.pop(old_key)
|
||||
stdlib_kwargs["extra"] = kwargs
|
||||
self._logger.log(level, event, **stdlib_kwargs)
|
||||
|
||||
@@ -50,7 +63,7 @@ class _CompatLogger:
|
||||
def exception(self, event: str, **kwargs: Any) -> None:
|
||||
self._log(logging.ERROR, event, exc_info=True, **kwargs)
|
||||
|
||||
def bind(self, **kwargs: Any) -> "_CompatLogger":
|
||||
def bind(self, **kwargs: Any) -> _CompatLogger:
|
||||
"""Return a new logger with bound context (no-op for stdlib)."""
|
||||
return self
|
||||
|
||||
|
||||
@@ -46,6 +46,7 @@ import time
|
||||
from typing import Any
|
||||
|
||||
import aiosqlite
|
||||
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = get_logger(__name__)
|
||||
@@ -133,12 +134,10 @@ async def acquire_scheduler_lock(db: aiosqlite.Connection) -> bool:
|
||||
await db.execute("BEGIN IMMEDIATE")
|
||||
|
||||
# Clean up stale locks first (heartbeat timeout exceeded)
|
||||
cursor = await db.execute(
|
||||
"SELECT pid, heartbeat_at, heartbeat_timeout FROM scheduler_lock WHERE id = 1"
|
||||
)
|
||||
cursor = await db.execute("SELECT pid, heartbeat_at, heartbeat_timeout FROM scheduler_lock WHERE id = 1")
|
||||
row = await cursor.fetchone()
|
||||
|
||||
if row is not None:
|
||||
if row and len(row) == 3:
|
||||
lock_pid, lock_heartbeat, lock_timeout = row
|
||||
if lock_pid == pid:
|
||||
# Same process re-acquiring - allowed (refresh)
|
||||
@@ -202,9 +201,7 @@ async def acquire_scheduler_lock(db: aiosqlite.Connection) -> bool:
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Failed to acquire scheduler lock due to database error: {e}"
|
||||
) from e
|
||||
raise RuntimeError(f"Failed to acquire scheduler lock due to database error: {e}") from e
|
||||
|
||||
|
||||
async def release_scheduler_lock(db: aiosqlite.Connection) -> None:
|
||||
@@ -372,9 +369,7 @@ async def get_lock_health(db: aiosqlite.Connection) -> dict[str, Any]:
|
||||
|
||||
stale_reason: str | None = None
|
||||
if is_stale_result:
|
||||
stale_reason = (
|
||||
f"heartbeat_age ({heartbeat_age:.1f}s) > timeout ({info['heartbeat_timeout']:.1f}s)"
|
||||
)
|
||||
stale_reason = f"heartbeat_age ({heartbeat_age:.1f}s) > timeout ({info['heartbeat_timeout']:.1f}s)"
|
||||
|
||||
return {
|
||||
"has_lock": True,
|
||||
|
||||
Reference in New Issue
Block a user