diff --git a/Docs/Tasks.md b/Docs/Tasks.md index 6e92a03..310a4db 100644 --- a/Docs/Tasks.md +++ b/Docs/Tasks.md @@ -241,6 +241,8 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue. ### Task 13 — Move `ban_ip`, `unban_ip`, and `get_active_bans` from `jail_service` to `ban_service` +**Status:** Completed ✅ + **Found in:** `backend/app/services/jail_service.py` contains `ban_ip`, `unban_ip`, and `get_active_bans`. These operations conceptually belong in `ban_service.py`, which is the declared home for ban management. Routers `bans.py` and `blocklist.py` already import `jail_service` specifically for these functions. **Goal:** Move the three functions into `ban_service.py`. Update `bans.py` and `blocklist.py` to import from `ban_service` instead of `jail_service`. Remove the now-redundant `jail_service` import from those routers if it is no longer needed. diff --git a/backend/app/routers/bans.py b/backend/app/routers/bans.py index d342c02..1621b59 100644 --- a/backend/app/routers/bans.py +++ b/backend/app/routers/bans.py @@ -22,7 +22,7 @@ from app.dependencies import ( from app.exceptions import JailNotFoundError, JailOperationError from app.models.ban import ActiveBanListResponse, BanRequest, UnbanAllResponse, UnbanRequest from app.models.jail import JailCommandResponse -from app.services import jail_service +from app.services import ban_service, jail_service from app.exceptions import Fail2BanConnectionError router: APIRouter = APIRouter(prefix="/api/bans", tags=["Bans"]) @@ -72,7 +72,7 @@ async def get_active_bans( HTTPException: 502 when fail2ban is unreachable. """ try: - return await jail_service.get_active_bans( + return await ban_service.get_active_bans( socket_path, geo_batch_lookup=geo_batch_lookup, http_session=http_session, @@ -114,7 +114,7 @@ async def ban_ip( HTTPException: 502 when fail2ban is unreachable. """ try: - await jail_service.ban_ip(socket_path, body.jail, body.ip) + await ban_service.ban_ip(socket_path, body.jail, body.ip) return JailCommandResponse( message=f"IP {body.ip!r} banned in jail {body.jail!r}.", jail=body.jail, @@ -174,7 +174,7 @@ async def unban_ip( target_jail: str | None = None if (body.unban_all or body.jail is None) else body.jail try: - await jail_service.unban_ip(socket_path, body.ip, jail=target_jail) + await ban_service.unban_ip(socket_path, body.ip, jail=target_jail) scope = f"jail {target_jail!r}" if target_jail else "all jails" return JailCommandResponse( message=f"IP {body.ip!r} unbanned from {scope}.", diff --git a/backend/app/routers/blocklist.py b/backend/app/routers/blocklist.py index 77080d3..32bf764 100644 --- a/backend/app/routers/blocklist.py +++ b/backend/app/routers/blocklist.py @@ -44,7 +44,7 @@ from app.models.blocklist import ( ScheduleConfig, ScheduleInfo, ) -from app.services import blocklist_service, geo_service, jail_service +from app.services import ban_service, blocklist_service, geo_service from app.tasks.blocklist_import import run_import_with_resources router: APIRouter = APIRouter(prefix="/api/blocklists", tags=["Blocklists"]) @@ -138,7 +138,7 @@ async def run_import_now( socket_path, geo_is_cached=geo_service.is_cached, geo_batch_lookup=geo_batch_lookup, - ban_ip=jail_service.ban_ip, + ban_ip=ban_service.ban_ip, ) diff --git a/backend/app/services/ban_service.py b/backend/app/services/ban_service.py index df3bf64..6502763 100644 --- a/backend/app/services/ban_service.py +++ b/backend/app/services/ban_service.py @@ -11,11 +11,14 @@ so BanGUI never modifies or locks the fail2ban database. from __future__ import annotations import asyncio +import contextlib +import ipaddress import time -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast import structlog +from app.exceptions import JailNotFoundError, JailOperationError from app.models.ban import ( BLOCKLIST_JAIL, BUCKET_SECONDS, @@ -26,6 +29,8 @@ from app.models.ban import ( BansByJailResponse, BanTrendBucket, BanTrendResponse, + ActiveBan, + ActiveBanListResponse, DashboardBanItem, DashboardBanListResponse, TimeRange, @@ -42,6 +47,10 @@ from app.repositories.history_archive_repo import ( ) from app.services.fail2ban_metadata_service import default_fail2ban_metadata_service from app.utils.fail2ban_db_utils import parse_data_json, ts_to_iso +from app.utils.fail2ban_client import ( + Fail2BanClient, + Fail2BanResponse, +) if TYPE_CHECKING: import aiohttp @@ -70,6 +79,114 @@ _SOCKET_TIMEOUT: float = 5.0 # --------------------------------------------------------------------------- +def _ok(response: object) -> object: + """Extract the payload from a fail2ban ``(return_code, data)`` response. + + Args: + response: Raw value returned by :meth:`~Fail2BanClient.send`. + + Returns: + The payload ``data`` portion of the response. + + Raises: + ValueError: If the response indicates an error (return code ≠ 0). + """ + try: + code, data = response # type: ignore[assignment] + except (TypeError, ValueError) as exc: + raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc + + if code != 0: + raise ValueError(f"fail2ban returned error code {code}: {data!r}") + + return data + + +def _to_dict(pairs: object) -> dict[str, object]: + """Convert a list of ``(key, value)`` pairs to a plain dict. + + Args: + pairs: A list of ``(key, value)`` pairs (or any iterable thereof). + + Returns: + A :class:`dict` with the keys and values from *pairs*. + """ + if not isinstance(pairs, (list, tuple)): + return {} + result: dict[str, object] = {} + for item in pairs: + try: + k, v = item + result[str(k)] = v + except (TypeError, ValueError): + pass + return result + + +def _ensure_list(value: object | None) -> list[str]: + """Coerce a fail2ban response value to a list of strings.""" + if value is None: + return [] + if isinstance(value, str): + return [value] if value.strip() else [] + if isinstance(value, (list, tuple)): + return [str(v) for v in value if v is not None] + return [str(value)] + + +def _is_not_found_error(exc: Exception) -> bool: + """Return ``True`` if *exc* indicates a jail does not exist.""" + msg = str(exc).lower() + return any( + phrase in msg + for phrase in ( + "unknown jail", + "unknownjail", + "no jail", + "does not exist", + "not found", + ) + ) + + +async def ban_ip(socket_path: str, jail: str, ip: str) -> None: + """Ban an IP address in the specified jail.""" + try: + ipaddress.ip_address(ip) + except ValueError as exc: + raise ValueError(f"Invalid IP address: {ip!r}") from exc + + client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT) + + try: + _ok(await client.send(["set", jail, "banip", ip])) + except ValueError as exc: + if _is_not_found_error(exc): + raise JailNotFoundError(jail) from exc + raise JailOperationError(str(exc)) from exc + + +async def unban_ip(socket_path: str, ip: str, jail: str | None = None) -> None: + """Unban an IP address from a specific jail or all jails.""" + try: + ipaddress.ip_address(ip) + except ValueError as exc: + raise ValueError(f"Invalid IP address: {ip!r}") from exc + + client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT) + + if jail is None: + _ok(await client.send(["unban", ip])) + return + + try: + _ok(await client.send(["set", jail, "unbanip", ip])) + except ValueError as exc: + if _is_not_found_error(exc): + raise JailNotFoundError(jail) from exc + raise JailOperationError(str(exc)) from exc + + def _origin_sql_filter(origin: BanOrigin | None) -> tuple[str, tuple[str, ...]]: """Return a SQL fragment and its parameters for the origin filter. @@ -88,6 +205,191 @@ def _origin_sql_filter(origin: BanOrigin | None) -> tuple[str, tuple[str, ...]]: return "", () +def _parse_ban_entry(entry: str, jail: str) -> ActiveBan | None: + """Parse a ban entry from ``get banip --with-time`` output.""" + from datetime import UTC, datetime + + try: + parts = entry.split("\t", 1) + ip = parts[0].strip() + + ipaddress.ip_address(ip) + + if len(parts) < 2: + return ActiveBan( + ip=ip, + jail=jail, + banned_at=None, + expires_at=None, + ban_count=1, + country=None, + ) + + time_part = parts[1].strip() + plus_idx = time_part.find(" + ") + if plus_idx == -1: + banned_at_str = time_part.strip() + expires_at_str: str | None = None + else: + banned_at_str = time_part[:plus_idx].strip() + remainder = time_part[plus_idx + 3 :] + eq_idx = remainder.find(" = ") + expires_at_str = remainder[eq_idx + 3 :].strip() if eq_idx != -1 else None + + _date_fmt = "%Y-%m-%d %H:%M:%S" + + def _to_iso(ts: str) -> str: + dt = datetime.strptime(ts, _date_fmt).replace(tzinfo=UTC) + return dt.isoformat() + + banned_at_iso: str | None = None + expires_at_iso: str | None = None + + with contextlib.suppress(ValueError): + banned_at_iso = _to_iso(banned_at_str) + + with contextlib.suppress(ValueError): + if expires_at_str: + expires_at_iso = _to_iso(expires_at_str) + + return ActiveBan( + ip=ip, + jail=jail, + banned_at=banned_at_iso, + expires_at=expires_at_iso, + ban_count=1, + country=None, + ) + except (ValueError, IndexError, AttributeError) as exc: + log.debug("ban_entry_parse_error", entry=entry, jail=jail, error=str(exc)) + return None + + +async def _enrich_bans( + bans: list[ActiveBan], + geo_enricher: GeoEnricher, +) -> list[ActiveBan]: + """Enrich ban records with geo data asynchronously.""" + geo_results: list[object | Exception] = await asyncio.gather( + *[cast("Awaitable[object]", geo_enricher(ban.ip)) for ban in bans], + return_exceptions=True, + ) + enriched: list[ActiveBan] = [] + for ban, geo in zip(bans, geo_results, strict=False): + if geo is not None and not isinstance(geo, Exception): + geo_info = cast("GeoInfo", geo) + enriched.append(ban.model_copy(update={"country": geo_info.country_code})) + else: + enriched.append(ban) + return enriched + + +async def get_active_bans( + socket_path: str, + geo_batch_lookup: GeoBatchLookup | None = None, + geo_enricher: GeoEnricher | None = None, + http_session: aiohttp.ClientSession | None = None, + app_db: aiosqlite.Connection | None = None, +) -> ActiveBanListResponse: + """Return all currently banned IPs across every jail. + + For each jail the ``get banip --with-time`` command is used + to retrieve ban start and expiry times alongside the IP address. + + Geo enrichment strategy (highest priority first): + + 1. When *http_session* is provided the entire set of banned IPs is resolved + in a single :func:`~app.services.geo_service.lookup_batch` call (up to + 100 IPs per HTTP request). This is far more efficient than concurrent + per-IP lookups and stays within ip-api.com rate limits. + 2. When only *geo_enricher* is provided (legacy / test path) each IP is + resolved individually via the supplied async callable. + + Args: + socket_path: Path to the fail2ban Unix domain socket. + geo_enricher: Optional async callable ``(ip) -> GeoInfo | None``. + Used to enrich each ban entry with country and ASN data. + Ignored when *http_session* is provided. + http_session: Optional shared :class:`aiohttp.ClientSession`. When + provided, :func:`~app.services.geo_service.lookup_batch` is used + for efficient bulk geo resolution. + app_db: Optional BanGUI application database connection used to + persist newly resolved geo entries across restarts. Only + meaningful when *http_session* is provided. + + Returns: + :class:`~app.models.ban.ActiveBanListResponse` with all active bans. + + Raises: + ~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket + cannot be reached. + """ + + client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT) + + global_status = _to_dict(_ok(await client.send(["status"]))) + jail_list_raw: str = str(global_status.get("Jail list", "") or "").strip() + jail_names: list[str] = ( + [j.strip() for j in jail_list_raw.split(",") if j.strip()] + if jail_list_raw + else [] + ) + + if not jail_names: + return ActiveBanListResponse(bans=[], total=0) + + results: list[object | Exception] = await asyncio.gather( + *[client.send(["get", jn, "banip", "--with-time"]) for jn in jail_names], + return_exceptions=True, + ) + + bans: list[ActiveBan] = [] + for jail_name, raw_result in zip(jail_names, results, strict=False): + if isinstance(raw_result, Exception): + log.warning( + "active_bans_fetch_error", + jail=jail_name, + error=str(raw_result), + ) + continue + + try: + ban_list: list[str] = cast("list[str]", _ok(raw_result)) or [] + except (TypeError, ValueError) as exc: + log.warning( + "active_bans_parse_error", + jail=jail_name, + error=str(exc), + ) + continue + + for entry in ban_list: + ban = _parse_ban_entry(str(entry), jail_name) + if ban is not None: + bans.append(ban) + + if http_session is not None and bans and geo_batch_lookup is not None: + all_ips: list[str] = [ban.ip for ban in bans] + try: + geo_map = await geo_batch_lookup(all_ips, http_session, db=app_db) + except Exception: # noqa: BLE001 + log.warning("active_bans_batch_geo_failed") + geo_map = {} + enriched: list[ActiveBan] = [] + for ban in bans: + geo = geo_map.get(ban.ip) + if geo is not None: + enriched.append(ban.model_copy(update={"country": geo.country_code})) + else: + enriched.append(ban) + bans = enriched + elif geo_enricher is not None: + bans = await _enrich_bans(bans, geo_enricher) + + log.info("active_bans_fetched", total=len(bans)) + return ActiveBanListResponse(bans=bans, total=len(bans)) + + _TIME_RANGE_SLACK_SECONDS: int = 60 @@ -502,7 +804,7 @@ async def bans_by_country( country_names[cc] = cn # Build companion table from recent rows (geo already cached from batch step). - bans: list[DashboardBanItem] = [] + bans: list[ActiveBan] = [] for companion_row in companion_rows: if source == "archive": ip = companion_row["ip"] diff --git a/backend/app/services/jail_service.py b/backend/app/services/jail_service.py index 21cb6f6..3c45026 100644 --- a/backend/app/services/jail_service.py +++ b/backend/app/services/jail_service.py @@ -802,194 +802,6 @@ async def restart_daemon( ) -# --------------------------------------------------------------------------- -# Public API — Ban / Unban -# --------------------------------------------------------------------------- - - -async def ban_ip(socket_path: str, jail: str, ip: str) -> None: - """Ban an IP address in a specific fail2ban jail. - - The IP address is validated with :mod:`ipaddress` before the command - is sent to fail2ban. - - Args: - socket_path: Path to the fail2ban Unix domain socket. - jail: Jail in which to apply the ban. - ip: IP address to ban (IPv4 or IPv6). - - Raises: - ValueError: If *ip* is not a valid IP address. - JailNotFoundError: If *jail* is not a known jail. - JailOperationError: If fail2ban reports the operation failed. - ~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket - cannot be reached. - """ - # Validate the IP address before sending to avoid injection. - try: - ipaddress.ip_address(ip) - except ValueError as exc: - raise ValueError(f"Invalid IP address: {ip!r}") from exc - - client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT) - try: - _ok(await client.send(["set", jail, "banip", ip])) - log.info("ip_banned", ip=ip, jail=jail) - except ValueError as exc: - if _is_not_found_error(exc): - raise JailNotFoundError(jail) from exc - raise JailOperationError(str(exc)) from exc - - -async def unban_ip( - socket_path: str, - ip: str, - jail: str | None = None, -) -> None: - """Unban an IP address from one or all fail2ban jails. - - If *jail* is ``None``, the IP is unbanned from every jail using the - global ``unban`` command. Otherwise only the specified jail is - targeted. - - Args: - socket_path: Path to the fail2ban Unix domain socket. - ip: IP address to unban. - jail: Jail to unban from. ``None`` means all jails. - - Raises: - ValueError: If *ip* is not a valid IP address. - JailNotFoundError: If *jail* is specified but does not exist. - JailOperationError: If fail2ban reports the operation failed. - ~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket - cannot be reached. - """ - try: - ipaddress.ip_address(ip) - except ValueError as exc: - raise ValueError(f"Invalid IP address: {ip!r}") from exc - - client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT) - try: - if jail is None: - _ok(await client.send(["unban", ip])) - log.info("ip_unbanned_all_jails", ip=ip) - else: - _ok(await client.send(["set", jail, "unbanip", ip])) - log.info("ip_unbanned", ip=ip, jail=jail) - except ValueError as exc: - if _is_not_found_error(exc): - raise JailNotFoundError(jail or "") from exc - raise JailOperationError(str(exc)) from exc - - -async def get_active_bans( - socket_path: str, - geo_batch_lookup: GeoBatchLookup | None = None, - geo_enricher: GeoEnricher | None = None, - http_session: aiohttp.ClientSession | None = None, - app_db: aiosqlite.Connection | None = None, -) -> ActiveBanListResponse: - """Return all currently banned IPs across every jail. - - For each jail the ``get banip --with-time`` command is used - to retrieve ban start and expiry times alongside the IP address. - - Geo enrichment strategy (highest priority first): - - 1. When *http_session* is provided the entire set of banned IPs is resolved - in a single :func:`~app.services.geo_service.lookup_batch` call (up to - 100 IPs per HTTP request). This is far more efficient than concurrent - per-IP lookups and stays within ip-api.com rate limits. - 2. When only *geo_enricher* is provided (legacy / test path) each IP is - resolved individually via the supplied async callable. - - Args: - socket_path: Path to the fail2ban Unix domain socket. - geo_enricher: Optional async callable ``(ip) → GeoInfo | None`` - used to enrich each ban entry with country and ASN data. - Ignored when *http_session* is provided. - http_session: Optional shared :class:`aiohttp.ClientSession`. When - provided, :func:`~app.services.geo_service.lookup_batch` is used - for efficient bulk geo resolution. - app_db: Optional BanGUI application database connection used to - persist newly resolved geo entries across restarts. Only - meaningful when *http_session* is provided. - - Returns: - :class:`~app.models.ban.ActiveBanListResponse` with all active bans. - - Raises: - ~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket - cannot be reached. - """ - - client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT) - - # Fetch jail names. - global_status = _to_dict(_ok(await client.send(["status"]))) - jail_list_raw: str = str(global_status.get("Jail list", "") or "").strip() - jail_names: list[str] = ( - [j.strip() for j in jail_list_raw.split(",") if j.strip()] - if jail_list_raw - else [] - ) - - if not jail_names: - return ActiveBanListResponse(bans=[], total=0) - - # For each jail, fetch the ban list with time info in parallel. - results: list[object | Exception] = await asyncio.gather( - *[client.send(["get", jn, "banip", "--with-time"]) for jn in jail_names], - return_exceptions=True, - ) - - bans: list[ActiveBan] = [] - for jail_name, raw_result in zip(jail_names, results, strict=False): - if isinstance(raw_result, Exception): - log.warning( - "active_bans_fetch_error", - jail=jail_name, - error=str(raw_result), - ) - continue - - try: - ban_list: list[str] = cast("list[str]", _ok(raw_result)) or [] - except (TypeError, ValueError) as exc: - log.warning( - "active_bans_parse_error", - jail=jail_name, - error=str(exc), - ) - continue - - for entry in ban_list: - ban = _parse_ban_entry(str(entry), jail_name) - if ban is not None: - bans.append(ban) - - # Enrich with geo data — prefer batch lookup over per-IP enricher. - if http_session is not None and bans and geo_batch_lookup is not None: - all_ips: list[str] = [ban.ip for ban in bans] - try: - geo_map = await geo_batch_lookup(all_ips, http_session, db=app_db) - except Exception: # noqa: BLE001 - log.warning("active_bans_batch_geo_failed") - geo_map = {} - enriched: list[ActiveBan] = [] - for ban in bans: - geo = geo_map.get(ban.ip) - if geo is not None: - enriched.append(ban.model_copy(update={"country": geo.country_code})) - else: - enriched.append(ban) - bans = enriched - elif geo_enricher is not None: - bans = await _enrich_bans(bans, geo_enricher) - - log.info("active_bans_fetched", total=len(bans)) - return ActiveBanListResponse(bans=bans, total=len(bans)) def _parse_ban_entry(entry: str, jail: str) -> ActiveBan | None: diff --git a/backend/app/tasks/blocklist_import.py b/backend/app/tasks/blocklist_import.py index f2e1605..ae0f290 100644 --- a/backend/app/tasks/blocklist_import.py +++ b/backend/app/tasks/blocklist_import.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any import structlog from app.db import open_db -from app.services import blocklist_service, jail_service +from app.services import ban_service, blocklist_service from app.utils.runtime_state import get_effective_settings if TYPE_CHECKING: @@ -55,7 +55,7 @@ async def _run_import_with_resources(settings: Settings, http_session: ClientSes db, http_session, socket_path, - ban_ip=jail_service.ban_ip, + ban_ip=ban_service.ban_ip, ) log.info( "blocklist_import_finished", diff --git a/backend/tests/test_routers/test_bans.py b/backend/tests/test_routers/test_bans.py index 2e2d6d9..e84d127 100644 --- a/backend/tests/test_routers/test_bans.py +++ b/backend/tests/test_routers/test_bans.py @@ -84,7 +84,7 @@ class TestGetActiveBans: total=1, ) with patch( - "app.routers.bans.jail_service.get_active_bans", + "app.routers.bans.ban_service.get_active_bans", AsyncMock(return_value=mock_response), ): resp = await bans_client.get("/api/bans/active") @@ -107,7 +107,7 @@ class TestGetActiveBans: """GET /api/bans/active returns empty list when no bans are active.""" mock_response = ActiveBanListResponse(bans=[], total=0) with patch( - "app.routers.bans.jail_service.get_active_bans", + "app.routers.bans.ban_service.get_active_bans", AsyncMock(return_value=mock_response), ): resp = await bans_client.get("/api/bans/active") @@ -132,7 +132,7 @@ class TestGetActiveBans: total=1, ) with patch( - "app.routers.bans.jail_service.get_active_bans", + "app.routers.bans.ban_service.get_active_bans", AsyncMock(return_value=mock_response), ): resp = await bans_client.get("/api/bans/active") @@ -156,7 +156,7 @@ class TestBanIp: async def test_201_on_success(self, bans_client: AsyncClient) -> None: """POST /api/bans returns 201 when the IP is banned.""" with patch( - "app.routers.bans.jail_service.ban_ip", + "app.routers.bans.ban_service.ban_ip", AsyncMock(return_value=None), ): resp = await bans_client.post( @@ -170,7 +170,7 @@ class TestBanIp: async def test_400_for_invalid_ip(self, bans_client: AsyncClient) -> None: """POST /api/bans returns 400 for an invalid IP address.""" with patch( - "app.routers.bans.jail_service.ban_ip", + "app.routers.bans.ban_service.ban_ip", AsyncMock(side_effect=ValueError("Invalid IP address: 'bad'")), ): resp = await bans_client.post( @@ -185,7 +185,7 @@ class TestBanIp: from app.services.jail_service import JailNotFoundError with patch( - "app.routers.bans.jail_service.ban_ip", + "app.routers.bans.ban_service.ban_ip", AsyncMock(side_effect=JailNotFoundError("ghost")), ): resp = await bans_client.post( @@ -215,7 +215,7 @@ class TestUnbanIp: async def test_200_unban_from_all(self, bans_client: AsyncClient) -> None: """DELETE /api/bans with unban_all=true unbans from all jails.""" with patch( - "app.routers.bans.jail_service.unban_ip", + "app.routers.bans.ban_service.unban_ip", AsyncMock(return_value=None), ): resp = await bans_client.request( @@ -230,7 +230,7 @@ class TestUnbanIp: async def test_200_unban_from_specific_jail(self, bans_client: AsyncClient) -> None: """DELETE /api/bans with a jail unbans from that jail only.""" with patch( - "app.routers.bans.jail_service.unban_ip", + "app.routers.bans.ban_service.unban_ip", AsyncMock(return_value=None), ): resp = await bans_client.request( @@ -245,7 +245,7 @@ class TestUnbanIp: async def test_400_for_invalid_ip(self, bans_client: AsyncClient) -> None: """DELETE /api/bans returns 400 for an invalid IP.""" with patch( - "app.routers.bans.jail_service.unban_ip", + "app.routers.bans.ban_service.unban_ip", AsyncMock(side_effect=ValueError("Invalid IP address: 'bad'")), ): resp = await bans_client.request( @@ -261,7 +261,7 @@ class TestUnbanIp: from app.services.jail_service import JailNotFoundError with patch( - "app.routers.bans.jail_service.unban_ip", + "app.routers.bans.ban_service.unban_ip", AsyncMock(side_effect=JailNotFoundError("ghost")), ): resp = await bans_client.request( diff --git a/backend/tests/test_services/test_blocklist_service.py b/backend/tests/test_services/test_blocklist_service.py index 11fa32d..b4f13c0 100644 --- a/backend/tests/test_services/test_blocklist_service.py +++ b/backend/tests/test_services/test_blocklist_service.py @@ -174,17 +174,17 @@ class TestImport: source = await blocklist_service.create_source(db, "Import Test", "https://t.test/") - from app.services import jail_service + from app.services import ban_service with patch( - "app.services.jail_service.ban_ip", new_callable=AsyncMock + "app.services.ban_service.ban_ip", new_callable=AsyncMock ) as mock_ban: result = await blocklist_service.import_source( source, session, "/tmp/fake.sock", db, - ban_ip=jail_service.ban_ip, + ban_ip=ban_service.ban_ip, ) assert result.ips_imported == 2 @@ -198,15 +198,15 @@ class TestImport: session = _make_session(content) source = await blocklist_service.create_source(db, "CIDR Test", "https://c.test/") - from app.services import jail_service + from app.services import ban_service - with patch("app.services.jail_service.ban_ip", new_callable=AsyncMock): + with patch("app.services.ban_service.ban_ip", new_callable=AsyncMock): result = await blocklist_service.import_source( source, session, "/tmp/fake.sock", db, - ban_ip=jail_service.ban_ip, + ban_ip=ban_service.ban_ip, ) assert result.ips_imported == 1 @@ -217,14 +217,14 @@ class TestImport: session = _make_session("", status=503) source = await blocklist_service.create_source(db, "Err Source", "https://err.test/") - from app.services import jail_service + from app.services import ban_service result = await blocklist_service.import_source( source, session, "/tmp/fake.sock", db, - ban_ip=jail_service.ban_ip, + ban_ip=ban_service.ban_ip, ) assert result.ips_imported == 0 @@ -234,6 +234,7 @@ class TestImport: """import_source aborts immediately and records an error when the target jail does not exist in fail2ban instead of silently skipping every IP.""" from app.services.jail_service import JailNotFoundError + from app.services import ban_service content = "\n".join(f"1.2.3.{i}" for i in range(100)) session = _make_session(content) @@ -246,15 +247,13 @@ class TestImport: call_count += 1 raise JailNotFoundError(jail) - with patch("app.services.jail_service.ban_ip", side_effect=_raise_jail_not_found): - from app.services import jail_service - + with patch("app.services.ban_service.ban_ip", side_effect=_raise_jail_not_found): result = await blocklist_service.import_source( source, session, "/tmp/fake.sock", db, - ban_ip=jail_service.ban_ip, + ban_ip=ban_service.ban_ip, ) # Must abort after the first JailNotFoundError — only one ban attempt. @@ -273,15 +272,15 @@ class TestImport: session = _make_session(content) with patch( - "app.services.jail_service.ban_ip", new_callable=AsyncMock + "app.services.ban_service.ban_ip", new_callable=AsyncMock ): - from app.services import jail_service + from app.services import ban_service result = await blocklist_service.import_all( db, session, "/tmp/fake.sock", - ban_ip=jail_service.ban_ip, + ban_ip=ban_service.ban_ip, ) # Only S1 is enabled, S2 is disabled. @@ -415,16 +414,16 @@ class TestGeoPrewarmCacheFilter: def _mock_is_cached(ip: str) -> bool: return ip == "1.2.3.4" - from app.services import jail_service + from app.services import ban_service mock_batch = AsyncMock(return_value={}) - with patch("app.services.jail_service.ban_ip", new_callable=AsyncMock): + with patch("app.services.ban_service.ban_ip", new_callable=AsyncMock): result = await blocklist_service.import_source( source, session, "/tmp/fake.sock", db, - ban_ip=jail_service.ban_ip, + ban_ip=ban_service.ban_ip, geo_is_cached=_mock_is_cached, geo_batch_lookup=mock_batch, ) diff --git a/backend/tests/test_services/test_jail_service.py b/backend/tests/test_services/test_jail_service.py index 643eb84..1f574b8 100644 --- a/backend/tests/test_services/test_jail_service.py +++ b/backend/tests/test_services/test_jail_service.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import contextlib from typing import Any from unittest.mock import AsyncMock, patch @@ -12,7 +13,7 @@ from app.exceptions import Fail2BanConnectionError from app.models.ban import ActiveBanListResponse, JailBannedIpsResponse from app.models.geo import GeoDetail, GeoInfo from app.models.jail import JailDetailResponse, JailListResponse -from app.services import jail_service +from app.services import ban_service, jail_service from app.services.jail_service import JailNotFoundError, JailOperationError # --------------------------------------------------------------------------- @@ -71,7 +72,10 @@ def _patch_client(responses: dict[str, Any]) -> Any: def __init__(self, **_kw: Any) -> None: self.send = mock_send - return patch("app.services.jail_service.Fail2BanClient", _FakeClient) + stack = contextlib.ExitStack() + stack.enter_context(patch("app.services.jail_service.Fail2BanClient", _FakeClient)) + stack.enter_context(patch("app.services.ban_service.Fail2BanClient", _FakeClient)) + return stack # --------------------------------------------------------------------------- @@ -555,19 +559,19 @@ class TestJailControls: class TestBanUnban: - """Unit tests for :func:`~app.services.jail_service.ban_ip` and - :func:`~app.services.jail_service.unban_ip`. + """Unit tests for :func:`~app.services.ban_service.ban_ip` and + :func:`~app.services.ban_service.unban_ip`. """ async def test_ban_ip_success(self) -> None: """ban_ip sends the banip command for a valid IP.""" with _patch_client({"set|sshd|banip|1.2.3.4": (0, 1)}): - await jail_service.ban_ip(_SOCKET, "sshd", "1.2.3.4") # should not raise + await ban_service.ban_ip(_SOCKET, "sshd", "1.2.3.4") # should not raise async def test_ban_ip_invalid_raises(self) -> None: """ban_ip raises ValueError for a non-IP value.""" with pytest.raises(ValueError, match="Invalid IP"): - await jail_service.ban_ip(_SOCKET, "sshd", "not-an-ip") + await ban_service.ban_ip(_SOCKET, "sshd", "not-an-ip") async def test_ban_ip_unknown_jail_exception_raises_jail_not_found(self) -> None: """ban_ip raises JailNotFoundError when fail2ban returns UnknownJailException. @@ -581,27 +585,27 @@ class TestBanUnban: _patch_client({"set|missing-jail|banip|1.2.3.4": response}), pytest.raises(JailNotFoundError, match="missing-jail"), ): - await jail_service.ban_ip(_SOCKET, "missing-jail", "1.2.3.4") + await ban_service.ban_ip(_SOCKET, "missing-jail", "1.2.3.4") async def test_ban_ipv6_success(self) -> None: """ban_ip accepts an IPv6 address.""" with _patch_client({"set|sshd|banip|::1": (0, 1)}): - await jail_service.ban_ip(_SOCKET, "sshd", "::1") # should not raise + await ban_service.ban_ip(_SOCKET, "sshd", "::1") # should not raise async def test_unban_ip_all_jails(self) -> None: """unban_ip with jail=None uses the global unban command.""" with _patch_client({"unban|1.2.3.4": (0, 1)}): - await jail_service.unban_ip(_SOCKET, "1.2.3.4") # should not raise + await ban_service.unban_ip(_SOCKET, "1.2.3.4") # should not raise async def test_unban_ip_specific_jail(self) -> None: """unban_ip with a jail sends the set unbanip command.""" with _patch_client({"set|sshd|unbanip|1.2.3.4": (0, 1)}): - await jail_service.unban_ip(_SOCKET, "1.2.3.4", jail="sshd") # should not raise + await ban_service.unban_ip(_SOCKET, "1.2.3.4", jail="sshd") # should not raise async def test_unban_invalid_ip_raises(self) -> None: """unban_ip raises ValueError for an invalid IP.""" with pytest.raises(ValueError, match="Invalid IP"): - await jail_service.unban_ip(_SOCKET, "bad-ip") + await ban_service.unban_ip(_SOCKET, "bad-ip") # --------------------------------------------------------------------------- @@ -610,7 +614,7 @@ class TestBanUnban: class TestGetActiveBans: - """Unit tests for :func:`~app.services.jail_service.get_active_bans`.""" + """Unit tests for :func:`~app.services.ban_service.get_active_bans`.""" async def test_returns_active_ban_list_response(self) -> None: """get_active_bans returns an ActiveBanListResponse.""" @@ -622,7 +626,7 @@ class TestGetActiveBans: ), } with _patch_client(responses): - result = await jail_service.get_active_bans(_SOCKET) + result = await ban_service.get_active_bans(_SOCKET) assert isinstance(result, ActiveBanListResponse) assert result.total == 1 @@ -633,7 +637,7 @@ class TestGetActiveBans: """get_active_bans returns empty list when no jails are active.""" responses = {"status": (0, [("Number of jail", 0), ("Jail list", "")])} with _patch_client(responses): - result = await jail_service.get_active_bans(_SOCKET) + result = await ban_service.get_active_bans(_SOCKET) assert result.total == 0 assert result.bans == [] @@ -645,7 +649,7 @@ class TestGetActiveBans: "get|sshd|banip|--with-time": (0, []), } with _patch_client(responses): - result = await jail_service.get_active_bans(_SOCKET) + result = await ban_service.get_active_bans(_SOCKET) assert result.total == 0 @@ -659,7 +663,7 @@ class TestGetActiveBans: ), } with _patch_client(responses): - result = await jail_service.get_active_bans(_SOCKET) + result = await ban_service.get_active_bans(_SOCKET) ban = result.bans[0] assert ban.banned_at is not None @@ -691,8 +695,8 @@ class TestGetActiveBans: def __init__(self, **_kw: Any) -> None: self.send = AsyncMock(side_effect=_side) - with patch("app.services.jail_service.Fail2BanClient", _FakeClientPartial): - result = await jail_service.get_active_bans(_SOCKET) + with patch("app.services.ban_service.Fail2BanClient", _FakeClientPartial): + result = await ban_service.get_active_bans(_SOCKET) # Only sshd ban returned (nginx silently skipped) assert result.total == 1 @@ -714,7 +718,7 @@ class TestGetActiveBans: with _patch_client(responses): mock_session = AsyncMock() - result = await jail_service.get_active_bans( + result = await ban_service.get_active_bans( _SOCKET, http_session=mock_session, geo_batch_lookup=mock_batch, @@ -738,7 +742,7 @@ class TestGetActiveBans: with _patch_client(responses): mock_session = AsyncMock() - result = await jail_service.get_active_bans( + result = await ban_service.get_active_bans( _SOCKET, http_session=mock_session, geo_batch_lookup=failing_batch, @@ -763,7 +767,7 @@ class TestGetActiveBans: return GeoInfo(country_code="JP", country_name="Japan", asn=None, org=None) with _patch_client(responses): - result = await jail_service.get_active_bans( + result = await ban_service.get_active_bans( _SOCKET, geo_enricher=_enricher )