Fix HIGH priority issues: unbounded queries, rate limiting, health checks
Issue #3 - Unbounded Query Results (OOM): - get_all_archived_history() now uses keyset pagination with bounded max_rows (50k default) - Added 'id' field to records from get_archived_history() and get_archived_history_keyset() - Protocol signature updated with page_size, max_rows, last_ban_id params Issue #7 - Docker Health Check Fails: - Added curl to Dockerfile.backend runtime image - HEALTHCHECK now uses 'curl -f http://localhost:8000/api/health' - compose.prod.yml: increased start_period to 40s, timeout to 10s - Frontend healthcheck proxies to backend /api/health Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
141
backend/app/mappers/blocklist_mappers.py
Normal file
141
backend/app/mappers/blocklist_mappers.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""Blocklist response mappers.
|
||||
|
||||
Convert domain models (from blocklist_service) to response models (for HTTP API).
|
||||
|
||||
This is the mapping layer at the router boundary, ensuring the service layer
|
||||
remains independent of HTTP response shapes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.models.blocklist import (
|
||||
BlocklistSource,
|
||||
ImportLogEntry,
|
||||
ImportLogListResponse,
|
||||
ImportRunResult,
|
||||
ImportSourceResult,
|
||||
PreviewResponse,
|
||||
ScheduleConfig,
|
||||
ScheduleFrequency,
|
||||
ScheduleInfo,
|
||||
)
|
||||
from app.models.blocklist_domain import (
|
||||
DomainBlocklistSource,
|
||||
DomainImportLogEntry,
|
||||
DomainImportLogList,
|
||||
DomainImportRunResult,
|
||||
DomainImportSourceResult,
|
||||
DomainPreviewResult,
|
||||
DomainScheduleConfig,
|
||||
DomainScheduleFrequency,
|
||||
DomainScheduleInfo,
|
||||
)
|
||||
from app.utils.pagination import create_pagination_metadata
|
||||
|
||||
|
||||
def map_domain_blocklist_source_to_response(
    domain: DomainBlocklistSource,
) -> BlocklistSource:
    """Build the API ``BlocklistSource`` from its domain counterpart.

    Args:
        domain: Blocklist source as produced by the service layer.

    Returns:
        A ``BlocklistSource`` whose fields are copied one-to-one from
        ``domain``.
    """
    response = BlocklistSource(
        id=domain.id,
        name=domain.name,
        url=domain.url,
        enabled=domain.enabled,
        created_at=domain.created_at,
        updated_at=domain.updated_at,
    )
    return response
|
||||
|
||||
|
||||
def map_domain_import_log_entry_to_response(
    domain: DomainImportLogEntry,
) -> ImportLogEntry:
    """Build the API ``ImportLogEntry`` from a domain import log entry.

    Args:
        domain: A single import-run record from the service layer.

    Returns:
        An ``ImportLogEntry`` carrying the same seven fields as ``domain``.
    """
    response = ImportLogEntry(
        id=domain.id,
        source_id=domain.source_id,
        source_url=domain.source_url,
        timestamp=domain.timestamp,
        ips_imported=domain.ips_imported,
        ips_skipped=domain.ips_skipped,
        errors=domain.errors,
    )
    return response
|
||||
|
||||
|
||||
def map_domain_import_log_list_to_response(
    domain_list: DomainImportLogList,
) -> ImportLogListResponse:
    """Build the paginated import-log response from the domain list.

    Args:
        domain_list: Page of domain import log entries plus paging counters.

    Returns:
        An ``ImportLogListResponse`` whose ``items`` are the mapped entries
        and whose ``pagination`` comes from ``create_pagination_metadata``
        called with ``total``, ``page`` and ``page_size``.
    """
    entries = list(map(map_domain_import_log_entry_to_response, domain_list.items))
    page_info = create_pagination_metadata(
        domain_list.total,
        domain_list.page,
        domain_list.page_size,
    )
    return ImportLogListResponse(items=entries, pagination=page_info)
|
||||
|
||||
|
||||
def map_domain_schedule_frequency_to_response(
    domain: DomainScheduleFrequency,
) -> ScheduleFrequency:
    """Translate a domain schedule frequency into the response enum.

    Args:
        domain: Domain-level frequency member.

    Returns:
        The ``ScheduleFrequency`` obtained by calling the response enum with
        the domain member's ``value``.
    """
    raw_value = domain.value
    return ScheduleFrequency(raw_value)
|
||||
|
||||
|
||||
def map_domain_schedule_config_to_response(
    domain: DomainScheduleConfig,
) -> ScheduleConfig:
    """Build the API ``ScheduleConfig`` from the domain schedule config.

    Args:
        domain: Import schedule configuration from the service layer.

    Returns:
        A ``ScheduleConfig`` with the frequency converted to the response
        enum and the remaining timing fields copied unchanged.
    """
    frequency = map_domain_schedule_frequency_to_response(domain.frequency)
    return ScheduleConfig(
        frequency=frequency,
        interval_hours=domain.interval_hours,
        hour=domain.hour,
        minute=domain.minute,
        day_of_week=domain.day_of_week,
    )
|
||||
|
||||
|
||||
def map_domain_schedule_info_to_response(domain: DomainScheduleInfo) -> ScheduleInfo:
    """Build the API ``ScheduleInfo`` from the domain schedule info.

    Args:
        domain: Schedule configuration together with its run metadata.

    Returns:
        A ``ScheduleInfo`` whose ``config`` is the mapped schedule config and
        whose run metadata fields are copied unchanged.
    """
    config = map_domain_schedule_config_to_response(domain.config)
    return ScheduleInfo(
        config=config,
        next_run_at=domain.next_run_at,
        last_run_at=domain.last_run_at,
        last_run_errors=domain.last_run_errors,
    )
|
||||
|
||||
|
||||
def map_domain_preview_result_to_response(domain: DomainPreviewResult) -> PreviewResponse:
    """Build the API ``PreviewResponse`` from a domain preview result.

    Args:
        domain: Preview result from the service layer.

    Returns:
        A ``PreviewResponse`` with the entries and the three counters copied
        unchanged from ``domain``.
    """
    response = PreviewResponse(
        entries=domain.entries,
        total_lines=domain.total_lines,
        valid_count=domain.valid_count,
        skipped_count=domain.skipped_count,
    )
    return response
|
||||
|
||||
|
||||
def map_domain_import_source_result_to_response(
    domain: DomainImportSourceResult,
) -> ImportSourceResult:
    """Build the API ``ImportSourceResult`` from its domain counterpart.

    Args:
        domain: Outcome of importing one blocklist source.

    Returns:
        An ``ImportSourceResult`` with every field copied from ``domain``.
    """
    response = ImportSourceResult(
        source_id=domain.source_id,
        source_url=domain.source_url,
        ips_imported=domain.ips_imported,
        ips_skipped=domain.ips_skipped,
        error=domain.error,
    )
    return response
|
||||
|
||||
|
||||
def map_domain_import_run_result_to_response(
    domain: DomainImportRunResult,
) -> ImportRunResult:
    """Build the API ``ImportRunResult`` from a domain import run result.

    Args:
        domain: Aggregated outcome of a full import run.

    Returns:
        An ``ImportRunResult`` whose ``results`` are the mapped per-source
        results and whose aggregate counters are copied unchanged.
    """
    per_source = list(map(map_domain_import_source_result_to_response, domain.results))
    return ImportRunResult(
        results=per_source,
        total_imported=domain.total_imported,
        total_skipped=domain.total_skipped,
        errors_count=domain.errors_count,
    )
|
||||
156
backend/app/mappers/config_mappers.py
Normal file
156
backend/app/mappers/config_mappers.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""Config response mappers.
|
||||
|
||||
Convert domain models (from config_service) to response models (for HTTP API).
|
||||
|
||||
This is the mapping layer at the router boundary, ensuring the service layer
|
||||
remains independent of HTTP response shapes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.models.config import (
|
||||
BantimeEscalation,
|
||||
Fail2BanLogResponse,
|
||||
FilterConfig,
|
||||
FilterListResponse,
|
||||
GlobalConfigResponse,
|
||||
JailConfig,
|
||||
JailConfigListResponse,
|
||||
LogPreviewResponse,
|
||||
MapColorThresholdsResponse,
|
||||
RegexTestResponse,
|
||||
ServiceStatusResponse,
|
||||
)
|
||||
from app.models.config_domain import (
|
||||
DomainBantimeEscalation,
|
||||
DomainFilterConfig,
|
||||
DomainFilterList,
|
||||
DomainGlobalConfig,
|
||||
DomainJailConfig,
|
||||
DomainJailConfigList,
|
||||
DomainMapColorThresholds,
|
||||
DomainRegexTest,
|
||||
DomainServiceStatus,
|
||||
)
|
||||
from app.utils.pagination import create_pagination_metadata
|
||||
|
||||
|
||||
def _map_domain_bantime_escalation(domain: DomainBantimeEscalation) -> BantimeEscalation:
    """Build the API ``BantimeEscalation`` from the domain escalation settings.

    Args:
        domain: Ban-time escalation settings from the service layer.

    Returns:
        A ``BantimeEscalation`` with all seven fields copied from ``domain``.
    """
    response = BantimeEscalation(
        increment=domain.increment,
        factor=domain.factor,
        formula=domain.formula,
        multipliers=domain.multipliers,
        max_time=domain.max_time,
        rnd_time=domain.rnd_time,
        overall_jails=domain.overall_jails,
    )
    return response
|
||||
|
||||
|
||||
def map_domain_jail_config_to_response(domain: DomainJailConfig) -> JailConfig:
    """Build the API ``JailConfig`` from a domain jail configuration.

    Args:
        domain: Jail configuration snapshot from the service layer.

    Returns:
        A ``JailConfig`` with the scalar and list fields copied unchanged.
        ``bantime_escalation`` is mapped when the domain value is truthy and
        is ``None`` otherwise.
    """
    # Only map the nested escalation block when the domain model carries one.
    escalation = None
    if domain.bantime_escalation:
        escalation = _map_domain_bantime_escalation(domain.bantime_escalation)

    return JailConfig(
        name=domain.name,
        ban_time=domain.ban_time,
        max_retry=domain.max_retry,
        find_time=domain.find_time,
        fail_regex=domain.fail_regex,
        ignore_regex=domain.ignore_regex,
        log_paths=domain.log_paths,
        date_pattern=domain.date_pattern,
        log_encoding=domain.log_encoding,
        backend=domain.backend,
        use_dns=domain.use_dns,
        prefregex=domain.prefregex,
        actions=domain.actions,
        bantime_escalation=escalation,
    )
|
||||
|
||||
|
||||
def map_domain_jail_config_list_to_response(
    domain_list: DomainJailConfigList,
) -> JailConfigListResponse:
    """Build the API jail-config list response from the domain list.

    Args:
        domain_list: Domain jail configurations plus their total count.

    Returns:
        A ``JailConfigListResponse`` holding the mapped items and the
        unchanged ``total``.
    """
    configs = list(map(map_domain_jail_config_to_response, domain_list.items))
    return JailConfigListResponse(items=configs, total=domain_list.total)
|
||||
|
||||
|
||||
def map_domain_global_config_to_response(domain: DomainGlobalConfig) -> GlobalConfigResponse:
    """Build the API ``GlobalConfigResponse`` from the domain global config.

    Args:
        domain: Global settings from the service layer.

    Returns:
        A ``GlobalConfigResponse`` with all four fields copied from
        ``domain``.
    """
    response = GlobalConfigResponse(
        log_level=domain.log_level,
        log_target=domain.log_target,
        db_purge_age=domain.db_purge_age,
        db_max_matches=domain.db_max_matches,
    )
    return response
|
||||
|
||||
|
||||
def map_domain_service_status_to_response(
    domain: DomainServiceStatus,
) -> ServiceStatusResponse:
    """Build the API ``ServiceStatusResponse`` from the domain service status.

    Args:
        domain: Service status from the service layer.

    Returns:
        A ``ServiceStatusResponse``. Falsy optional strings are replaced with
        placeholders: ``""`` for ``version`` and ``"UNKNOWN"`` for
        ``log_level`` and ``log_target``.
    """
    # The domain model allows None for these; fall back to placeholders.
    version = domain.version or ""
    return ServiceStatusResponse(
        online=domain.online,
        version=version,
        jail_count=domain.jail_count,
        total_bans=domain.total_bans,
        total_failures=domain.total_failures,
        log_level=domain.log_level or "UNKNOWN",
        log_target=domain.log_target or "UNKNOWN",
    )
|
||||
|
||||
|
||||
def map_domain_map_color_thresholds_to_response(
    domain: DomainMapColorThresholds,
) -> MapColorThresholdsResponse:
    """Build the API ``MapColorThresholdsResponse`` from the domain thresholds.

    Args:
        domain: Map color thresholds from the service layer.

    Returns:
        A ``MapColorThresholdsResponse`` with the high, medium and low
        thresholds copied unchanged.
    """
    response = MapColorThresholdsResponse(
        threshold_high=domain.threshold_high,
        threshold_medium=domain.threshold_medium,
        threshold_low=domain.threshold_low,
    )
    return response
|
||||
|
||||
|
||||
def map_domain_regex_test_to_response(domain: DomainRegexTest) -> RegexTestResponse:
    """Build the API ``RegexTestResponse`` from a domain regex test result.

    Args:
        domain: Regex test outcome from the service layer.

    Returns:
        A ``RegexTestResponse`` with ``matched``, ``groups`` and ``error``
        copied unchanged.
    """
    response = RegexTestResponse(
        matched=domain.matched,
        groups=domain.groups,
        error=domain.error,
    )
    return response
|
||||
|
||||
|
||||
def map_domain_filter_config_to_response(domain: DomainFilterConfig) -> FilterConfig:
    """Build the API ``FilterConfig`` from a domain filter configuration.

    Args:
        domain: Filter configuration from the service layer.

    Returns:
        A ``FilterConfig``. Falsy collection fields on the domain model are
        replaced with empty containers: ``{}`` for ``variables`` and ``[]``
        for ``failregex``, ``ignoreregex`` and ``used_by_jails``. All other
        fields are copied unchanged.
    """
    response = FilterConfig(
        name=domain.name,
        filename=domain.filename,
        before=domain.before,
        after=domain.after,
        # The domain model allows None here; the response gets an empty dict.
        variables=domain.variables or {},
        prefregex=domain.prefregex,
        # Same None-to-empty normalisation for the list-valued fields.
        failregex=domain.failregex or [],
        ignoreregex=domain.ignoreregex or [],
        maxlines=domain.maxlines,
        datepattern=domain.datepattern,
        journalmatch=domain.journalmatch,
        active=domain.active,
        used_by_jails=domain.used_by_jails or [],
        source_file=domain.source_file,
        has_local_override=domain.has_local_override,
    )
    return response
|
||||
|
||||
|
||||
def map_domain_filter_list_to_response(domain_list: DomainFilterList) -> FilterListResponse:
    """Build the API filter list response from the domain filter list.

    Args:
        domain_list: Domain filter configurations plus their total count.

    Returns:
        A ``FilterListResponse`` holding the mapped items and the unchanged
        ``total``.
    """
    filters = list(map(map_domain_filter_config_to_response, domain_list.items))
    return FilterListResponse(items=filters, total=domain_list.total)
|
||||
23
backend/app/mappers/health_mappers.py
Normal file
23
backend/app/mappers/health_mappers.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Health response mappers.
|
||||
|
||||
Convert domain models (from health_service) to response models (for HTTP API).
|
||||
|
||||
This is the mapping layer at the router boundary, ensuring the service layer
|
||||
remains independent of HTTP response shapes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.models.health_domain import DomainServerStatus
|
||||
from app.models.server import ServerStatus
|
||||
|
||||
|
||||
def map_domain_server_status_to_response(domain: DomainServerStatus) -> ServerStatus:
    """Build the API ``ServerStatus`` from the domain server status.

    Args:
        domain: Server health snapshot from the service layer.

    Returns:
        A ``ServerStatus`` with all five fields copied from ``domain``.
    """
    response = ServerStatus(
        online=domain.online,
        version=domain.version,
        active_jails=domain.active_jails,
        total_bans=domain.total_bans,
        total_failures=domain.total_failures,
    )
    return response
