refactor(ban_service): extract _bans_by_country_load_data helper

Break up a long function into focused helpers, keeping data loading separate from aggregation.
2026-05-03 17:00:34 +02:00
parent 5058a50143
commit 2df029f7e8
8 changed files with 458 additions and 321 deletions


@@ -3386,6 +3386,64 @@ When user-supplied URLs are fetched by the backend, validate them before making
- `async validate_blocklist_url(url: AnyHttpUrl) → None`: Async DNS resolution + private IP check.
- Service layer calls `await validate_blocklist_url(url)` before persisting; router catches `ValueError` and returns 400.
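The two bullets above describe a resolve-then-check flow. A minimal sketch of that flow follows, assuming a plain `str` URL instead of `AnyHttpUrl` and using only the standard library; `validate_url_sketch` is a hypothetical name and is not the project's `validate_blocklist_url`.

```python
import asyncio
import ipaddress
from urllib.parse import urlsplit


async def validate_url_sketch(url: str) -> None:
    """Raise ValueError if the URL's host resolves to a private address.

    Hypothetical stand-in for illustration, not the project's implementation.
    """
    parts = urlsplit(url)
    if parts.scheme not in ("http", "https") or not parts.hostname:
        raise ValueError(f"Unsupported URL: {url!r}")

    loop = asyncio.get_running_loop()
    try:
        # Resolve through the event loop so the lookup does not block it.
        infos = await loop.getaddrinfo(parts.hostname, parts.port or 443)
    except OSError as exc:
        raise ValueError(f"Cannot resolve host: {parts.hostname!r}") from exc

    for info in infos:
        # info[4] is the socket address tuple; its first item is the IP string.
        addr = ipaddress.ip_address(info[4][0])
        if addr.is_private or addr.is_loopback or addr.is_link_local:
            raise ValueError(f"URL resolves to a non-public address: {addr}")
```

A router following the convention above would catch the `ValueError` and translate it into a 400 response.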
### 17.8 Function Complexity Limits
Functions longer than about 100 lines are harder to maintain and more likely to hide bugs. Limits by function type:
- **Service functions**: target ≤ 100 lines, absolute max 150 lines.
- **Utility functions**: target ≤ 50 lines, absolute max 80 lines.
- **Router handlers**: target ≤ 40 lines, absolute max 60 lines.
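
The limits above could be checked mechanically. The following is a rough sketch using the standard `ast` module; the `LIMITS` table mirrors the numbers listed above, while `check_function_lengths` and the `kind` labels are hypothetical names, not part of the project.

```python
import ast

# (target, absolute max) per function kind, copied from the limits above.
LIMITS = {"service": (100, 150), "utility": (50, 80), "router": (40, 60)}


def check_function_lengths(source: str, kind: str) -> list[tuple[str, int, str]]:
    """Return (name, line_count, level) for every function over its target."""
    target, hard_max = LIMITS[kind]
    findings = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            # Count from the def line through the last line of the body.
            length = node.end_lineno - node.lineno + 1
            if length > hard_max:
                findings.append((node.name, length, "over max"))
            elif length > target:
                findings.append((node.name, length, "over target"))
    return findings
```

Counting physical lines (including blank lines and docstrings) is a design choice of this sketch; a stricter tool might count statements instead.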
When a function grows beyond its target:
1. **Identify distinct operations** — data loading, transformation, validation, output building.
2. **Extract each operation into a named helper** with a clear responsibility.
3. **Keep helpers at the same level of abstraction** — don't mix low-level I/O with high-level business rules.
Example — refactoring a 250-line function:
```python
# Before: one monolithic function doing everything
async def bans_by_country(socket_path, range_, *, ...):
    # 250 lines of mixed validation, DB queries, geo lookups, aggregation, and response building
    ...

# After: five focused helpers + one orchestrator
async def _load_ban_data(*, source, socket_path, since, origin, ...):
    """Step 1: Query per-IP ban counts from the right source."""
    ...

async def _resolve_geo(unique_ips, *, http_session, geo_cache_lookup, ...):
    """Step 2: Resolve geo info from cache or enricher."""
    ...

async def _load_companion_rows(*, source, country_code, geo_map, ...):
    """Step 3: Load companion ban rows, optionally filtered by country."""
    ...

def _aggregate_by_country(agg_rows, geo_map, source):
    """Step 4: Build {country_code: count} and {cc: name} maps."""
    ...

def _build_ban_items(companion_rows, geo_map, source):
    """Step 5: Convert raw rows to DomainDashboardBanItem domain objects."""
    ...

async def bans_by_country(socket_path, range_, *, ...):
    agg_rows, total, unique_ips = await _load_ban_data(...)
    geo_map = await _resolve_geo(unique_ips, ...)
    companion_rows, _ = await _load_companion_rows(...)
    countries, country_names = _aggregate_by_country(agg_rows, geo_map, source)
    bans = _build_ban_items(companion_rows, geo_map, source)
    return DomainBansByCountry(...)
```
**Why this works**:
- Each helper is independently testable.
- Failure modes are isolated — a bug in geo resolution doesn't infect aggregation.
- Code review can proceed one small helper at a time instead of one large block.
- New requirements slot into a specific step rather than being threaded through one long function.
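
To illustrate the testability point, here is a small sketch of a pure aggregation helper with a unit test that needs no database, socket, or HTTP session. `GeoInfoStub` and `aggregate_by_country` are simplified, hypothetical stand-ins for the project's types and helper, written only for this example.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class GeoInfoStub:
    """Hypothetical stand-in for the project's GeoInfo type."""

    country_code: str | None
    country_name: str | None


def aggregate_by_country(
    agg_rows: dict[str, int], geo_map: dict[str, GeoInfoStub]
) -> tuple[dict[str, int], dict[str, str]]:
    """Sum per-IP event counts into per-country totals."""
    countries: dict[str, int] = {}
    country_names: dict[str, str] = {}
    for ip, count in agg_rows.items():
        geo = geo_map.get(ip)
        if geo is None or not geo.country_code:
            continue  # IPs without a resolved country are skipped
        cc = geo.country_code
        countries[cc] = countries.get(cc, 0) + count
        if geo.country_name and cc not in country_names:
            country_names[cc] = geo.country_name
    return countries, country_names


def test_aggregate_sums_counts_per_country() -> None:
    geo_map = {
        "1.1.1.1": GeoInfoStub("DE", "Germany"),
        "2.2.2.2": GeoInfoStub("DE", "Germany"),
        "3.3.3.3": GeoInfoStub("FR", "France"),
    }
    agg_rows = {"1.1.1.1": 2, "2.2.2.2": 3, "3.3.3.3": 1, "4.4.4.4": 7}
    countries, names = aggregate_by_country(agg_rows, geo_map)
    assert countries == {"DE": 5, "FR": 1}
    assert names == {"DE": "Germany", "FR": "France"}
```

Because the helper takes plain dicts and returns plain dicts, the test needs no mocks.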
**Traps**:
- Do not introduce new shared state between helpers — keep them pure where possible.
- Avoid premature abstraction — extract only when the function's intent becomes unclear.
- Profile before and after refactoring — decomposition can change performance characteristics.
## 18. Quick Reference — Do / Don't
| Do | Don't |


@@ -115,7 +115,30 @@ When adding a new response model to `backend/app/models/`:
---
-## 8. Related Documents
+## 8. TypedDict for Error Metadata
Error response metadata uses `ErrorMetadata` (a `TypedDict` with `total=False`) instead of the generic `dict[str, str | int | float | bool | None]`. Exception handlers get type-safe field access, and type checkers can verify that only declared keys are used with their declared value types.
```python
# BAD — generic dict, no type narrowing
def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
    return {"jail_name": self.name}

# GOOD — TypedDict, type checker knows exact fields
def get_error_metadata(self) -> ErrorMetadata:
    return {"jail_name": self.name}
```
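When accessing error metadata in exception handlers, the type checker can now verify that a key is declared on `ErrorMetadata` and knows its value type. Because of `total=False`, this does not prove that the key is present at runtime:
```python
metadata = exc.get_error_metadata()
jail_name = metadata["jail_name"]  # type checker verifies "jail_name" is a declared key of type str
```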
`ErrorMetadata` is defined in `backend/app/models/response.py` and imported via `TYPE_CHECKING` blocks in `exceptions.py` and `main.py` to avoid circular dependencies at runtime.
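
A short sketch of how a `total=False` TypedDict behaves when read in a handler. `ErrorMetadataSketch` is a trimmed-down, hypothetical stand-in that declares only two of the real keys, and `describe` is an illustrative function, not project code.

```python
from typing import TypedDict


class ErrorMetadataSketch(TypedDict, total=False):
    """Hypothetical, trimmed-down stand-in for ErrorMetadata."""

    jail_name: str
    retry_after_seconds: float


def describe(metadata: ErrorMetadataSketch) -> str:
    # total=False makes every key optional, so .get() is the safe way to read
    # unless the caller knows which exception produced the metadata.
    jail = metadata.get("jail_name")
    if jail is not None:
        return f"jail={jail}"
    retry = metadata.get("retry_after_seconds")
    if retry is not None:
        return f"retry_after={retry}"
    return "no metadata"
```

A type checker would reject `metadata.get("jail")` or `{"jail_name": 5}` here, while an empty dict remains a valid value.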
## 9. Related Documents
- [Architekture.md](Architekture.md) — system architecture and data flow
- [Backend-Development.md](Backend-Development.md) — Python coding conventions, Pydantic usage


@@ -1,95 +1,3 @@
### Issue #23: MEDIUM - Missing Default Configuration Documentation
**Where found**:
- `.env.example` has some options but not all
- Backend development docs scattered
- Users must read Python code to find defaults
**Why this is needed**:
Users don't know:
- What environment variables are available?
- What are the defaults?
- What values are valid?
**Goal**:
Create comprehensive configuration reference documentation.
**What to do**:
1. Create `Docs/CONFIGURATION.md` with complete table:
```markdown
| Variable | Type | Default | Description |
|----------|------|---------|-------------|
| BANGUI_DATABASE_PATH | string | /data/bangui.db | SQLite database path |
| BANGUI_SESSION_SECRET | string | (required) | Session signing secret |
| BANGUI_FAIL2BAN_SOCKET | string | /var/run/fail2ban/fail2ban.sock | fail2ban socket |
```
2. Document each option with valid values and constraints
3. Organize by section (database, security, performance, etc.)
4. Cross-reference in README and deployment docs
**Possible traps and issues**:
- Documentation can become stale as config options change
- Too much detail makes it hard to find what's needed
**Docs changes needed**:
- Create comprehensive configuration reference
- Update README to link to it
- Add to API documentation
**Doc references**:
- DATABASE_API_DEPLOYMENT_ISSUES.md - Issue "5.1, 5.5, 5.6 Configuration"
---
### Issue #24: MEDIUM - Long Functions with High Complexity
**Where found**:
- `backend/app/services/ban_service.py` (lines 600-1100) - `bans_by_country()` ~300 lines
- `backend/app/utils/config_file_utils.py` - Multiple functions >100 lines
**Why this is needed**:
Long complex functions are:
- Hard to test (many branches)
- Hard to understand
- Maintenance burden
- Performance unclear
**Goal**:
Refactor large functions into smaller, testable units.
**What to do**:
1. Identify functions >100 lines
2. Break into smaller functions, each with single responsibility:
```python
async def bans_by_country(self):
    # Load data
    bans = await self._load_bans_paginated()
    # Aggregate
    by_country = self._aggregate_by_country(bans)
    # Enrich
    enriched = await self._enrich_with_geo(by_country)
    # Sort and return
    return self._sort_by_count(enriched)
```
3. Each smaller function is testable independently
4. Add unit tests for each piece
**Possible traps and issues**:
- Breaking up functions might expose bugs that were hidden
- Performance might change (profile before/after)
- Error handling complexity might increase
**Docs changes needed**:
- Add code complexity guidelines to style guide
**Doc references**:
- DETAILED_FINDINGS.md - Issue #19 "Long Functions"
---
### Issue #25: MEDIUM - Incomplete Type Hints in Error Handling
**Where found**:


@@ -39,6 +39,11 @@ See Backend-Development.md for the complete exception contract.
 from __future__ import annotations
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+    from app.models.response import ErrorMetadata
 # ---------------------------------------------------------------------------
 # Exception Base Classes (Categories)
 # ---------------------------------------------------------------------------
@@ -54,7 +59,7 @@ class DomainError(Exception):
     error_code: str = "internal_error"
-    def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
+    def get_error_metadata(self) -> ErrorMetadata:
         """Return structured metadata for the API error response.
         Subclasses should override to expose only safe, relevant metadata.
@@ -116,7 +121,7 @@ class RateLimitError(DomainError):
         self.retry_after_seconds: float = retry_after_seconds
         super().__init__(message)
-    def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
+    def get_error_metadata(self) -> ErrorMetadata:
         return {"retry_after_seconds": self.retry_after_seconds}
@@ -134,7 +139,7 @@ class JailNotFoundError(NotFoundError):
         self.name = name
         super().__init__(f"Jail not found: {name!r}")
-    def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
+    def get_error_metadata(self) -> ErrorMetadata:
         return {"jail_name": self.name}
@@ -176,7 +181,7 @@ class ConfigFileNotFoundError(NotFoundError):
         self.filename = filename
         super().__init__(f"Config file not found: {filename!r}")
-    def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
+    def get_error_metadata(self) -> ErrorMetadata:
         return {"filename": self.filename}
@@ -194,7 +199,7 @@ class ConfigFileExistsError(ConflictError):
         self.filename = filename
         super().__init__(f"Config file already exists: {filename!r}")
-    def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
+    def get_error_metadata(self) -> ErrorMetadata:
         return {"filename": self.filename}
@@ -231,7 +236,7 @@ class Fail2BanConnectionError(ServiceUnavailableError):
         self.socket_path: str = socket_path
         super().__init__(f"{message} (socket: {socket_path})")
-    def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
+    def get_error_metadata(self) -> ErrorMetadata:
         return {"socket_path": self.socket_path}
@@ -252,7 +257,7 @@ class FilterInvalidRegexError(BadRequestError):
         self.error = error
         super().__init__(f"Invalid regex {pattern!r}: {error}")
-    def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
+    def get_error_metadata(self) -> ErrorMetadata:
         return {"pattern": self.pattern, "error": self.error}
@@ -276,7 +281,7 @@ class FilterRegexTooLongError(BadRequestError):
             f"{self.actual_length} provided"
         )
-    def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
+    def get_error_metadata(self) -> ErrorMetadata:
         return {
             "pattern_length": self.actual_length,
             "max_length": self.max_length,
@@ -302,7 +307,7 @@ class FilterRegexTimeoutError(BadRequestError):
             f"(possible ReDoS attack). Pattern is too complex or causes catastrophic backtracking."
         )
-    def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
+    def get_error_metadata(self) -> ErrorMetadata:
         return {"timeout_seconds": self.timeout_seconds}
@@ -315,7 +320,7 @@ class JailNotFoundInConfigError(NotFoundError):
         self.name = name
         super().__init__(f"Jail not found in config: {name!r}")
-    def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
+    def get_error_metadata(self) -> ErrorMetadata:
         return {"jail_name": self.name}
@@ -328,7 +333,7 @@ class ConfigWriteError(OperationError):
         self.message = message
         super().__init__(message)
-    def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
+    def get_error_metadata(self) -> ErrorMetadata:
         return {"message": self.message}
@@ -347,7 +352,7 @@ class JailAlreadyActiveError(ConflictError):
         self.name = name
         super().__init__(f"Jail is already active: {name!r}")
-    def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
+    def get_error_metadata(self) -> ErrorMetadata:
         return {"jail_name": self.name}
@@ -360,7 +365,7 @@ class JailAlreadyInactiveError(ConflictError):
         self.name = name
         super().__init__(f"Jail is already inactive: {name!r}")
-    def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
+    def get_error_metadata(self) -> ErrorMetadata:
         return {"jail_name": self.name}
@@ -373,7 +378,7 @@ class FilterNotFoundError(NotFoundError):
         self.name = name
         super().__init__(f"Filter not found: {name!r}")
-    def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
+    def get_error_metadata(self) -> ErrorMetadata:
         return {"filter_name": self.name}
@@ -386,7 +391,7 @@ class FilterAlreadyExistsError(ConflictError):
         self.name = name
         super().__init__(f"Filter already exists: {name!r}")
-    def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
+    def get_error_metadata(self) -> ErrorMetadata:
         return {"filter_name": self.name}
@@ -407,7 +412,7 @@ class FilterReadonlyError(ConflictError):
             f"Filter {name!r} is a shipped default (.conf only); only user-created .local files can be deleted."
         )
-    def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
+    def get_error_metadata(self) -> ErrorMetadata:
         return {"filter_name": self.name}
@@ -420,7 +425,7 @@ class ActionNotFoundError(NotFoundError):
         self.name = name
         super().__init__(f"Action not found: {name!r}")
-    def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
+    def get_error_metadata(self) -> ErrorMetadata:
         return {"action_name": self.name}
@@ -433,7 +438,7 @@ class ActionAlreadyExistsError(ConflictError):
         self.name = name
         super().__init__(f"Action already exists: {name!r}")
-    def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
+    def get_error_metadata(self) -> ErrorMetadata:
         return {"action_name": self.name}
@@ -454,7 +459,7 @@ class ActionReadonlyError(ConflictError):
             f"Action {name!r} is a shipped default (.conf only); only user-created .local files can be deleted."
         )
-    def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
+    def get_error_metadata(self) -> ErrorMetadata:
         return {"action_name": self.name}
@@ -478,7 +483,7 @@ class BlocklistSourceNotFoundError(NotFoundError):
         self.source_id = source_id
         super().__init__(f"Blocklist source not found: {source_id}")
-    def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
+    def get_error_metadata(self) -> ErrorMetadata:
         return {"source_id": self.source_id}
@@ -494,7 +499,7 @@ class BlocklistSourceHasLogsError(ConflictError):
             "Delete the import logs first."
         )
-    def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
+    def get_error_metadata(self) -> ErrorMetadata:
         return {"source_id": self.source_id}
@@ -507,5 +512,5 @@ class HistoryNotFoundError(NotFoundError):
         self.ip = ip
         super().__init__(f"No history found for IP: {ip}")
-    def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
+    def get_error_metadata(self) -> ErrorMetadata:
         return {"ip": self.ip}


@@ -22,6 +22,8 @@ if TYPE_CHECKING:
     from starlette.responses import Response as StarletteResponse
+    from app.models.response import ErrorMetadata
 import structlog
 from fastapi import FastAPI, HTTPException, Request, status
 from fastapi.exceptions import RequestValidationError
@@ -327,7 +329,7 @@ def _get_error_code(exc: Exception) -> str:
     return snake_case
-def _get_error_metadata(exc: Exception) -> dict[str, str | int | float | bool | None]:
+def _get_error_metadata(exc: Exception) -> ErrorMetadata:
     """Get structured metadata from an exception.
     Calls the exception's get_error_metadata() method if available.


@@ -97,6 +97,7 @@ Note on field naming:
 from typing import Generic, Literal, TypeVar
 from pydantic import BaseModel, ConfigDict, Field
+from typing_extensions import TypedDict
 T = TypeVar("T")
@@ -318,7 +319,7 @@ class ErrorResponse(BanGuiBaseModel):
     code: str = Field(..., description="Machine-readable error code for client-side branching.")
     detail: str = Field(..., description="Human-readable error description.")
-    metadata: dict[str, str | int | float | bool | None] = Field(
+    metadata: "ErrorMetadata" = Field(
         default_factory=dict,
         description="Optional structured context for the error.",
     )
@@ -328,6 +329,53 @@ class ErrorResponse(BanGuiBaseModel):
     )
+# ErrorMetadata must be defined after ErrorResponse due to Pydantic forward-ref resolution
+# but before use at type-check time. This ordering is intentional.
+class ErrorMetadata(TypedDict, total=False):
+    """Typed metadata fields for error responses.
+    Allows type-safe access to known metadata keys in exception handlers.
+    Keys are optional — exceptions return only relevant fields.
+    Fields:
+        jail_name: Name of the jail involved in the error.
+        filename: Config filename involved in the error.
+        filter_name: Name of the filter involved in the error.
+        action_name: Name of the action involved in the error.
+        source_id: ID of a blocklist source involved in the error.
+        ip: IP address involved in the error.
+        pattern: Regex pattern that caused an error.
+        error: Regex compilation error message.
+        pattern_length: Actual length of an oversized pattern.
+        max_length: Maximum allowed length for a pattern.
+        timeout_seconds: Timeout value for regex compilation.
+        retry_after_seconds: Seconds to wait before retrying (rate limit errors).
+        socket_path: fail2ban socket path for connection errors.
+        current_status: Current jail status for conflict errors.
+        actual_length: Actual pattern length (alias for pattern_length).
+        message: Generic error message string.
+    """
+    jail_name: str
+    filename: str
+    filter_name: str
+    action_name: str
+    source_id: int
+    ip: str
+    pattern: str
+    error: str
+    pattern_length: int
+    max_length: int
+    timeout_seconds: int
+    retry_after_seconds: float
+    socket_path: str
+    current_status: str
+    actual_length: int
+    message: str
 class ComponentHealth(BanGuiBaseModel):
     """Health status of a single application component.


@@ -13,7 +13,7 @@ from __future__ import annotations
 import asyncio
 import contextlib
 import ipaddress
-from typing import TYPE_CHECKING, cast
+from typing import TYPE_CHECKING, Any, cast
 import aiohttp
 import structlog
@@ -514,6 +514,262 @@ async def list_bans(
 _MAX_COMPANION_BANS: int = 200
+# ---------------------------------------------------------------------------
+# bans_by_country — implementation helpers
+# ---------------------------------------------------------------------------
+async def _bans_by_country_load_data(
+    *,
+    source: str,
+    socket_path: str,
+    since: int,
+    origin: BanOrigin | None,
+    history_archive_repo: HistoryArchiveRepository,
+    app_db: aiosqlite.Connection | None,
+) -> tuple[dict[str, int], int, list[str]]:
+    """Load per-IP ban counts and total for the requested time window.
+    Returns:
+        Tuple of (agg_rows dict mapping ip->event_count, total_ban_count, unique_ip_list).
+    """
+    if source == "archive":
+        if app_db is None:
+            raise ValueError("app_db must be provided when source is 'archive'")
+        ip_counts = await history_archive_repo.get_ip_ban_counts(
+            db=app_db,
+            since=since,
+            origin=origin,
+            action="ban",
+        )
+        agg_rows = {row["ip"]: int(row["event_count"]) for row in ip_counts}
+        total = sum(agg_rows.values())
+        unique_ips = list(agg_rows.keys())
+    else:
+        db_path: str = await get_fail2ban_db_path(socket_path)
+        log.info(
+            "ban_service_bans_by_country",
+            db_path=db_path,
+            since=since,
+            origin=origin,
+        )
+        _, total = await fail2ban_db_repo.get_currently_banned(
+            db_path=db_path,
+            since=since,
+            origin=origin,
+            limit=0,
+            offset=0,
+        )
+        agg_rows_list = await fail2ban_db_repo.get_ban_event_counts(
+            db_path=db_path,
+            since=since,
+            origin=origin,
+        )
+        agg_rows = {r.ip: r.event_count for r in agg_rows_list}
+        unique_ips = list(agg_rows.keys())
+    return agg_rows, total, unique_ips
+async def _bans_by_country_resolve_geo(
+    unique_ips: list[str],
+    *,
+    http_session: aiohttp.ClientSession | None,
+    geo_cache_lookup: GeoCacheLookup | None,
+    geo_cache: GeoCache | None,
+    geo_enricher: GeoEnricher | None,
+    app_db: aiosqlite.Connection | None,
+) -> dict[str, GeoInfo]:
+    """Resolve geo information for a list of unique IPs.
+    Uses the geo cache when available; falls back to legacy enricher.
+    Uncached IPs are scheduled for background resolution to warm the cache.
+    """
+    if not unique_ips:
+        return {}
+    geo_map: dict[str, GeoInfo] = {}
+    if http_session is not None and geo_cache_lookup is not None:
+        geo_map, uncached = geo_cache_lookup(unique_ips)
+        if uncached:
+            log.info(
+                "ban_service_geo_background_scheduled",
+                uncached=len(uncached),
+                cached=len(geo_map),
+            )
+            if geo_cache is not None:
+                asyncio.create_task(
+                    logged_task(
+                        geo_cache.lookup_batch(uncached, http_session, db=app_db),
+                        "geo_bans_by_country",
+                    ),
+                    name="geo_bans_by_country",
+                )
+    elif geo_enricher is not None:
+        async def _safe_lookup(ip: str) -> tuple[str, GeoInfo | None]:
+            try:
+                return ip, await geo_enricher(ip)
+            except (TimeoutError, aiohttp.ClientError, OSError):
+                log.warning("ban_service_geo_lookup_failed", ip=ip)
+                return ip, None
+            except Exception as exc:
+                log.error(
+                    "ban_service_geo_lookup_unexpected_error",
+                    ip=ip,
+                    error=type(exc).__name__,
+                )
+                raise
+        results = await asyncio.gather(*(_safe_lookup(ip) for ip in unique_ips))
+        geo_map = {ip: geo for ip, geo in results if geo is not None}
+    return geo_map
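
The enricher fallback above combines `asyncio.gather` with a per-IP wrapper that swallows expected failures. A stripped-down sketch of that pattern follows; `fake_enricher` and `resolve_all` are hypothetical names, and the enricher returns a bare country-code string rather than a `GeoInfo`.

```python
from __future__ import annotations

import asyncio


async def fake_enricher(ip: str) -> str:
    """Hypothetical enricher: returns a country code or times out."""
    if ip.startswith("10."):
        raise TimeoutError("lookup timed out")
    return "DE"


async def resolve_all(ips: list[str]) -> dict[str, str]:
    async def safe_lookup(ip: str) -> tuple[str, str | None]:
        try:
            return ip, await fake_enricher(ip)
        except (TimeoutError, OSError):
            # Expected, recoverable failure: drop this IP and keep the rest.
            return ip, None

    # All lookups run concurrently; one failed IP does not cancel the others.
    results = await asyncio.gather(*(safe_lookup(ip) for ip in ips))
    return {ip: cc for ip, cc in results if cc is not None}
```

Unexpected exception types are deliberately not caught in this sketch, so they propagate out of `gather`, which matches the re-raise in the helper above.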
+async def _bans_by_country_load_companion(
+    *,
+    source: str,
+    country_code: str | None,
+    geo_map: dict[str, GeoInfo],
+    since: int,
+    origin: BanOrigin | None,
+    db_path: str | None,
+    app_db: aiosqlite.Connection | None,
+    history_archive_repo: HistoryArchiveRepository,
+) -> tuple[list[dict[str, Any] | fail2ban_db_repo.BanRecord], list[str]]:
+    """Load companion ban rows and matched IPs for the given country filter.
+    Returns:
+        Tuple of (companion_rows, matched_ips_for_country).
+    """
+    if country_code is None:
+        if source == "archive":
+            rows, _ = await history_archive_repo.get_archived_history(
+                db=app_db,
+                since=since,
+                origin=origin,
+                action="ban",
+                page=1,
+                page_size=_MAX_COMPANION_BANS,
+            )
+        else:
+            rows, _ = await fail2ban_db_repo.get_currently_banned(
+                db_path=db_path,
+                since=since,
+                origin=origin,
+                limit=_MAX_COMPANION_BANS,
+                offset=0,
+            )
+        return rows, []
+    matched_ips = [
+        ip
+        for ip, geo in geo_map.items()
+        if geo is not None and geo.country_code == country_code
+    ]
+    if not matched_ips:
+        return [], matched_ips
+    if source == "archive":
+        rows, _ = await history_archive_repo.get_archived_history(
+            db=app_db,
+            since=since,
+            origin=origin,
+            action="ban",
+            ip_filter=matched_ips,
+            page=1,
+            page_size=_MAX_COMPANION_BANS,
+        )
+    else:
+        rows, _ = await fail2ban_db_repo.get_currently_banned(
+            db_path=db_path,
+            since=since,
+            origin=origin,
+            ip_filter=matched_ips,
+        )
+    return rows, matched_ips
+def _bans_by_country_aggregate(
+    agg_rows: dict[str, int],
+    geo_map: dict[str, GeoInfo],
+    source: str,
+) -> tuple[dict[str, int], dict[str, str]]:
+    """Aggregate ban counts by country code.
+    Returns:
+        Tuple of (countries dict mapping cc->count, country_names dict mapping cc->name).
+    """
+    countries: dict[str, int] = {}
+    country_names: dict[str, str] = {}
+    for ip, event_count in agg_rows.items():
+        geo = geo_map.get(ip)
+        cc: str | None = geo.country_code if geo else None
+        cn: str | None = geo.country_name if geo else None
+        if cc:
+            countries[cc] = countries.get(cc, 0) + event_count
+            if cn and cc not in country_names:
+                country_names[cc] = cn
+    return countries, country_names
+def _bans_by_country_build_ban_items(
+    companion_rows: list[dict[str, Any] | fail2ban_db_repo.BanRecord],
+    geo_map: dict[str, GeoInfo],
+    source: str,
+) -> list[DomainDashboardBanItem]:
+    """Build DomainDashboardBanItem list from raw companion rows."""
+    bans: list[DomainDashboardBanItem] = []
+    for companion_row in companion_rows:
+        if source == "archive":
+            ip = companion_row["ip"]
+            jail = companion_row["jail"]
+            banned_at = ts_to_iso(int(companion_row["timeofban"]))
+            ban_count = int(companion_row["bancount"])
+            service = None
+        else:
+            ip = companion_row.ip
+            jail = companion_row.jail
+            banned_at = ts_to_iso(companion_row.timeofban)
+            ban_count = companion_row.bancount
+            matches, _ = parse_data_json(companion_row.data)
+            service = matches[0] if matches else None
+        geo = geo_map.get(ip)
+        cc = geo.country_code if geo else None
+        cn = geo.country_name if geo else None
+        asn: str | None = geo.asn if geo else None
+        org: str | None = geo.org if geo else None
+        bans.append(
+            DomainDashboardBanItem(
+                ip=ip,
+                jail=jail,
+                banned_at=banned_at,
+                service=service,
+                country_code=cc,
+                country_name=cn,
+                asn=asn,
+                org=org,
+                ban_count=ban_count,
+                origin=_derive_origin(jail),
+            )
+        )
+    return bans
 async def bans_by_country(
     socket_path: str,
     range_: TimeRange,
@@ -569,211 +825,47 @@ async def bans_by_country(
     if source not in ("fail2ban", "archive"):
         raise ValueError(f"Unsupported source: {source!r}")
-    if source == "archive":
-        if app_db is None:
-            raise ValueError("app_db must be provided when source is 'archive'")
-        # SQL aggregation — no row materialisation into Python memory.
-        ip_counts = await history_archive_repo.get_ip_ban_counts(
-            db=app_db,
-            since=since,
-            origin=origin,
-            action="ban",
-        )
-        # Total = sum of all event counts.
-        total = sum(int(row["event_count"]) for row in ip_counts)
-        # {ip: event_count} for downstream geo aggregation.
-        agg_rows = {row["ip"]: int(row["event_count"]) for row in ip_counts}
-        unique_ips = list(agg_rows.keys())
-    else:
-        origin_clause, origin_params = _origin_sql_filter(origin)
-        db_path: str = await get_fail2ban_db_path(socket_path)
-        log.info(
-            "ban_service_bans_by_country",
-            db_path=db_path,
-            since=since,
-            range=range_,
-            origin=origin,
-        )
-        # Total count and companion rows reuse the same SQL query logic.
-        # Passing limit=0 returns only the total from the count query.
-        _, total = await fail2ban_db_repo.get_currently_banned(
-            db_path=db_path,
-            since=since,
-            origin=origin,
-            limit=0,
-            offset=0,
-        )
-        agg_rows = await fail2ban_db_repo.get_ban_event_counts(
-            db_path=db_path,
-            since=since,
-            origin=origin,
-        )
-        unique_ips = [r.ip for r in agg_rows]
-    geo_map: dict[str, GeoInfo] = {}
-    if http_session is not None and unique_ips and geo_cache_lookup is not None:
-        # Serve only what is already in the in-memory cache — no API calls on
-        # the hot path. Uncached IPs are resolved asynchronously in the
-        # background so subsequent requests benefit from a warmer cache.
-        geo_map, uncached = geo_cache_lookup(unique_ips)
-        if uncached:
-            log.info(
-                "ban_service_geo_background_scheduled",
-                uncached=len(uncached),
-                cached=len(geo_map),
-            )
-            if geo_cache is not None:
-                # Fire-and-forget: lookup_batch handles rate-limiting / retries.
-                # The dirty-set flush task persists results to the DB.
-                asyncio.create_task(
-                    logged_task(
-                        geo_cache.lookup_batch(uncached, http_session, db=app_db),
-                        "geo_bans_by_country",
-                    ),
-                    name="geo_bans_by_country",
-                )
-    elif geo_enricher is not None and unique_ips:
-        # Fallback: legacy per-IP enricher (used in tests / older callers).
-        async def _safe_lookup(ip: str) -> tuple[str, GeoInfo | None]:
-            try:
-                return ip, await geo_enricher(ip)
-            except (TimeoutError, aiohttp.ClientError, OSError):
-                log.warning("ban_service_geo_lookup_failed", ip=ip)
-                return ip, None
-            except Exception as exc:
-                log.error("ban_service_geo_lookup_unexpected_error", ip=ip, error=type(exc).__name__)
-                raise  # Bubble programming errors to global handler
-        results = await asyncio.gather(*(_safe_lookup(ip) for ip in unique_ips))
-        geo_map = {ip: geo for ip, geo in results if geo is not None}
-    companion_rows: list[dict[str, Any] | fail2ban_db_repo.BanRecord]
-    if country_code is None:
-        if source == "archive":
-            companion_rows, _ = await history_archive_repo.get_archived_history(
-                db=app_db,
-                since=since,
-                origin=origin,
-                action="ban",
-                page=1,
-                page_size=_MAX_COMPANION_BANS,
-            )
-        else:
-            companion_rows, _ = await fail2ban_db_repo.get_currently_banned(
-                db_path=db_path,
-                since=since,
-                origin=origin,
-                limit=_MAX_COMPANION_BANS,
-                offset=0,
-            )
-    else:
-        matched_ips = [
-            ip
-            for ip, geo in geo_map.items()
-            if geo is not None and geo.country_code == country_code
-        ]
-        if source == "archive":
-            if matched_ips:
-                # Use keyset pagination instead of loading all matched IPs at once.
-                companion_rows, _ = await history_archive_repo.get_archived_history(
-                    db=app_db,
-                    since=since,
-                    origin=origin,
-                    action="ban",
-                    ip_filter=matched_ips,
-                    page=1,
-                    page_size=_MAX_COMPANION_BANS,
-                )
-            else:
-                companion_rows = []
-        else:
-            if matched_ips:
-                companion_rows, _ = await fail2ban_db_repo.get_currently_banned(
-                    db_path=db_path,
-                    since=since,
-                    origin=origin,
-                    ip_filter=matched_ips,
-                )
-            else:
-                companion_rows = []
-    # Build country aggregation from the SQL-grouped rows.
-    countries: dict[str, int] = {}
-    country_names: dict[str, str] = {}
-    if source == "archive":
-        agg_items = [
-            {
-                "ip": ip,
-                "event_count": count,
-            }
-            for ip, count in agg_rows.items()
-        ]
-    else:
-        agg_items = agg_rows
-    for agg_row in agg_items:
-        if source == "archive":
-            ip = agg_row["ip"]
-            event_count = agg_row["event_count"]
-        else:
-            ip = agg_row.ip
-            event_count = agg_row.event_count
-        geo = geo_map.get(ip)
-        cc: str | None = geo.country_code if geo else None
-        cn: str | None = geo.country_name if geo else None
-        if cc:
-            countries[cc] = countries.get(cc, 0) + event_count
-            if cn and cc not in country_names:
-                country_names[cc] = cn
-    # Build companion table from recent rows (geo already cached from batch step).
-    bans: list[DomainDashboardBanItem] = []
-    for companion_row in companion_rows:
-        if source == "archive":
-            ip = companion_row["ip"]
-            jail = companion_row["jail"]
-            banned_at = ts_to_iso(int(companion_row["timeofban"]))
-            ban_count = int(companion_row["bancount"])
-            service = None
-        else:
-            ip = companion_row.ip
-            jail = companion_row.jail
-            banned_at = ts_to_iso(companion_row.timeofban)
-            ban_count = companion_row.bancount
-            matches, _ = parse_data_json(companion_row.data)
-            service = matches[0] if matches else None
-        geo = geo_map.get(ip)
-        cc = geo.country_code if geo else None
-        cn = geo.country_name if geo else None
-        asn: str | None = geo.asn if geo else None
-        org: str | None = geo.org if geo else None
-        bans.append(
-            DomainDashboardBanItem(
-                ip=ip,
-                jail=jail,
-                banned_at=banned_at,
-                service=service,
-                country_code=cc,
-                country_name=cn,
-                asn=asn,
-                org=org,
-                ban_count=ban_count,
-                origin=_derive_origin(jail),
-            )
-        )
+    # Step 1: Load per-IP ban counts and total.
+    db_path: str | None = None
+    if source == "fail2ban":
+        db_path = await get_fail2ban_db_path(socket_path)
+    agg_rows, total, unique_ips = await _bans_by_country_load_data(
+        source=source,
+        socket_path=socket_path,
+        since=since,
+        origin=origin,
+        history_archive_repo=history_archive_repo,
+        app_db=app_db,
+    )
+    # Step 2: Resolve geo for unique IPs (from cache or enricher).
+    geo_map = await _bans_by_country_resolve_geo(
+        unique_ips,
+        http_session=http_session,
+        geo_cache_lookup=geo_cache_lookup,
+        geo_cache=geo_cache,
+        geo_enricher=geo_enricher,
+        app_db=app_db,
+    )
+    # Step 3: Load companion ban rows (filtered by country if provided).
+    companion_rows, _ = await _bans_by_country_load_companion(
+        source=source,
+        country_code=country_code,
+        geo_map=geo_map,
+        since=since,
+        origin=origin,
+        db_path=db_path,
+        app_db=app_db,
+        history_archive_repo=history_archive_repo,
+    )
+    # Step 4: Aggregate counts by country.
+    countries, country_names = _bans_by_country_aggregate(agg_rows, geo_map, source)
+    # Step 5: Build companion ban items for the response.
+    bans = _bans_by_country_build_ban_items(companion_rows, geo_map, source)
     return DomainBansByCountry(
         countries=countries,


@@ -2,6 +2,7 @@ from __future__ import annotations
 import asyncio
 import configparser
+import contextlib
 import io
 import os
 import re