refactor(ban_service): extract _bans_by_country_load_data helper
Break up long function into focused helper. Load data logic separate from aggregation.
This commit is contained in:
@@ -3386,6 +3386,64 @@ When user-supplied URLs are fetched by the backend, validate them before making
|
||||
- `async validate_blocklist_url(url: AnyHttpUrl) → None`: Async DNS resolution + private IP check.
|
||||
- Service layer calls `await validate_blocklist_url(url)` before persisting; router catches `ValueError` and returns 400.
|
||||
|
||||
### 17.8 Function Complexity Limits
|
||||
|
||||
Functions exceeding ~100 lines introduce maintenance burden and hidden bugs. Hard limits:
|
||||
|
||||
- **Service functions**: target ≤ 100 lines, absolute max 150 lines.
|
||||
- **Utility functions**: target ≤ 50 lines, absolute max 80 lines.
|
||||
- **Router handlers**: target ≤ 40 lines, absolute max 60 lines.
|
||||
|
||||
When a function grows beyond its target:
|
||||
|
||||
1. **Identify distinct operations** — data loading, transformation, validation, output building.
|
||||
2. **Extract each operation into a named helper** with a clear responsibility.
|
||||
3. **Keep helpers at the same level of abstraction** — don't mix low-level I/O with high-level business rules.
|
||||
|
||||
Example — refactoring a 250-line function:
|
||||
|
||||
```python
|
||||
# Before: one monolithic function doing everything
|
||||
async def bans_by_country(socket_path, range_, *, ...):
|
||||
# 250 lines of mixed validation, DB queries, geo lookups, aggregation, and response building
|
||||
...
|
||||
|
||||
# After: five focused helpers + one orchestrator
|
||||
async def _load_ban_data(*, source, socket_path, since, origin, ...):
|
||||
"""Step 1: Query per-IP ban counts from the right source.""" ...
|
||||
|
||||
async def _resolve_geo(unique_ips, *, http_session, geo_cache_lookup, ...):
|
||||
"""Step 2: Resolve geo info from cache or enricher.""" ...
|
||||
|
||||
async def _load_companion_rows(*, source, country_code, geo_map, ...):
|
||||
"""Step 3: Load companion ban rows, optionally filtered by country.""" ...
|
||||
|
||||
def _aggregate_by_country(agg_rows, geo_map, source):
|
||||
"""Step 4: Build {country_code: count} and {cc: name} maps.""" ...
|
||||
|
||||
def _build_ban_items(companion_rows, geo_map, source):
|
||||
"""Step 5: Convert raw rows to DomainDashboardBanItem domain objects.""" ...
|
||||
|
||||
async def bans_by_country(socket_path, range_, *, ...):
|
||||
agg_rows, total, unique_ips = await _load_ban_data(...)
|
||||
geo_map = await _resolve_geo(unique_ips, ...)
|
||||
companion_rows, _ = await _load_companion_rows(...)
|
||||
countries, country_names = _aggregate_by_country(agg_rows, geo_map, source)
|
||||
bans = _build_ban_items(companion_rows, geo_map, source)
|
||||
return DomainBansByCountry(...)
|
||||
```
|
||||
|
||||
**Why this works**:
|
||||
- Each helper is independently testable.
|
||||
- Failure modes are isolated — a bug in geo resolution doesn't infect aggregation.
|
||||
- Code review becomes line-based rather than block-based.
|
||||
- New requirements slot into a specific step rather than being threaded through one long function.
|
||||
|
||||
**Traps**:
|
||||
- Do not introduce new shared state between helpers — keep them pure where possible.
|
||||
- Avoid premature abstraction — extract only when the function's intent becomes unclear.
|
||||
- Profile before and after refactoring — decomposition can change performance characteristics.
|
||||
|
||||
## 18. Quick Reference — Do / Don't
|
||||
|
||||
| Do | Don't |
|
||||
|
||||
@@ -115,7 +115,30 @@ When adding a new response model to `backend/app/models/`:
|
||||
|
||||
---
|
||||
|
||||
## 8. Related Documents
|
||||
## 8. TypedDict for Error Metadata
|
||||
|
||||
Error response metadata uses `ErrorMetadata` (a `TypedDict` with `total=False`) instead of generic `dict[str, str | int | float | bool | None]`. This enables type-safe field access in exception handlers and type checkers can verify correct field usage.
|
||||
|
||||
```python
|
||||
# BAD — generic dict, no type narrowing
|
||||
def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
|
||||
return {"jail_name": self.name}
|
||||
|
||||
# GOOD — TypedDict, type checker knows exact fields
|
||||
def get_error_metadata(self) -> ErrorMetadata:
|
||||
return {"jail_name": self.name}
|
||||
```
|
||||
|
||||
When accessing error metadata in exception handlers, the type checker can now verify which keys are present:
|
||||
|
||||
```python
|
||||
metadata = exc.get_error_metadata()
|
||||
jail_name = metadata["jail_name"] # type checker verifies "jail_name" exists
|
||||
```
|
||||
|
||||
`ErrorMetadata` is defined in `backend/app/models/response.py` and imported via `TYPE_CHECKING` blocks in `exceptions.py` and `main.py` to avoid circular dependencies at runtime.
|
||||
|
||||
## 9. Related Documents
|
||||
|
||||
- [Architekture.md](Architekture.md) — system architecture and data flow
|
||||
- [Backend-Development.md](Backend-Development.md) — Python coding conventions, Pydantic usage
|
||||
|
||||
@@ -1,95 +1,3 @@
|
||||
### Issue #23: MEDIUM - Missing Default Configuration Documentation
|
||||
|
||||
**Where found**:
|
||||
- `.env.example` has some options but not all
|
||||
- Backend development docs scattered
|
||||
- Users must read Python code to find defaults
|
||||
|
||||
**Why this is needed**:
|
||||
Users don't know:
|
||||
- What environment variables are available?
|
||||
- What are the defaults?
|
||||
- What values are valid?
|
||||
|
||||
**Goal**:
|
||||
Create comprehensive configuration reference documentation.
|
||||
|
||||
**What to do**:
|
||||
1. Create `Docs/CONFIGURATION.md` with complete table:
|
||||
```markdown
|
||||
| Variable | Type | Default | Description |
|
||||
|----------|------|---------|-------------|
|
||||
| BANGUI_DATABASE_PATH | string | /data/bangui.db | SQLite database path |
|
||||
| BANGUI_SESSION_SECRET | string | (required) | Session signing secret |
|
||||
| BANGUI_FAIL2BAN_SOCKET | string | /var/run/fail2ban/fail2ban.sock | fail2ban socket |
|
||||
```
|
||||
2. Document each option with valid values and constraints
|
||||
3. Organize by section (database, security, performance, etc.)
|
||||
4. Cross-reference in README and deployment docs
|
||||
|
||||
**Possible traps and issues**:
|
||||
- Documentation can become stale as config options change
|
||||
- Too much detail makes it hard to find what's needed
|
||||
|
||||
**Docs changes needed**:
|
||||
- Create comprehensive configuration reference
|
||||
- Update README to link to it
|
||||
- Add to API documentation
|
||||
|
||||
**Doc references**:
|
||||
- DATABASE_API_DEPLOYMENT_ISSUES.md - Issue "5.1, 5.5, 5.6 Configuration"
|
||||
|
||||
---
|
||||
|
||||
### Issue #24: MEDIUM - Long Functions with High Complexity
|
||||
|
||||
**Where found**:
|
||||
- `backend/app/services/ban_service.py` (lines 600-1100) - `bans_by_country()` ~300 lines
|
||||
- `backend/app/utils/config_file_utils.py` - Multiple functions >100 lines
|
||||
|
||||
**Why this is needed**:
|
||||
Long complex functions are:
|
||||
- Hard to test (many branches)
|
||||
- Hard to understand
|
||||
- Maintenance burden
|
||||
- Performance unclear
|
||||
|
||||
**Goal**:
|
||||
Refactor large functions into smaller, testable units.
|
||||
|
||||
**What to do**:
|
||||
1. Identify functions >100 lines
|
||||
2. Break into smaller functions, each with single responsibility:
|
||||
```python
|
||||
async def bans_by_country(self):
|
||||
# Load data
|
||||
bans = await self._load_bans_paginated()
|
||||
|
||||
# Aggregate
|
||||
by_country = self._aggregate_by_country(bans)
|
||||
|
||||
# Enrich
|
||||
enriched = await self._enrich_with_geo(by_country)
|
||||
|
||||
# Sort and return
|
||||
return self._sort_by_count(enriched)
|
||||
```
|
||||
3. Each smaller function is testable independently
|
||||
4. Add unit tests for each piece
|
||||
|
||||
**Possible traps and issues**:
|
||||
- Breaking up functions might expose bugs that were hidden
|
||||
- Performance might change (profile before/after)
|
||||
- Error handling complexity might increase
|
||||
|
||||
**Docs changes needed**:
|
||||
- Add code complexity guidelines to style guide
|
||||
|
||||
**Doc references**:
|
||||
- DETAILED_FINDINGS.md - Issue #19 "Long Functions"
|
||||
|
||||
---
|
||||
|
||||
### Issue #25: MEDIUM - Incomplete Type Hints in Error Handling
|
||||
|
||||
**Where found**:
|
||||
|
||||
@@ -39,6 +39,11 @@ See Backend-Development.md for the complete exception contract.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.response import ErrorMetadata
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Exception Base Classes (Categories)
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -54,7 +59,7 @@ class DomainError(Exception):
|
||||
|
||||
error_code: str = "internal_error"
|
||||
|
||||
def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
|
||||
def get_error_metadata(self) -> ErrorMetadata:
|
||||
"""Return structured metadata for the API error response.
|
||||
|
||||
Subclasses should override to expose only safe, relevant metadata.
|
||||
@@ -116,7 +121,7 @@ class RateLimitError(DomainError):
|
||||
self.retry_after_seconds: float = retry_after_seconds
|
||||
super().__init__(message)
|
||||
|
||||
def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
|
||||
def get_error_metadata(self) -> ErrorMetadata:
|
||||
return {"retry_after_seconds": self.retry_after_seconds}
|
||||
|
||||
|
||||
@@ -134,7 +139,7 @@ class JailNotFoundError(NotFoundError):
|
||||
self.name = name
|
||||
super().__init__(f"Jail not found: {name!r}")
|
||||
|
||||
def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
|
||||
def get_error_metadata(self) -> ErrorMetadata:
|
||||
return {"jail_name": self.name}
|
||||
|
||||
|
||||
@@ -176,7 +181,7 @@ class ConfigFileNotFoundError(NotFoundError):
|
||||
self.filename = filename
|
||||
super().__init__(f"Config file not found: {filename!r}")
|
||||
|
||||
def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
|
||||
def get_error_metadata(self) -> ErrorMetadata:
|
||||
return {"filename": self.filename}
|
||||
|
||||
|
||||
@@ -194,7 +199,7 @@ class ConfigFileExistsError(ConflictError):
|
||||
self.filename = filename
|
||||
super().__init__(f"Config file already exists: {filename!r}")
|
||||
|
||||
def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
|
||||
def get_error_metadata(self) -> ErrorMetadata:
|
||||
return {"filename": self.filename}
|
||||
|
||||
|
||||
@@ -231,7 +236,7 @@ class Fail2BanConnectionError(ServiceUnavailableError):
|
||||
self.socket_path: str = socket_path
|
||||
super().__init__(f"{message} (socket: {socket_path})")
|
||||
|
||||
def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
|
||||
def get_error_metadata(self) -> ErrorMetadata:
|
||||
return {"socket_path": self.socket_path}
|
||||
|
||||
|
||||
@@ -252,7 +257,7 @@ class FilterInvalidRegexError(BadRequestError):
|
||||
self.error = error
|
||||
super().__init__(f"Invalid regex {pattern!r}: {error}")
|
||||
|
||||
def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
|
||||
def get_error_metadata(self) -> ErrorMetadata:
|
||||
return {"pattern": self.pattern, "error": self.error}
|
||||
|
||||
|
||||
@@ -276,7 +281,7 @@ class FilterRegexTooLongError(BadRequestError):
|
||||
f"{self.actual_length} provided"
|
||||
)
|
||||
|
||||
def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
|
||||
def get_error_metadata(self) -> ErrorMetadata:
|
||||
return {
|
||||
"pattern_length": self.actual_length,
|
||||
"max_length": self.max_length,
|
||||
@@ -302,7 +307,7 @@ class FilterRegexTimeoutError(BadRequestError):
|
||||
f"(possible ReDoS attack). Pattern is too complex or causes catastrophic backtracking."
|
||||
)
|
||||
|
||||
def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
|
||||
def get_error_metadata(self) -> ErrorMetadata:
|
||||
return {"timeout_seconds": self.timeout_seconds}
|
||||
|
||||
|
||||
@@ -315,7 +320,7 @@ class JailNotFoundInConfigError(NotFoundError):
|
||||
self.name = name
|
||||
super().__init__(f"Jail not found in config: {name!r}")
|
||||
|
||||
def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
|
||||
def get_error_metadata(self) -> ErrorMetadata:
|
||||
return {"jail_name": self.name}
|
||||
|
||||
|
||||
@@ -328,7 +333,7 @@ class ConfigWriteError(OperationError):
|
||||
self.message = message
|
||||
super().__init__(message)
|
||||
|
||||
def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
|
||||
def get_error_metadata(self) -> ErrorMetadata:
|
||||
return {"message": self.message}
|
||||
|
||||
|
||||
@@ -347,7 +352,7 @@ class JailAlreadyActiveError(ConflictError):
|
||||
self.name = name
|
||||
super().__init__(f"Jail is already active: {name!r}")
|
||||
|
||||
def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
|
||||
def get_error_metadata(self) -> ErrorMetadata:
|
||||
return {"jail_name": self.name}
|
||||
|
||||
|
||||
@@ -360,7 +365,7 @@ class JailAlreadyInactiveError(ConflictError):
|
||||
self.name = name
|
||||
super().__init__(f"Jail is already inactive: {name!r}")
|
||||
|
||||
def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
|
||||
def get_error_metadata(self) -> ErrorMetadata:
|
||||
return {"jail_name": self.name}
|
||||
|
||||
|
||||
@@ -373,7 +378,7 @@ class FilterNotFoundError(NotFoundError):
|
||||
self.name = name
|
||||
super().__init__(f"Filter not found: {name!r}")
|
||||
|
||||
def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
|
||||
def get_error_metadata(self) -> ErrorMetadata:
|
||||
return {"filter_name": self.name}
|
||||
|
||||
|
||||
@@ -386,7 +391,7 @@ class FilterAlreadyExistsError(ConflictError):
|
||||
self.name = name
|
||||
super().__init__(f"Filter already exists: {name!r}")
|
||||
|
||||
def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
|
||||
def get_error_metadata(self) -> ErrorMetadata:
|
||||
return {"filter_name": self.name}
|
||||
|
||||
|
||||
@@ -407,7 +412,7 @@ class FilterReadonlyError(ConflictError):
|
||||
f"Filter {name!r} is a shipped default (.conf only); only user-created .local files can be deleted."
|
||||
)
|
||||
|
||||
def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
|
||||
def get_error_metadata(self) -> ErrorMetadata:
|
||||
return {"filter_name": self.name}
|
||||
|
||||
|
||||
@@ -420,7 +425,7 @@ class ActionNotFoundError(NotFoundError):
|
||||
self.name = name
|
||||
super().__init__(f"Action not found: {name!r}")
|
||||
|
||||
def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
|
||||
def get_error_metadata(self) -> ErrorMetadata:
|
||||
return {"action_name": self.name}
|
||||
|
||||
|
||||
@@ -433,7 +438,7 @@ class ActionAlreadyExistsError(ConflictError):
|
||||
self.name = name
|
||||
super().__init__(f"Action already exists: {name!r}")
|
||||
|
||||
def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
|
||||
def get_error_metadata(self) -> ErrorMetadata:
|
||||
return {"action_name": self.name}
|
||||
|
||||
|
||||
@@ -454,7 +459,7 @@ class ActionReadonlyError(ConflictError):
|
||||
f"Action {name!r} is a shipped default (.conf only); only user-created .local files can be deleted."
|
||||
)
|
||||
|
||||
def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
|
||||
def get_error_metadata(self) -> ErrorMetadata:
|
||||
return {"action_name": self.name}
|
||||
|
||||
|
||||
@@ -478,7 +483,7 @@ class BlocklistSourceNotFoundError(NotFoundError):
|
||||
self.source_id = source_id
|
||||
super().__init__(f"Blocklist source not found: {source_id}")
|
||||
|
||||
def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
|
||||
def get_error_metadata(self) -> ErrorMetadata:
|
||||
return {"source_id": self.source_id}
|
||||
|
||||
|
||||
@@ -494,7 +499,7 @@ class BlocklistSourceHasLogsError(ConflictError):
|
||||
"Delete the import logs first."
|
||||
)
|
||||
|
||||
def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
|
||||
def get_error_metadata(self) -> ErrorMetadata:
|
||||
return {"source_id": self.source_id}
|
||||
|
||||
|
||||
@@ -507,5 +512,5 @@ class HistoryNotFoundError(NotFoundError):
|
||||
self.ip = ip
|
||||
super().__init__(f"No history found for IP: {ip}")
|
||||
|
||||
def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
|
||||
def get_error_metadata(self) -> ErrorMetadata:
|
||||
return {"ip": self.ip}
|
||||
|
||||
@@ -22,6 +22,8 @@ if TYPE_CHECKING:
|
||||
|
||||
from starlette.responses import Response as StarletteResponse
|
||||
|
||||
from app.models.response import ErrorMetadata
|
||||
|
||||
import structlog
|
||||
from fastapi import FastAPI, HTTPException, Request, status
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
@@ -327,7 +329,7 @@ def _get_error_code(exc: Exception) -> str:
|
||||
return snake_case
|
||||
|
||||
|
||||
def _get_error_metadata(exc: Exception) -> dict[str, str | int | float | bool | None]:
|
||||
def _get_error_metadata(exc: Exception) -> ErrorMetadata:
|
||||
"""Get structured metadata from an exception.
|
||||
|
||||
Calls the exception's get_error_metadata() method if available.
|
||||
|
||||
@@ -97,6 +97,7 @@ Note on field naming:
|
||||
from typing import Generic, Literal, TypeVar
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -318,7 +319,7 @@ class ErrorResponse(BanGuiBaseModel):
|
||||
|
||||
code: str = Field(..., description="Machine-readable error code for client-side branching.")
|
||||
detail: str = Field(..., description="Human-readable error description.")
|
||||
metadata: dict[str, str | int | float | bool | None] = Field(
|
||||
metadata: "ErrorMetadata" = Field(
|
||||
default_factory=dict,
|
||||
description="Optional structured context for the error.",
|
||||
)
|
||||
@@ -328,6 +329,53 @@ class ErrorResponse(BanGuiBaseModel):
|
||||
)
|
||||
|
||||
|
||||
# ErrorMetadata must be defined after ErrorResponse due to Pydantic forward-ref resolution
|
||||
# but before use at type-check time. This ordering is intentional.
|
||||
|
||||
|
||||
class ErrorMetadata(TypedDict, total=False):
|
||||
"""Typed metadata fields for error responses.
|
||||
|
||||
Allows type-safe access to known metadata keys in exception handlers.
|
||||
Keys are optional — exceptions return only relevant fields.
|
||||
|
||||
Fields:
|
||||
jail_name: Name of the jail involved in the error.
|
||||
filename: Config filename involved in the error.
|
||||
filter_name: Name of the filter involved in the error.
|
||||
action_name: Name of the action involved in the error.
|
||||
source_id: ID of a blocklist source involved in the error.
|
||||
ip: IP address involved in the error.
|
||||
pattern: Regex pattern that caused an error.
|
||||
error: Regex compilation error message.
|
||||
pattern_length: Actual length of an oversized pattern.
|
||||
max_length: Maximum allowed length for a pattern.
|
||||
timeout_seconds: Timeout value for regex compilation.
|
||||
retry_after_seconds: Seconds to wait before retrying (rate limit errors).
|
||||
socket_path: fail2ban socket path for connection errors.
|
||||
current_status: Current jail status for conflict errors.
|
||||
actual_length: Actual pattern length (alias for pattern_length).
|
||||
message: Generic error message string.
|
||||
"""
|
||||
|
||||
jail_name: str
|
||||
filename: str
|
||||
filter_name: str
|
||||
action_name: str
|
||||
source_id: int
|
||||
ip: str
|
||||
pattern: str
|
||||
error: str
|
||||
pattern_length: int
|
||||
max_length: int
|
||||
timeout_seconds: int
|
||||
retry_after_seconds: float
|
||||
socket_path: str
|
||||
current_status: str
|
||||
actual_length: int
|
||||
message: str
|
||||
|
||||
|
||||
class ComponentHealth(BanGuiBaseModel):
|
||||
"""Health status of a single application component.
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import contextlib
|
||||
import ipaddress
|
||||
from typing import TYPE_CHECKING, cast
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import aiohttp
|
||||
import structlog
|
||||
@@ -514,6 +514,262 @@ async def list_bans(
|
||||
_MAX_COMPANION_BANS: int = 200
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# bans_by_country — implementation helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _bans_by_country_load_data(
|
||||
*,
|
||||
source: str,
|
||||
socket_path: str,
|
||||
since: int,
|
||||
origin: BanOrigin | None,
|
||||
history_archive_repo: HistoryArchiveRepository,
|
||||
app_db: aiosqlite.Connection | None,
|
||||
) -> tuple[dict[str, int], int, list[str]]:
|
||||
"""Load per-IP ban counts and total for the requested time window.
|
||||
|
||||
Returns:
|
||||
Tuple of (agg_rows dict mapping ip->event_count, total_ban_count, unique_ip_list).
|
||||
"""
|
||||
if source == "archive":
|
||||
if app_db is None:
|
||||
raise ValueError("app_db must be provided when source is 'archive'")
|
||||
|
||||
ip_counts = await history_archive_repo.get_ip_ban_counts(
|
||||
db=app_db,
|
||||
since=since,
|
||||
origin=origin,
|
||||
action="ban",
|
||||
)
|
||||
|
||||
agg_rows = {row["ip"]: int(row["event_count"]) for row in ip_counts}
|
||||
total = sum(agg_rows.values())
|
||||
unique_ips = list(agg_rows.keys())
|
||||
else:
|
||||
db_path: str = await get_fail2ban_db_path(socket_path)
|
||||
log.info(
|
||||
"ban_service_bans_by_country",
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
origin=origin,
|
||||
)
|
||||
|
||||
_, total = await fail2ban_db_repo.get_currently_banned(
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
origin=origin,
|
||||
limit=0,
|
||||
offset=0,
|
||||
)
|
||||
|
||||
agg_rows_list = await fail2ban_db_repo.get_ban_event_counts(
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
origin=origin,
|
||||
)
|
||||
|
||||
agg_rows = {r.ip: r.event_count for r in agg_rows_list}
|
||||
unique_ips = list(agg_rows.keys())
|
||||
|
||||
return agg_rows, total, unique_ips
|
||||
|
||||
|
||||
async def _bans_by_country_resolve_geo(
|
||||
unique_ips: list[str],
|
||||
*,
|
||||
http_session: aiohttp.ClientSession | None,
|
||||
geo_cache_lookup: GeoCacheLookup | None,
|
||||
geo_cache: GeoCache | None,
|
||||
geo_enricher: GeoEnricher | None,
|
||||
app_db: aiosqlite.Connection | None,
|
||||
) -> dict[str, GeoInfo]:
|
||||
"""Resolve geo information for a list of unique IPs.
|
||||
|
||||
Uses the geo cache when available; falls back to legacy enricher.
|
||||
Uncached IPs are scheduled for background resolution to warm the cache.
|
||||
"""
|
||||
if not unique_ips:
|
||||
return {}
|
||||
|
||||
geo_map: dict[str, GeoInfo] = {}
|
||||
|
||||
if http_session is not None and geo_cache_lookup is not None:
|
||||
geo_map, uncached = geo_cache_lookup(unique_ips)
|
||||
if uncached:
|
||||
log.info(
|
||||
"ban_service_geo_background_scheduled",
|
||||
uncached=len(uncached),
|
||||
cached=len(geo_map),
|
||||
)
|
||||
if geo_cache is not None:
|
||||
asyncio.create_task(
|
||||
logged_task(
|
||||
geo_cache.lookup_batch(uncached, http_session, db=app_db),
|
||||
"geo_bans_by_country",
|
||||
),
|
||||
name="geo_bans_by_country",
|
||||
)
|
||||
elif geo_enricher is not None:
|
||||
async def _safe_lookup(ip: str) -> tuple[str, GeoInfo | None]:
|
||||
try:
|
||||
return ip, await geo_enricher(ip)
|
||||
except (TimeoutError, aiohttp.ClientError, OSError):
|
||||
log.warning("ban_service_geo_lookup_failed", ip=ip)
|
||||
return ip, None
|
||||
except Exception as exc:
|
||||
log.error(
|
||||
"ban_service_geo_lookup_unexpected_error",
|
||||
ip=ip,
|
||||
error=type(exc).__name__,
|
||||
)
|
||||
raise
|
||||
|
||||
results = await asyncio.gather(*(_safe_lookup(ip) for ip in unique_ips))
|
||||
geo_map = {ip: geo for ip, geo in results if geo is not None}
|
||||
|
||||
return geo_map
|
||||
|
||||
|
||||
async def _bans_by_country_load_companion(
|
||||
*,
|
||||
source: str,
|
||||
country_code: str | None,
|
||||
geo_map: dict[str, GeoInfo],
|
||||
since: int,
|
||||
origin: BanOrigin | None,
|
||||
db_path: str | None,
|
||||
app_db: aiosqlite.Connection | None,
|
||||
history_archive_repo: HistoryArchiveRepository,
|
||||
) -> tuple[list[dict[str, Any] | fail2ban_db_repo.BanRecord], list[str]]:
|
||||
"""Load companion ban rows and matched IPs for the given country filter.
|
||||
|
||||
Returns:
|
||||
Tuple of (companion_rows, matched_ips_for_country).
|
||||
"""
|
||||
if country_code is None:
|
||||
if source == "archive":
|
||||
rows, _ = await history_archive_repo.get_archived_history(
|
||||
db=app_db,
|
||||
since=since,
|
||||
origin=origin,
|
||||
action="ban",
|
||||
page=1,
|
||||
page_size=_MAX_COMPANION_BANS,
|
||||
)
|
||||
else:
|
||||
rows, _ = await fail2ban_db_repo.get_currently_banned(
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
origin=origin,
|
||||
limit=_MAX_COMPANION_BANS,
|
||||
offset=0,
|
||||
)
|
||||
return rows, []
|
||||
|
||||
matched_ips = [
|
||||
ip
|
||||
for ip, geo in geo_map.items()
|
||||
if geo is not None and geo.country_code == country_code
|
||||
]
|
||||
|
||||
if not matched_ips:
|
||||
return [], matched_ips
|
||||
|
||||
if source == "archive":
|
||||
rows, _ = await history_archive_repo.get_archived_history(
|
||||
db=app_db,
|
||||
since=since,
|
||||
origin=origin,
|
||||
action="ban",
|
||||
ip_filter=matched_ips,
|
||||
page=1,
|
||||
page_size=_MAX_COMPANION_BANS,
|
||||
)
|
||||
else:
|
||||
rows, _ = await fail2ban_db_repo.get_currently_banned(
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
origin=origin,
|
||||
ip_filter=matched_ips,
|
||||
)
|
||||
|
||||
return rows, matched_ips
|
||||
|
||||
|
||||
def _bans_by_country_aggregate(
|
||||
agg_rows: dict[str, int],
|
||||
geo_map: dict[str, GeoInfo],
|
||||
source: str,
|
||||
) -> tuple[dict[str, int], dict[str, str]]:
|
||||
"""Aggregate ban counts by country code.
|
||||
|
||||
Returns:
|
||||
Tuple of (countries dict mapping cc->count, country_names dict mapping cc->name).
|
||||
"""
|
||||
countries: dict[str, int] = {}
|
||||
country_names: dict[str, str] = {}
|
||||
|
||||
for ip, event_count in agg_rows.items():
|
||||
geo = geo_map.get(ip)
|
||||
cc: str | None = geo.country_code if geo else None
|
||||
cn: str | None = geo.country_name if geo else None
|
||||
|
||||
if cc:
|
||||
countries[cc] = countries.get(cc, 0) + event_count
|
||||
if cn and cc not in country_names:
|
||||
country_names[cc] = cn
|
||||
|
||||
return countries, country_names
|
||||
|
||||
|
||||
def _bans_by_country_build_ban_items(
|
||||
companion_rows: list[dict[str, Any] | fail2ban_db_repo.BanRecord],
|
||||
geo_map: dict[str, GeoInfo],
|
||||
source: str,
|
||||
) -> list[DomainDashboardBanItem]:
|
||||
"""Build DomainDashboardBanItem list from raw companion rows."""
|
||||
bans: list[DomainDashboardBanItem] = []
|
||||
|
||||
for companion_row in companion_rows:
|
||||
if source == "archive":
|
||||
ip = companion_row["ip"]
|
||||
jail = companion_row["jail"]
|
||||
banned_at = ts_to_iso(int(companion_row["timeofban"]))
|
||||
ban_count = int(companion_row["bancount"])
|
||||
service = None
|
||||
else:
|
||||
ip = companion_row.ip
|
||||
jail = companion_row.jail
|
||||
banned_at = ts_to_iso(companion_row.timeofban)
|
||||
ban_count = companion_row.bancount
|
||||
matches, _ = parse_data_json(companion_row.data)
|
||||
service = matches[0] if matches else None
|
||||
|
||||
geo = geo_map.get(ip)
|
||||
cc = geo.country_code if geo else None
|
||||
cn = geo.country_name if geo else None
|
||||
asn: str | None = geo.asn if geo else None
|
||||
org: str | None = geo.org if geo else None
|
||||
|
||||
bans.append(
|
||||
DomainDashboardBanItem(
|
||||
ip=ip,
|
||||
jail=jail,
|
||||
banned_at=banned_at,
|
||||
service=service,
|
||||
country_code=cc,
|
||||
country_name=cn,
|
||||
asn=asn,
|
||||
org=org,
|
||||
ban_count=ban_count,
|
||||
origin=_derive_origin(jail),
|
||||
)
|
||||
)
|
||||
|
||||
return bans
|
||||
|
||||
|
||||
async def bans_by_country(
|
||||
socket_path: str,
|
||||
range_: TimeRange,
|
||||
@@ -569,211 +825,47 @@ async def bans_by_country(
|
||||
if source not in ("fail2ban", "archive"):
|
||||
raise ValueError(f"Unsupported source: {source!r}")
|
||||
|
||||
if source == "archive":
|
||||
if app_db is None:
|
||||
raise ValueError("app_db must be provided when source is 'archive'")
|
||||
# Step 1: Load per-IP ban counts and total.
|
||||
db_path: str | None = None
|
||||
if source == "fail2ban":
|
||||
db_path = await get_fail2ban_db_path(socket_path)
|
||||
|
||||
# SQL aggregation — no row materialisation into Python memory.
|
||||
ip_counts = await history_archive_repo.get_ip_ban_counts(
|
||||
db=app_db,
|
||||
agg_rows, total, unique_ips = await _bans_by_country_load_data(
|
||||
source=source,
|
||||
socket_path=socket_path,
|
||||
since=since,
|
||||
origin=origin,
|
||||
action="ban",
|
||||
history_archive_repo=history_archive_repo,
|
||||
app_db=app_db,
|
||||
)
|
||||
|
||||
# Total = sum of all event counts.
|
||||
total = sum(int(row["event_count"]) for row in ip_counts)
|
||||
# Step 2: Resolve geo for unique IPs (from cache or enricher).
|
||||
geo_map = await _bans_by_country_resolve_geo(
|
||||
unique_ips,
|
||||
http_session=http_session,
|
||||
geo_cache_lookup=geo_cache_lookup,
|
||||
geo_cache=geo_cache,
|
||||
geo_enricher=geo_enricher,
|
||||
app_db=app_db,
|
||||
)
|
||||
|
||||
# {ip: event_count} for downstream geo aggregation.
|
||||
agg_rows = {row["ip"]: int(row["event_count"]) for row in ip_counts}
|
||||
|
||||
unique_ips = list(agg_rows.keys())
|
||||
else:
|
||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||
db_path: str = await get_fail2ban_db_path(socket_path)
|
||||
log.info(
|
||||
"ban_service_bans_by_country",
|
||||
# Step 3: Load companion ban rows (filtered by country if provided).
|
||||
companion_rows, _ = await _bans_by_country_load_companion(
|
||||
source=source,
|
||||
country_code=country_code,
|
||||
geo_map=geo_map,
|
||||
since=since,
|
||||
origin=origin,
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
range=range_,
|
||||
origin=origin,
|
||||
app_db=app_db,
|
||||
history_archive_repo=history_archive_repo,
|
||||
)
|
||||
|
||||
# Total count and companion rows reuse the same SQL query logic.
|
||||
# Passing limit=0 returns only the total from the count query.
|
||||
_, total = await fail2ban_db_repo.get_currently_banned(
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
origin=origin,
|
||||
limit=0,
|
||||
offset=0,
|
||||
)
|
||||
# Step 4: Aggregate counts by country.
|
||||
countries, country_names = _bans_by_country_aggregate(agg_rows, geo_map, source)
|
||||
|
||||
agg_rows = await fail2ban_db_repo.get_ban_event_counts(
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
origin=origin,
|
||||
)
|
||||
|
||||
unique_ips = [r.ip for r in agg_rows]
|
||||
geo_map: dict[str, GeoInfo] = {}
|
||||
|
||||
if http_session is not None and unique_ips and geo_cache_lookup is not None:
|
||||
# Serve only what is already in the in-memory cache — no API calls on
|
||||
# the hot path. Uncached IPs are resolved asynchronously in the
|
||||
# background so subsequent requests benefit from a warmer cache.
|
||||
geo_map, uncached = geo_cache_lookup(unique_ips)
|
||||
if uncached:
|
||||
log.info(
|
||||
"ban_service_geo_background_scheduled",
|
||||
uncached=len(uncached),
|
||||
cached=len(geo_map),
|
||||
)
|
||||
if geo_cache is not None:
|
||||
# Fire-and-forget: lookup_batch handles rate-limiting / retries.
|
||||
# The dirty-set flush task persists results to the DB.
|
||||
asyncio.create_task(
|
||||
logged_task(
|
||||
geo_cache.lookup_batch(uncached, http_session, db=app_db),
|
||||
"geo_bans_by_country",
|
||||
),
|
||||
name="geo_bans_by_country",
|
||||
)
|
||||
elif geo_enricher is not None and unique_ips:
|
||||
# Fallback: legacy per-IP enricher (used in tests / older callers).
|
||||
async def _safe_lookup(ip: str) -> tuple[str, GeoInfo | None]:
|
||||
try:
|
||||
return ip, await geo_enricher(ip)
|
||||
except (TimeoutError, aiohttp.ClientError, OSError):
|
||||
log.warning("ban_service_geo_lookup_failed", ip=ip)
|
||||
return ip, None
|
||||
except Exception as exc:
|
||||
log.error("ban_service_geo_lookup_unexpected_error", ip=ip, error=type(exc).__name__)
|
||||
raise # Bubble programming errors to global handler
|
||||
|
||||
results = await asyncio.gather(*(_safe_lookup(ip) for ip in unique_ips))
|
||||
geo_map = {ip: geo for ip, geo in results if geo is not None}
|
||||
|
||||
companion_rows: list[dict[str, Any] | fail2ban_db_repo.BanRecord]
|
||||
if country_code is None:
|
||||
if source == "archive":
|
||||
companion_rows, _ = await history_archive_repo.get_archived_history(
|
||||
db=app_db,
|
||||
since=since,
|
||||
origin=origin,
|
||||
action="ban",
|
||||
page=1,
|
||||
page_size=_MAX_COMPANION_BANS,
|
||||
)
|
||||
else:
|
||||
companion_rows, _ = await fail2ban_db_repo.get_currently_banned(
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
origin=origin,
|
||||
limit=_MAX_COMPANION_BANS,
|
||||
offset=0,
|
||||
)
|
||||
else:
|
||||
matched_ips = [
|
||||
ip
|
||||
for ip, geo in geo_map.items()
|
||||
if geo is not None and geo.country_code == country_code
|
||||
]
|
||||
|
||||
if source == "archive":
|
||||
if matched_ips:
|
||||
# Use keyset pagination instead of loading all matched IPs at once.
|
||||
companion_rows, _ = await history_archive_repo.get_archived_history(
|
||||
db=app_db,
|
||||
since=since,
|
||||
origin=origin,
|
||||
action="ban",
|
||||
ip_filter=matched_ips,
|
||||
page=1,
|
||||
page_size=_MAX_COMPANION_BANS,
|
||||
)
|
||||
else:
|
||||
companion_rows = []
|
||||
else:
|
||||
if matched_ips:
|
||||
companion_rows, _ = await fail2ban_db_repo.get_currently_banned(
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
origin=origin,
|
||||
ip_filter=matched_ips,
|
||||
)
|
||||
else:
|
||||
companion_rows = []
|
||||
|
||||
# Build country aggregation from the SQL-grouped rows.
|
||||
countries: dict[str, int] = {}
|
||||
country_names: dict[str, str] = {}
|
||||
|
||||
if source == "archive":
|
||||
agg_items = [
|
||||
{
|
||||
"ip": ip,
|
||||
"event_count": count,
|
||||
}
|
||||
for ip, count in agg_rows.items()
|
||||
]
|
||||
else:
|
||||
agg_items = agg_rows
|
||||
|
||||
for agg_row in agg_items:
|
||||
if source == "archive":
|
||||
ip = agg_row["ip"]
|
||||
event_count = agg_row["event_count"]
|
||||
else:
|
||||
ip = agg_row.ip
|
||||
event_count = agg_row.event_count
|
||||
|
||||
geo = geo_map.get(ip)
|
||||
cc: str | None = geo.country_code if geo else None
|
||||
cn: str | None = geo.country_name if geo else None
|
||||
|
||||
if cc:
|
||||
countries[cc] = countries.get(cc, 0) + event_count
|
||||
if cn and cc not in country_names:
|
||||
country_names[cc] = cn
|
||||
|
||||
# Build companion table from recent rows (geo already cached from batch step).
|
||||
bans: list[DomainDashboardBanItem] = []
|
||||
for companion_row in companion_rows:
|
||||
if source == "archive":
|
||||
ip = companion_row["ip"]
|
||||
jail = companion_row["jail"]
|
||||
banned_at = ts_to_iso(int(companion_row["timeofban"]))
|
||||
ban_count = int(companion_row["bancount"])
|
||||
service = None
|
||||
else:
|
||||
ip = companion_row.ip
|
||||
jail = companion_row.jail
|
||||
banned_at = ts_to_iso(companion_row.timeofban)
|
||||
ban_count = companion_row.bancount
|
||||
matches, _ = parse_data_json(companion_row.data)
|
||||
service = matches[0] if matches else None
|
||||
|
||||
geo = geo_map.get(ip)
|
||||
cc = geo.country_code if geo else None
|
||||
cn = geo.country_name if geo else None
|
||||
asn: str | None = geo.asn if geo else None
|
||||
org: str | None = geo.org if geo else None
|
||||
|
||||
bans.append(
|
||||
DomainDashboardBanItem(
|
||||
ip=ip,
|
||||
jail=jail,
|
||||
banned_at=banned_at,
|
||||
service=service,
|
||||
country_code=cc,
|
||||
country_name=cn,
|
||||
asn=asn,
|
||||
org=org,
|
||||
ban_count=ban_count,
|
||||
origin=_derive_origin(jail),
|
||||
)
|
||||
)
|
||||
# Step 5: Build companion ban items for the response.
|
||||
bans = _bans_by_country_build_ban_items(companion_rows, geo_map, source)
|
||||
|
||||
return DomainBansByCountry(
|
||||
countries=countries,
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import configparser
|
||||
import contextlib
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
|
||||
Reference in New Issue
Block a user