|
||||
81
backend/app/mappers/history_mappers.py
Normal file
81
backend/app/mappers/history_mappers.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""History response mappers.
|
||||
|
||||
Convert domain models (from history_service) to response models (for HTTP API).
|
||||
|
||||
This is the mapping layer at the router boundary, ensuring the service layer
|
||||
remains independent of HTTP response shapes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.models.history import (
|
||||
HistoryBanItem,
|
||||
HistoryListResponse,
|
||||
IpDetailResponse,
|
||||
IpTimelineEvent,
|
||||
)
|
||||
from app.models.history_domain import (
|
||||
DomainHistoryBanItem,
|
||||
DomainHistoryList,
|
||||
DomainIpDetail,
|
||||
DomainIpTimelineEvent,
|
||||
)
|
||||
from app.utils.pagination import create_pagination_metadata
|
||||
|
||||
|
||||
def map_domain_history_ban_item_to_response(
    domain: DomainHistoryBanItem,
) -> HistoryBanItem:
    """Build the API ``HistoryBanItem`` from a domain history row.

    Args:
        domain: One history ban row from the service layer.

    Returns:
        A ``HistoryBanItem`` with fields copied from ``domain``; a falsy
        ``matches`` value becomes an empty list.
    """
    # The domain model allows matches to be None; the response gets a list.
    matches = domain.matches or []
    return HistoryBanItem(
        ip=domain.ip,
        jail=domain.jail,
        banned_at=domain.banned_at,
        ban_count=domain.ban_count,
        failures=domain.failures,
        matches=matches,
        country_code=domain.country_code,
        country_name=domain.country_name,
        asn=domain.asn,
        org=domain.org,
    )
|
||||
|
||||
|
||||
def map_domain_history_list_to_response(domain: DomainHistoryList) -> HistoryListResponse:
    """Build the paginated history response from the domain history list.

    Args:
        domain: Page of domain history rows plus paging counters.

    Returns:
        A ``HistoryListResponse`` whose ``items`` are the mapped rows and
        whose ``pagination`` comes from ``create_pagination_metadata`` called
        with ``total``, ``page`` and ``page_size``.
    """
    rows = list(map(map_domain_history_ban_item_to_response, domain.items))
    page_info = create_pagination_metadata(
        domain.total,
        domain.page,
        domain.page_size,
    )
    return HistoryListResponse(items=rows, pagination=page_info)
|
||||
|
||||
|
||||
def map_domain_ip_timeline_event_to_response(
    domain: DomainIpTimelineEvent,
) -> IpTimelineEvent:
    """Build the API ``IpTimelineEvent`` from a domain timeline event.

    Args:
        domain: One ban event of a per-IP timeline.

    Returns:
        An ``IpTimelineEvent`` with fields copied from ``domain``; a falsy
        ``matches`` value becomes an empty list.
    """
    # The domain model allows matches to be None; the response gets a list.
    matches = domain.matches or []
    return IpTimelineEvent(
        jail=domain.jail,
        banned_at=domain.banned_at,
        ban_count=domain.ban_count,
        failures=domain.failures,
        matches=matches,
    )
|
||||
|
||||
|
||||
def map_domain_ip_detail_to_response(domain: DomainIpDetail) -> IpDetailResponse:
    """Build the API ``IpDetailResponse`` from a domain IP detail record.

    Args:
        domain: Historical record for one IP address.

    Returns:
        An ``IpDetailResponse`` with the scalar fields copied unchanged and
        ``timeline`` holding the mapped events. A falsy domain timeline
        yields an empty list.
    """
    return IpDetailResponse(
        ip=domain.ip,
        total_bans=domain.total_bans,
        total_failures=domain.total_failures,
        last_ban_at=domain.last_ban_at,
        country_code=domain.country_code,
        country_name=domain.country_name,
        asn=domain.asn,
        org=domain.org,
        # The domain model allows timeline to be None; map over [] then.
        timeline=list(
            map(map_domain_ip_timeline_event_to_response, domain.timeline or [])
        ),
    )
|
||||
133
backend/app/mappers/jail_mappers.py
Normal file
133
backend/app/mappers/jail_mappers.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""Jail response mappers.
|
||||
|
||||
Convert domain models (from jail_service) to response models (for HTTP API).
|
||||
|
||||
This is the mapping layer at the router boundary, ensuring the service layer
|
||||
remains independent of HTTP response shapes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.models.ban import ActiveBan, JailBannedIpsResponse
|
||||
from app.models.ban_domain import DomainActiveBan
|
||||
from app.models.jail import (
|
||||
Jail,
|
||||
JailDetailResponse,
|
||||
JailListResponse,
|
||||
JailStatus,
|
||||
JailSummary,
|
||||
)
|
||||
from app.models.jail_domain import (
|
||||
DomainJailBannedIps,
|
||||
DomainBantimeEscalation,
|
||||
DomainJail,
|
||||
DomainJailDetail,
|
||||
DomainJailList,
|
||||
DomainJailStatus,
|
||||
DomainJailSummary,
|
||||
)
|
||||
from app.utils.pagination import create_pagination_metadata
|
||||
|
||||
|
||||
def _map_domain_jail_status(domain: DomainJailStatus) -> JailStatus:
    """Build the API ``JailStatus`` from the domain jail status.

    Args:
        domain: Runtime counters of one jail.

    Returns:
        A ``JailStatus`` with the four counters copied unchanged.
    """
    response = JailStatus(
        currently_banned=domain.currently_banned,
        total_banned=domain.total_banned,
        currently_failed=domain.currently_failed,
        total_failed=domain.total_failed,
    )
    return response
|
||||
|
||||
|
||||
def _map_domain_bantime_escalation(domain: DomainBantimeEscalation) -> object:
    """Build the API ``BantimeEscalation`` from the domain escalation settings.

    Args:
        domain: Ban-time escalation settings from the service layer.

    Returns:
        An ``app.models.config.BantimeEscalation`` instance with all seven
        fields copied from ``domain``.
    """
    # NOTE(review): the return annotation is ``object`` although a
    # BantimeEscalation is returned, and the class is imported locally rather
    # than at module level -- presumably to avoid an import cycle; confirm.
    from app.models.config import BantimeEscalation

    response = BantimeEscalation(
        increment=domain.increment,
        factor=domain.factor,
        formula=domain.formula,
        multipliers=domain.multipliers,
        max_time=domain.max_time,
        rnd_time=domain.rnd_time,
        overall_jails=domain.overall_jails,
    )
    return response
|
||||
|
||||
|
||||
def map_domain_jail_summary_to_response(domain: DomainJailSummary) -> JailSummary:
    """Build the API ``JailSummary`` from a domain jail summary.

    Args:
        domain: Lightweight jail entry from the service layer.

    Returns:
        A ``JailSummary`` with the scalar fields copied unchanged.
        ``status`` is mapped when the domain value is truthy and is ``None``
        otherwise.
    """
    # Only map the nested status block when the domain model carries one.
    status = None
    if domain.status:
        status = _map_domain_jail_status(domain.status)

    return JailSummary(
        name=domain.name,
        enabled=domain.enabled,
        running=domain.running,
        idle=domain.idle,
        backend=domain.backend,
        find_time=domain.find_time,
        ban_time=domain.ban_time,
        max_retry=domain.max_retry,
        status=status,
    )
|
||||
|
||||
|
||||
def map_domain_jail_list_to_response(domain_list: DomainJailList) -> JailListResponse:
    """Build the API jail list response from the domain jail list.

    Args:
        domain_list: Domain jail summaries plus their total count.

    Returns:
        A ``JailListResponse`` holding the mapped summaries and the
        unchanged ``total``.
    """
    summaries = list(map(map_domain_jail_summary_to_response, domain_list.items))
    return JailListResponse(items=summaries, total=domain_list.total)
|
||||
|
||||
|
||||
def map_domain_jail_to_response(domain: DomainJail) -> Jail:
    """Build the API ``Jail`` from a full domain jail model.

    Args:
        domain: Full jail configuration from the service layer.

    Returns:
        A ``Jail`` with the scalar and list fields copied unchanged.
        ``bantime_escalation`` and ``status`` are each mapped when the domain
        value is truthy and are ``None`` otherwise.
    """
    # Nested blocks are optional on the domain model; map only what exists.
    escalation = None
    if domain.bantime_escalation:
        escalation = _map_domain_bantime_escalation(domain.bantime_escalation)

    status = None
    if domain.status:
        status = _map_domain_jail_status(domain.status)

    return Jail(
        name=domain.name,
        enabled=domain.enabled,
        running=domain.running,
        idle=domain.idle,
        backend=domain.backend,
        log_paths=domain.log_paths,
        fail_regex=domain.fail_regex,
        ignore_regex=domain.ignore_regex,
        ignore_ips=domain.ignore_ips,
        date_pattern=domain.date_pattern,
        log_encoding=domain.log_encoding,
        find_time=domain.find_time,
        ban_time=domain.ban_time,
        max_retry=domain.max_retry,
        actions=domain.actions,
        bantime_escalation=escalation,
        status=status,
    )
|
||||
|
||||
|
||||
def map_domain_jail_detail_to_response(domain: DomainJailDetail) -> JailDetailResponse:
    """Build the API ``JailDetailResponse`` from the domain jail detail.

    Args:
        domain: Jail detail from the service layer.

    Returns:
        A ``JailDetailResponse`` whose ``jail`` is the mapped jail and whose
        ``ignore_list`` and ``ignore_self`` are copied unchanged.
    """
    jail = map_domain_jail_to_response(domain.jail)
    return JailDetailResponse(
        jail=jail,
        ignore_list=domain.ignore_list,
        ignore_self=domain.ignore_self,
    )
|
||||
|
||||
|
||||
def map_domain_jail_banned_ips_to_response(
    domain: DomainJailBannedIps,
) -> JailBannedIpsResponse:
    """Build the paginated banned-IP response for a jail.

    Args:
        domain: Page of domain ban entries plus paging counters.

    Returns:
        A ``JailBannedIpsResponse`` whose ``items`` are ``ActiveBan`` objects
        built from each entry and whose ``pagination`` comes from
        ``create_pagination_metadata`` called with ``total``, ``page`` and
        ``page_size``.
    """
    bans = [
        ActiveBan(
            ip=entry.ip,
            jail=entry.jail,
            banned_at=entry.banned_at,
            expires_at=entry.expires_at,
            ban_count=entry.ban_count,
            country=entry.country,
        )
        for entry in domain.items
    ]
    page_info = create_pagination_metadata(domain.total, domain.page, domain.page_size)
    return JailBannedIpsResponse(items=bans, pagination=page_info)
|
||||
37
backend/app/mappers/server_mappers.py
Normal file
37
backend/app/mappers/server_mappers.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""Server response mappers.
|
||||
|
||||
Convert domain models (from server_service) to response models (for HTTP API).
|
||||
|
||||
This is the mapping layer at the router boundary, ensuring the service layer
|
||||
remains independent of HTTP response shapes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.models.server import ServerSettings, ServerSettingsResponse, ServerSettingsUpdate
|
||||
from app.models.server_domain import DomainServerSettings, DomainServerSettingsResult
|
||||
from app.utils.pagination import create_pagination_metadata
|
||||
|
||||
|
||||
def map_domain_server_settings_to_response(
    domain_settings: DomainServerSettings,
) -> ServerSettings:
    """Build the API ``ServerSettings`` from the domain server settings.

    Args:
        domain_settings: Server settings from the service layer.

    Returns:
        A ``ServerSettings`` with all six fields copied from
        ``domain_settings``.
    """
    response = ServerSettings(
        log_level=domain_settings.log_level,
        log_target=domain_settings.log_target,
        syslog_socket=domain_settings.syslog_socket,
        db_path=domain_settings.db_path,
        db_purge_age=domain_settings.db_purge_age,
        db_max_matches=domain_settings.db_max_matches,
    )
    return response
|
||||
|
||||
|
||||
def map_domain_server_settings_result_to_response(
    domain_result: DomainServerSettingsResult,
) -> ServerSettingsResponse:
    """Build the API ``ServerSettingsResponse`` from the domain result.

    Args:
        domain_result: Server settings together with their warnings.

    Returns:
        A ``ServerSettingsResponse`` whose ``settings`` are the mapped
        settings and whose ``warnings`` are copied unchanged.
    """
    settings = map_domain_server_settings_to_response(domain_result.settings)
    return ServerSettingsResponse(settings=settings, warnings=domain_result.warnings)
|
||||
108
backend/app/models/blocklist_domain.py
Normal file
108
backend/app/models/blocklist_domain.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""Blocklist domain models.
|
||||
|
||||
Internal domain-focused models used by blocklist_service. These represent the
|
||||
business domain layer and are independent of HTTP response shapes.
|
||||
|
||||
Response models are defined in `app.models.blocklist` and mappers convert domain
|
||||
models to response models at the router boundary.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DomainBlocklistSource:
    """Domain-layer description of one blocklist source.

    Instances are immutable (frozen dataclass). All six fields are required
    and have no defaults.
    """

    # Identity and location of the source.
    id: int
    name: str
    url: str
    enabled: bool
    # Timestamps are carried as strings; their format is not fixed here.
    created_at: str
    updated_at: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DomainImportLogEntry:
    """Domain-layer record of one blocklist import run.

    Instances are immutable (frozen dataclass). Every field is required;
    ``source_id`` and ``errors`` accept ``None``.
    """

    id: int
    source_id: int | None
    source_url: str
    timestamp: str
    # Per-run counters.
    ips_imported: int
    ips_skipped: int
    errors: str | None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DomainImportLogList:
    """Domain-layer page of import log entries.

    Instances are immutable (frozen dataclass). ``total``, ``page`` and
    ``page_size`` are the values the response mapper passes to
    ``create_pagination_metadata``.
    """

    items: list[DomainImportLogEntry]
    # Paging counters.
    total: int
    page: int
    page_size: int
|
||||
|
||||
|
||||
class DomainScheduleFrequency(StrEnum):
|
||||
"""Available import schedule frequency presets (domain model)."""
|
||||
|
||||
hourly = "hourly"
|
||||
daily = "daily"
|
||||
weekly = "weekly"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DomainScheduleConfig:
    """Domain-layer import schedule configuration.

    Instances are immutable (frozen dataclass). Only ``frequency`` is
    required; the timing fields fall back to the defaults below.
    """

    frequency: DomainScheduleFrequency
    # Timing fields with their defaults.
    interval_hours: int = 24
    hour: int = 3
    minute: int = 0
    day_of_week: int = 0
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DomainScheduleInfo:
    """Domain-layer schedule configuration plus run metadata.

    Instances are immutable (frozen dataclass). Only ``config`` is required;
    the three run metadata fields default to ``None``.
    """

    config: DomainScheduleConfig
    # Run metadata; each defaults to None.
    next_run_at: str | None = None
    last_run_at: str | None = None
    last_run_errors: bool | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DomainPreviewResult:
    """Domain-layer outcome of previewing a blocklist URL.

    Instances are immutable (frozen dataclass). All four fields are required.
    """

    entries: list[str]
    # Counters reported alongside the entries.
    total_lines: int
    valid_count: int
    skipped_count: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DomainImportSourceResult:
    """Domain-layer outcome of importing one blocklist source.

    Instances are immutable (frozen dataclass). Every field is required;
    ``source_id`` and ``error`` accept ``None``.
    """

    source_id: int | None
    source_url: str
    # Per-source counters.
    ips_imported: int
    ips_skipped: int
    error: str | None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DomainImportRunResult:
    """Domain-layer aggregate of a full import run.

    Instances are immutable (frozen dataclass). All four fields are required.
    """

    results: list[DomainImportSourceResult]
    # Aggregate counters for the run.
    total_imported: int
    total_skipped: int
    errors_count: int
