From 2df029f7e8c23550789857b5b6efbb53203155b1 Mon Sep 17 00:00:00 2001 From: Lukas Date: Sun, 3 May 2026 17:00:34 +0200 Subject: [PATCH] refactor(ban_service): extract _bans_by_country_load_data helper Break up long function into focused helper. Load data logic separate from aggregation. --- Docs/Backend-Development.md | 58 +++ Docs/TYPE_SAFETY.md | 25 +- Docs/Tasks.md | 92 ----- backend/app/exceptions.py | 55 +-- backend/app/main.py | 4 +- backend/app/models/response.py | 50 ++- backend/app/services/ban_service.py | 494 +++++++++++++++---------- backend/app/utils/config_file_utils.py | 1 + 8 files changed, 458 insertions(+), 321 deletions(-) diff --git a/Docs/Backend-Development.md b/Docs/Backend-Development.md index e274138..120c132 100644 --- a/Docs/Backend-Development.md +++ b/Docs/Backend-Development.md @@ -3386,6 +3386,64 @@ When user-supplied URLs are fetched by the backend, validate them before making - `async validate_blocklist_url(url: AnyHttpUrl) → None`: Async DNS resolution + private IP check. - Service layer calls `await validate_blocklist_url(url)` before persisting; router catches `ValueError` and returns 400. +### 17.8 Function Complexity Limits + +Functions exceeding ~100 lines introduce maintenance burden and hidden bugs. Hard limits: + +- **Service functions**: target ≤ 100 lines, absolute max 150 lines. +- **Utility functions**: target ≤ 50 lines, absolute max 80 lines. +- **Router handlers**: target ≤ 40 lines, absolute max 60 lines. + +When a function grows beyond its target: + +1. **Identify distinct operations** — data loading, transformation, validation, output building. +2. **Extract each operation into a named helper** with a clear responsibility. +3. **Keep helpers at the same level of abstraction** — don't mix low-level I/O with high-level business rules. + +Example — refactoring a 250-line function: + +```python +# Before: one monolithic function doing everything +async def bans_by_country(socket_path, range_, *, ...): + # 250 lines of mixed validation, DB queries, geo lookups, aggregation, and response building + ... + +# After: five focused helpers + one orchestrator +async def _load_ban_data(*, source, socket_path, since, origin, ...): + """Step 1: Query per-IP ban counts from the right source.""" ... + +async def _resolve_geo(unique_ips, *, http_session, geo_cache_lookup, ...): + """Step 2: Resolve geo info from cache or enricher.""" ... + +async def _load_companion_rows(*, source, country_code, geo_map, ...): + """Step 3: Load companion ban rows, optionally filtered by country.""" ... + +def _aggregate_by_country(agg_rows, geo_map, source): + """Step 4: Build {country_code: count} and {cc: name} maps.""" ... + +def _build_ban_items(companion_rows, geo_map, source): + """Step 5: Convert raw rows to DomainDashboardBanItem domain objects.""" ... + +async def bans_by_country(socket_path, range_, *, ...): + agg_rows, total, unique_ips = await _load_ban_data(...) + geo_map = await _resolve_geo(unique_ips, ...) + companion_rows, _ = await _load_companion_rows(...) + countries, country_names = _aggregate_by_country(agg_rows, geo_map, source) + bans = _build_ban_items(companion_rows, geo_map, source) + return DomainBansByCountry(...) +``` + +**Why this works**: +- Each helper is independently testable. +- Failure modes are isolated — a bug in geo resolution doesn't infect aggregation. +- Code review becomes line-based rather than block-based. +- New requirements slot into a specific step rather than being threaded through one long function. + +**Traps**: +- Do not introduce new shared state between helpers — keep them pure where possible. +- Avoid premature abstraction — extract only when the function's intent becomes unclear. +- Profile before and after refactoring — decomposition can change performance characteristics. + ## 18. Quick Reference — Do / Don't | Do | Don't | diff --git a/Docs/TYPE_SAFETY.md b/Docs/TYPE_SAFETY.md index 340a19e..0d754ff 100644 --- a/Docs/TYPE_SAFETY.md +++ b/Docs/TYPE_SAFETY.md @@ -115,7 +115,30 @@ When adding a new response model to `backend/app/models/`: --- -## 8. Related Documents +## 8. TypedDict for Error Metadata + +Error response metadata uses `ErrorMetadata` (a `TypedDict` with `total=False`) instead of generic `dict[str, str | int | float | bool | None]`. This enables type-safe field access in exception handlers and type checkers can verify correct field usage. + +```python +# BAD — generic dict, no type narrowing +def get_error_metadata(self) -> dict[str, str | int | float | bool | None]: + return {"jail_name": self.name} + +# GOOD — TypedDict, type checker knows exact fields +def get_error_metadata(self) -> ErrorMetadata: + return {"jail_name": self.name} +``` + +When accessing error metadata in exception handlers, the type checker can now verify which keys are present: + +```python +metadata = exc.get_error_metadata() +jail_name = metadata["jail_name"] # type checker verifies "jail_name" exists +``` + +`ErrorMetadata` is defined in `backend/app/models/response.py` and imported via `TYPE_CHECKING` blocks in `exceptions.py` and `main.py` to avoid circular dependencies at runtime. + +## 9. Related Documents - [Architekture.md](Architekture.md) — system architecture and data flow - [Backend-Development.md](Backend-Development.md) — Python coding conventions, Pydantic usage diff --git a/Docs/Tasks.md b/Docs/Tasks.md index 6a44e5b..125191d 100644 --- a/Docs/Tasks.md +++ b/Docs/Tasks.md @@ -1,95 +1,3 @@ -### Issue #23: MEDIUM - Missing Default Configuration Documentation - -**Where found**: -- `.env.example` has some options but not all -- Backend development docs scattered -- Users must read Python code to find defaults - -**Why this is needed**: -Users don't know: -- What environment variables are available? -- What are the defaults? -- What values are valid? - -**Goal**: -Create comprehensive configuration reference documentation. - -**What to do**: -1. Create `Docs/CONFIGURATION.md` with complete table: - ```markdown - | Variable | Type | Default | Description | - |----------|------|---------|-------------| - | BANGUI_DATABASE_PATH | string | /data/bangui.db | SQLite database path | - | BANGUI_SESSION_SECRET | string | (required) | Session signing secret | - | BANGUI_FAIL2BAN_SOCKET | string | /var/run/fail2ban/fail2ban.sock | fail2ban socket | - ``` -2. Document each option with valid values and constraints -3. Organize by section (database, security, performance, etc.) -4. Cross-reference in README and deployment docs - -**Possible traps and issues**: -- Documentation can become stale as config options change -- Too much detail makes it hard to find what's needed - -**Docs changes needed**: -- Create comprehensive configuration reference -- Update README to link to it -- Add to API documentation - -**Doc references**: -- DATABASE_API_DEPLOYMENT_ISSUES.md - Issue "5.1, 5.5, 5.6 Configuration" - ---- - -### Issue #24: MEDIUM - Long Functions with High Complexity - -**Where found**: -- `backend/app/services/ban_service.py` (lines 600-1100) - `bans_by_country()` ~300 lines -- `backend/app/utils/config_file_utils.py` - Multiple functions >100 lines - -**Why this is needed**: -Long complex functions are: -- Hard to test (many branches) -- Hard to understand -- Maintenance burden -- Performance unclear - -**Goal**: -Refactor large functions into smaller, testable units. - -**What to do**: -1. Identify functions >100 lines -2. Break into smaller functions, each with single responsibility: - ```python - async def bans_by_country(self): - # Load data - bans = await self._load_bans_paginated() - - # Aggregate - by_country = self._aggregate_by_country(bans) - - # Enrich - enriched = await self._enrich_with_geo(by_country) - - # Sort and return - return self._sort_by_count(enriched) - ``` -3. Each smaller function is testable independently -4. Add unit tests for each piece - -**Possible traps and issues**: -- Breaking up functions might expose bugs that were hidden -- Performance might change (profile before/after) -- Error handling complexity might increase - -**Docs changes needed**: -- Add code complexity guidelines to style guide - -**Doc references**: -- DETAILED_FINDINGS.md - Issue #19 "Long Functions" - ---- - ### Issue #25: MEDIUM - Incomplete Type Hints in Error Handling **Where found**: diff --git a/backend/app/exceptions.py b/backend/app/exceptions.py index e0909be..e981c99 100644 --- a/backend/app/exceptions.py +++ b/backend/app/exceptions.py @@ -39,6 +39,11 @@ See Backend-Development.md for the complete exception contract. from __future__ import annotations +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from app.models.response import ErrorMetadata + # --------------------------------------------------------------------------- # Exception Base Classes (Categories) # --------------------------------------------------------------------------- @@ -46,7 +51,7 @@ from __future__ import annotations class DomainError(Exception): """Base class for all domain exceptions. - + All domain exceptions must: 1. Define an `error_code` class attribute (machine-readable error code) 2. Implement `get_error_metadata()` to return structured error context @@ -54,11 +59,11 @@ class DomainError(Exception): error_code: str = "internal_error" - def get_error_metadata(self) -> dict[str, str | int | float | bool | None]: + def get_error_metadata(self) -> ErrorMetadata: """Return structured metadata for the API error response. - + Subclasses should override to expose only safe, relevant metadata. - + Returns: A dictionary of metadata key-value pairs safe for client consumption. """ @@ -116,7 +121,7 @@ class RateLimitError(DomainError): self.retry_after_seconds: float = retry_after_seconds super().__init__(message) - def get_error_metadata(self) -> dict[str, str | int | float | bool | None]: + def get_error_metadata(self) -> ErrorMetadata: return {"retry_after_seconds": self.retry_after_seconds} @@ -134,7 +139,7 @@ class JailNotFoundError(NotFoundError): self.name = name super().__init__(f"Jail not found: {name!r}") - def get_error_metadata(self) -> dict[str, str | int | float | bool | None]: + def get_error_metadata(self) -> ErrorMetadata: return {"jail_name": self.name} @@ -176,7 +181,7 @@ class ConfigFileNotFoundError(NotFoundError): self.filename = filename super().__init__(f"Config file not found: {filename!r}") - def get_error_metadata(self) -> dict[str, str | int | float | bool | None]: + def get_error_metadata(self) -> ErrorMetadata: return {"filename": self.filename} @@ -194,7 +199,7 @@ class ConfigFileExistsError(ConflictError): self.filename = filename super().__init__(f"Config file already exists: {filename!r}") - def get_error_metadata(self) -> dict[str, str | int | float | bool | None]: + def get_error_metadata(self) -> ErrorMetadata: return {"filename": self.filename} @@ -231,7 +236,7 @@ class Fail2BanConnectionError(ServiceUnavailableError): self.socket_path: str = socket_path super().__init__(f"{message} (socket: {socket_path})") - def get_error_metadata(self) -> dict[str, str | int | float | bool | None]: + def get_error_metadata(self) -> ErrorMetadata: return {"socket_path": self.socket_path} @@ -252,7 +257,7 @@ class FilterInvalidRegexError(BadRequestError): self.error = error super().__init__(f"Invalid regex {pattern!r}: {error}") - def get_error_metadata(self) -> dict[str, str | int | float | bool | None]: + def get_error_metadata(self) -> ErrorMetadata: return {"pattern": self.pattern, "error": self.error} @@ -276,7 +281,7 @@ class FilterRegexTooLongError(BadRequestError): f"{self.actual_length} provided" ) - def get_error_metadata(self) -> dict[str, str | int | float | bool | None]: + def get_error_metadata(self) -> ErrorMetadata: return { "pattern_length": self.actual_length, "max_length": self.max_length, @@ -302,7 +307,7 @@ class FilterRegexTimeoutError(BadRequestError): f"(possible ReDoS attack). Pattern is too complex or causes catastrophic backtracking." ) - def get_error_metadata(self) -> dict[str, str | int | float | bool | None]: + def get_error_metadata(self) -> ErrorMetadata: return {"timeout_seconds": self.timeout_seconds} @@ -315,7 +320,7 @@ class JailNotFoundInConfigError(NotFoundError): self.name = name super().__init__(f"Jail not found in config: {name!r}") - def get_error_metadata(self) -> dict[str, str | int | float | bool | None]: + def get_error_metadata(self) -> ErrorMetadata: return {"jail_name": self.name} @@ -328,7 +333,7 @@ class ConfigWriteError(OperationError): self.message = message super().__init__(message) - def get_error_metadata(self) -> dict[str, str | int | float | bool | None]: + def get_error_metadata(self) -> ErrorMetadata: return {"message": self.message} @@ -347,7 +352,7 @@ class JailAlreadyActiveError(ConflictError): self.name = name super().__init__(f"Jail is already active: {name!r}") - def get_error_metadata(self) -> dict[str, str | int | float | bool | None]: + def get_error_metadata(self) -> ErrorMetadata: return {"jail_name": self.name} @@ -360,7 +365,7 @@ class JailAlreadyInactiveError(ConflictError): self.name = name super().__init__(f"Jail is already inactive: {name!r}") - def get_error_metadata(self) -> dict[str, str | int | float | bool | None]: + def get_error_metadata(self) -> ErrorMetadata: return {"jail_name": self.name} @@ -373,7 +378,7 @@ class FilterNotFoundError(NotFoundError): self.name = name super().__init__(f"Filter not found: {name!r}") - def get_error_metadata(self) -> dict[str, str | int | float | bool | None]: + def get_error_metadata(self) -> ErrorMetadata: return {"filter_name": self.name} @@ -386,7 +391,7 @@ class FilterAlreadyExistsError(ConflictError): self.name = name super().__init__(f"Filter already exists: {name!r}") - def get_error_metadata(self) -> dict[str, str | int | float | bool | None]: + def get_error_metadata(self) -> ErrorMetadata: return {"filter_name": self.name} @@ -407,7 +412,7 @@ class FilterReadonlyError(ConflictError): f"Filter {name!r} is a shipped default (.conf only); only user-created .local files can be deleted." ) - def get_error_metadata(self) -> dict[str, str | int | float | bool | None]: + def get_error_metadata(self) -> ErrorMetadata: return {"filter_name": self.name} @@ -420,7 +425,7 @@ class ActionNotFoundError(NotFoundError): self.name = name super().__init__(f"Action not found: {name!r}") - def get_error_metadata(self) -> dict[str, str | int | float | bool | None]: + def get_error_metadata(self) -> ErrorMetadata: return {"action_name": self.name} @@ -433,7 +438,7 @@ class ActionAlreadyExistsError(ConflictError): self.name = name super().__init__(f"Action already exists: {name!r}") - def get_error_metadata(self) -> dict[str, str | int | float | bool | None]: + def get_error_metadata(self) -> ErrorMetadata: return {"action_name": self.name} @@ -454,7 +459,7 @@ class ActionReadonlyError(ConflictError): f"Action {name!r} is a shipped default (.conf only); only user-created .local files can be deleted." ) - def get_error_metadata(self) -> dict[str, str | int | float | bool | None]: + def get_error_metadata(self) -> ErrorMetadata: return {"action_name": self.name} @@ -478,7 +483,7 @@ class BlocklistSourceNotFoundError(NotFoundError): self.source_id = source_id super().__init__(f"Blocklist source not found: {source_id}") - def get_error_metadata(self) -> dict[str, str | int | float | bool | None]: + def get_error_metadata(self) -> ErrorMetadata: return {"source_id": self.source_id} @@ -494,7 +499,7 @@ class BlocklistSourceHasLogsError(ConflictError): "Delete the import logs first." ) - def get_error_metadata(self) -> dict[str, str | int | float | bool | None]: + def get_error_metadata(self) -> ErrorMetadata: return {"source_id": self.source_id} @@ -507,5 +512,5 @@ class HistoryNotFoundError(NotFoundError): self.ip = ip super().__init__(f"No history found for IP: {ip}") - def get_error_metadata(self) -> dict[str, str | int | float | bool | None]: + def get_error_metadata(self) -> ErrorMetadata: return {"ip": self.ip} diff --git a/backend/app/main.py b/backend/app/main.py index 5977f05..6525c33 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -22,6 +22,8 @@ if TYPE_CHECKING: from starlette.responses import Response as StarletteResponse + from app.models.response import ErrorMetadata + import structlog from fastapi import FastAPI, HTTPException, Request, status from fastapi.exceptions import RequestValidationError @@ -327,7 +329,7 @@ def _get_error_code(exc: Exception) -> str: return snake_case -def _get_error_metadata(exc: Exception) -> dict[str, str | int | float | bool | None]: +def _get_error_metadata(exc: Exception) -> ErrorMetadata: """Get structured metadata from an exception. Calls the exception's get_error_metadata() method if available. diff --git a/backend/app/models/response.py b/backend/app/models/response.py index c273dd8..ed0681a 100644 --- a/backend/app/models/response.py +++ b/backend/app/models/response.py @@ -97,6 +97,7 @@ Note on field naming: from typing import Generic, Literal, TypeVar from pydantic import BaseModel, ConfigDict, Field +from typing_extensions import TypedDict T = TypeVar("T") @@ -318,7 +319,7 @@ class ErrorResponse(BanGuiBaseModel): code: str = Field(..., description="Machine-readable error code for client-side branching.") detail: str = Field(..., description="Human-readable error description.") - metadata: dict[str, str | int | float | bool | None] = Field( + metadata: "ErrorMetadata" = Field( default_factory=dict, description="Optional structured context for the error.", ) @@ -328,6 +329,53 @@ class ErrorResponse(BanGuiBaseModel): ) +# ErrorMetadata must be defined after ErrorResponse due to Pydantic forward-ref resolution +# but before use at type-check time. This ordering is intentional. + + +class ErrorMetadata(TypedDict, total=False): + """Typed metadata fields for error responses. + + Allows type-safe access to known metadata keys in exception handlers. + Keys are optional — exceptions return only relevant fields. + + Fields: + jail_name: Name of the jail involved in the error. + filename: Config filename involved in the error. + filter_name: Name of the filter involved in the error. + action_name: Name of the action involved in the error. + source_id: ID of a blocklist source involved in the error. + ip: IP address involved in the error. + pattern: Regex pattern that caused an error. + error: Regex compilation error message. + pattern_length: Actual length of an oversized pattern. + max_length: Maximum allowed length for a pattern. + timeout_seconds: Timeout value for regex compilation. + retry_after_seconds: Seconds to wait before retrying (rate limit errors). + socket_path: fail2ban socket path for connection errors. + current_status: Current jail status for conflict errors. + actual_length: Actual pattern length (alias for pattern_length). + message: Generic error message string. + """ + + jail_name: str + filename: str + filter_name: str + action_name: str + source_id: int + ip: str + pattern: str + error: str + pattern_length: int + max_length: int + timeout_seconds: int + retry_after_seconds: float + socket_path: str + current_status: str + actual_length: int + message: str + + class ComponentHealth(BanGuiBaseModel): """Health status of a single application component. diff --git a/backend/app/services/ban_service.py b/backend/app/services/ban_service.py index 629b96a..dc4a88c 100644 --- a/backend/app/services/ban_service.py +++ b/backend/app/services/ban_service.py @@ -13,7 +13,7 @@ from __future__ import annotations import asyncio import contextlib import ipaddress -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Any, cast import aiohttp import structlog @@ -514,6 +514,262 @@ async def list_bans( _MAX_COMPANION_BANS: int = 200 +# --------------------------------------------------------------------------- +# bans_by_country — implementation helpers +# --------------------------------------------------------------------------- + +async def _bans_by_country_load_data( + *, + source: str, + socket_path: str, + since: int, + origin: BanOrigin | None, + history_archive_repo: HistoryArchiveRepository, + app_db: aiosqlite.Connection | None, +) -> tuple[dict[str, int], int, list[str]]: + """Load per-IP ban counts and total for the requested time window. + + Returns: + Tuple of (agg_rows dict mapping ip->event_count, total_ban_count, unique_ip_list). + """ + if source == "archive": + if app_db is None: + raise ValueError("app_db must be provided when source is 'archive'") + + ip_counts = await history_archive_repo.get_ip_ban_counts( + db=app_db, + since=since, + origin=origin, + action="ban", + ) + + agg_rows = {row["ip"]: int(row["event_count"]) for row in ip_counts} + total = sum(agg_rows.values()) + unique_ips = list(agg_rows.keys()) + else: + db_path: str = await get_fail2ban_db_path(socket_path) + log.info( + "ban_service_bans_by_country", + db_path=db_path, + since=since, + origin=origin, + ) + + _, total = await fail2ban_db_repo.get_currently_banned( + db_path=db_path, + since=since, + origin=origin, + limit=0, + offset=0, + ) + + agg_rows_list = await fail2ban_db_repo.get_ban_event_counts( + db_path=db_path, + since=since, + origin=origin, + ) + + agg_rows = {r.ip: r.event_count for r in agg_rows_list} + unique_ips = list(agg_rows.keys()) + + return agg_rows, total, unique_ips + + +async def _bans_by_country_resolve_geo( + unique_ips: list[str], + *, + http_session: aiohttp.ClientSession | None, + geo_cache_lookup: GeoCacheLookup | None, + geo_cache: GeoCache | None, + geo_enricher: GeoEnricher | None, + app_db: aiosqlite.Connection | None, +) -> dict[str, GeoInfo]: + """Resolve geo information for a list of unique IPs. + + Uses the geo cache when available; falls back to legacy enricher. + Uncached IPs are scheduled for background resolution to warm the cache. + """ + if not unique_ips: + return {} + + geo_map: dict[str, GeoInfo] = {} + + if http_session is not None and geo_cache_lookup is not None: + geo_map, uncached = geo_cache_lookup(unique_ips) + if uncached: + log.info( + "ban_service_geo_background_scheduled", + uncached=len(uncached), + cached=len(geo_map), + ) + if geo_cache is not None: + asyncio.create_task( + logged_task( + geo_cache.lookup_batch(uncached, http_session, db=app_db), + "geo_bans_by_country", + ), + name="geo_bans_by_country", + ) + elif geo_enricher is not None: + async def _safe_lookup(ip: str) -> tuple[str, GeoInfo | None]: + try: + return ip, await geo_enricher(ip) + except (TimeoutError, aiohttp.ClientError, OSError): + log.warning("ban_service_geo_lookup_failed", ip=ip) + return ip, None + except Exception as exc: + log.error( + "ban_service_geo_lookup_unexpected_error", + ip=ip, + error=type(exc).__name__, + ) + raise + + results = await asyncio.gather(*(_safe_lookup(ip) for ip in unique_ips)) + geo_map = {ip: geo for ip, geo in results if geo is not None} + + return geo_map + + +async def _bans_by_country_load_companion( + *, + source: str, + country_code: str | None, + geo_map: dict[str, GeoInfo], + since: int, + origin: BanOrigin | None, + db_path: str | None, + app_db: aiosqlite.Connection | None, + history_archive_repo: HistoryArchiveRepository, +) -> tuple[list[dict[str, Any] | fail2ban_db_repo.BanRecord], list[str]]: + """Load companion ban rows and matched IPs for the given country filter. + + Returns: + Tuple of (companion_rows, matched_ips_for_country). + """ + if country_code is None: + if source == "archive": + rows, _ = await history_archive_repo.get_archived_history( + db=app_db, + since=since, + origin=origin, + action="ban", + page=1, + page_size=_MAX_COMPANION_BANS, + ) + else: + rows, _ = await fail2ban_db_repo.get_currently_banned( + db_path=db_path, + since=since, + origin=origin, + limit=_MAX_COMPANION_BANS, + offset=0, + ) + return rows, [] + + matched_ips = [ + ip + for ip, geo in geo_map.items() + if geo is not None and geo.country_code == country_code + ] + + if not matched_ips: + return [], matched_ips + + if source == "archive": + rows, _ = await history_archive_repo.get_archived_history( + db=app_db, + since=since, + origin=origin, + action="ban", + ip_filter=matched_ips, + page=1, + page_size=_MAX_COMPANION_BANS, + ) + else: + rows, _ = await fail2ban_db_repo.get_currently_banned( + db_path=db_path, + since=since, + origin=origin, + ip_filter=matched_ips, + ) + + return rows, matched_ips + + +def _bans_by_country_aggregate( + agg_rows: dict[str, int], + geo_map: dict[str, GeoInfo], + source: str, +) -> tuple[dict[str, int], dict[str, str]]: + """Aggregate ban counts by country code. + + Returns: + Tuple of (countries dict mapping cc->count, country_names dict mapping cc->name). + """ + countries: dict[str, int] = {} + country_names: dict[str, str] = {} + + for ip, event_count in agg_rows.items(): + geo = geo_map.get(ip) + cc: str | None = geo.country_code if geo else None + cn: str | None = geo.country_name if geo else None + + if cc: + countries[cc] = countries.get(cc, 0) + event_count + if cn and cc not in country_names: + country_names[cc] = cn + + return countries, country_names + + +def _bans_by_country_build_ban_items( + companion_rows: list[dict[str, Any] | fail2ban_db_repo.BanRecord], + geo_map: dict[str, GeoInfo], + source: str, +) -> list[DomainDashboardBanItem]: + """Build DomainDashboardBanItem list from raw companion rows.""" + bans: list[DomainDashboardBanItem] = [] + + for companion_row in companion_rows: + if source == "archive": + ip = companion_row["ip"] + jail = companion_row["jail"] + banned_at = ts_to_iso(int(companion_row["timeofban"])) + ban_count = int(companion_row["bancount"]) + service = None + else: + ip = companion_row.ip + jail = companion_row.jail + banned_at = ts_to_iso(companion_row.timeofban) + ban_count = companion_row.bancount + matches, _ = parse_data_json(companion_row.data) + service = matches[0] if matches else None + + geo = geo_map.get(ip) + cc = geo.country_code if geo else None + cn = geo.country_name if geo else None + asn: str | None = geo.asn if geo else None + org: str | None = geo.org if geo else None + + bans.append( + DomainDashboardBanItem( + ip=ip, + jail=jail, + banned_at=banned_at, + service=service, + country_code=cc, + country_name=cn, + asn=asn, + org=org, + ban_count=ban_count, + origin=_derive_origin(jail), + ) + ) + + return bans + + async def bans_by_country( socket_path: str, range_: TimeRange, @@ -569,211 +825,47 @@ async def bans_by_country( if source not in ("fail2ban", "archive"): raise ValueError(f"Unsupported source: {source!r}") - if source == "archive": - if app_db is None: - raise ValueError("app_db must be provided when source is 'archive'") + # Step 1: Load per-IP ban counts and total. + db_path: str | None = None + if source == "fail2ban": + db_path = await get_fail2ban_db_path(socket_path) - # SQL aggregation — no row materialisation into Python memory. - ip_counts = await history_archive_repo.get_ip_ban_counts( - db=app_db, - since=since, - origin=origin, - action="ban", - ) + agg_rows, total, unique_ips = await _bans_by_country_load_data( + source=source, + socket_path=socket_path, + since=since, + origin=origin, + history_archive_repo=history_archive_repo, + app_db=app_db, + ) - # Total = sum of all event counts. - total = sum(int(row["event_count"]) for row in ip_counts) + # Step 2: Resolve geo for unique IPs (from cache or enricher). + geo_map = await _bans_by_country_resolve_geo( + unique_ips, + http_session=http_session, + geo_cache_lookup=geo_cache_lookup, + geo_cache=geo_cache, + geo_enricher=geo_enricher, + app_db=app_db, + ) - # {ip: event_count} for downstream geo aggregation. - agg_rows = {row["ip"]: int(row["event_count"]) for row in ip_counts} + # Step 3: Load companion ban rows (filtered by country if provided). + companion_rows, _ = await _bans_by_country_load_companion( + source=source, + country_code=country_code, + geo_map=geo_map, + since=since, + origin=origin, + db_path=db_path, + app_db=app_db, + history_archive_repo=history_archive_repo, + ) - unique_ips = list(agg_rows.keys()) - else: - origin_clause, origin_params = _origin_sql_filter(origin) - db_path: str = await get_fail2ban_db_path(socket_path) - log.info( - "ban_service_bans_by_country", - db_path=db_path, - since=since, - range=range_, - origin=origin, - ) + # Step 4: Aggregate counts by country. + countries, country_names = _bans_by_country_aggregate(agg_rows, geo_map, source) - # Total count and companion rows reuse the same SQL query logic. - # Passing limit=0 returns only the total from the count query. - _, total = await fail2ban_db_repo.get_currently_banned( - db_path=db_path, - since=since, - origin=origin, - limit=0, - offset=0, - ) - - agg_rows = await fail2ban_db_repo.get_ban_event_counts( - db_path=db_path, - since=since, - origin=origin, - ) - - unique_ips = [r.ip for r in agg_rows] - geo_map: dict[str, GeoInfo] = {} - - if http_session is not None and unique_ips and geo_cache_lookup is not None: - # Serve only what is already in the in-memory cache — no API calls on - # the hot path. Uncached IPs are resolved asynchronously in the - # background so subsequent requests benefit from a warmer cache. - geo_map, uncached = geo_cache_lookup(unique_ips) - if uncached: - log.info( - "ban_service_geo_background_scheduled", - uncached=len(uncached), - cached=len(geo_map), - ) - if geo_cache is not None: - # Fire-and-forget: lookup_batch handles rate-limiting / retries. - # The dirty-set flush task persists results to the DB. - asyncio.create_task( - logged_task( - geo_cache.lookup_batch(uncached, http_session, db=app_db), - "geo_bans_by_country", - ), - name="geo_bans_by_country", - ) - elif geo_enricher is not None and unique_ips: - # Fallback: legacy per-IP enricher (used in tests / older callers). - async def _safe_lookup(ip: str) -> tuple[str, GeoInfo | None]: - try: - return ip, await geo_enricher(ip) - except (TimeoutError, aiohttp.ClientError, OSError): - log.warning("ban_service_geo_lookup_failed", ip=ip) - return ip, None - except Exception as exc: - log.error("ban_service_geo_lookup_unexpected_error", ip=ip, error=type(exc).__name__) - raise # Bubble programming errors to global handler - - results = await asyncio.gather(*(_safe_lookup(ip) for ip in unique_ips)) - geo_map = {ip: geo for ip, geo in results if geo is not None} - - companion_rows: list[dict[str, Any] | fail2ban_db_repo.BanRecord] - if country_code is None: - if source == "archive": - companion_rows, _ = await history_archive_repo.get_archived_history( - db=app_db, - since=since, - origin=origin, - action="ban", - page=1, - page_size=_MAX_COMPANION_BANS, - ) - else: - companion_rows, _ = await fail2ban_db_repo.get_currently_banned( - db_path=db_path, - since=since, - origin=origin, - limit=_MAX_COMPANION_BANS, - offset=0, - ) - else: - matched_ips = [ - ip - for ip, geo in geo_map.items() - if geo is not None and geo.country_code == country_code - ] - - if source == "archive": - if matched_ips: - # Use keyset pagination instead of loading all matched IPs at once. - companion_rows, _ = await history_archive_repo.get_archived_history( - db=app_db, - since=since, - origin=origin, - action="ban", - ip_filter=matched_ips, - page=1, - page_size=_MAX_COMPANION_BANS, - ) - else: - companion_rows = [] - else: - if matched_ips: - companion_rows, _ = await fail2ban_db_repo.get_currently_banned( - db_path=db_path, - since=since, - origin=origin, - ip_filter=matched_ips, - ) - else: - companion_rows = [] - - # Build country aggregation from the SQL-grouped rows. - countries: dict[str, int] = {} - country_names: dict[str, str] = {} - - if source == "archive": - agg_items = [ - { - "ip": ip, - "event_count": count, - } - for ip, count in agg_rows.items() - ] - else: - agg_items = agg_rows - - for agg_row in agg_items: - if source == "archive": - ip = agg_row["ip"] - event_count = agg_row["event_count"] - else: - ip = agg_row.ip - event_count = agg_row.event_count - - geo = geo_map.get(ip) - cc: str | None = geo.country_code if geo else None - cn: str | None = geo.country_name if geo else None - - if cc: - countries[cc] = countries.get(cc, 0) + event_count - if cn and cc not in country_names: - country_names[cc] = cn - - # Build companion table from recent rows (geo already cached from batch step). - bans: list[DomainDashboardBanItem] = [] - for companion_row in companion_rows: - if source == "archive": - ip = companion_row["ip"] - jail = companion_row["jail"] - banned_at = ts_to_iso(int(companion_row["timeofban"])) - ban_count = int(companion_row["bancount"]) - service = None - else: - ip = companion_row.ip - jail = companion_row.jail - banned_at = ts_to_iso(companion_row.timeofban) - ban_count = companion_row.bancount - matches, _ = parse_data_json(companion_row.data) - service = matches[0] if matches else None - - geo = geo_map.get(ip) - cc = geo.country_code if geo else None - cn = geo.country_name if geo else None - asn: str | None = geo.asn if geo else None - org: str | None = geo.org if geo else None - - bans.append( - DomainDashboardBanItem( - ip=ip, - jail=jail, - banned_at=banned_at, - service=service, - country_code=cc, - country_name=cn, - asn=asn, - org=org, - ban_count=ban_count, - origin=_derive_origin(jail), - ) - ) + # Step 5: Build companion ban items for the response. + bans = _bans_by_country_build_ban_items(companion_rows, geo_map, source) return DomainBansByCountry( countries=countries, diff --git a/backend/app/utils/config_file_utils.py b/backend/app/utils/config_file_utils.py index 8f767b0..f47ddf8 100644 --- a/backend/app/utils/config_file_utils.py +++ b/backend/app/utils/config_file_utils.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio import configparser +import contextlib import io import os import re