Task 13: move ban_ip, unban_ip, and get_active_bans from jail_service to ban_service and update routers/tests
This commit is contained in:
@@ -174,17 +174,17 @@ class TestImport:
|
||||
|
||||
source = await blocklist_service.create_source(db, "Import Test", "https://t.test/")
|
||||
|
||||
from app.services import jail_service
|
||||
from app.services import ban_service
|
||||
|
||||
with patch(
|
||||
"app.services.jail_service.ban_ip", new_callable=AsyncMock
|
||||
"app.services.ban_service.ban_ip", new_callable=AsyncMock
|
||||
) as mock_ban:
|
||||
result = await blocklist_service.import_source(
|
||||
source,
|
||||
session,
|
||||
"/tmp/fake.sock",
|
||||
db,
|
||||
ban_ip=jail_service.ban_ip,
|
||||
ban_ip=ban_service.ban_ip,
|
||||
)
|
||||
|
||||
assert result.ips_imported == 2
|
||||
@@ -198,15 +198,15 @@ class TestImport:
|
||||
session = _make_session(content)
|
||||
source = await blocklist_service.create_source(db, "CIDR Test", "https://c.test/")
|
||||
|
||||
from app.services import jail_service
|
||||
from app.services import ban_service
|
||||
|
||||
with patch("app.services.jail_service.ban_ip", new_callable=AsyncMock):
|
||||
with patch("app.services.ban_service.ban_ip", new_callable=AsyncMock):
|
||||
result = await blocklist_service.import_source(
|
||||
source,
|
||||
session,
|
||||
"/tmp/fake.sock",
|
||||
db,
|
||||
ban_ip=jail_service.ban_ip,
|
||||
ban_ip=ban_service.ban_ip,
|
||||
)
|
||||
|
||||
assert result.ips_imported == 1
|
||||
@@ -217,14 +217,14 @@ class TestImport:
|
||||
session = _make_session("", status=503)
|
||||
source = await blocklist_service.create_source(db, "Err Source", "https://err.test/")
|
||||
|
||||
from app.services import jail_service
|
||||
from app.services import ban_service
|
||||
|
||||
result = await blocklist_service.import_source(
|
||||
source,
|
||||
session,
|
||||
"/tmp/fake.sock",
|
||||
db,
|
||||
ban_ip=jail_service.ban_ip,
|
||||
ban_ip=ban_service.ban_ip,
|
||||
)
|
||||
|
||||
assert result.ips_imported == 0
|
||||
@@ -234,6 +234,7 @@ class TestImport:
|
||||
"""import_source aborts immediately and records an error when the target jail
|
||||
does not exist in fail2ban instead of silently skipping every IP."""
|
||||
from app.services.jail_service import JailNotFoundError
|
||||
from app.services import ban_service
|
||||
|
||||
content = "\n".join(f"1.2.3.{i}" for i in range(100))
|
||||
session = _make_session(content)
|
||||
@@ -246,15 +247,13 @@ class TestImport:
|
||||
call_count += 1
|
||||
raise JailNotFoundError(jail)
|
||||
|
||||
with patch("app.services.jail_service.ban_ip", side_effect=_raise_jail_not_found):
|
||||
from app.services import jail_service
|
||||
|
||||
with patch("app.services.ban_service.ban_ip", side_effect=_raise_jail_not_found):
|
||||
result = await blocklist_service.import_source(
|
||||
source,
|
||||
session,
|
||||
"/tmp/fake.sock",
|
||||
db,
|
||||
ban_ip=jail_service.ban_ip,
|
||||
ban_ip=ban_service.ban_ip,
|
||||
)
|
||||
|
||||
# Must abort after the first JailNotFoundError — only one ban attempt.
|
||||
@@ -273,15 +272,15 @@ class TestImport:
|
||||
session = _make_session(content)
|
||||
|
||||
with patch(
|
||||
"app.services.jail_service.ban_ip", new_callable=AsyncMock
|
||||
"app.services.ban_service.ban_ip", new_callable=AsyncMock
|
||||
):
|
||||
from app.services import jail_service
|
||||
from app.services import ban_service
|
||||
|
||||
result = await blocklist_service.import_all(
|
||||
db,
|
||||
session,
|
||||
"/tmp/fake.sock",
|
||||
ban_ip=jail_service.ban_ip,
|
||||
ban_ip=ban_service.ban_ip,
|
||||
)
|
||||
|
||||
# Only S1 is enabled, S2 is disabled.
|
||||
@@ -415,16 +414,16 @@ class TestGeoPrewarmCacheFilter:
|
||||
def _mock_is_cached(ip: str) -> bool:
|
||||
return ip == "1.2.3.4"
|
||||
|
||||
from app.services import jail_service
|
||||
from app.services import ban_service
|
||||
|
||||
mock_batch = AsyncMock(return_value={})
|
||||
with patch("app.services.jail_service.ban_ip", new_callable=AsyncMock):
|
||||
with patch("app.services.ban_service.ban_ip", new_callable=AsyncMock):
|
||||
result = await blocklist_service.import_source(
|
||||
source,
|
||||
session,
|
||||
"/tmp/fake.sock",
|
||||
db,
|
||||
ban_ip=jail_service.ban_ip,
|
||||
ban_ip=ban_service.ban_ip,
|
||||
geo_is_cached=_mock_is_cached,
|
||||
geo_batch_lookup=mock_batch,
|
||||
)
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
@@ -12,7 +13,7 @@ from app.exceptions import Fail2BanConnectionError
|
||||
from app.models.ban import ActiveBanListResponse, JailBannedIpsResponse
|
||||
from app.models.geo import GeoDetail, GeoInfo
|
||||
from app.models.jail import JailDetailResponse, JailListResponse
|
||||
from app.services import jail_service
|
||||
from app.services import ban_service, jail_service
|
||||
from app.services.jail_service import JailNotFoundError, JailOperationError
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -71,7 +72,10 @@ def _patch_client(responses: dict[str, Any]) -> Any:
|
||||
def __init__(self, **_kw: Any) -> None:
|
||||
self.send = mock_send
|
||||
|
||||
return patch("app.services.jail_service.Fail2BanClient", _FakeClient)
|
||||
stack = contextlib.ExitStack()
|
||||
stack.enter_context(patch("app.services.jail_service.Fail2BanClient", _FakeClient))
|
||||
stack.enter_context(patch("app.services.ban_service.Fail2BanClient", _FakeClient))
|
||||
return stack
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -555,19 +559,19 @@ class TestJailControls:
|
||||
|
||||
|
||||
class TestBanUnban:
|
||||
"""Unit tests for :func:`~app.services.jail_service.ban_ip` and
|
||||
:func:`~app.services.jail_service.unban_ip`.
|
||||
"""Unit tests for :func:`~app.services.ban_service.ban_ip` and
|
||||
:func:`~app.services.ban_service.unban_ip`.
|
||||
"""
|
||||
|
||||
async def test_ban_ip_success(self) -> None:
|
||||
"""ban_ip sends the banip command for a valid IP."""
|
||||
with _patch_client({"set|sshd|banip|1.2.3.4": (0, 1)}):
|
||||
await jail_service.ban_ip(_SOCKET, "sshd", "1.2.3.4") # should not raise
|
||||
await ban_service.ban_ip(_SOCKET, "sshd", "1.2.3.4") # should not raise
|
||||
|
||||
async def test_ban_ip_invalid_raises(self) -> None:
|
||||
"""ban_ip raises ValueError for a non-IP value."""
|
||||
with pytest.raises(ValueError, match="Invalid IP"):
|
||||
await jail_service.ban_ip(_SOCKET, "sshd", "not-an-ip")
|
||||
await ban_service.ban_ip(_SOCKET, "sshd", "not-an-ip")
|
||||
|
||||
async def test_ban_ip_unknown_jail_exception_raises_jail_not_found(self) -> None:
|
||||
"""ban_ip raises JailNotFoundError when fail2ban returns UnknownJailException.
|
||||
@@ -581,27 +585,27 @@ class TestBanUnban:
|
||||
_patch_client({"set|missing-jail|banip|1.2.3.4": response}),
|
||||
pytest.raises(JailNotFoundError, match="missing-jail"),
|
||||
):
|
||||
await jail_service.ban_ip(_SOCKET, "missing-jail", "1.2.3.4")
|
||||
await ban_service.ban_ip(_SOCKET, "missing-jail", "1.2.3.4")
|
||||
|
||||
async def test_ban_ipv6_success(self) -> None:
|
||||
"""ban_ip accepts an IPv6 address."""
|
||||
with _patch_client({"set|sshd|banip|::1": (0, 1)}):
|
||||
await jail_service.ban_ip(_SOCKET, "sshd", "::1") # should not raise
|
||||
await ban_service.ban_ip(_SOCKET, "sshd", "::1") # should not raise
|
||||
|
||||
async def test_unban_ip_all_jails(self) -> None:
|
||||
"""unban_ip with jail=None uses the global unban command."""
|
||||
with _patch_client({"unban|1.2.3.4": (0, 1)}):
|
||||
await jail_service.unban_ip(_SOCKET, "1.2.3.4") # should not raise
|
||||
await ban_service.unban_ip(_SOCKET, "1.2.3.4") # should not raise
|
||||
|
||||
async def test_unban_ip_specific_jail(self) -> None:
|
||||
"""unban_ip with a jail sends the set unbanip command."""
|
||||
with _patch_client({"set|sshd|unbanip|1.2.3.4": (0, 1)}):
|
||||
await jail_service.unban_ip(_SOCKET, "1.2.3.4", jail="sshd") # should not raise
|
||||
await ban_service.unban_ip(_SOCKET, "1.2.3.4", jail="sshd") # should not raise
|
||||
|
||||
async def test_unban_invalid_ip_raises(self) -> None:
|
||||
"""unban_ip raises ValueError for an invalid IP."""
|
||||
with pytest.raises(ValueError, match="Invalid IP"):
|
||||
await jail_service.unban_ip(_SOCKET, "bad-ip")
|
||||
await ban_service.unban_ip(_SOCKET, "bad-ip")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -610,7 +614,7 @@ class TestBanUnban:
|
||||
|
||||
|
||||
class TestGetActiveBans:
|
||||
"""Unit tests for :func:`~app.services.jail_service.get_active_bans`."""
|
||||
"""Unit tests for :func:`~app.services.ban_service.get_active_bans`."""
|
||||
|
||||
async def test_returns_active_ban_list_response(self) -> None:
|
||||
"""get_active_bans returns an ActiveBanListResponse."""
|
||||
@@ -622,7 +626,7 @@ class TestGetActiveBans:
|
||||
),
|
||||
}
|
||||
with _patch_client(responses):
|
||||
result = await jail_service.get_active_bans(_SOCKET)
|
||||
result = await ban_service.get_active_bans(_SOCKET)
|
||||
|
||||
assert isinstance(result, ActiveBanListResponse)
|
||||
assert result.total == 1
|
||||
@@ -633,7 +637,7 @@ class TestGetActiveBans:
|
||||
"""get_active_bans returns empty list when no jails are active."""
|
||||
responses = {"status": (0, [("Number of jail", 0), ("Jail list", "")])}
|
||||
with _patch_client(responses):
|
||||
result = await jail_service.get_active_bans(_SOCKET)
|
||||
result = await ban_service.get_active_bans(_SOCKET)
|
||||
|
||||
assert result.total == 0
|
||||
assert result.bans == []
|
||||
@@ -645,7 +649,7 @@ class TestGetActiveBans:
|
||||
"get|sshd|banip|--with-time": (0, []),
|
||||
}
|
||||
with _patch_client(responses):
|
||||
result = await jail_service.get_active_bans(_SOCKET)
|
||||
result = await ban_service.get_active_bans(_SOCKET)
|
||||
|
||||
assert result.total == 0
|
||||
|
||||
@@ -659,7 +663,7 @@ class TestGetActiveBans:
|
||||
),
|
||||
}
|
||||
with _patch_client(responses):
|
||||
result = await jail_service.get_active_bans(_SOCKET)
|
||||
result = await ban_service.get_active_bans(_SOCKET)
|
||||
|
||||
ban = result.bans[0]
|
||||
assert ban.banned_at is not None
|
||||
@@ -691,8 +695,8 @@ class TestGetActiveBans:
|
||||
def __init__(self, **_kw: Any) -> None:
|
||||
self.send = AsyncMock(side_effect=_side)
|
||||
|
||||
with patch("app.services.jail_service.Fail2BanClient", _FakeClientPartial):
|
||||
result = await jail_service.get_active_bans(_SOCKET)
|
||||
with patch("app.services.ban_service.Fail2BanClient", _FakeClientPartial):
|
||||
result = await ban_service.get_active_bans(_SOCKET)
|
||||
|
||||
# Only sshd ban returned (nginx silently skipped)
|
||||
assert result.total == 1
|
||||
@@ -714,7 +718,7 @@ class TestGetActiveBans:
|
||||
|
||||
with _patch_client(responses):
|
||||
mock_session = AsyncMock()
|
||||
result = await jail_service.get_active_bans(
|
||||
result = await ban_service.get_active_bans(
|
||||
_SOCKET,
|
||||
http_session=mock_session,
|
||||
geo_batch_lookup=mock_batch,
|
||||
@@ -738,7 +742,7 @@ class TestGetActiveBans:
|
||||
|
||||
with _patch_client(responses):
|
||||
mock_session = AsyncMock()
|
||||
result = await jail_service.get_active_bans(
|
||||
result = await ban_service.get_active_bans(
|
||||
_SOCKET,
|
||||
http_session=mock_session,
|
||||
geo_batch_lookup=failing_batch,
|
||||
@@ -763,7 +767,7 @@ class TestGetActiveBans:
|
||||
return GeoInfo(country_code="JP", country_name="Japan", asn=None, org=None)
|
||||
|
||||
with _patch_client(responses):
|
||||
result = await jail_service.get_active_bans(
|
||||
result = await ban_service.get_active_bans(
|
||||
_SOCKET, geo_enricher=_enricher
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user