|
||||
130
backend/app/models/config_domain.py
Normal file
130
backend/app/models/config_domain.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""Config domain models.
|
||||
|
||||
Internal domain-focused models used by config_service. These represent the
|
||||
business domain layer and are independent of HTTP response shapes.
|
||||
|
||||
Response models are defined in `app.models.config` and mappers convert domain
|
||||
models to response models at the router boundary.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
|
||||
# Literal aliases used to annotate fields of the config domain models.

# Annotates ``DomainJailConfig.use_dns``.
DNSMode = Literal["yes", "warn", "no", "raw"]

# Annotates ``DomainJailConfig.log_encoding``; both "utf-8" and "UTF-8" are listed.
LogEncoding = Literal["auto", "ascii", "utf-8", "UTF-8", "latin-1"]

# Annotates ``DomainJailConfig.backend``.
BackendType = Literal["auto", "polling", "pyinotify", "systemd", "gamin"]

# Annotates ``DomainGlobalConfig.log_level``.
LogLevel = Literal["CRITICAL", "ERROR", "WARNING", "NOTICE", "INFO", "DEBUG"]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DomainBantimeEscalation:
    """Domain-layer ban-time escalation settings.

    Instances are immutable (frozen dataclass). Every field has a default,
    so the class can be constructed without arguments.
    """

    increment: bool = False
    # Optional tuning values; None when not set.
    factor: float | None = None
    formula: str | None = None
    multipliers: str | None = None
    max_time: int | None = None
    rnd_time: int | None = None
    overall_jails: bool = False
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DomainJailConfig:
    """Domain-layer configuration snapshot of one jail.

    Instances are immutable (frozen dataclass). The first eight fields are
    required; the rest fall back to the defaults below.
    """

    # Required fields.
    name: str
    ban_time: int
    max_retry: int
    find_time: int
    fail_regex: list[str]
    ignore_regex: list[str]
    log_paths: list[str]
    actions: list[str]
    # Optional fields with defaults.
    date_pattern: str | None = None
    log_encoding: LogEncoding = "UTF-8"
    backend: BackendType = "polling"
    use_dns: DNSMode = "warn"
    prefregex: str = ""
    # None when no escalation block is attached; the mapper then emits None.
    bantime_escalation: DomainBantimeEscalation | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DomainJailConfigList:
    """Domain-layer collection of jail configurations.

    Instances are immutable (frozen dataclass). Both fields are required.
    """

    items: list[DomainJailConfig]
    total: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DomainGlobalConfig:
    """Domain-layer global fail2ban settings.

    Instances are immutable (frozen dataclass). All four fields are required.
    """

    # Logging settings.
    log_level: LogLevel
    log_target: str
    # Database settings.
    db_purge_age: int
    db_max_matches: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DomainServiceStatus:
    """Domain-layer fail2ban service health status.

    Instances are immutable (frozen dataclass). Only ``online`` is required.
    The optional strings default to ``None``; the response mapper replaces
    falsy values with ``""`` (version) or ``"UNKNOWN"`` (log level/target).
    """

    online: bool
    version: str | None = None
    # Counters default to zero.
    jail_count: int = 0
    total_bans: int = 0
    total_failures: int = 0
    log_level: str | None = None
    log_target: str | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DomainMapColorThresholds:
    """Domain-layer map color threshold configuration.

    Instances are immutable (frozen dataclass). All three integer thresholds
    are required, in the order high, medium, low.
    """

    threshold_high: int
    threshold_medium: int
    threshold_low: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DomainRegexTest:
    """Domain-layer outcome of a regex test.

    Instances are immutable (frozen dataclass). ``matched`` and ``groups``
    are required; ``error`` defaults to ``None``.
    """

    matched: bool
    groups: list[str]
    error: str | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DomainFilterConfig:
    """Domain-layer representation of a ``filter.d/*.conf`` file.

    Instances are immutable (frozen dataclass). ``name`` and ``filename``
    are required; everything else is defaulted. The collection fields
    default to ``None`` rather than empty containers; the response mapper
    turns falsy values into ``{}`` / ``[]``.
    """

    # Required fields.
    name: str
    filename: str
    # Optional fields; all default to None.
    before: str | None = None
    after: str | None = None
    variables: dict[str, str] | None = None
    prefregex: str | None = None
    failregex: list[str] | None = None
    ignoreregex: list[str] | None = None
    maxlines: int | None = None
    datepattern: str | None = None
    journalmatch: str | None = None
    # Remaining fields with their defaults.
    active: bool = False
    used_by_jails: list[str] | None = None
    source_file: str = ""
    has_local_override: bool = False
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DomainFilterList:
    """Domain-layer collection of filter configurations.

    Instances are immutable (frozen dataclass). Both fields are required.
    """

    items: list[DomainFilterConfig]
    total: int
|
||||
23
backend/app/models/health_domain.py
Normal file
23
backend/app/models/health_domain.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Health domain models.
|
||||
|
||||
Internal domain-focused models used by health_service. These represent the
|
||||
business domain layer and are independent of HTTP response shapes.
|
||||
|
||||
Response models are defined in `app.models.server` and mappers convert domain
|
||||
models to response models at the router boundary.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DomainServerStatus:
    """Domain-layer snapshot of fail2ban server health.

    Instances are immutable (frozen dataclass). Only ``online`` is required;
    ``version`` defaults to ``None`` and the counters default to zero.
    """

    online: bool
    version: str | None = None
    # Counters default to zero.
    active_jails: int = 0
    total_bans: int = 0
    total_failures: int = 0
|
||||
64
backend/app/models/history_domain.py
Normal file
64
backend/app/models/history_domain.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""History domain models.
|
||||
|
||||
Internal domain-focused models used by history_service. These represent the
|
||||
business domain layer and are independent of HTTP response shapes.
|
||||
|
||||
Response models are defined in `app.models.history` and mappers convert domain
|
||||
models to response models at the router boundary.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DomainHistoryBanItem:
    """Domain-layer row of the history ban list.

    Instances are immutable (frozen dataclass). The first four fields are
    required. ``matches`` defaults to ``None``; the response mapper turns a
    falsy value into an empty list.
    """

    # Required fields.
    ip: str
    jail: str
    banned_at: str
    ban_count: int
    # Optional fields with defaults.
    failures: int = 0
    matches: list[str] | None = None
    country_code: str | None = None
    country_name: str | None = None
    asn: str | None = None
    org: str | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DomainHistoryList:
    """Domain-layer page of history ban rows.

    Instances are immutable (frozen dataclass). ``total``, ``page`` and
    ``page_size`` are the values the response mapper passes to
    ``create_pagination_metadata``.
    """

    items: list[DomainHistoryBanItem]
    # Paging counters.
    total: int
    page: int
    page_size: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DomainIpTimelineEvent:
    """Domain-layer ban event within a per-IP timeline.

    Instances are immutable (frozen dataclass). ``jail``, ``banned_at`` and
    ``ban_count`` are required. ``matches`` defaults to ``None``; the
    response mapper turns a falsy value into an empty list.
    """

    jail: str
    banned_at: str
    ban_count: int
    # Optional fields with defaults.
    failures: int = 0
    matches: list[str] | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DomainIpDetail:
    """Domain-layer historical record for one IP address.

    Instances are immutable (frozen dataclass). ``ip``, ``total_bans`` and
    ``total_failures`` are required. ``timeline`` defaults to ``None``; the
    response mapper treats a falsy timeline as empty.
    """

    # Required fields.
    ip: str
    total_bans: int
    total_failures: int
    # Optional fields; all default to None.
    last_ban_at: str | None = None
    country_code: str | None = None
    country_name: str | None = None
    asn: str | None = None
    org: str | None = None
    timeline: list[DomainIpTimelineEvent] | None = None
|
||||
112
backend/app/models/jail_domain.py
Normal file
112
backend/app/models/jail_domain.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""Jail domain models.
|
||||
|
||||
Internal domain-focused models used by jail_service. These represent the
|
||||
business domain layer and are independent of HTTP response shapes.
|
||||
|
||||
Response models are defined in `app.models.jail` and mappers convert domain
|
||||
models to response models at the router boundary.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DomainJailStatus:
    """Domain-layer runtime counters of one jail.

    Instances are immutable (frozen dataclass). All four integer counters
    are required.
    """

    # Ban counters.
    currently_banned: int
    total_banned: int
    # Failure counters.
    currently_failed: int
    total_failed: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DomainBantimeEscalation:
    """Incremental ban-time escalation configuration (domain model).

    Every field has a default, so ``DomainBantimeEscalation()`` describes a
    configuration with escalation switched off and no tuning values set.

    NOTE(review): the field names look like fail2ban's ``bantime.*`` options
    (increment, factor, formula, multipliers, maxtime, rndtime,
    overalljails) — confirm the mapping in the service that builds this.
    """

    # Whether incremental ban time is enabled.
    increment: bool = False
    # Optional tuning values; None means "not set".
    factor: float | None = None
    formula: str | None = None
    multipliers: str | None = None
    # Presumably expressed in seconds — TODO confirm units.
    max_time: int | None = None
    rnd_time: int | None = None
    # Presumably "consider bans across all jails" — confirm semantics.
    overall_jails: bool = False
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DomainJailSummary:
    """Lightweight jail entry for the overview list (domain model).

    Element type of ``DomainJailList.items``. Carries the identifying and
    timing fields of a jail plus an optional runtime status block.
    """

    # Jail name.
    name: str
    # State flags of the jail.
    enabled: bool
    running: bool
    idle: bool
    # Name of the log backend the jail uses.
    backend: str
    # Timing / threshold settings — presumably fail2ban's findtime, bantime
    # (seconds) and maxretry; TODO confirm units with jail_service.
    find_time: int
    ban_time: int
    max_retry: int
    # Runtime counters; None when no status is supplied.
    status: DomainJailStatus | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DomainJailList:
    """List of active jails (domain model).

    ``frozen=True`` prevents reassigning ``items`` or ``total`` but does not
    make the ``items`` list itself immutable.
    """

    # One summary per jail.
    items: list[DomainJailSummary]
    # Number of jails; this class does not itself enforce total == len(items).
    total: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DomainJail:
    """Full jail configuration (domain model).

    Superset of the fields on ``DomainJailSummary``: adds log paths, regex
    lists, ignore IPs, actions and optional escalation settings.
    """

    # Jail name.
    name: str
    # State flags of the jail.
    enabled: bool
    running: bool
    idle: bool
    # Name of the log backend the jail uses.
    backend: str
    # Log files monitored by the jail.
    log_paths: list[str]
    # Regex patterns — presumably fail2ban failregex / ignoreregex entries.
    fail_regex: list[str]
    ignore_regex: list[str]
    # IPs (or presumably networks) exempt from banning — confirm accepted forms.
    ignore_ips: list[str]
    # Timing / threshold settings — presumably seconds for the two times;
    # TODO confirm units with jail_service.
    find_time: int
    ban_time: int
    max_retry: int
    # Names of the actions attached to the jail.
    actions: list[str]
    # Optional date pattern; None when not set.
    date_pattern: str | None = None
    # Encoding used for the log files; defaults to "UTF-8".
    log_encoding: str = "UTF-8"
    # Optional escalation settings and runtime counters.
    bantime_escalation: DomainBantimeEscalation | None = None
    status: DomainJailStatus | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DomainActiveBan:
    """A currently active ban entry from a jail (domain model).

    NOTE(review): a class of the same name is also imported from
    ``app.models.ban_domain`` by jail_service — confirm which definition is
    the intended one and whether this copy is needed.
    """

    # Banned IP address.
    ip: str
    # Name of the jail holding the ban.
    jail: str
    # Start / expiry of the ban as strings — presumably ISO 8601; None when
    # not supplied.
    banned_at: str | None = None
    expires_at: str | None = None
    # Ban counter for the IP; defaults to 1.
    ban_count: int = 1
    # Country of the IP — presumably a country code; None when unknown.
    country: str | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DomainJailBannedIps:
    """Paginated list of currently banned IPs for a jail (domain model).

    Same paging shape as the other paginated domain models: the page's
    items plus ``total``, ``page`` and ``page_size``.
    """

    # Active bans on this page.
    items: list[DomainActiveBan]
    # Total number of banned IPs — presumably across all pages; TODO confirm.
    total: int
    # Page number — presumably 1-based; TODO confirm with callers.
    page: int
    # Number of items requested per page.
    page_size: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DomainJailDetail:
    """Full jail with supplemental metadata (domain model).

    Wraps a ``DomainJail`` together with its ignore list and ignore-self
    flag.

    NOTE(review): this is a plain frozen dataclass, so it has no pydantic
    ``model_copy``; a modified copy must be made with
    ``dataclasses.replace(obj, ...)``. Verify that callers (e.g. the jail
    detail route) do not rely on pydantic-only methods.
    """

    # The jail's full configuration.
    jail: DomainJail
    # Entries on the jail's ignore list.
    ignore_list: list[str]
    # Presumably fail2ban's ``ignoreself`` setting — confirm.
    ignore_self: bool
|
||||
32
backend/app/models/server_domain.py
Normal file
32
backend/app/models/server_domain.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Server domain models.
|
||||
|
||||
Internal domain-focused models used by server_service. These represent the
|
||||
business domain layer and are independent of HTTP response shapes.
|
||||
|
||||
Response models are defined in `app.models.server` and mappers convert domain
|
||||
models to response models at the router boundary.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DomainServerSettings:
    """Fail2ban server-level settings (domain model).

    Immutable value object; only ``syslog_socket`` is optional.
    """

    # Logging configuration of the server.
    log_level: str
    log_target: str
    # Location of the fail2ban database.
    db_path: str
    # Presumably the purge age in seconds — TODO confirm units.
    db_purge_age: int
    # Presumably the maximum number of matches stored per ban — confirm.
    db_max_matches: int
    # Syslog socket path; None when not set.
    syslog_socket: str | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DomainServerSettingsResult:
    """Server settings with warnings (domain model).

    Pairs the settings with a mapping of warning flags. ``frozen=True``
    blocks rebinding of the attributes; the ``warnings`` dict itself is
    still mutable.
    """

    # The server settings themselves.
    settings: DomainServerSettings
    # Warning flags keyed by name; the set of keys is not defined here —
    # see the service that builds this result.
    warnings: dict[str, bool]
|
||||
@@ -107,7 +107,7 @@ async def get_archived_history(
|
||||
total = int(row[0]) if row is not None and row[0] is not None else 0
|
||||
|
||||
async with db.execute(
|
||||
"SELECT jail, ip, timeofban, bancount, data, action "
|
||||
"SELECT id, jail, ip, timeofban, bancount, data, action "
|
||||
"FROM history_archive "
|
||||
f"{where_sql} "
|
||||
"ORDER BY timeofban DESC LIMIT ? OFFSET ?",
|
||||
@@ -117,12 +117,13 @@ async def get_archived_history(
|
||||
|
||||
records = [
|
||||
{
|
||||
"jail": str(r[0]),
|
||||
"ip": str(r[1]),
|
||||
"timeofban": int(r[2]),
|
||||
"bancount": int(r[3]),
|
||||
"data": str(r[4]),
|
||||
"action": str(r[5]),
|
||||
"id": int(r[0]),
|
||||
"jail": str(r[1]),
|
||||
"ip": str(r[2]),
|
||||
"timeofban": int(r[3]),
|
||||
"bancount": int(r[4]),
|
||||
"data": str(r[5]),
|
||||
"action": str(r[6]),
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
@@ -137,29 +138,59 @@ async def get_all_archived_history(
|
||||
ip_filter: str | list[str] | None = None,
|
||||
origin: BanOrigin | None = None,
|
||||
action: str | None = None,
|
||||
page_size: int = 1000,
|
||||
max_rows: int = 50_000,
|
||||
last_ban_id: int | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Return all archived history rows for the given filters."""
|
||||
page: int = 1
|
||||
page_size: int = 500
|
||||
"""Return archived history rows for the given filters, bounded to *max_rows*.
|
||||
|
||||
Uses keyset pagination internally for constant-time performance regardless
|
||||
of how deep into the result set we go. The caller must provide *last_ban_id*
|
||||
from the previous call to continue pagination; ``None`` starts fresh.
|
||||
|
||||
Args:
|
||||
page_size: Number of rows to fetch per internal batch (default 1000).
|
||||
max_rows: Hard cap on total rows returned (default 50 000). When
|
||||
reached the function returns even if more rows exist. Pass ``0``
|
||||
to request zero rows (useful for count-only callers).
|
||||
last_ban_id: Cursor from the previous call. ``None`` for the first
|
||||
call — the result set will start from the newest row.
|
||||
"""
|
||||
if max_rows <= 0:
|
||||
return []
|
||||
|
||||
all_rows: list[dict[str, Any]] = []
|
||||
current_last_ban_id: int | None = last_ban_id
|
||||
|
||||
while True:
|
||||
rows, total = await get_archived_history(
|
||||
batch, has_more = await get_archived_history_keyset(
|
||||
db=db,
|
||||
since=since,
|
||||
jail=jail,
|
||||
ip_filter=ip_filter,
|
||||
origin=origin,
|
||||
action=action,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
last_ban_id=current_last_ban_id,
|
||||
)
|
||||
all_rows.extend(rows)
|
||||
if len(rows) < page_size:
|
||||
if not batch:
|
||||
break
|
||||
all_rows.extend(batch)
|
||||
if len(all_rows) >= max_rows:
|
||||
break
|
||||
if not has_more:
|
||||
break
|
||||
# Use the id of the last row in the batch as the next cursor.
|
||||
# Rows are ordered id DESC, so the last row has the smallest id
|
||||
# seen in this batch and is the correct keyset anchor.
|
||||
last_row = batch[-1]
|
||||
current_last_ban_id = last_row.get("id")
|
||||
if current_last_ban_id is None:
|
||||
# Fallback: determine id from the WHERE clause of the previous query.
|
||||
# If we somehow cannot determine the id, stop to avoid an infinite loop.
|
||||
break
|
||||
page += 1
|
||||
|
||||
return all_rows
|
||||
return all_rows[:max_rows]
|
||||
|
||||
|
||||
async def purge_archived_history(db: aiosqlite.Connection, age_seconds: int) -> int:
|
||||
@@ -266,6 +297,7 @@ async def get_archived_history_keyset(
|
||||
|
||||
records = [
|
||||
{
|
||||
"id": int(r[0]),
|
||||
"jail": str(r[1]),
|
||||
"ip": str(r[2]),
|
||||
"timeofban": int(r[3]),
|
||||
|
||||
@@ -292,6 +292,9 @@ class HistoryArchiveRepository(Protocol):
|
||||
ip_filter: str | list[str] | None = None,
|
||||
origin: BanOrigin | None = None,
|
||||
action: str | None = None,
|
||||
page_size: int = 1000,
|
||||
max_rows: int = 50_000,
|
||||
last_ban_id: int | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
...
|
||||
|
||||
|
||||
@@ -34,6 +34,7 @@ from app.dependencies import (
|
||||
SettingsDep,
|
||||
)
|
||||
from app.exceptions import BadRequestError, BlocklistSourceNotFoundError
|
||||
from app.mappers import blocklist_mappers
|
||||
from app.models.blocklist import (
|
||||
BlocklistListResponse,
|
||||
BlocklistSource,
|
||||
@@ -370,6 +371,7 @@ async def preview_blocklist(
|
||||
raise BlocklistSourceNotFoundError(source_id)
|
||||
|
||||
try:
|
||||
return await blocklist_service.preview_source(source.url, http_session)
|
||||
domain_result = await blocklist_service.preview_source(source.url, http_session)
|
||||
return blocklist_mappers.map_domain_preview_result_to_response(domain_result)
|
||||
except ValueError as exc:
|
||||
raise BadRequestError(f"Could not fetch blocklist: {exc}") from exc
|
||||
|
||||
@@ -27,6 +27,7 @@ from app.models.config import (
|
||||
RegexTestResponse,
|
||||
ServiceStatusResponse,
|
||||
)
|
||||
from app.mappers import config_mappers
|
||||
from app.services import (
|
||||
config_service,
|
||||
jail_service,
|
||||
@@ -94,7 +95,8 @@ async def get_global_config(
|
||||
Raises:
|
||||
HTTPException: 502 when fail2ban is unreachable.
|
||||
"""
|
||||
return await config_service.get_global_config(socket_path)
|
||||
domain_result = await config_service.get_global_config(socket_path)
|
||||
return config_mappers.map_domain_global_config_to_response(domain_result)
|
||||
|
||||
|
||||
@router.put(
|
||||
@@ -400,7 +402,8 @@ async def get_service_status(
|
||||
"""
|
||||
from app.services import health_service
|
||||
|
||||
return await health_service.get_service_status(
|
||||
domain_result = await health_service.get_service_status(
|
||||
socket_path,
|
||||
probe_fn=health_service.probe,
|
||||
)
|
||||
return config_mappers.map_domain_service_status_to_response(domain_result)
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Annotated
|
||||
from fastapi import APIRouter, Path, Query, Request, status
|
||||
|
||||
from app.dependencies import AuthDep, Fail2BanConfigDirDep, Fail2BanSocketDep
|
||||
from app.mappers import config_mappers
|
||||
from app.models.config import (
|
||||
FilterConfig,
|
||||
FilterCreateRequest,
|
||||
@@ -50,10 +51,10 @@ async def list_filters(
|
||||
:class:`~app.models.config.FilterListResponse` with all discovered
|
||||
filters.
|
||||
"""
|
||||
result = await filter_config_service.list_filters(config_dir, socket_path)
|
||||
domain_result = await filter_config_service.list_filters(config_dir, socket_path)
|
||||
# Sort: active first (by name), then inactive (by name).
|
||||
result.filters.sort(key=lambda f: (not f.active, f.name.lower()))
|
||||
return result
|
||||
domain_result.items.sort(key=lambda f: (not f.active, f.name.lower()))
|
||||
return config_mappers.map_domain_filter_list_to_response(domain_result)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ from app.dependencies import (
|
||||
HttpSessionDep,
|
||||
)
|
||||
from app.exceptions import HistoryNotFoundError
|
||||
from app.mappers import history_mappers
|
||||
from app.models._common import TimeRange
|
||||
from app.models.ban import BanOrigin
|
||||
from app.models.history import HistoryListResponse, IpDetailResponse
|
||||
@@ -99,7 +100,7 @@ async def get_history(
|
||||
and the total matching count.
|
||||
"""
|
||||
|
||||
return await history_service.list_history(
|
||||
domain_result = await history_service.list_history(
|
||||
socket_path,
|
||||
range_=range,
|
||||
jail=jail,
|
||||
@@ -112,6 +113,7 @@ async def get_history(
|
||||
db=history_ctx.db,
|
||||
fail2ban_metadata_service=fail2ban_metadata_service,
|
||||
)
|
||||
return history_mappers.map_domain_history_list_to_response(domain_result)
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -136,7 +138,7 @@ async def get_history_archive(
|
||||
page_size: int = Query(default=DEFAULT_PAGE_SIZE, ge=1, le=500, description="Items per page (max 500)."),
|
||||
) -> HistoryListResponse:
|
||||
|
||||
return await history_service.list_history(
|
||||
domain_result = await history_service.list_history(
|
||||
socket_path,
|
||||
range_=range,
|
||||
jail=jail,
|
||||
@@ -148,6 +150,7 @@ async def get_history_archive(
|
||||
db=history_ctx.db,
|
||||
fail2ban_metadata_service=fail2ban_metadata_service,
|
||||
)
|
||||
return history_mappers.map_domain_history_list_to_response(domain_result)
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -182,14 +185,14 @@ async def get_ip_history(
|
||||
HTTPException: 404 if the IP has no history in the database.
|
||||
"""
|
||||
|
||||
detail: IpDetailResponse | None = await history_service.get_ip_detail(
|
||||
domain_result = await history_service.get_ip_detail(
|
||||
socket_path,
|
||||
ip,
|
||||
http_session=http_session,
|
||||
fail2ban_metadata_service=fail2ban_metadata_service,
|
||||
)
|
||||
|
||||
if detail is None:
|
||||
if domain_result is None:
|
||||
raise HistoryNotFoundError(ip)
|
||||
|
||||
return detail
|
||||
return history_mappers.map_domain_ip_detail_to_response(domain_result)
|
||||
|
||||
@@ -15,6 +15,7 @@ from app.dependencies import (
|
||||
PendingRecoveryDep,
|
||||
)
|
||||
from app.exceptions import BadRequestError
|
||||
from app.mappers import config_mappers
|
||||
from app.models.config import (
|
||||
ActivateJailRequest,
|
||||
AddLogPathRequest,
|
||||
@@ -68,7 +69,8 @@ async def get_jail_configs(
|
||||
Returns:
|
||||
:class:`~app.models.config.JailConfigListResponse`.
|
||||
"""
|
||||
return await config_service.list_jail_configs(socket_path)
|
||||
domain_result = await config_service.list_jail_configs(socket_path)
|
||||
return config_mappers.map_domain_jail_config_list_to_response(domain_result)
|
||||
|
||||
|
||||
|
||||
@@ -150,7 +152,8 @@ async def get_jail_config(
|
||||
HTTPException: 404 when the jail does not exist.
|
||||
HTTPException: 502 when fail2ban is unreachable.
|
||||
"""
|
||||
return await config_service.get_jail_config(socket_path, name)
|
||||
domain_result = await config_service.get_jail_config(socket_path, name)
|
||||
return config_mappers.map_domain_jail_config_to_response(domain_result)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ from app.dependencies import (
|
||||
JailServiceStateDep,
|
||||
)
|
||||
from app.exceptions import BadRequestError
|
||||
from app.mappers import jail_mappers
|
||||
from app.models.ban import JailBannedIpsResponse
|
||||
from app.models.jail import (
|
||||
IgnoreIpRequest,
|
||||
@@ -76,7 +77,8 @@ async def get_jails(
|
||||
Returns:
|
||||
:class:`~app.models.jail.JailListResponse` with all active jails.
|
||||
"""
|
||||
return await jail_service.list_jails(socket_path, state)
|
||||
domain_result = await jail_service.list_jails(socket_path, state)
|
||||
return jail_mappers.map_domain_jail_list_to_response(domain_result)
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -106,16 +108,16 @@ async def get_jail(
|
||||
HTTPException: 404 when the jail does not exist.
|
||||
HTTPException: 502 when fail2ban is unreachable.
|
||||
"""
|
||||
jail, ignore_list, ignore_self = await asyncio.gather(
|
||||
jail_detail, ignore_list, ignore_self = await asyncio.gather(
|
||||
jail_service.get_jail(socket_path, name),
|
||||
jail_service.get_ignore_list(socket_path, name),
|
||||
jail_service.get_ignore_self(socket_path, name),
|
||||
)
|
||||
return JailDetailResponse(
|
||||
jail=jail,
|
||||
ignore_list=ignore_list,
|
||||
ignore_self=ignore_self,
|
||||
# Merge ignore_list and ignore_self from dedicated service calls
|
||||
jail_detail_with_ignore = jail_detail.model_copy(
|
||||
update={"ignore_list": ignore_list, "ignore_self": ignore_self}
|
||||
)
|
||||
return jail_mappers.map_domain_jail_detail_to_response(jail_detail_with_ignore)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -474,7 +476,7 @@ async def get_jail_banned_ips(
|
||||
if not (1 <= page_size <= 100):
|
||||
raise BadRequestError("page_size must be between 1 and 100.")
|
||||
|
||||
return await jail_service.get_jail_banned_ips(
|
||||
domain_result = await jail_service.get_jail_banned_ips(
|
||||
socket_path=socket_path,
|
||||
jail_name=name,
|
||||
page=page,
|
||||
@@ -484,3 +486,4 @@ async def get_jail_banned_ips(
|
||||
http_session=http_session,
|
||||
app_db=ban_ctx.db,
|
||||
)
|
||||
return jail_mappers.map_domain_jail_banned_ips_to_response(domain_result)
|
||||
|
||||
@@ -13,6 +13,7 @@ from __future__ import annotations
|
||||
from fastapi import APIRouter, Request, status
|
||||
|
||||
from app.dependencies import AuthDep, Fail2BanSocketDep
|
||||
from app.mappers import server_mappers
|
||||
from app.models.server import ServerSettingsResponse, ServerSettingsUpdate
|
||||
from app.services import server_service
|
||||
|
||||
@@ -49,7 +50,8 @@ async def get_server_settings(
|
||||
Raises:
|
||||
HTTPException: 502 when fail2ban is unreachable.
|
||||
"""
|
||||
return await server_service.get_settings(socket_path)
|
||||
domain_result = await server_service.get_settings(socket_path)
|
||||
return server_mappers.map_domain_server_settings_result_to_response(domain_result)
|
||||
|
||||
|
||||
@router.put(
|
||||
|
||||
@@ -29,19 +29,17 @@ if TYPE_CHECKING:
|
||||
from app.exceptions import ConfigOperationError, ConfigValidationError, JailNotFoundError
|
||||
from app.models.config import (
|
||||
AddLogPathRequest,
|
||||
BantimeEscalation,
|
||||
GlobalConfigResponse,
|
||||
GlobalConfigUpdate,
|
||||
JailConfig,
|
||||
JailConfigListResponse,
|
||||
JailConfigResponse,
|
||||
JailConfigUpdate,
|
||||
LogPreviewRequest,
|
||||
LogPreviewResponse,
|
||||
MapColorThresholdsResponse,
|
||||
MapColorThresholdsUpdate,
|
||||
RegexTestRequest,
|
||||
RegexTestResponse,
|
||||
)
|
||||
from app.models.config_domain import (
|
||||
DomainBantimeEscalation,
|
||||
DomainGlobalConfig,
|
||||
DomainJailConfig,
|
||||
DomainJailConfigList,
|
||||
)
|
||||
from app.services.log_service import preview_log as util_preview_log
|
||||
from app.services.log_service import test_regex as util_test_regex
|
||||
@@ -120,7 +118,7 @@ def _validate_regex(pattern: str) -> str | None:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def get_jail_config(socket_path: str, name: str) -> JailConfigResponse:
|
||||
async def get_jail_config(socket_path: str, name: str) -> DomainJailConfig:
|
||||
"""Return the editable configuration for a single jail.
|
||||
|
||||
Args:
|
||||
@@ -128,7 +126,7 @@ async def get_jail_config(socket_path: str, name: str) -> JailConfigResponse:
|
||||
name: Jail name.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.config.JailConfigResponse`.
|
||||
:class:`~app.models.config_domain.DomainJailConfig`.
|
||||
|
||||
Raises:
|
||||
JailNotFoundError: If *name* is not a known jail.
|
||||
@@ -164,7 +162,7 @@ async def get_jail_config(socket_path: str, name: str) -> JailConfigResponse:
|
||||
bt_rndtime_raw: str | int | None = await _safe_get_typed(client, ["get", name, "bantime.rndtime"], None)
|
||||
bt_overalljails_raw: bool = await _safe_get_typed(client, ["get", name, "bantime.overalljails"], False)
|
||||
|
||||
bantime_escalation = BantimeEscalation(
|
||||
bantime_escalation = DomainBantimeEscalation(
|
||||
increment=bool(bt_increment_raw),
|
||||
factor=float(bt_factor_raw) if bt_factor_raw is not None else None,
|
||||
formula=str(bt_formula_raw) if bt_formula_raw else None,
|
||||
@@ -174,7 +172,7 @@ async def get_jail_config(socket_path: str, name: str) -> JailConfigResponse:
|
||||
overall_jails=bool(bt_overalljails_raw),
|
||||
)
|
||||
|
||||
jail_cfg = JailConfig(
|
||||
jail_cfg = DomainJailConfig(
|
||||
name=name,
|
||||
ban_time=int(bantime_raw or 600),
|
||||
find_time=int(findtime_raw or 600),
|
||||
@@ -192,17 +190,17 @@ async def get_jail_config(socket_path: str, name: str) -> JailConfigResponse:
|
||||
)
|
||||
|
||||
log.info("jail_config_fetched", jail=name)
|
||||
return JailConfigResponse(jail=jail_cfg)
|
||||
return jail_cfg
|
||||
|
||||
|
||||
async def list_jail_configs(socket_path: str) -> JailConfigListResponse:
|
||||
async def list_jail_configs(socket_path: str) -> DomainJailConfigList:
|
||||
"""Return configuration for all active jails.
|
||||
|
||||
Args:
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.config.JailConfigListResponse`.
|
||||
:class:`~app.models.config_domain.DomainJailConfigList`.
|
||||
|
||||
Raises:
|
||||
~app.utils.fail2ban_client.Fail2BanConnectionError: Socket unreachable.
|
||||
@@ -218,16 +216,15 @@ async def list_jail_configs(socket_path: str) -> JailConfigListResponse:
|
||||
)
|
||||
|
||||
if not jail_names:
|
||||
return JailConfigListResponse(items=[], total=0)
|
||||
return DomainJailConfigList(items=[], total=0)
|
||||
|
||||
responses: list[JailConfigResponse] = await asyncio.gather(
|
||||
jail_configs: list[DomainJailConfig] = await asyncio.gather(
|
||||
*[get_jail_config(socket_path, name) for name in jail_names],
|
||||
return_exceptions=False,
|
||||
)
|
||||
|
||||
jails = [r.jail for r in responses]
|
||||
log.info("jail_configs_listed", count=len(jails))
|
||||
return JailConfigListResponse(items=jails, total=len(jails))
|
||||
log.info("jail_configs_listed", count=len(jail_configs))
|
||||
return DomainJailConfigList(items=jail_configs, total=len(jail_configs))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -379,14 +376,14 @@ async def _replace_regex_list(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def get_global_config(socket_path: str) -> GlobalConfigResponse:
|
||||
async def get_global_config(socket_path: str) -> DomainGlobalConfig:
|
||||
"""Return fail2ban global configuration settings.
|
||||
|
||||
Args:
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.config.GlobalConfigResponse`.
|
||||
:class:`~app.models.config_domain.DomainGlobalConfig`.
|
||||
|
||||
Raises:
|
||||
~app.utils.fail2ban_client.Fail2BanConnectionError: Socket unreachable.
|
||||
@@ -405,7 +402,7 @@ async def get_global_config(socket_path: str) -> GlobalConfigResponse:
|
||||
_safe_get_typed(client, ["get", "dbmaxmatches"], 10),
|
||||
)
|
||||
|
||||
return GlobalConfigResponse(
|
||||
return DomainGlobalConfig(
|
||||
log_level=str(log_level_raw or "INFO").upper(),
|
||||
log_target=str(log_target_raw or "STDOUT"),
|
||||
db_purge_age=int(db_purge_age_raw or 86400),
|
||||
|
||||
@@ -27,12 +27,11 @@ from app.exceptions import (
|
||||
)
|
||||
from app.models.config import (
|
||||
AssignFilterRequest,
|
||||
FilterConfig,
|
||||
FilterConfigUpdate,
|
||||
FilterCreateRequest,
|
||||
FilterListResponse,
|
||||
FilterUpdateRequest,
|
||||
)
|
||||
from app.models.config_domain import DomainFilterConfig, DomainFilterList
|
||||
from app.utils import conffile_parser
|
||||
from app.utils.async_utils import run_blocking
|
||||
from app.utils.config_file_utils import (
|
||||
@@ -308,12 +307,12 @@ def _write_filter_local_sync(filter_d: Path, name: str, content: str) -> None:
|
||||
async def list_filters(
|
||||
config_dir: str,
|
||||
socket_path: str,
|
||||
) -> FilterListResponse:
|
||||
) -> DomainFilterList:
|
||||
"""Return all available filters from ``filter.d/`` with active/inactive status.
|
||||
|
||||
Scans ``{config_dir}/filter.d/`` for ``.conf`` files, merges any
|
||||
corresponding ``.local`` overrides, parses each file into a
|
||||
:class:`~app.models.config.FilterConfig`, and cross-references with the
|
||||
:class:`~app.models.config_domain.DomainFilterConfig`, and cross-references with the
|
||||
currently running jails to determine which filters are active.
|
||||
|
||||
A filter is considered *active* when its base name matches the ``filter``
|
||||
@@ -324,7 +323,7 @@ async def list_filters(
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.config.FilterListResponse` with all filters
|
||||
:class:`~app.models.config_domain.DomainFilterList` with all filters
|
||||
sorted alphabetically, active ones carrying non-empty
|
||||
``used_by_jails`` lists.
|
||||
"""
|
||||
@@ -342,12 +341,12 @@ async def list_filters(
|
||||
|
||||
filter_to_jails = _build_filter_to_jails_map(all_jails, active_names)
|
||||
|
||||
filters: list[FilterConfig] = []
|
||||
filters: list[DomainFilterConfig] = []
|
||||
for name, filename, content, has_local, source_path in raw_filters:
|
||||
cfg = conffile_parser.parse_filter_file(content, name=name, filename=filename)
|
||||
used_by = sorted(filter_to_jails.get(name, []))
|
||||
filters.append(
|
||||
FilterConfig(
|
||||
DomainFilterConfig(
|
||||
name=cfg.name,
|
||||
filename=cfg.filename,
|
||||
before=cfg.before,
|
||||
@@ -367,7 +366,7 @@ async def list_filters(
|
||||
)
|
||||
|
||||
log.info("filters_listed", total=len(filters), active=sum(1 for f in filters if f.active))
|
||||
return FilterListResponse(filters=filters, total=len(filters))
|
||||
return DomainFilterList(filters=filters, total=len(filters))
|
||||
|
||||
|
||||
async def get_filter(
|
||||
|
||||
@@ -16,7 +16,7 @@ from typing import TypeVar, cast
|
||||
import structlog
|
||||
|
||||
from app import __version__
|
||||
from app.models.config import ServiceStatusResponse
|
||||
from app.models.config_domain import DomainServiceStatus
|
||||
from app.models.server import ServerStatus
|
||||
from app.utils.constants import FAIL2BAN_SOCKET_TIMEOUT_FAST
|
||||
from app.utils.fail2ban_client import (
|
||||
@@ -69,7 +69,7 @@ async def _safe_get_typed(
|
||||
async def get_service_status(
|
||||
socket_path: str,
|
||||
probe_fn: Callable[[str], Awaitable[ServerStatus]] | None = None,
|
||||
) -> ServiceStatusResponse:
|
||||
) -> DomainServiceStatus:
|
||||
"""Return fail2ban service health status with log configuration.
|
||||
|
||||
Delegates to an injectable *probe_fn* (defaults to
|
||||
@@ -80,7 +80,7 @@ async def get_service_status(
|
||||
probe_fn: Optional probe function.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.config.ServiceStatusResponse`.
|
||||
:class:`~app.models.config_domain.DomainServiceStatus`.
|
||||
"""
|
||||
if probe_fn is None:
|
||||
raise ValueError(
|
||||
@@ -110,7 +110,7 @@ async def get_service_status(
|
||||
jail_count=server_status.active_jails,
|
||||
)
|
||||
|
||||
return ServiceStatusResponse(
|
||||
return DomainServiceStatus(
|
||||
online=server_status.online,
|
||||
version=__version__,
|
||||
jail_count=server_status.active_jails,
|
||||
|
||||
@@ -25,17 +25,16 @@ if TYPE_CHECKING:
|
||||
from app.repositories.protocols import HistoryArchiveRepository
|
||||
from app.services.protocols import Fail2BanMetadataService
|
||||
|
||||
from app.models.history import (
|
||||
HistoryBanItem,
|
||||
HistoryListResponse,
|
||||
IpDetailResponse,
|
||||
IpTimelineEvent,
|
||||
from app.models.history_domain import (
|
||||
DomainHistoryBanItem,
|
||||
DomainHistoryList,
|
||||
DomainIpDetail,
|
||||
DomainIpTimelineEvent,
|
||||
)
|
||||
from app.repositories import fail2ban_db_repo
|
||||
from app.repositories import history_archive_repo as default_history_archive_repo
|
||||
from app.utils.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
|
||||
from app.utils.fail2ban_db_utils import parse_data_json, ts_to_iso
|
||||
from app.utils.pagination import create_pagination_metadata
|
||||
from app.utils.time_utils import since_unix
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
@@ -190,7 +189,7 @@ async def list_history(
|
||||
db: aiosqlite.Connection | None = None,
|
||||
history_archive_repo: HistoryArchiveRepository = default_history_archive_repo,
|
||||
fail2ban_metadata_service: Fail2BanMetadataService | None = None,
|
||||
) -> HistoryListResponse:
|
||||
) -> DomainHistoryList:
|
||||
"""Return a paginated list of historical ban records with optional filters.
|
||||
|
||||
Queries the fail2ban ``bans`` table applying the requested filters and
|
||||
@@ -214,7 +213,7 @@ async def list_history(
|
||||
If not provided, uses the default singleton (lazy import).
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.history.HistoryListResponse` with paginated items
|
||||
:class:`~app.models.history_domain.DomainHistoryList` with paginated items
|
||||
and the total matching count.
|
||||
"""
|
||||
effective_page_size: int = min(page_size, MAX_PAGE_SIZE)
|
||||
@@ -237,7 +236,7 @@ async def list_history(
|
||||
page=page,
|
||||
)
|
||||
|
||||
items: list[HistoryBanItem] = []
|
||||
items: list[DomainHistoryBanItem] = []
|
||||
total: int
|
||||
|
||||
if source == "archive":
|
||||
@@ -281,7 +280,7 @@ async def list_history(
|
||||
log.warning("history_service_geo_lookup_failed", ip=ip)
|
||||
|
||||
items.append(
|
||||
HistoryBanItem(
|
||||
DomainHistoryBanItem(
|
||||
ip=ip,
|
||||
jail=jail_name,
|
||||
banned_at=banned_at,
|
||||
@@ -332,7 +331,7 @@ async def list_history(
|
||||
log.warning("history_service_geo_lookup_failed", ip=ip)
|
||||
|
||||
items.append(
|
||||
HistoryBanItem(
|
||||
DomainHistoryBanItem(
|
||||
ip=ip,
|
||||
jail=jail_name,
|
||||
banned_at=banned_at,
|
||||
@@ -346,9 +345,11 @@ async def list_history(
|
||||
)
|
||||
)
|
||||
|
||||
return HistoryListResponse(
|
||||
return DomainHistoryList(
|
||||
items=items,
|
||||
pagination=create_pagination_metadata(total, page, effective_page_size),
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=effective_page_size,
|
||||
)
|
||||
|
||||
|
||||
@@ -359,7 +360,7 @@ async def get_ip_detail(
|
||||
http_session: aiohttp.ClientSession | None = None,
|
||||
geo_enricher: GeoEnricher | None = None,
|
||||
fail2ban_metadata_service: Fail2BanMetadataService | None = None,
|
||||
) -> IpDetailResponse | None:
|
||||
) -> DomainIpDetail | None:
|
||||
"""Return the full historical record for a single IP address.
|
||||
|
||||
Fetches all ban events for *ip* from the fail2ban database, ordered
|
||||
@@ -376,7 +377,7 @@ async def get_ip_detail(
|
||||
If not provided, uses the default singleton (lazy import).
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.history.IpDetailResponse` if any records exist
|
||||
:class:`~app.models.history_domain.DomainIpDetail` if any records exist
|
||||
for *ip*, or ``None`` if the IP has no history in the database.
|
||||
"""
|
||||
if fail2ban_metadata_service is None:
|
||||
@@ -390,7 +391,7 @@ async def get_ip_detail(
|
||||
if not rows:
|
||||
return None
|
||||
|
||||
timeline: list[IpTimelineEvent] = []
|
||||
timeline: list[DomainIpTimelineEvent] = []
|
||||
total_failures: int = 0
|
||||
|
||||
for row in rows:
|
||||
@@ -400,7 +401,7 @@ async def get_ip_detail(
|
||||
matches, failures = parse_data_json(row.data)
|
||||
total_failures += failures
|
||||
timeline.append(
|
||||
IpTimelineEvent(
|
||||
DomainIpTimelineEvent(
|
||||
jail=jail_name,
|
||||
banned_at=banned_at,
|
||||
ban_count=ban_count,
|
||||
@@ -430,7 +431,7 @@ async def get_ip_detail(
|
||||
except Exception: # noqa: BLE001
|
||||
log.warning("history_service_geo_lookup_failed_detail", ip=ip)
|
||||
|
||||
return IpDetailResponse(
|
||||
return DomainIpDetail(
|
||||
ip=ip,
|
||||
total_bans=len(timeline),
|
||||
total_failures=total_failures,
|
||||
|
||||
@@ -23,15 +23,17 @@ from typing import TYPE_CHECKING, cast
|
||||
import structlog
|
||||
|
||||
from app.exceptions import JailNotFoundError, JailOperationError
|
||||
from app.models.ban import ActiveBan, JailBannedIpsResponse
|
||||
from app.models.ban_domain import DomainActiveBan
|
||||
from app.models.config import BantimeEscalation
|
||||
from app.models.geo import GeoDetail, IpLookupResponse
|
||||
from app.models.jail import (
|
||||
Jail,
|
||||
JailDetailResponse,
|
||||
JailListResponse,
|
||||
JailStatus,
|
||||
JailSummary,
|
||||
from app.models.jail_domain import (
|
||||
DomainJailBannedIps,
|
||||
DomainBantimeEscalation,
|
||||
DomainJail,
|
||||
DomainJailDetail,
|
||||
DomainJailList,
|
||||
DomainJailStatus,
|
||||
DomainJailSummary,
|
||||
)
|
||||
from app.utils.config_file_utils import start_daemon, wait_for_fail2ban
|
||||
from app.utils.constants import FAIL2BAN_SOCKET_TIMEOUT
|
||||
@@ -174,7 +176,7 @@ async def _check_backend_cmd_supported(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def list_jails(socket_path: str, state: JailServiceState) -> JailListResponse:
|
||||
async def list_jails(socket_path: str, state: JailServiceState) -> DomainJailList:
|
||||
"""Return a summary list of all active fail2ban jails.
|
||||
|
||||
Queries the daemon for the global jail list and then fetches status
|
||||
@@ -185,7 +187,7 @@ async def list_jails(socket_path: str, state: JailServiceState) -> JailListRespo
|
||||
state: The jail service state holder for capability cache.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.jail.JailListResponse` with all active jails.
|
||||
:class:`~app.models.jail_domain.DomainJailList` with all active jails.
|
||||
|
||||
Raises:
|
||||
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
|
||||
@@ -205,23 +207,23 @@ async def list_jails(socket_path: str, state: JailServiceState) -> JailListRespo
|
||||
log.info("jail_list_fetched", count=len(jail_names))
|
||||
|
||||
if not jail_names:
|
||||
return JailListResponse(items=[], total=0)
|
||||
return DomainJailList(items=[], total=0)
|
||||
|
||||
# 2. Fetch summary data for every jail in parallel.
|
||||
summaries: list[JailSummary] = await asyncio.gather(
|
||||
summaries: list[DomainJailSummary] = await asyncio.gather(
|
||||
*[_fetch_jail_summary(client, name, state) for name in jail_names],
|
||||
return_exceptions=False,
|
||||
)
|
||||
|
||||
return JailListResponse(items=list(summaries), total=len(summaries))
|
||||
return DomainJailList(items=list(summaries), total=len(summaries))
|
||||
|
||||
|
||||
async def _fetch_jail_summary(
|
||||
client: Fail2BanClient,
|
||||
name: str,
|
||||
state: JailServiceState,
|
||||
) -> JailSummary:
|
||||
"""Fetch and build a :class:`~app.models.jail.JailSummary` for one jail.
|
||||
) -> DomainJailSummary:
|
||||
"""Fetch and build a :class:`~app.models.jail_domain.DomainJailSummary` for one jail.
|
||||
|
||||
Sends the ``status``, ``get ... bantime``, ``findtime``, ``maxretry``,
|
||||
``backend``, and ``idle`` commands in parallel (if supported).
|
||||
@@ -236,7 +238,7 @@ async def _fetch_jail_summary(
|
||||
state: The jail service state holder for capability cache.
|
||||
|
||||
Returns:
|
||||
A :class:`~app.models.jail.JailSummary` populated from the responses.
|
||||
A :class:`~app.models.jail_domain.DomainJailSummary` populated from the responses.
|
||||
"""
|
||||
# Check whether optional backend/idle commands are supported.
|
||||
# This probe happens once per session and is cached to avoid repeated
|
||||
@@ -276,13 +278,13 @@ async def _fetch_jail_summary(
|
||||
idle_raw: object | Exception = _r[5]
|
||||
|
||||
# Parse jail status (filter + actions).
|
||||
jail_status: JailStatus | None = None
|
||||
jail_status: DomainJailStatus | None = None
|
||||
if not isinstance(status_raw, Exception):
|
||||
try:
|
||||
raw = to_dict(ok(status_raw))
|
||||
filter_stats = to_dict(raw.get("Filter") or [])
|
||||
action_stats = to_dict(raw.get("Actions") or [])
|
||||
jail_status = JailStatus(
|
||||
jail_status = DomainJailStatus(
|
||||
currently_banned=int(str(action_stats.get("Currently banned", 0) or 0)),
|
||||
total_banned=int(str(action_stats.get("Total banned", 0) or 0)),
|
||||
currently_failed=int(str(filter_stats.get("Currently failed", 0) or 0)),
|
||||
@@ -315,7 +317,7 @@ async def _fetch_jail_summary(
|
||||
except (ValueError, TypeError):
|
||||
return fallback
|
||||
|
||||
return JailSummary(
|
||||
return DomainJailSummary(
|
||||
name=name,
|
||||
enabled=True,
|
||||
running=True,
|
||||
@@ -328,7 +330,7 @@ async def _fetch_jail_summary(
|
||||
)
|
||||
|
||||
|
||||
async def get_jail(socket_path: str, name: str) -> JailDetailResponse:
|
||||
async def get_jail(socket_path: str, name: str) -> DomainJailDetail:
|
||||
"""Return full detail for a single fail2ban jail.
|
||||
|
||||
Sends multiple ``get`` and ``status`` commands in parallel to build
|
||||
@@ -339,7 +341,7 @@ async def get_jail(socket_path: str, name: str) -> JailDetailResponse:
|
||||
name: Jail name.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.jail.JailDetailResponse` with the full jail.
|
||||
:class:`~app.models.jail_domain.DomainJailDetail` with the full jail.
|
||||
|
||||
Raises:
|
||||
JailNotFoundError: If *name* is not a known jail.
|
||||
@@ -360,7 +362,7 @@ async def get_jail(socket_path: str, name: str) -> JailDetailResponse:
|
||||
filter_stats = to_dict(raw.get("Filter") or [])
|
||||
action_stats = to_dict(raw.get("Actions") or [])
|
||||
|
||||
jail_status = JailStatus(
|
||||
jail_status = DomainJailStatus(
|
||||
currently_banned=int(str(action_stats.get("Currently banned", 0) or 0)),
|
||||
total_banned=int(str(action_stats.get("Total banned", 0) or 0)),
|
||||
currently_failed=int(str(filter_stats.get("Currently failed", 0) or 0)),
|
||||
@@ -411,7 +413,7 @@ async def get_jail(socket_path: str, name: str) -> JailDetailResponse:
|
||||
)
|
||||
|
||||
bt_increment: bool = bool(bt_increment_raw)
|
||||
bantime_escalation = BantimeEscalation(
|
||||
bantime_escalation = DomainBantimeEscalation(
|
||||
increment=bt_increment,
|
||||
factor=float(str(bt_factor_raw)) if bt_factor_raw is not None else None,
|
||||
formula=str(bt_formula_raw) if bt_formula_raw else None,
|
||||
@@ -421,7 +423,7 @@ async def get_jail(socket_path: str, name: str) -> JailDetailResponse:
|
||||
overall_jails=bool(bt_overalljails_raw),
|
||||
)
|
||||
|
||||
jail = Jail(
|
||||
jail = DomainJail(
|
||||
name=name,
|
||||
enabled=True,
|
||||
running=True,
|
||||
@@ -442,7 +444,7 @@ async def get_jail(socket_path: str, name: str) -> JailDetailResponse:
|
||||
)
|
||||
|
||||
log.info("jail_detail_fetched", jail=name)
|
||||
return JailDetailResponse(jail=jail)
|
||||
return DomainJailDetail(jail=jail, ignore_list=[], ignore_self=False)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -630,7 +632,7 @@ async def restart_daemon(
|
||||
|
||||
|
||||
|
||||
def _parse_ban_entry(entry: str, jail: str) -> ActiveBan | None:
|
||||
def _parse_ban_entry(entry: str, jail: str) -> DomainActiveBan | None:
|
||||
"""Parse a ban entry from ``get <jail> banip --with-time`` output.
|
||||
|
||||
Expected format::
|
||||
@@ -642,7 +644,7 @@ def _parse_ban_entry(entry: str, jail: str) -> ActiveBan | None:
|
||||
jail: Jail name for the resulting record.
|
||||
|
||||
Returns:
|
||||
An :class:`~app.models.ban.ActiveBan` or ``None`` if parsing fails.
|
||||
A :class:`~app.models.ban_domain.DomainActiveBan` or ``None`` if parsing fails.
|
||||
"""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
@@ -655,7 +657,7 @@ def _parse_ban_entry(entry: str, jail: str) -> ActiveBan | None:
|
||||
|
||||
if len(parts) < 2:
|
||||
# Entry has no time info — return with unknown times.
|
||||
return ActiveBan(
|
||||
return DomainActiveBan(
|
||||
ip=ip,
|
||||
jail=jail,
|
||||
banned_at=None,
|
||||
@@ -693,7 +695,7 @@ def _parse_ban_entry(entry: str, jail: str) -> ActiveBan | None:
|
||||
if expires_at_str:
|
||||
expires_at_iso = _to_iso(expires_at_str)
|
||||
|
||||
return ActiveBan(
|
||||
return DomainActiveBan(
|
||||
ip=ip,
|
||||
jail=jail,
|
||||
banned_at=banned_at_iso,
|
||||
@@ -720,7 +722,7 @@ async def get_jail_banned_ips(
|
||||
geo_cache: GeoCache | None = None,
|
||||
http_session: aiohttp.ClientSession | None = None,
|
||||
app_db: aiosqlite.Connection | None = None,
|
||||
) -> JailBannedIpsResponse:
|
||||
) -> DomainJailBannedIps:
|
||||
"""Return a paginated list of currently banned IPs for a single jail.
|
||||
|
||||
Fetches the full ban list from the fail2ban socket, applies an optional
|
||||
@@ -738,7 +740,7 @@ async def get_jail_banned_ips(
|
||||
app_db: Optional BanGUI application database for persistent geo cache.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.ban.JailBannedIpsResponse` with the paginated bans.
|
||||
:class:`~app.models.jail_domain.DomainJailBannedIps` with the paginated bans.
|
||||
|
||||
Raises:
|
||||
JailNotFoundError: If *jail_name* is not a known active jail.
|
||||
@@ -767,7 +769,7 @@ async def get_jail_banned_ips(
|
||||
ban_list: list[str] = cast("list[str]", raw_result) or []
|
||||
|
||||
# Parse all entries.
|
||||
all_bans: list[ActiveBan] = []
|
||||
all_bans: list[DomainActiveBan] = []
|
||||
for entry in ban_list:
|
||||
ban = _parse_ban_entry(str(entry), jail_name)
|
||||
if ban is not None:
|
||||
@@ -792,11 +794,20 @@ async def get_jail_banned_ips(
|
||||
except Exception: # noqa: BLE001
|
||||
log.warning("jail_banned_ips_geo_failed", jail=jail_name)
|
||||
geo_map = {}
|
||||
enriched_page: list[ActiveBan] = []
|
||||
enriched_page: list[DomainActiveBan] = []
|
||||
for ban in page_bans:
|
||||
geo = geo_map.get(ban.ip)
|
||||
if geo is not None:
|
||||
enriched_page.append(ban.model_copy(update={"country": geo.country_code}))
|
||||
enriched_page.append(
|
||||
DomainActiveBan(
|
||||
ip=ban.ip,
|
||||
jail=ban.jail,
|
||||
banned_at=ban.banned_at,
|
||||
expires_at=ban.expires_at,
|
||||
ban_count=ban.ban_count,
|
||||
country=geo.country_code,
|
||||
)
|
||||
)
|
||||
else:
|
||||
enriched_page.append(ban)
|
||||
page_bans = enriched_page
|
||||
@@ -808,20 +819,22 @@ async def get_jail_banned_ips(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return JailBannedIpsResponse(
|
||||
return DomainJailBannedIps(
|
||||
items=page_bans,
|
||||
pagination=create_pagination_metadata(total, page, page_size),
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
async def _enrich_bans(
|
||||
bans: list[ActiveBan],
|
||||
bans: list[DomainActiveBan],
|
||||
geo_enricher: GeoEnricher,
|
||||
) -> list[ActiveBan]:
|
||||
) -> list[DomainActiveBan]:
|
||||
"""Enrich ban records with geo data asynchronously.
|
||||
|
||||
Args:
|
||||
bans: The list of :class:`~app.models.ban.ActiveBan` records to enrich.
|
||||
bans: The list of :class:`~app.models.ban_domain.DomainActiveBan` records to enrich.
|
||||
geo_enricher: Async callable ``(ip) → GeoInfo | None``.
|
||||
|
||||
Returns:
|
||||
@@ -831,11 +844,20 @@ async def _enrich_bans(
|
||||
*[cast("Awaitable[object]", geo_enricher(ban.ip)) for ban in bans],
|
||||
return_exceptions=True,
|
||||
)
|
||||
enriched: list[ActiveBan] = []
|
||||
enriched: list[DomainActiveBan] = []
|
||||
for ban, geo in zip(bans, geo_results, strict=False):
|
||||
if geo is not None and not isinstance(geo, Exception):
|
||||
geo_info = cast("GeoInfo", geo)
|
||||
enriched.append(ban.model_copy(update={"country": geo_info.country_code}))
|
||||
enriched.append(
|
||||
DomainActiveBan(
|
||||
ip=ban.ip,
|
||||
jail=ban.jail,
|
||||
banned_at=ban.banned_at,
|
||||
expires_at=ban.expires_at,
|
||||
ban_count=ban.ban_count,
|
||||
country=geo_info.country_code,
|
||||
)
|
||||
)
|
||||
else:
|
||||
enriched.append(ban)
|
||||
return enriched
|
||||
|
||||
@@ -26,12 +26,14 @@ if TYPE_CHECKING:
|
||||
ScheduleConfig,
|
||||
ScheduleInfo,
|
||||
)
|
||||
from app.models.config_domain import (
|
||||
DomainGlobalConfig,
|
||||
DomainJailConfig,
|
||||
DomainJailConfigList,
|
||||
)
|
||||
from app.models.config import (
|
||||
AddLogPathRequest,
|
||||
GlobalConfigResponse,
|
||||
GlobalConfigUpdate,
|
||||
JailConfigListResponse,
|
||||
JailConfigResponse,
|
||||
JailConfigUpdate,
|
||||
LogPreviewRequest,
|
||||
LogPreviewResponse,
|
||||
@@ -40,9 +42,9 @@ if TYPE_CHECKING:
|
||||
RegexTestResponse,
|
||||
)
|
||||
from app.models.geo import GeoEnricher, GeoInfo
|
||||
from app.models.history import HistoryListResponse, IpDetailResponse
|
||||
from app.models.jail import JailDetailResponse, JailListResponse
|
||||
from app.models.server import ServerSettingsResponse, ServerSettingsUpdate, ServerStatus
|
||||
from app.models.history_domain import DomainHistoryList, DomainIpDetail
|
||||
from app.models.jail_domain import DomainJailBannedIps, DomainJailDetail, DomainJailList
|
||||
from app.models.server_domain import DomainServerSettingsResult
|
||||
from app.services.geo_cache import GeoCache
|
||||
|
||||
|
||||
@@ -81,10 +83,10 @@ class AuthService(Protocol):
|
||||
class JailService(Protocol):
|
||||
"""Protocol for jail management service operations."""
|
||||
|
||||
async def list_jails(self, socket_path: str) -> JailListResponse:
|
||||
async def list_jails(self, socket_path: str) -> DomainJailList:
|
||||
...
|
||||
|
||||
async def get_jail(self, socket_path: str, name: str) -> JailDetailResponse:
|
||||
async def get_jail(self, socket_path: str, name: str) -> DomainJailDetail:
|
||||
...
|
||||
|
||||
async def reload_all(self, socket_path: str) -> None:
|
||||
@@ -125,7 +127,7 @@ class JailService(Protocol):
|
||||
geo_batch_lookup: object,
|
||||
http_session: object,
|
||||
app_db: aiosqlite.Connection,
|
||||
) -> JailBannedIpsResponse:
|
||||
) -> DomainJailBannedIps:
|
||||
...
|
||||
|
||||
async def lookup_ip(
|
||||
@@ -233,10 +235,10 @@ class BlocklistService(Protocol):
|
||||
|
||||
@runtime_checkable
|
||||
class ConfigService(Protocol):
|
||||
async def get_jail_config(self, socket_path: str, name: str) -> JailConfigResponse:
|
||||
async def get_jail_config(self, socket_path: str, name: str) -> DomainJailConfig:
|
||||
...
|
||||
|
||||
async def list_jail_configs(self, socket_path: str) -> JailConfigListResponse:
|
||||
async def list_jail_configs(self, socket_path: str) -> DomainJailConfigList:
|
||||
...
|
||||
|
||||
async def update_jail_config(
|
||||
@@ -247,7 +249,7 @@ class ConfigService(Protocol):
|
||||
) -> None:
|
||||
...
|
||||
|
||||
async def get_global_config(self, socket_path: str) -> GlobalConfigResponse:
|
||||
async def get_global_config(self, socket_path: str) -> DomainGlobalConfig:
|
||||
...
|
||||
|
||||
async def update_global_config(
|
||||
@@ -305,7 +307,7 @@ class HistoryService(Protocol):
|
||||
http_session: aiohttp.ClientSession | None = None,
|
||||
geo_enricher: GeoEnricher | None = None,
|
||||
db: aiosqlite.Connection | None = None,
|
||||
) -> HistoryListResponse:
|
||||
) -> DomainHistoryList:
|
||||
...
|
||||
|
||||
async def get_ip_detail(
|
||||
@@ -315,7 +317,7 @@ class HistoryService(Protocol):
|
||||
*,
|
||||
http_session: aiohttp.ClientSession | None = None,
|
||||
geo_enricher: GeoEnricher | None = None,
|
||||
) -> IpDetailResponse | None:
|
||||
) -> DomainIpDetail | None:
|
||||
...
|
||||
|
||||
|
||||
@@ -394,7 +396,7 @@ class HealthProbe(Protocol):
|
||||
|
||||
@runtime_checkable
|
||||
class ServerService(Protocol):
|
||||
async def get_settings(self, socket_path: str) -> ServerSettingsResponse:
|
||||
async def get_settings(self, socket_path: str) -> DomainServerSettingsResult:
|
||||
...
|
||||
|
||||
async def update_settings(
|
||||
|
||||
@@ -15,7 +15,8 @@ from typing import cast
|
||||
import structlog
|
||||
|
||||
from app.exceptions import ServerOperationError
|
||||
from app.models.server import ServerSettings, ServerSettingsResponse, ServerSettingsUpdate
|
||||
from app.models.server_domain import DomainServerSettings, DomainServerSettingsResult
|
||||
from app.models.server import ServerSettingsUpdate
|
||||
from app.utils.constants import FAIL2BAN_SOCKET_TIMEOUT
|
||||
from app.utils.fail2ban_client import Fail2BanClient, Fail2BanCommand, Fail2BanResponse
|
||||
from app.utils.fail2ban_response import ok
|
||||
@@ -87,7 +88,7 @@ async def _safe_get(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def get_settings(socket_path: str) -> ServerSettingsResponse:
|
||||
async def get_settings(socket_path: str) -> DomainServerSettingsResult:
|
||||
"""Return current fail2ban server-level settings.
|
||||
|
||||
Fetches log level, log target, syslog socket, database file path, purge
|
||||
@@ -97,7 +98,7 @@ async def get_settings(socket_path: str) -> ServerSettingsResponse:
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.server.ServerSettingsResponse`.
|
||||
:class:`~app.models.server_domain.DomainServerSettingsResult`.
|
||||
|
||||
Raises:
|
||||
~app.utils.fail2ban_client.Fail2BanConnectionError: Socket unreachable.
|
||||
@@ -129,7 +130,7 @@ async def get_settings(socket_path: str) -> ServerSettingsResponse:
|
||||
db_purge_age = _to_int(db_purge_age_raw, 86400)
|
||||
db_max_matches = _to_int(db_max_matches_raw, 10)
|
||||
|
||||
settings = ServerSettings(
|
||||
settings = DomainServerSettings(
|
||||
log_level=log_level,
|
||||
log_target=log_target,
|
||||
syslog_socket=syslog_socket,
|
||||
@@ -143,7 +144,7 @@ async def get_settings(socket_path: str) -> ServerSettingsResponse:
|
||||
}
|
||||
|
||||
log.info("server_settings_fetched", db_purge_age=db_purge_age, warnings=warnings)
|
||||
return ServerSettingsResponse(settings=settings, warnings=warnings)
|
||||
return DomainServerSettingsResult(settings=settings, warnings=warnings)
|
||||
|
||||
|
||||
async def update_settings(socket_path: str, update: ServerSettingsUpdate) -> None:
|
||||
|
||||
@@ -5,10 +5,10 @@ BanGUI instance runs the background scheduler, even in container orchestration
|
||||
environments where multiple instances might start simultaneously.
|
||||
|
||||
The lock uses atomic database operations to prevent race conditions:
|
||||
- Lock acquisition is atomic: INSERT fails if the singleton row already exists
|
||||
- Lock release is atomic: DELETE with PID check ensures only the owner releases
|
||||
- Stale lock detection uses heartbeat timestamps: a lock older than TTL is
|
||||
considered abandoned and eligible for cleanup on the next startup
|
||||
- Lock acquisition is atomic: INSERT ... ON CONFLICT with BEGIN IMMEDIATE transaction
|
||||
- Lock stealing: If heartbeat exceeds timeout, lock can be taken by another instance
|
||||
- Heartbeat update is conditional: UPDATE only if we still hold the lock
|
||||
- Stale lock detection uses heartbeat timestamps with configurable timeout
|
||||
|
||||
This approach is more reliable than filesystem-based locking in containerized
|
||||
environments because:
|
||||
@@ -23,12 +23,13 @@ The lock record stores:
|
||||
- hostname: Container/host name for debugging
|
||||
- created_at: When the lock was first acquired
|
||||
- heartbeat_at: When the lock was last confirmed alive (updated periodically)
|
||||
- heartbeat_timeout: Seconds after which lock is considered stale (default 300)
|
||||
|
||||
On startup:
|
||||
1. Cleanup any stale locks (where heartbeat_at > TTL)
|
||||
2. Try to insert the lock for this instance
|
||||
1. Cleanup any stale locks (where heartbeat_at + heartbeat_timeout < now)
|
||||
2. Try to insert the lock for this instance using ON CONFLICT to steal stale locks
|
||||
3. If INSERT succeeds, lock is acquired
|
||||
4. If INSERT fails (IntegrityError), another instance holds the lock
|
||||
4. If another instance holds a valid (non-stale) lock, acquisition is refused and the existing row is left unchanged
|
||||
|
||||
On running (periodic):
|
||||
- Update heartbeat_at to keep the lock alive and prevent false positives
|
||||
@@ -62,6 +63,11 @@ SCHEDULER_LOCK_TTL_SECONDS: int = 60
|
||||
# expiring, providing robust protection against temporary delays.
|
||||
SCHEDULER_LOCK_HEARTBEAT_INTERVAL_SECONDS: int = 5
|
||||
|
||||
# Default heartbeat timeout: how long to wait before considering a lock stale
|
||||
# when another instance tries to acquire it. This is the max time a lock holder
|
||||
# can go without updating heartbeat before someone else can steal it.
|
||||
SCHEDULER_LOCK_HEARTBEAT_TIMEOUT_SECONDS: int = 300
|
||||
|
||||
|
||||
async def init_scheduler_lock_table(db: aiosqlite.Connection) -> None:
    """Create the scheduler_lock table if it doesn't exist.

    The ``heartbeat_timeout`` column default is rendered directly into the
    DDL text instead of being bound as a ``?`` parameter: SQLite does not
    accept bound parameters inside a column ``DEFAULT`` clause (the default
    must be a constant expression), so a parameterised ``CREATE TABLE`` fails.
    The interpolated value is a module-level integer constant coerced with
    ``int()``, so no external input reaches the statement.

    Args:
        db: The SQLite database connection.
    """
    # NOTE(review): the opening DDL lines (``IF NOT EXISTS`` and the singleton
    # ``id`` column) are reconstructed from the function's summary line and the
    # test fixture schema — confirm against the full file.
    # NOTE(review): ``IF NOT EXISTS`` leaves an already-existing table as is,
    # so a table created before ``heartbeat_timeout`` was introduced would not
    # gain the column here — confirm a migration covers that case.
    default_timeout = int(SCHEDULER_LOCK_HEARTBEAT_TIMEOUT_SECONDS)
    await db.execute(
        f"""
        CREATE TABLE IF NOT EXISTS scheduler_lock (
            id INTEGER PRIMARY KEY CHECK (id = 1),
            pid INTEGER NOT NULL,
            hostname TEXT NOT NULL,
            created_at REAL NOT NULL,
            heartbeat_at REAL NOT NULL,
            heartbeat_timeout REAL NOT NULL DEFAULT {default_timeout}
        );
        """
    )
    await db.commit()
|
||||
|
||||
|
||||
async def is_lock_stale(heartbeat_at: float, timeout: float, now: float) -> bool:
    """Check if a lock is considered stale based on its heartbeat timestamp.

    The comparison is strict: a lock whose heartbeat age is exactly equal to
    ``timeout`` is still treated as valid.

    This is declared as a coroutine because callers ``await`` it; the body
    itself performs no I/O and contains no ``await``.

    Args:
        heartbeat_at: Last heartbeat timestamp from the lock record.
        timeout: Heartbeat timeout in seconds.
        now: Current timestamp.

    Returns:
        True if ``(now - heartbeat_at) > timeout``, indicating a stale lock.
    """
    return (now - heartbeat_at) > timeout
|
||||
|
||||
|
||||
async def acquire_scheduler_lock(db: aiosqlite.Connection) -> bool:
|
||||
"""Try to acquire the scheduler lock.
|
||||
|
||||
This function performs two operations:
|
||||
1. Clean up any stale locks (where heartbeat_at + TTL < now)
|
||||
2. Try to insert a lock record for this instance
|
||||
|
||||
If another instance already holds a valid lock, the INSERT will fail and
|
||||
this function returns False. The caller should reject startup with a clear
|
||||
error message.
|
||||
Uses atomic INSERT ... ON CONFLICT to acquire or steal the lock:
|
||||
- If no lock exists: INSERT succeeds, lock acquired
|
||||
- If stale lock (heartbeat timeout exceeded): INSERT succeeds, lock stolen
|
||||
- If valid lock held by another process: the transaction is rolled back and False is returned
|
||||
|
||||
Args:
|
||||
db: The SQLite database connection.
|
||||
@@ -104,30 +123,51 @@ async def acquire_scheduler_lock(db: aiosqlite.Connection) -> bool:
|
||||
True if the lock was successfully acquired, False if held by another instance.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If database operations fail for reasons other than the lock
|
||||
being held (e.g., database is corrupted or inaccessible).
|
||||
RuntimeError: If database operations fail.
|
||||
"""
|
||||
now = time.time()
|
||||
pid = os.getpid()
|
||||
hostname = socket.gethostname()
|
||||
|
||||
try:
|
||||
# Clean up stale locks first
|
||||
await db.execute(
|
||||
"""
|
||||
DELETE FROM scheduler_lock
|
||||
WHERE (? - heartbeat_at) > ?
|
||||
""",
|
||||
(now, SCHEDULER_LOCK_TTL_SECONDS),
|
||||
)
|
||||
await db.execute("BEGIN IMMEDIATE")
|
||||
|
||||
# Try to acquire the lock (atomic: INSERT fails if row exists)
|
||||
# Clean up stale locks first (heartbeat timeout exceeded)
|
||||
cursor = await db.execute(
|
||||
"SELECT pid, heartbeat_at, heartbeat_timeout FROM scheduler_lock WHERE id = 1"
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
|
||||
if row is not None:
|
||||
lock_pid, lock_heartbeat, lock_timeout = row
|
||||
if lock_pid == pid:
|
||||
# Same process re-acquiring - allowed (refresh)
|
||||
pass
|
||||
elif (now - lock_heartbeat) <= lock_timeout:
|
||||
# Another process holds a valid lock - cannot acquire
|
||||
await db.rollback()
|
||||
log.warning(
|
||||
"scheduler_lock_held_by_other_instance",
|
||||
our_pid=pid,
|
||||
lock_pid=lock_pid,
|
||||
lock_heartbeat_age_seconds=now - lock_heartbeat,
|
||||
)
|
||||
return False
|
||||
# Stale lock (held by another process that crashed) - will be overwritten below
|
||||
|
||||
# Try to insert or update the lock
|
||||
await db.execute(
|
||||
"""
|
||||
INSERT INTO scheduler_lock (id, pid, hostname, created_at, heartbeat_at)
|
||||
VALUES (1, ?, ?, ?, ?)
|
||||
INSERT INTO scheduler_lock (id, pid, hostname, created_at, heartbeat_at, heartbeat_timeout)
|
||||
VALUES (1, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
pid = excluded.pid,
|
||||
hostname = excluded.hostname,
|
||||
created_at = excluded.created_at,
|
||||
heartbeat_at = excluded.heartbeat_at,
|
||||
heartbeat_timeout = excluded.heartbeat_timeout
|
||||
""",
|
||||
(pid, hostname, now, now),
|
||||
(pid, hostname, now, now, SCHEDULER_LOCK_HEARTBEAT_TIMEOUT_SECONDS),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
@@ -140,34 +180,30 @@ async def acquire_scheduler_lock(db: aiosqlite.Connection) -> bool:
|
||||
|
||||
except aiosqlite.IntegrityError:
|
||||
# Lock is already held by another instance (INSERT failed due to UNIQUE constraint)
|
||||
# Log details about who holds the lock to help with debugging
|
||||
# and the ON CONFLICT WHERE condition was not met (lock is fresh, not stale)
|
||||
try:
|
||||
cursor = await db.execute(
|
||||
"SELECT pid, hostname, created_at, heartbeat_at FROM scheduler_lock WHERE id = 1"
|
||||
"SELECT pid, hostname, heartbeat_at, heartbeat_timeout FROM scheduler_lock WHERE id = 1"
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
lock_pid, lock_hostname, lock_created, lock_heartbeat = row
|
||||
age_seconds = now - lock_created
|
||||
lock_pid, lock_hostname, lock_heartbeat, lock_timeout = row
|
||||
heartbeat_age = now - lock_heartbeat
|
||||
log.warning(
|
||||
"scheduler_lock_held_by_other_instance",
|
||||
our_pid=pid,
|
||||
lock_pid=lock_pid,
|
||||
lock_hostname=lock_hostname,
|
||||
lock_age_seconds=age_seconds,
|
||||
heartbeat_age_seconds=heartbeat_age,
|
||||
heartbeat_timeout=lock_timeout,
|
||||
)
|
||||
except Exception as e:
|
||||
log.warning("scheduler_lock_held_but_could_not_read_holder", error=str(e))
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
# Unexpected database error (not an IntegrityError)
|
||||
raise RuntimeError(
|
||||
f"Failed to acquire scheduler lock due to database error: {e}\n"
|
||||
"Check that the database is accessible and not corrupted."
|
||||
f"Failed to acquire scheduler lock due to database error: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
@@ -213,15 +249,18 @@ async def update_scheduler_lock_heartbeat(db: aiosqlite.Connection) -> bool:
|
||||
the lock from being considered stale. It only succeeds if this process
|
||||
still holds the lock.
|
||||
|
||||
Error handling: If the heartbeat update fails due to a database error, this
|
||||
function returns False (indicating lock loss) rather than raising an exception.
|
||||
This prevents the scheduler from crashing due to transient database issues,
|
||||
allowing the running application to continue and potentially recover the lock
|
||||
if it still holds it.
|
||||
|
||||
Args:
|
||||
db: The SQLite database connection.
|
||||
|
||||
Returns:
|
||||
True if the heartbeat was updated (we still hold the lock), False if
|
||||
we no longer hold the lock (another instance has taken over).
|
||||
|
||||
Raises:
|
||||
RuntimeError: If database operations fail.
|
||||
we no longer hold the lock or a database error occurred.
|
||||
"""
|
||||
now = time.time()
|
||||
pid = os.getpid()
|
||||
@@ -238,14 +277,22 @@ async def update_scheduler_lock_heartbeat(db: aiosqlite.Connection) -> bool:
|
||||
log.warning(
|
||||
"scheduler_lock_heartbeat_lost",
|
||||
our_pid=pid,
|
||||
message="Heartbeat failed; we no longer hold the lock.",
|
||||
message="Heartbeat update failed; we no longer hold the lock.",
|
||||
)
|
||||
return False
|
||||
|
||||
log.debug("scheduler_lock_heartbeat_updated", pid=pid)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to update scheduler lock heartbeat: {e}") from e
|
||||
# Don't crash the scheduler if heartbeat update fails - log and return False
|
||||
log.error(
|
||||
"scheduler_lock_heartbeat_error",
|
||||
our_pid=pid,
|
||||
error=str(e),
|
||||
message="Heartbeat update failed due to database error. Will retry on next interval.",
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
async def get_scheduler_lock_info(db: aiosqlite.Connection) -> dict[str, Any] | None:
|
||||
@@ -258,23 +305,84 @@ async def get_scheduler_lock_info(db: aiosqlite.Connection) -> dict[str, Any] |
|
||||
db: The SQLite database connection.
|
||||
|
||||
Returns:
|
||||
A dict with keys: pid, hostname, created_at, heartbeat_at, or None
|
||||
if no lock exists.
|
||||
A dict with keys: pid, hostname, created_at, heartbeat_at, heartbeat_timeout,
|
||||
or None if no lock exists.
|
||||
"""
|
||||
try:
|
||||
cursor = await db.execute(
|
||||
"SELECT pid, hostname, created_at, heartbeat_at FROM scheduler_lock WHERE id = 1"
|
||||
"SELECT pid, hostname, created_at, heartbeat_at, heartbeat_timeout FROM scheduler_lock WHERE id = 1"
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
pid, hostname, created_at, heartbeat_at = row
|
||||
pid, hostname, created_at, heartbeat_at, heartbeat_timeout = row
|
||||
return {
|
||||
"pid": pid,
|
||||
"hostname": hostname,
|
||||
"created_at": created_at,
|
||||
"heartbeat_at": heartbeat_at,
|
||||
"heartbeat_timeout": heartbeat_timeout,
|
||||
}
|
||||
return None
|
||||
except Exception as e:
|
||||
log.warning("scheduler_lock_info_query_failed", error=str(e))
|
||||
return None
|
||||
|
||||
|
||||
async def get_lock_health(db: aiosqlite.Connection) -> dict[str, Any]:
    """Get health status of the scheduler lock for monitoring.

    Returns a dict with lock status, age, and whether it's stale. Used for
    observability endpoints and monitoring dashboards.

    The lock record is read through ``get_scheduler_lock_info``. That helper
    returns ``None`` both when no lock row exists and when the query itself
    fails, so ``has_lock=False`` covers both situations.

    Args:
        db: The SQLite database connection.

    Returns:
        A dict with keys:
        - has_lock: bool indicating if a lock exists
        - is_stale: bool indicating if lock is stale (heartbeat timeout exceeded)
        - pid: int or None
        - hostname: str or None
        - heartbeat_age_seconds: float or None (time since last heartbeat)
        - created_at: float or None
        - heartbeat_timeout: float or None
        - stale_reason: str or None (why lock is considered stale)
    """
    info = await get_scheduler_lock_info(db)
    now = time.time()

    if info is None:
        # No readable lock record: report an empty, non-stale status.
        return {
            "has_lock": False,
            "is_stale": False,
            "pid": None,
            "hostname": None,
            "heartbeat_age_seconds": None,
            "created_at": None,
            "heartbeat_timeout": None,
            "stale_reason": None,
        }

    heartbeat_age = now - info["heartbeat_at"]
    # Use the same ``now`` for the age and the staleness check so the two
    # reported values are consistent with each other.
    is_stale_result = await is_lock_stale(
        info["heartbeat_at"],
        info["heartbeat_timeout"],
        now,
    )

    stale_reason: str | None = None
    if is_stale_result:
        stale_reason = (
            f"heartbeat_age ({heartbeat_age:.1f}s) > timeout ({info['heartbeat_timeout']:.1f}s)"
        )

    return {
        "has_lock": True,
        "is_stale": is_stale_result,
        "pid": info["pid"],
        "hostname": info["hostname"],
        "heartbeat_age_seconds": heartbeat_age,
        "created_at": info["created_at"],
        "heartbeat_timeout": info["heartbeat_timeout"],
        "stale_reason": stale_reason,
    }
|
||||
|
||||
@@ -16,9 +16,12 @@ import pytest
|
||||
|
||||
from app.utils.scheduler_lock import (
|
||||
SCHEDULER_LOCK_HEARTBEAT_INTERVAL_SECONDS,
|
||||
SCHEDULER_LOCK_HEARTBEAT_TIMEOUT_SECONDS,
|
||||
SCHEDULER_LOCK_TTL_SECONDS,
|
||||
acquire_scheduler_lock,
|
||||
get_lock_health,
|
||||
get_scheduler_lock_info,
|
||||
is_lock_stale,
|
||||
release_scheduler_lock,
|
||||
update_scheduler_lock_heartbeat,
|
||||
)
|
||||
@@ -30,13 +33,14 @@ async def lock_db(tmp_path: Any) -> aiosqlite.Connection:
|
||||
db_path = tmp_path / "test.db"
|
||||
db = await aiosqlite.connect(str(db_path))
|
||||
await db.execute(
|
||||
"""
|
||||
f"""
|
||||
CREATE TABLE scheduler_lock (
|
||||
id INTEGER PRIMARY KEY CHECK (id = 1),
|
||||
pid INTEGER NOT NULL,
|
||||
hostname TEXT NOT NULL,
|
||||
created_at REAL NOT NULL,
|
||||
heartbeat_at REAL NOT NULL
|
||||
heartbeat_at REAL NOT NULL,
|
||||
heartbeat_timeout REAL NOT NULL DEFAULT {SCHEDULER_LOCK_HEARTBEAT_TIMEOUT_SECONDS}
|
||||
);
|
||||
"""
|
||||
)
|
||||
@@ -61,14 +65,54 @@ async def test_acquire_scheduler_lock_success(lock_db: aiosqlite.Connection) ->
|
||||
async def test_acquire_scheduler_lock_fails_when_held(
    lock_db: aiosqlite.Connection,
) -> None:
    """Test that lock acquisition fails if already held by another process.

    Re-acquiring from the same PID is allowed (it acts as a refresh), so the
    rejection path is exercised against a second database whose lock row
    carries a different PID and a fresh heartbeat.
    """
    # First acquisition by this process succeeds.
    assert await acquire_scheduler_lock(lock_db) is True

    # Same process, same PID: re-acquire succeeds (refresh).
    assert await acquire_scheduler_lock(lock_db) is True

    # Separate in-memory database pre-populated with a lock owned by
    # "another process". Closed in ``finally`` so a failing assertion does
    # not leak the connection.
    db_other = await aiosqlite.connect(":memory:")
    try:
        # DDL cannot take bound parameters, hence the f-string for DEFAULT.
        await db_other.execute(
            f"""
            CREATE TABLE scheduler_lock (
                id INTEGER PRIMARY KEY CHECK (id = 1),
                pid INTEGER NOT NULL,
                hostname TEXT NOT NULL,
                created_at REAL NOT NULL,
                heartbeat_at REAL NOT NULL,
                heartbeat_timeout REAL NOT NULL DEFAULT {SCHEDULER_LOCK_HEARTBEAT_TIMEOUT_SECONDS}
            )
            """
        )

        # PID -1 can never equal os.getpid(); the heartbeat is current, so
        # the lock is live rather than stale.
        now = time.time()
        await db_other.execute(
            """
            INSERT INTO scheduler_lock (id, pid, hostname, created_at, heartbeat_at, heartbeat_timeout)
            VALUES (1, -1, 'other-host', ?, ?, ?)
            """,
            (now, now, SCHEDULER_LOCK_HEARTBEAT_TIMEOUT_SECONDS),
        )
        await db_other.commit()

        # Acquisition must fail while the other process holds a live lock.
        assert await acquire_scheduler_lock(db_other) is False
    finally:
        await db_other.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -78,13 +122,13 @@ async def test_acquire_scheduler_lock_cleans_stale_locks(
|
||||
"""Test that stale locks are automatically cleaned up."""
|
||||
# Insert a stale lock manually (old heartbeat)
|
||||
now = time.time()
|
||||
stale_heartbeat = now - SCHEDULER_LOCK_TTL_SECONDS - 10
|
||||
stale_heartbeat = now - SCHEDULER_LOCK_HEARTBEAT_TIMEOUT_SECONDS - 10
|
||||
await lock_db.execute(
|
||||
"""
|
||||
INSERT INTO scheduler_lock (id, pid, hostname, created_at, heartbeat_at)
|
||||
VALUES (1, 9999, 'stale-host', ?, ?)
|
||||
INSERT INTO scheduler_lock (id, pid, hostname, created_at, heartbeat_at, heartbeat_timeout)
|
||||
VALUES (1, 9999, 'stale-host', ?, ?, ?)
|
||||
""",
|
||||
(now - 100, stale_heartbeat),
|
||||
(now - 100, stale_heartbeat, SCHEDULER_LOCK_HEARTBEAT_TIMEOUT_SECONDS),
|
||||
)
|
||||
await lock_db.commit()
|
||||
|
||||
@@ -103,6 +147,39 @@ async def test_acquire_scheduler_lock_cleans_stale_locks(
|
||||
assert hostname is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_acquire_scheduler_lock_cleans_stale_locks_with_new_schema(
    lock_db: aiosqlite.Connection,
) -> None:
    """Check that a lock whose heartbeat exceeded its stored timeout is replaced."""
    # Seed a lock row whose heartbeat is 10 s older than the timeout allows.
    now = time.time()
    expired_at = now - SCHEDULER_LOCK_HEARTBEAT_TIMEOUT_SECONDS - 10
    seed_values = (now - 100, expired_at, SCHEDULER_LOCK_HEARTBEAT_TIMEOUT_SECONDS)
    await lock_db.execute(
        """
        INSERT INTO scheduler_lock (id, pid, hostname, created_at, heartbeat_at, heartbeat_timeout)
        VALUES (1, 9999, 'stale-host', ?, ?, ?)
        """,
        seed_values,
    )
    await lock_db.commit()

    # Acquisition is expected to take over the expired lock.
    acquired = await acquire_scheduler_lock(lock_db)
    assert acquired is True

    # The stored row must now describe this process.
    cursor = await lock_db.execute(
        "SELECT pid, hostname, heartbeat_timeout FROM scheduler_lock WHERE id = 1"
    )
    lock_row = await cursor.fetchone()
    assert lock_row is not None

    owner_pid, owner_host, stored_timeout = lock_row
    assert owner_pid == os.getpid()
    assert owner_host is not None
    assert stored_timeout == SCHEDULER_LOCK_HEARTBEAT_TIMEOUT_SECONDS
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_release_scheduler_lock_success(
|
||||
lock_db: aiosqlite.Connection,
|
||||
@@ -246,50 +323,210 @@ async def test_scheduler_lock_heartbeat_interval_sanity(
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_scheduler_lock_two_instances_cannot_both_hold(
    tmp_path: Any,
) -> None:
    """Test that two different processes cannot both hold the lock.

    Two connections to the same database file stand in for two instances.
    Same-PID re-acquire is allowed (refresh), so to test rejection the lock
    row is overwritten with a different PID before the second connection
    tries to acquire.
    """
    db_path = tmp_path / "test.db"

    # Instance A connects, creates the schema and acquires the lock.
    db_a = await aiosqlite.connect(str(db_path))
    try:
        # DDL cannot take bound parameters, hence the f-string for DEFAULT.
        await db_a.execute(
            f"""
            CREATE TABLE scheduler_lock (
                id INTEGER PRIMARY KEY CHECK (id = 1),
                pid INTEGER NOT NULL,
                hostname TEXT NOT NULL,
                created_at REAL NOT NULL,
                heartbeat_at REAL NOT NULL,
                heartbeat_timeout REAL NOT NULL DEFAULT {SCHEDULER_LOCK_HEARTBEAT_TIMEOUT_SECONDS}
            );
            """
        )
        await db_a.commit()

        assert await acquire_scheduler_lock(db_a) is True

        # Same-PID re-acquire succeeds (refresh behavior).
        assert await acquire_scheduler_lock(db_a) is True

        # Simulate another process holding a live lock. One timestamp is
        # bound as a parameter so created_at and heartbeat_at are identical
        # and no float is formatted into the SQL text.
        now = time.time()
        await db_a.execute(
            """
            INSERT OR REPLACE INTO scheduler_lock (id, pid, hostname, created_at, heartbeat_at, heartbeat_timeout)
            VALUES (1, -999, 'other-host', ?, ?, ?)
            """,
            (now, now, SCHEDULER_LOCK_HEARTBEAT_TIMEOUT_SECONDS),
        )
        await db_a.commit()

        # Instance B (different connection, same PID in test) tries to acquire.
        db_b = await aiosqlite.connect(str(db_path))
        try:
            # Must fail: a different PID (-999) holds a fresh lock.
            assert await acquire_scheduler_lock(db_b) is False

            # Clear the conflicting lock (simulating the other process dying).
            await db_a.execute("DELETE FROM scheduler_lock")
            await db_a.commit()

            # With the row gone, Instance B can acquire.
            assert await acquire_scheduler_lock(db_b) is True
        finally:
            await db_b.close()
    finally:
        # Closed even when an assertion fails, so no connection is leaked.
        await db_a.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_acquire_scheduler_lock_steals_stale_lock(
    lock_db: aiosqlite.Connection,
) -> None:
    """Test that a stale lock can be stolen by another instance.

    Scenario: Process A acquires the lock but crashes (never releases it).
    Process B starts up and sees the lock has a stale heartbeat (past timeout).
    Process B should be able to steal the lock.

    The stale row is re-assigned to a foreign PID before the second acquire.
    Same-PID re-acquire succeeds as a refresh, so leaving this process's own
    PID in the row would let the test pass without exercising the steal path.
    """
    # Create the lock row through the normal acquisition path.
    result_a = await acquire_scheduler_lock(lock_db)
    assert result_a is True

    # Read back the timeout stored with the lock.
    info = await get_scheduler_lock_info(lock_db)
    assert info is not None
    heartbeat_timeout = info["heartbeat_timeout"]

    # Turn the row into a crashed "Process A": foreign PID/hostname and a
    # heartbeat 10 s older than the timeout allows.
    now = time.time()
    stale_heartbeat = now - heartbeat_timeout - 10
    await lock_db.execute(
        "UPDATE scheduler_lock SET pid = ?, hostname = ?, heartbeat_at = ? WHERE id = 1",
        (-1, "crashed-host", stale_heartbeat),
    )
    await lock_db.commit()

    # Process B (this process) should now be able to steal the stale lock.
    result_b = await acquire_scheduler_lock(lock_db)
    assert result_b is True

    # Verify Process B now holds the lock.
    info_b = await get_scheduler_lock_info(lock_db)
    assert info_b is not None
    assert info_b["pid"] == os.getpid()
|
||||
@pytest.mark.asyncio
async def test_is_lock_stale_function() -> None:
    """Exercise is_lock_stale with fresh, expired and boundary heartbeats."""
    now = time.time()
    timeout = 300.0

    # (heartbeat age in seconds, expected staleness)
    scenarios = (
        (10, False),  # fresh heartbeat
        (400, True),  # heartbeat older than the timeout
        (300, False),  # exactly at the timeout: not stale (boundary)
    )
    for age_seconds, expected in scenarios:
        heartbeat_at = now - age_seconds
        assert await is_lock_stale(heartbeat_at, timeout, now) is expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_lock_health_no_lock(lock_db: aiosqlite.Connection) -> None:
    """Check the health report returned when the lock table is empty."""
    report = await get_lock_health(lock_db)

    # No row: nothing is held and nothing can be stale.
    assert report["has_lock"] is False
    assert report["is_stale"] is False
    # Owner and reason fields are left unset.
    assert report["pid"] is None
    assert report["stale_reason"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_lock_health_active_lock(lock_db: aiosqlite.Connection) -> None:
    """Check the health report for a lock this process has just acquired."""
    await acquire_scheduler_lock(lock_db)
    report = await get_lock_health(lock_db)

    # A freshly acquired lock is present and not stale.
    assert report["has_lock"] is True
    assert report["is_stale"] is False
    assert report["stale_reason"] is None

    # Ownership and timeout reflect this process and the module default.
    assert report["pid"] == os.getpid()
    assert report["hostname"] is not None
    assert report["heartbeat_timeout"] == SCHEDULER_LOCK_HEARTBEAT_TIMEOUT_SECONDS
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_lock_health_stale_lock(lock_db: aiosqlite.Connection) -> None:
    """Test get_lock_health with a stale lock.

    The lock is acquired normally, then its heartbeat is pushed 10 s past the
    stored timeout so the health report must flag it as stale and explain why.
    """
    await acquire_scheduler_lock(lock_db)

    # Guard before subscripting, as the sibling tests do, so a missing row
    # fails with a clear assertion instead of a TypeError on None.
    info = await get_scheduler_lock_info(lock_db)
    assert info is not None

    # Manually make the lock stale.
    now = time.time()
    stale_heartbeat = now - info["heartbeat_timeout"] - 10
    await lock_db.execute(
        "UPDATE scheduler_lock SET heartbeat_at = ? WHERE id = 1",
        (stale_heartbeat,),
    )
    await lock_db.commit()

    health = await get_lock_health(lock_db)
    assert health["has_lock"] is True
    assert health["is_stale"] is True
    # The reason string names both the heartbeat age and the timeout.
    assert health["stale_reason"] is not None
    assert "heartbeat_age" in health["stale_reason"]
    assert "timeout" in health["stale_reason"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_heartbeat_update_error_returns_false(
    lock_db: aiosqlite.Connection,
) -> None:
    """Check the heartbeat update returns False without a lock and True with one."""
    # No lock has been acquired yet: the update reports failure via its
    # return value rather than raising.
    without_lock = await update_scheduler_lock_heartbeat(lock_db)
    assert without_lock is False

    # Once this process holds the lock, the same call succeeds.
    await acquire_scheduler_lock(lock_db)
    with_lock = await update_scheduler_lock_heartbeat(lock_db)
    assert with_lock is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_concurrent_acquire_from_same_process(lock_db: aiosqlite.Connection) -> None:
    """Check that repeated acquisition from one process succeeds each time."""
    # Initial acquisition.
    first = await acquire_scheduler_lock(lock_db)
    assert first is True

    # Acquiring again from the same process is accepted (re-acquire/refresh).
    second = await acquire_scheduler_lock(lock_db)
    assert second is True

    # The lock row is still present after the refresh.
    lock_info = await get_scheduler_lock_info(lock_db)
    assert lock_info is not None

    # After an explicit release the lock can be taken again.
    await release_scheduler_lock(lock_db)
    third = await acquire_scheduler_lock(lock_db)
    assert third is True
|
||||
|
||||
Reference in New Issue
Block a user