refactor: improve backend type safety and import organization

- Add TYPE_CHECKING guards for runtime-expensive imports (aiohttp, aiosqlite)
- Reorganize imports to follow PEP 8 conventions
- Convert TypeAlias to modern PEP 695 type syntax (where appropriate)
- Use Sequence/Mapping from collections.abc for type hints (covariant)
- Replace string literals with cast() for improved type inference
- Fix casting of Fail2BanResponse and TypedDict patterns
- Add IpLookupResult TypedDict for precise return type annotation
- Reformat overlong lines for readability (120 char limit)
- Add asyncio_mode and filterwarnings to pytest config
- Update test fixtures with improved type hints

This improves mypy type checking and makes type relationships explicit.
This commit is contained in:
2026-03-20 13:44:14 +01:00
parent bdcdd5d672
commit 1c0bac1353
30 changed files with 431 additions and 644 deletions

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
from collections.abc import Generator
from unittest.mock import patch
import pytest
@@ -157,12 +158,12 @@ class TestRequireAuthSessionCache:
"""In-memory session token cache inside ``require_auth``."""
@pytest.fixture(autouse=True)
def reset_cache(self) -> None: # type: ignore[misc]
def reset_cache(self) -> Generator[None, None, None]:
"""Flush the session cache before and after every test in this class."""
from app import dependencies
dependencies.clear_session_cache()
yield # type: ignore[misc]
yield
dependencies.clear_session_cache()
async def test_second_request_skips_db(self, client: AsyncClient) -> None:

View File

@@ -70,7 +70,7 @@ class TestGeoLookup:
async def test_200_with_geo_info(self, geo_client: AsyncClient) -> None:
"""GET /api/geo/lookup/{ip} returns 200 with enriched result."""
geo = GeoInfo(country_code="DE", country_name="Germany", asn="12345", org="Acme")
result = {
result: dict[str, object] = {
"ip": "1.2.3.4",
"currently_banned_in": ["sshd"],
"geo": geo,
@@ -92,7 +92,7 @@ class TestGeoLookup:
async def test_200_when_not_banned(self, geo_client: AsyncClient) -> None:
"""GET /api/geo/lookup/{ip} returns empty list when IP is not banned anywhere."""
result = {
result: dict[str, object] = {
"ip": "8.8.8.8",
"currently_banned_in": [],
"geo": GeoInfo(country_code="US", country_name="United States", asn=None, org=None),
@@ -108,7 +108,7 @@ class TestGeoLookup:
async def test_200_with_no_geo(self, geo_client: AsyncClient) -> None:
"""GET /api/geo/lookup/{ip} returns null geo when enricher fails."""
result = {
result: dict[str, object] = {
"ip": "1.2.3.4",
"currently_banned_in": [],
"geo": None,
@@ -144,7 +144,7 @@ class TestGeoLookup:
async def test_ipv6_address(self, geo_client: AsyncClient) -> None:
"""GET /api/geo/lookup/{ip} handles IPv6 addresses."""
result = {
result: dict[str, object] = {
"ip": "2001:db8::1",
"currently_banned_in": [],
"geo": None,

View File

@@ -12,6 +12,7 @@ from httpx import ASGITransport, AsyncClient
from app.config import Settings
from app.db import init_db
from app.main import create_app
from app.models.ban import JailBannedIpsResponse
from app.models.jail import Jail, JailDetailResponse, JailListResponse, JailStatus, JailSummary
# ---------------------------------------------------------------------------
@@ -801,17 +802,17 @@ class TestGetJailBannedIps:
def _mock_response(
self,
*,
items: list[dict] | None = None,
items: list[dict[str, str | None]] | None = None,
total: int = 2,
page: int = 1,
page_size: int = 25,
) -> "JailBannedIpsResponse": # type: ignore[name-defined]
) -> JailBannedIpsResponse:
from app.models.ban import ActiveBan, JailBannedIpsResponse
ban_items = (
[
ActiveBan(
ip=item.get("ip", "1.2.3.4"),
ip=item.get("ip") or "1.2.3.4",
jail="sshd",
banned_at=item.get("banned_at", "2025-01-01T10:00:00+00:00"),
expires_at=item.get("expires_at", "2025-01-01T10:10:00+00:00"),

View File

@@ -247,9 +247,9 @@ class TestSetupCompleteCaching:
assert not getattr(app.state, "_setup_complete_cached", False)
# First non-exempt request — middleware queries DB and sets the flag.
await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]}) # type: ignore[call-overload]
await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]})
assert app.state._setup_complete_cached is True # type: ignore[attr-defined]
assert app.state._setup_complete_cached is True
async def test_cached_path_skips_is_setup_complete(
self,
@@ -267,12 +267,12 @@ class TestSetupCompleteCaching:
# Do setup and warm the cache.
await client.post("/api/setup", json=_SETUP_PAYLOAD)
await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]}) # type: ignore[call-overload]
assert app.state._setup_complete_cached is True # type: ignore[attr-defined]
await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]})
assert app.state._setup_complete_cached is True
call_count = 0
async def _counting(db): # type: ignore[no-untyped-def]
async def _counting(db: aiosqlite.Connection) -> bool:
nonlocal call_count
call_count += 1
return True

View File

@@ -73,7 +73,7 @@ class TestCheckPasswordAsync:
auth_service._check_password("secret", hashed), # noqa: SLF001
auth_service._check_password("wrong", hashed), # noqa: SLF001
)
assert results == [True, False]
assert tuple(results) == (True, False)
class TestLogin:

View File

@@ -64,7 +64,7 @@ async def _create_f2b_db(path: str, rows: list[dict[str, Any]]) -> None:
@pytest.fixture
async def f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc]
async def f2b_db_path(tmp_path: Path) -> str:
"""Return the path to a test fail2ban SQLite database with several bans."""
path = str(tmp_path / "fail2ban_test.sqlite3")
await _create_f2b_db(
@@ -103,7 +103,7 @@ async def f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc]
@pytest.fixture
async def mixed_origin_db_path(tmp_path: Path) -> str: # type: ignore[misc]
async def mixed_origin_db_path(tmp_path: Path) -> str:
"""Return a database with bans from both blocklist-import and organic jails."""
path = str(tmp_path / "fail2ban_mixed_origin.sqlite3")
await _create_f2b_db(
@@ -136,7 +136,7 @@ async def mixed_origin_db_path(tmp_path: Path) -> str: # type: ignore[misc]
@pytest.fixture
async def empty_f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc]
async def empty_f2b_db_path(tmp_path: Path) -> str:
"""Return the path to a fail2ban SQLite database with no ban records."""
path = str(tmp_path / "fail2ban_empty.sqlite3")
await _create_f2b_db(path, [])
@@ -632,13 +632,13 @@ class TestBansbyCountryBackground:
from app.services import geo_service
# Pre-populate the cache for all three IPs in the fixture.
geo_service._cache["10.0.0.1"] = geo_service.GeoInfo( # type: ignore[attr-defined]
geo_service._cache["10.0.0.1"] = geo_service.GeoInfo(
country_code="DE", country_name="Germany", asn=None, org=None
)
geo_service._cache["10.0.0.2"] = geo_service.GeoInfo( # type: ignore[attr-defined]
geo_service._cache["10.0.0.2"] = geo_service.GeoInfo(
country_code="US", country_name="United States", asn=None, org=None
)
geo_service._cache["10.0.0.3"] = geo_service.GeoInfo( # type: ignore[attr-defined]
geo_service._cache["10.0.0.3"] = geo_service.GeoInfo(
country_code="JP", country_name="Japan", asn=None, org=None
)

View File

@@ -114,13 +114,13 @@ async def _seed_f2b_db(path: str, n: int) -> list[str]:
@pytest.fixture(scope="module")
def event_loop_policy() -> None: # type: ignore[misc]
def event_loop_policy() -> None:
"""Use the default event loop policy for module-scoped fixtures."""
return None
@pytest.fixture(scope="module")
async def perf_db_path(tmp_path_factory: Any) -> str: # type: ignore[misc]
async def perf_db_path(tmp_path_factory: Any) -> str:
"""Return the path to a fail2ban DB seeded with 10 000 synthetic bans.
Module-scoped so the database is created only once for all perf tests.

View File

@@ -13,15 +13,19 @@ from app.services.config_file_service import (
JailNameError,
JailNotFoundInConfigError,
_build_inactive_jail,
_extract_action_base_name,
_extract_filter_base_name,
_ordered_config_files,
_parse_jails_sync,
_resolve_filter,
_safe_jail_name,
_validate_jail_config_sync,
_write_local_override_sync,
activate_jail,
deactivate_jail,
list_inactive_jails,
rollback_jail,
validate_jail_config,
)
# ---------------------------------------------------------------------------
@@ -292,9 +296,7 @@ class TestBuildInactiveJail:
def test_has_local_override_absent(self, tmp_path: Path) -> None:
"""has_local_override is False when no .local file exists."""
jail = _build_inactive_jail(
"sshd", {}, "/etc/fail2ban/jail.d/sshd.conf", config_dir=tmp_path
)
jail = _build_inactive_jail("sshd", {}, "/etc/fail2ban/jail.d/sshd.conf", config_dir=tmp_path)
assert jail.has_local_override is False
def test_has_local_override_present(self, tmp_path: Path) -> None:
@@ -302,9 +304,7 @@ class TestBuildInactiveJail:
local = tmp_path / "jail.d" / "sshd.local"
local.parent.mkdir(parents=True, exist_ok=True)
local.write_text("[sshd]\nenabled = false\n")
jail = _build_inactive_jail(
"sshd", {}, "/etc/fail2ban/jail.d/sshd.conf", config_dir=tmp_path
)
jail = _build_inactive_jail("sshd", {}, "/etc/fail2ban/jail.d/sshd.conf", config_dir=tmp_path)
assert jail.has_local_override is True
def test_has_local_override_no_config_dir(self) -> None:
@@ -363,9 +363,7 @@ class TestWriteLocalOverrideSync:
assert "2222" in content
def test_override_logpath_list(self, tmp_path: Path) -> None:
_write_local_override_sync(
tmp_path, "sshd", True, {"logpath": ["/var/log/auth.log", "/var/log/secure"]}
)
_write_local_override_sync(tmp_path, "sshd", True, {"logpath": ["/var/log/auth.log", "/var/log/secure"]})
content = (tmp_path / "jail.d" / "sshd.local").read_text()
assert "/var/log/auth.log" in content
assert "/var/log/secure" in content
@@ -447,9 +445,7 @@ class TestListInactiveJails:
assert "sshd" in names
assert "apache-auth" in names
async def test_has_local_override_true_when_local_file_exists(
self, tmp_path: Path
) -> None:
async def test_has_local_override_true_when_local_file_exists(self, tmp_path: Path) -> None:
"""has_local_override is True for a jail whose jail.d .local file exists."""
_write(tmp_path / "jail.conf", JAIL_CONF)
local = tmp_path / "jail.d" / "apache-auth.local"
@@ -463,9 +459,7 @@ class TestListInactiveJails:
jail = next(j for j in result.jails if j.name == "apache-auth")
assert jail.has_local_override is True
async def test_has_local_override_false_when_no_local_file(
self, tmp_path: Path
) -> None:
async def test_has_local_override_false_when_no_local_file(self, tmp_path: Path) -> None:
"""has_local_override is False when no jail.d .local file exists."""
_write(tmp_path / "jail.conf", JAIL_CONF)
with patch(
@@ -608,7 +602,8 @@ class TestActivateJail:
patch(
"app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()),
),pytest.raises(JailNotFoundInConfigError)
),
pytest.raises(JailNotFoundInConfigError),
):
await activate_jail(str(tmp_path), "/fake.sock", "nonexistent", req)
@@ -621,7 +616,8 @@ class TestActivateJail:
patch(
"app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value={"sshd"}),
),pytest.raises(JailAlreadyActiveError)
),
pytest.raises(JailAlreadyActiveError),
):
await activate_jail(str(tmp_path), "/fake.sock", "sshd", req)
@@ -691,7 +687,8 @@ class TestDeactivateJail:
patch(
"app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value={"sshd"}),
),pytest.raises(JailNotFoundInConfigError)
),
pytest.raises(JailNotFoundInConfigError),
):
await deactivate_jail(str(tmp_path), "/fake.sock", "nonexistent")
@@ -701,7 +698,8 @@ class TestDeactivateJail:
patch(
"app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()),
),pytest.raises(JailAlreadyInactiveError)
),
pytest.raises(JailAlreadyInactiveError),
):
await deactivate_jail(str(tmp_path), "/fake.sock", "apache-auth")
@@ -710,38 +708,6 @@ class TestDeactivateJail:
await deactivate_jail(str(tmp_path), "/fake.sock", "a/b")
# ---------------------------------------------------------------------------
# _extract_filter_base_name
# ---------------------------------------------------------------------------
class TestExtractFilterBaseName:
def test_simple_name(self) -> None:
from app.services.config_file_service import _extract_filter_base_name
assert _extract_filter_base_name("sshd") == "sshd"
def test_name_with_mode(self) -> None:
from app.services.config_file_service import _extract_filter_base_name
assert _extract_filter_base_name("sshd[mode=aggressive]") == "sshd"
def test_name_with_variable_mode(self) -> None:
from app.services.config_file_service import _extract_filter_base_name
assert _extract_filter_base_name("sshd[mode=%(mode)s]") == "sshd"
def test_whitespace_stripped(self) -> None:
from app.services.config_file_service import _extract_filter_base_name
assert _extract_filter_base_name(" nginx ") == "nginx"
def test_empty_string(self) -> None:
from app.services.config_file_service import _extract_filter_base_name
assert _extract_filter_base_name("") == ""
# ---------------------------------------------------------------------------
# _build_filter_to_jails_map
# ---------------------------------------------------------------------------
@@ -757,9 +723,7 @@ class TestBuildFilterToJailsMap:
def test_inactive_jail_not_included(self) -> None:
from app.services.config_file_service import _build_filter_to_jails_map
result = _build_filter_to_jails_map(
{"apache-auth": {"filter": "apache-auth"}}, set()
)
result = _build_filter_to_jails_map({"apache-auth": {"filter": "apache-auth"}}, set())
assert result == {}
def test_multiple_jails_sharing_filter(self) -> None:
@@ -775,9 +739,7 @@ class TestBuildFilterToJailsMap:
def test_mode_suffix_stripped(self) -> None:
from app.services.config_file_service import _build_filter_to_jails_map
result = _build_filter_to_jails_map(
{"sshd": {"filter": "sshd[mode=aggressive]"}}, {"sshd"}
)
result = _build_filter_to_jails_map({"sshd": {"filter": "sshd[mode=aggressive]"}}, {"sshd"})
assert "sshd" in result
def test_missing_filter_key_falls_back_to_jail_name(self) -> None:
@@ -988,10 +950,13 @@ class TestGetFilter:
async def test_raises_filter_not_found(self, tmp_path: Path) -> None:
from app.services.config_file_service import FilterNotFoundError, get_filter
with patch(
"app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()),
), pytest.raises(FilterNotFoundError):
with (
patch(
"app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()),
),
pytest.raises(FilterNotFoundError),
):
await get_filter(str(tmp_path), "/fake.sock", "nonexistent")
async def test_has_local_override_detected(self, tmp_path: Path) -> None:
@@ -1093,10 +1058,13 @@ class TestGetFilterLocalOnly:
async def test_raises_when_neither_conf_nor_local(self, tmp_path: Path) -> None:
from app.services.config_file_service import FilterNotFoundError, get_filter
with patch(
"app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()),
), pytest.raises(FilterNotFoundError):
with (
patch(
"app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()),
),
pytest.raises(FilterNotFoundError),
):
await get_filter(str(tmp_path), "/fake.sock", "nonexistent")
async def test_accepts_local_extension(self, tmp_path: Path) -> None:
@@ -1212,9 +1180,7 @@ class TestSetJailLocalKeySync:
jail_d = tmp_path / "jail.d"
jail_d.mkdir()
(jail_d / "sshd.local").write_text(
"[sshd]\nenabled = true\n"
)
(jail_d / "sshd.local").write_text("[sshd]\nenabled = true\n")
_set_jail_local_key_sync(tmp_path, "sshd", "filter", "newfilter")
@@ -1300,10 +1266,13 @@ class TestUpdateFilter:
from app.models.config import FilterUpdateRequest
from app.services.config_file_service import FilterNotFoundError, update_filter
with patch(
"app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()),
), pytest.raises(FilterNotFoundError):
with (
patch(
"app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()),
),
pytest.raises(FilterNotFoundError),
):
await update_filter(
str(tmp_path),
"/fake.sock",
@@ -1321,10 +1290,13 @@ class TestUpdateFilter:
filter_d = tmp_path / "filter.d"
_write(filter_d / "sshd.conf", _FILTER_CONF_WITH_REGEX)
with patch(
"app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()),
), pytest.raises(FilterInvalidRegexError):
with (
patch(
"app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()),
),
pytest.raises(FilterInvalidRegexError),
):
await update_filter(
str(tmp_path),
"/fake.sock",
@@ -1351,13 +1323,16 @@ class TestUpdateFilter:
filter_d = tmp_path / "filter.d"
_write(filter_d / "sshd.conf", _FILTER_CONF)
with patch(
"app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()),
), patch(
"app.services.config_file_service.jail_service.reload_all",
new=AsyncMock(),
) as mock_reload:
with (
patch(
"app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()),
),
patch(
"app.services.config_file_service.jail_service.reload_all",
new=AsyncMock(),
) as mock_reload,
):
await update_filter(
str(tmp_path),
"/fake.sock",
@@ -1405,10 +1380,13 @@ class TestCreateFilter:
filter_d = tmp_path / "filter.d"
_write(filter_d / "sshd.conf", _FILTER_CONF)
with patch(
"app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()),
), pytest.raises(FilterAlreadyExistsError):
with (
patch(
"app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()),
),
pytest.raises(FilterAlreadyExistsError),
):
await create_filter(
str(tmp_path),
"/fake.sock",
@@ -1422,10 +1400,13 @@ class TestCreateFilter:
filter_d = tmp_path / "filter.d"
_write(filter_d / "custom.local", "[Definition]\n")
with patch(
"app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()),
), pytest.raises(FilterAlreadyExistsError):
with (
patch(
"app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()),
),
pytest.raises(FilterAlreadyExistsError),
):
await create_filter(
str(tmp_path),
"/fake.sock",
@@ -1436,10 +1417,13 @@ class TestCreateFilter:
from app.models.config import FilterCreateRequest
from app.services.config_file_service import FilterInvalidRegexError, create_filter
with patch(
"app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()),
), pytest.raises(FilterInvalidRegexError):
with (
patch(
"app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()),
),
pytest.raises(FilterInvalidRegexError),
):
await create_filter(
str(tmp_path),
"/fake.sock",
@@ -1461,13 +1445,16 @@ class TestCreateFilter:
from app.models.config import FilterCreateRequest
from app.services.config_file_service import create_filter
with patch(
"app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()),
), patch(
"app.services.config_file_service.jail_service.reload_all",
new=AsyncMock(),
) as mock_reload:
with (
patch(
"app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()),
),
patch(
"app.services.config_file_service.jail_service.reload_all",
new=AsyncMock(),
) as mock_reload,
):
await create_filter(
str(tmp_path),
"/fake.sock",
@@ -1485,9 +1472,7 @@ class TestCreateFilter:
@pytest.mark.asyncio
class TestDeleteFilter:
async def test_deletes_local_file_when_conf_and_local_exist(
self, tmp_path: Path
) -> None:
async def test_deletes_local_file_when_conf_and_local_exist(self, tmp_path: Path) -> None:
from app.services.config_file_service import delete_filter
filter_d = tmp_path / "filter.d"
@@ -1524,9 +1509,7 @@ class TestDeleteFilter:
with pytest.raises(FilterNotFoundError):
await delete_filter(str(tmp_path), "nonexistent")
async def test_accepts_filter_name_error_for_invalid_name(
self, tmp_path: Path
) -> None:
async def test_accepts_filter_name_error_for_invalid_name(self, tmp_path: Path) -> None:
from app.services.config_file_service import FilterNameError, delete_filter
with pytest.raises(FilterNameError):
@@ -1607,9 +1590,7 @@ class TestAssignFilterToJail:
AssignFilterRequest(filter_name="sshd"),
)
async def test_raises_filter_name_error_for_invalid_filter(
self, tmp_path: Path
) -> None:
async def test_raises_filter_name_error_for_invalid_filter(self, tmp_path: Path) -> None:
from app.models.config import AssignFilterRequest
from app.services.config_file_service import FilterNameError, assign_filter_to_jail
@@ -1719,34 +1700,26 @@ class TestBuildActionToJailsMap:
def test_active_jail_maps_to_action(self) -> None:
from app.services.config_file_service import _build_action_to_jails_map
result = _build_action_to_jails_map(
{"sshd": {"action": "iptables-multiport"}}, {"sshd"}
)
result = _build_action_to_jails_map({"sshd": {"action": "iptables-multiport"}}, {"sshd"})
assert result == {"iptables-multiport": ["sshd"]}
def test_inactive_jail_not_included(self) -> None:
from app.services.config_file_service import _build_action_to_jails_map
result = _build_action_to_jails_map(
{"sshd": {"action": "iptables-multiport"}}, set()
)
result = _build_action_to_jails_map({"sshd": {"action": "iptables-multiport"}}, set())
assert result == {}
def test_multiple_actions_per_jail(self) -> None:
from app.services.config_file_service import _build_action_to_jails_map
result = _build_action_to_jails_map(
{"sshd": {"action": "iptables-multiport\niptables-ipset"}}, {"sshd"}
)
result = _build_action_to_jails_map({"sshd": {"action": "iptables-multiport\niptables-ipset"}}, {"sshd"})
assert "iptables-multiport" in result
assert "iptables-ipset" in result
def test_parameter_block_stripped(self) -> None:
from app.services.config_file_service import _build_action_to_jails_map
result = _build_action_to_jails_map(
{"sshd": {"action": "iptables[port=ssh, protocol=tcp]"}}, {"sshd"}
)
result = _build_action_to_jails_map({"sshd": {"action": "iptables[port=ssh, protocol=tcp]"}}, {"sshd"})
assert "iptables" in result
def test_multiple_jails_sharing_action(self) -> None:
@@ -2001,10 +1974,13 @@ class TestGetAction:
async def test_raises_for_unknown_action(self, tmp_path: Path) -> None:
from app.services.config_file_service import ActionNotFoundError, get_action
with patch(
"app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()),
), pytest.raises(ActionNotFoundError):
with (
patch(
"app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()),
),
pytest.raises(ActionNotFoundError),
):
await get_action(str(tmp_path), "/fake.sock", "nonexistent")
async def test_local_only_action_returned(self, tmp_path: Path) -> None:
@@ -2118,10 +2094,13 @@ class TestUpdateAction:
from app.models.config import ActionUpdateRequest
from app.services.config_file_service import ActionNotFoundError, update_action
with patch(
"app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()),
), pytest.raises(ActionNotFoundError):
with (
patch(
"app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()),
),
pytest.raises(ActionNotFoundError),
):
await update_action(
str(tmp_path),
"/fake.sock",
@@ -2587,9 +2566,7 @@ class TestRemoveActionFromJail:
"app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()),
):
await remove_action_from_jail(
str(tmp_path), "/fake.sock", "sshd", "iptables-multiport"
)
await remove_action_from_jail(str(tmp_path), "/fake.sock", "sshd", "iptables-multiport")
content = (jail_d / "sshd.local").read_text()
assert "iptables-multiport" not in content
@@ -2601,17 +2578,13 @@ class TestRemoveActionFromJail:
)
with pytest.raises(JailNotFoundInConfigError):
await remove_action_from_jail(
str(tmp_path), "/fake.sock", "nonexistent", "iptables"
)
await remove_action_from_jail(str(tmp_path), "/fake.sock", "nonexistent", "iptables")
async def test_raises_jail_name_error(self, tmp_path: Path) -> None:
from app.services.config_file_service import JailNameError, remove_action_from_jail
with pytest.raises(JailNameError):
await remove_action_from_jail(
str(tmp_path), "/fake.sock", "../evil", "iptables"
)
await remove_action_from_jail(str(tmp_path), "/fake.sock", "../evil", "iptables")
async def test_raises_action_name_error(self, tmp_path: Path) -> None:
from app.services.config_file_service import ActionNameError, remove_action_from_jail
@@ -2619,9 +2592,7 @@ class TestRemoveActionFromJail:
_write(tmp_path / "jail.conf", JAIL_CONF)
with pytest.raises(ActionNameError):
await remove_action_from_jail(
str(tmp_path), "/fake.sock", "sshd", "../evil"
)
await remove_action_from_jail(str(tmp_path), "/fake.sock", "sshd", "../evil")
async def test_triggers_reload_when_requested(self, tmp_path: Path) -> None:
from app.services.config_file_service import remove_action_from_jail
@@ -2640,9 +2611,7 @@ class TestRemoveActionFromJail:
new=AsyncMock(),
) as mock_reload,
):
await remove_action_from_jail(
str(tmp_path), "/fake.sock", "sshd", "iptables", do_reload=True
)
await remove_action_from_jail(str(tmp_path), "/fake.sock", "sshd", "iptables", do_reload=True)
mock_reload.assert_awaited_once()
@@ -2680,13 +2649,9 @@ class TestActivateJailReloadArgs:
mock_js.reload_all = AsyncMock()
await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
mock_js.reload_all.assert_awaited_once_with(
"/fake.sock", include_jails=["apache-auth"]
)
mock_js.reload_all.assert_awaited_once_with("/fake.sock", include_jails=["apache-auth"])
async def test_activate_returns_active_true_when_jail_starts(
self, tmp_path: Path
) -> None:
async def test_activate_returns_active_true_when_jail_starts(self, tmp_path: Path) -> None:
"""activate_jail returns active=True when the jail appears in post-reload names."""
_write(tmp_path / "jail.conf", JAIL_CONF)
from app.models.config import ActivateJailRequest, JailValidationResult
@@ -2708,16 +2673,12 @@ class TestActivateJailReloadArgs:
),
):
mock_js.reload_all = AsyncMock()
result = await activate_jail(
str(tmp_path), "/fake.sock", "apache-auth", req
)
result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
assert result.active is True
assert "activated" in result.message.lower()
async def test_activate_returns_active_false_when_jail_does_not_start(
self, tmp_path: Path
) -> None:
async def test_activate_returns_active_false_when_jail_does_not_start(self, tmp_path: Path) -> None:
"""activate_jail returns active=False when the jail is absent after reload.
This covers the Stage 3.1 requirement: if the jail config is invalid
@@ -2746,9 +2707,7 @@ class TestActivateJailReloadArgs:
),
):
mock_js.reload_all = AsyncMock()
result = await activate_jail(
str(tmp_path), "/fake.sock", "apache-auth", req
)
result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
assert result.active is False
assert "apache-auth" in result.name
@@ -2776,23 +2735,13 @@ class TestDeactivateJailReloadArgs:
mock_js.reload_all = AsyncMock()
await deactivate_jail(str(tmp_path), "/fake.sock", "sshd")
mock_js.reload_all.assert_awaited_once_with(
"/fake.sock", exclude_jails=["sshd"]
)
mock_js.reload_all.assert_awaited_once_with("/fake.sock", exclude_jails=["sshd"])
# ---------------------------------------------------------------------------
# _validate_jail_config_sync (Task 3)
# ---------------------------------------------------------------------------
from app.services.config_file_service import ( # noqa: E402 (added after block)
_validate_jail_config_sync,
_extract_filter_base_name,
_extract_action_base_name,
validate_jail_config,
rollback_jail,
)
class TestExtractFilterBaseName:
def test_plain_name(self) -> None:
@@ -2938,11 +2887,11 @@ class TestRollbackJail:
with (
patch(
"app.services.config_file_service._start_daemon",
"app.services.config_file_service.start_daemon",
new=AsyncMock(return_value=True),
),
patch(
"app.services.config_file_service._wait_for_fail2ban",
"app.services.config_file_service.wait_for_fail2ban",
new=AsyncMock(return_value=True),
),
patch(
@@ -2950,9 +2899,7 @@ class TestRollbackJail:
new=AsyncMock(return_value=set()),
),
):
result = await rollback_jail(
str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]
)
result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
assert result.disabled is True
assert result.fail2ban_running is True
@@ -2968,26 +2915,22 @@ class TestRollbackJail:
with (
patch(
"app.services.config_file_service._start_daemon",
"app.services.config_file_service.start_daemon",
new=AsyncMock(return_value=False),
),
patch(
"app.services.config_file_service._wait_for_fail2ban",
"app.services.config_file_service.wait_for_fail2ban",
new=AsyncMock(return_value=False),
),
):
result = await rollback_jail(
str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]
)
result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
assert result.fail2ban_running is False
assert result.disabled is True
async def test_rollback_rejects_unsafe_name(self, tmp_path: Path) -> None:
with pytest.raises(JailNameError):
await rollback_jail(
str(tmp_path), "/fake.sock", "../evil", ["fail2ban-client", "start"]
)
await rollback_jail(str(tmp_path), "/fake.sock", "../evil", ["fail2ban-client", "start"])
# ---------------------------------------------------------------------------
@@ -3096,9 +3039,7 @@ class TestActivateJailBlocking:
class TestActivateJailRollback:
"""Rollback logic in activate_jail restores the .local file and recovers."""
async def test_activate_jail_rollback_on_reload_failure(
self, tmp_path: Path
) -> None:
async def test_activate_jail_rollback_on_reload_failure(self, tmp_path: Path) -> None:
"""Rollback when reload_all raises on the activation reload.
Expects:
@@ -3135,23 +3076,17 @@ class TestActivateJailRollback:
),
patch(
"app.services.config_file_service._validate_jail_config_sync",
return_value=JailValidationResult(
jail_name="apache-auth", valid=True
),
return_value=JailValidationResult(jail_name="apache-auth", valid=True),
),
):
mock_js.reload_all = AsyncMock(side_effect=reload_side_effect)
result = await activate_jail(
str(tmp_path), "/fake.sock", "apache-auth", req
)
result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
assert result.active is False
assert result.recovered is True
assert local_path.read_text() == original_local
async def test_activate_jail_rollback_on_health_check_failure(
self, tmp_path: Path
) -> None:
async def test_activate_jail_rollback_on_health_check_failure(self, tmp_path: Path) -> None:
"""Rollback when fail2ban is unreachable after the activation reload.
Expects:
@@ -3190,15 +3125,11 @@ class TestActivateJailRollback:
),
patch(
"app.services.config_file_service._validate_jail_config_sync",
return_value=JailValidationResult(
jail_name="apache-auth", valid=True
),
return_value=JailValidationResult(jail_name="apache-auth", valid=True),
),
):
mock_js.reload_all = AsyncMock()
result = await activate_jail(
str(tmp_path), "/fake.sock", "apache-auth", req
)
result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
assert result.active is False
assert result.recovered is True
@@ -3232,25 +3163,17 @@ class TestActivateJailRollback:
),
patch(
"app.services.config_file_service._validate_jail_config_sync",
return_value=JailValidationResult(
jail_name="apache-auth", valid=True
),
return_value=JailValidationResult(jail_name="apache-auth", valid=True),
),
):
# Both the activation reload and the recovery reload fail.
mock_js.reload_all = AsyncMock(
side_effect=RuntimeError("fail2ban unavailable")
)
result = await activate_jail(
str(tmp_path), "/fake.sock", "apache-auth", req
)
mock_js.reload_all = AsyncMock(side_effect=RuntimeError("fail2ban unavailable"))
result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
assert result.active is False
assert result.recovered is False
async def test_activate_jail_rollback_on_jail_not_found_error(
self, tmp_path: Path
) -> None:
async def test_activate_jail_rollback_on_jail_not_found_error(self, tmp_path: Path) -> None:
"""Rollback when reload_all raises JailNotFoundError (invalid config).
When fail2ban cannot create a jail due to invalid configuration
@@ -3294,16 +3217,12 @@ class TestActivateJailRollback:
),
patch(
"app.services.config_file_service._validate_jail_config_sync",
return_value=JailValidationResult(
jail_name="apache-auth", valid=True
),
return_value=JailValidationResult(jail_name="apache-auth", valid=True),
),
):
mock_js.reload_all = AsyncMock(side_effect=reload_side_effect)
mock_js.JailNotFoundError = JailNotFoundError
result = await activate_jail(
str(tmp_path), "/fake.sock", "apache-auth", req
)
result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
assert result.active is False
assert result.recovered is True
@@ -3311,9 +3230,7 @@ class TestActivateJailRollback:
# Verify the error message mentions logpath issues.
assert "logpath" in result.message.lower() or "check that all logpath" in result.message.lower()
async def test_activate_jail_rollback_deletes_file_when_no_prior_local(
self, tmp_path: Path
) -> None:
async def test_activate_jail_rollback_deletes_file_when_no_prior_local(self, tmp_path: Path) -> None:
"""Rollback deletes the .local file when none existed before activation.
When a jail had no .local override before activation, activate_jail
@@ -3355,15 +3272,11 @@ class TestActivateJailRollback:
),
patch(
"app.services.config_file_service._validate_jail_config_sync",
return_value=JailValidationResult(
jail_name="apache-auth", valid=True
),
return_value=JailValidationResult(jail_name="apache-auth", valid=True),
),
):
mock_js.reload_all = AsyncMock(side_effect=reload_side_effect)
result = await activate_jail(
str(tmp_path), "/fake.sock", "apache-auth", req
)
result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
assert result.active is False
assert result.recovered is True
@@ -3376,7 +3289,7 @@ class TestActivateJailRollback:
@pytest.mark.asyncio
class TestRollbackJail:
class TestRollbackJailIntegration:
"""Integration tests for :func:`~app.services.config_file_service.rollback_jail`."""
async def test_local_file_written_enabled_false(self, tmp_path: Path) -> None:
@@ -3419,15 +3332,11 @@ class TestRollbackJail:
AsyncMock(return_value={"other"}),
),
):
await rollback_jail(
str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]
)
await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
mock_start.assert_awaited_once_with(["fail2ban-client", "start"])
async def test_fail2ban_running_reflects_socket_probe_not_subprocess_exit(
self, tmp_path: Path
) -> None:
async def test_fail2ban_running_reflects_socket_probe_not_subprocess_exit(self, tmp_path: Path) -> None:
"""fail2ban_running in the response reflects the socket probe result.
Even when start_daemon returns True (subprocess exit 0), if the socket
@@ -3443,15 +3352,11 @@ class TestRollbackJail:
AsyncMock(return_value=False), # socket still unresponsive
),
):
result = await rollback_jail(
str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]
)
result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
assert result.fail2ban_running is False
async def test_active_jails_zero_when_fail2ban_not_running(
self, tmp_path: Path
) -> None:
async def test_active_jails_zero_when_fail2ban_not_running(self, tmp_path: Path) -> None:
"""active_jails is 0 in the response when fail2ban_running is False."""
with (
patch(
@@ -3463,15 +3368,11 @@ class TestRollbackJail:
AsyncMock(return_value=False),
),
):
result = await rollback_jail(
str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]
)
result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
assert result.active_jails == 0
async def test_active_jails_count_from_socket_when_running(
self, tmp_path: Path
) -> None:
async def test_active_jails_count_from_socket_when_running(self, tmp_path: Path) -> None:
"""active_jails reflects the actual jail count from the socket when fail2ban is up."""
with (
patch(
@@ -3487,15 +3388,11 @@ class TestRollbackJail:
AsyncMock(return_value={"sshd", "nginx", "apache-auth"}),
),
):
result = await rollback_jail(
str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]
)
result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
assert result.active_jails == 3
async def test_fail2ban_down_at_start_still_succeeds_file_write(
self, tmp_path: Path
) -> None:
async def test_fail2ban_down_at_start_still_succeeds_file_write(self, tmp_path: Path) -> None:
"""rollback_jail writes the local file even when fail2ban is down at call time."""
# fail2ban is down: start_daemon fails and wait_for_fail2ban returns False.
with (
@@ -3508,12 +3405,9 @@ class TestRollbackJail:
AsyncMock(return_value=False),
),
):
result = await rollback_jail(
str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]
)
result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
local = tmp_path / "jail.d" / "sshd.local"
assert local.is_file(), "local file must be written even when fail2ban is down"
assert result.disabled is True
assert result.fail2ban_running is False

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
from collections.abc import Mapping, Sequence
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -44,7 +45,7 @@ def _make_session(response_json: dict[str, object], status: int = 200) -> MagicM
@pytest.fixture(autouse=True)
def clear_geo_cache() -> None: # type: ignore[misc]
def clear_geo_cache() -> None:
"""Flush the module-level geo cache before every test."""
geo_service.clear_cache()
@@ -68,7 +69,7 @@ class TestLookupSuccess:
"org": "AS3320 Deutsche Telekom AG",
}
)
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
result = await geo_service.lookup("1.2.3.4", session)
assert result is not None
assert result.country_code == "DE"
@@ -84,7 +85,7 @@ class TestLookupSuccess:
"org": "Google LLC",
}
)
result = await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type]
result = await geo_service.lookup("8.8.8.8", session)
assert result is not None
assert result.country_name == "United States"
@@ -100,7 +101,7 @@ class TestLookupSuccess:
"org": "Deutsche Telekom",
}
)
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
result = await geo_service.lookup("1.2.3.4", session)
assert result is not None
assert result.asn == "AS3320"
@@ -116,7 +117,7 @@ class TestLookupSuccess:
"org": "Google LLC",
}
)
result = await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type]
result = await geo_service.lookup("8.8.8.8", session)
assert result is not None
assert result.org == "Google LLC"
@@ -142,8 +143,8 @@ class TestLookupCaching:
}
)
await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
await geo_service.lookup("1.2.3.4", session)
await geo_service.lookup("1.2.3.4", session)
# The session.get() should only have been called once.
assert session.get.call_count == 1
@@ -160,9 +161,9 @@ class TestLookupCaching:
}
)
await geo_service.lookup("2.3.4.5", session) # type: ignore[arg-type]
await geo_service.lookup("2.3.4.5", session)
geo_service.clear_cache()
await geo_service.lookup("2.3.4.5", session) # type: ignore[arg-type]
await geo_service.lookup("2.3.4.5", session)
assert session.get.call_count == 2
@@ -172,8 +173,8 @@ class TestLookupCaching:
{"status": "fail", "message": "reserved range"}
)
await geo_service.lookup("192.168.1.1", session) # type: ignore[arg-type]
await geo_service.lookup("192.168.1.1", session) # type: ignore[arg-type]
await geo_service.lookup("192.168.1.1", session)
await geo_service.lookup("192.168.1.1", session)
# Second call is blocked by the negative cache — only one API hit.
assert session.get.call_count == 1
@@ -190,7 +191,7 @@ class TestLookupFailures:
async def test_non_200_response_returns_null_geo_info(self) -> None:
"""A 429 or 500 status returns GeoInfo with null fields (not None)."""
session = _make_session({}, status=429)
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
result = await geo_service.lookup("1.2.3.4", session)
assert result is not None
assert isinstance(result, GeoInfo)
assert result.country_code is None
@@ -203,7 +204,7 @@ class TestLookupFailures:
mock_ctx.__aexit__ = AsyncMock(return_value=False)
session.get = MagicMock(return_value=mock_ctx)
result = await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type]
result = await geo_service.lookup("10.0.0.1", session)
assert result is not None
assert isinstance(result, GeoInfo)
assert result.country_code is None
@@ -211,7 +212,7 @@ class TestLookupFailures:
async def test_failed_status_returns_geo_info_with_nulls(self) -> None:
"""When ip-api returns ``status=fail`` a GeoInfo with null fields is returned (but not cached)."""
session = _make_session({"status": "fail", "message": "private range"})
result = await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type]
result = await geo_service.lookup("10.0.0.1", session)
assert result is not None
assert isinstance(result, GeoInfo)
@@ -231,8 +232,8 @@ class TestNegativeCache:
"""After a failed lookup the second call is served from the neg cache."""
session = _make_session({"status": "fail", "message": "private range"})
r1 = await geo_service.lookup("192.0.2.1", session) # type: ignore[arg-type]
r2 = await geo_service.lookup("192.0.2.1", session) # type: ignore[arg-type]
r1 = await geo_service.lookup("192.0.2.1", session)
r2 = await geo_service.lookup("192.0.2.1", session)
# Only one HTTP call should have been made; second served from neg cache.
assert session.get.call_count == 1
@@ -243,12 +244,12 @@ class TestNegativeCache:
"""When the neg-cache entry is older than the TTL a new API call is made."""
session = _make_session({"status": "fail", "message": "private range"})
await geo_service.lookup("192.0.2.2", session) # type: ignore[arg-type]
await geo_service.lookup("192.0.2.2", session)
# Manually expire the neg-cache entry.
geo_service._neg_cache["192.0.2.2"] -= geo_service._NEG_CACHE_TTL + 1 # type: ignore[attr-defined]
geo_service._neg_cache["192.0.2.2"] -= geo_service._NEG_CACHE_TTL + 1
await geo_service.lookup("192.0.2.2", session) # type: ignore[arg-type]
await geo_service.lookup("192.0.2.2", session)
# Both calls should have hit the API.
assert session.get.call_count == 2
@@ -257,9 +258,9 @@ class TestNegativeCache:
"""After clearing the neg cache the IP is eligible for a new API call."""
session = _make_session({"status": "fail", "message": "private range"})
await geo_service.lookup("192.0.2.3", session) # type: ignore[arg-type]
await geo_service.lookup("192.0.2.3", session)
geo_service.clear_neg_cache()
await geo_service.lookup("192.0.2.3", session) # type: ignore[arg-type]
await geo_service.lookup("192.0.2.3", session)
assert session.get.call_count == 2
@@ -275,9 +276,9 @@ class TestNegativeCache:
}
)
await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
await geo_service.lookup("1.2.3.4", session)
assert "1.2.3.4" not in geo_service._neg_cache # type: ignore[attr-defined]
assert "1.2.3.4" not in geo_service._neg_cache
# ---------------------------------------------------------------------------
@@ -307,7 +308,7 @@ class TestGeoipFallback:
mock_reader = self._make_geoip_reader("DE", "Germany")
with patch.object(geo_service, "_geoip_reader", mock_reader):
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
result = await geo_service.lookup("1.2.3.4", session)
mock_reader.country.assert_called_once_with("1.2.3.4")
assert result is not None
@@ -320,12 +321,12 @@ class TestGeoipFallback:
mock_reader = self._make_geoip_reader("US", "United States")
with patch.object(geo_service, "_geoip_reader", mock_reader):
await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type]
await geo_service.lookup("8.8.8.8", session)
# Second call must be served from positive cache without hitting API.
await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type]
await geo_service.lookup("8.8.8.8", session)
assert session.get.call_count == 1
assert "8.8.8.8" in geo_service._cache # type: ignore[attr-defined]
assert "8.8.8.8" in geo_service._cache
async def test_geoip_fallback_not_called_on_api_success(self) -> None:
"""When ip-api succeeds, the geoip2 reader must not be consulted."""
@@ -341,7 +342,7 @@ class TestGeoipFallback:
mock_reader = self._make_geoip_reader("XX", "Nowhere")
with patch.object(geo_service, "_geoip_reader", mock_reader):
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
result = await geo_service.lookup("1.2.3.4", session)
mock_reader.country.assert_not_called()
assert result is not None
@@ -352,7 +353,7 @@ class TestGeoipFallback:
session = _make_session({"status": "fail", "message": "private range"})
with patch.object(geo_service, "_geoip_reader", None):
result = await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type]
result = await geo_service.lookup("10.0.0.1", session)
assert result is not None
assert result.country_code is None
@@ -363,7 +364,7 @@ class TestGeoipFallback:
# ---------------------------------------------------------------------------
def _make_batch_session(batch_response: list[dict[str, object]]) -> MagicMock:
def _make_batch_session(batch_response: Sequence[Mapping[str, object]]) -> MagicMock:
"""Build a mock aiohttp.ClientSession for batch POST calls.
Args:
@@ -412,7 +413,7 @@ class TestLookupBatchSingleCommit:
session = _make_batch_session(batch_response)
db = _make_async_db()
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type]
await geo_service.lookup_batch(ips, session, db=db)
db.commit.assert_awaited_once()
@@ -426,7 +427,7 @@ class TestLookupBatchSingleCommit:
session = _make_batch_session(batch_response)
db = _make_async_db()
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type]
await geo_service.lookup_batch(ips, session, db=db)
db.commit.assert_awaited_once()
@@ -452,13 +453,13 @@ class TestLookupBatchSingleCommit:
async def test_no_commit_for_all_cached_ips(self) -> None:
"""When all IPs are already cached, no HTTP call and no commit occur."""
geo_service._cache["5.5.5.5"] = GeoInfo( # type: ignore[attr-defined]
geo_service._cache["5.5.5.5"] = GeoInfo(
country_code="FR", country_name="France", asn="AS1", org="ISP"
)
db = _make_async_db()
session = _make_batch_session([])
result = await geo_service.lookup_batch(["5.5.5.5"], session, db=db) # type: ignore[arg-type]
result = await geo_service.lookup_batch(["5.5.5.5"], session, db=db)
assert result["5.5.5.5"].country_code == "FR"
db.commit.assert_not_awaited()
@@ -476,26 +477,26 @@ class TestDirtySetTracking:
def test_successful_resolution_adds_to_dirty(self) -> None:
"""Storing a GeoInfo with a country_code adds the IP to _dirty."""
info = GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP")
geo_service._store("1.2.3.4", info) # type: ignore[attr-defined]
geo_service._store("1.2.3.4", info)
assert "1.2.3.4" in geo_service._dirty # type: ignore[attr-defined]
assert "1.2.3.4" in geo_service._dirty
def test_null_country_does_not_add_to_dirty(self) -> None:
"""Storing a GeoInfo with country_code=None must not pollute _dirty."""
info = GeoInfo(country_code=None, country_name=None, asn=None, org=None)
geo_service._store("10.0.0.1", info) # type: ignore[attr-defined]
geo_service._store("10.0.0.1", info)
assert "10.0.0.1" not in geo_service._dirty # type: ignore[attr-defined]
assert "10.0.0.1" not in geo_service._dirty
def test_clear_cache_also_clears_dirty(self) -> None:
"""clear_cache() must discard any pending dirty entries."""
info = GeoInfo(country_code="US", country_name="United States", asn="AS1", org="ISP")
geo_service._store("8.8.8.8", info) # type: ignore[attr-defined]
assert geo_service._dirty # type: ignore[attr-defined]
geo_service._store("8.8.8.8", info)
assert geo_service._dirty
geo_service.clear_cache()
assert not geo_service._dirty # type: ignore[attr-defined]
assert not geo_service._dirty
async def test_lookup_batch_populates_dirty(self) -> None:
"""After lookup_batch() with db=None, resolved IPs appear in _dirty."""
@@ -509,7 +510,7 @@ class TestDirtySetTracking:
await geo_service.lookup_batch(ips, session, db=None)
for ip in ips:
assert ip in geo_service._dirty # type: ignore[attr-defined]
assert ip in geo_service._dirty
class TestFlushDirty:
@@ -518,8 +519,8 @@ class TestFlushDirty:
async def test_flush_writes_and_clears_dirty(self) -> None:
"""flush_dirty() inserts all dirty IPs and clears _dirty afterwards."""
info = GeoInfo(country_code="GB", country_name="United Kingdom", asn="AS2856", org="BT")
geo_service._store("100.0.0.1", info) # type: ignore[attr-defined]
assert "100.0.0.1" in geo_service._dirty # type: ignore[attr-defined]
geo_service._store("100.0.0.1", info)
assert "100.0.0.1" in geo_service._dirty
db = _make_async_db()
count = await geo_service.flush_dirty(db)
@@ -527,7 +528,7 @@ class TestFlushDirty:
assert count == 1
db.executemany.assert_awaited_once()
db.commit.assert_awaited_once()
assert "100.0.0.1" not in geo_service._dirty # type: ignore[attr-defined]
assert "100.0.0.1" not in geo_service._dirty
async def test_flush_returns_zero_when_nothing_dirty(self) -> None:
"""flush_dirty() returns 0 and makes no DB calls when _dirty is empty."""
@@ -541,7 +542,7 @@ class TestFlushDirty:
async def test_flush_re_adds_to_dirty_on_db_error(self) -> None:
"""When the DB write fails, entries are re-added to _dirty for retry."""
info = GeoInfo(country_code="AU", country_name="Australia", asn="AS1", org="ISP")
geo_service._store("200.0.0.1", info) # type: ignore[attr-defined]
geo_service._store("200.0.0.1", info)
db = _make_async_db()
db.executemany = AsyncMock(side_effect=OSError("disk full"))
@@ -549,7 +550,7 @@ class TestFlushDirty:
count = await geo_service.flush_dirty(db)
assert count == 0
assert "200.0.0.1" in geo_service._dirty # type: ignore[attr-defined]
assert "200.0.0.1" in geo_service._dirty
async def test_flush_batch_and_lookup_batch_integration(self) -> None:
"""lookup_batch() populates _dirty; flush_dirty() then persists them."""
@@ -562,14 +563,14 @@ class TestFlushDirty:
# Resolve without DB to populate only in-memory cache and _dirty.
await geo_service.lookup_batch(ips, session, db=None)
assert geo_service._dirty == set(ips) # type: ignore[attr-defined]
assert geo_service._dirty == set(ips)
# Now flush to the DB.
db = _make_async_db()
count = await geo_service.flush_dirty(db)
assert count == 2
assert not geo_service._dirty # type: ignore[attr-defined]
assert not geo_service._dirty
db.commit.assert_awaited_once()
@@ -585,7 +586,7 @@ class TestLookupBatchThrottling:
"""When more than _BATCH_SIZE IPs are sent, asyncio.sleep is called
between consecutive batch HTTP calls with at least _BATCH_DELAY."""
# Generate _BATCH_SIZE + 1 IPs so we get exactly 2 batch calls.
batch_size: int = geo_service._BATCH_SIZE # type: ignore[attr-defined]
batch_size: int = geo_service._BATCH_SIZE
ips = [f"10.0.{i // 256}.{i % 256}" for i in range(batch_size + 1)]
def _make_result(chunk: list[str], _session: object) -> dict[str, GeoInfo]:
@@ -608,7 +609,7 @@ class TestLookupBatchThrottling:
assert mock_batch.call_count == 2
mock_sleep.assert_awaited_once()
delay_arg: float = mock_sleep.call_args[0][0]
assert delay_arg >= geo_service._BATCH_DELAY # type: ignore[attr-defined]
assert delay_arg >= geo_service._BATCH_DELAY
async def test_lookup_batch_retries_on_full_chunk_failure(self) -> None:
"""When a chunk returns all-None on first try, it retries and succeeds."""
@@ -650,7 +651,7 @@ class TestLookupBatchThrottling:
_empty = GeoInfo(country_code=None, country_name=None, asn=None, org=None)
_failure: dict[str, GeoInfo] = dict.fromkeys(ips, _empty)
max_retries: int = geo_service._BATCH_MAX_RETRIES # type: ignore[attr-defined]
max_retries: int = geo_service._BATCH_MAX_RETRIES
with (
patch(
@@ -667,11 +668,11 @@ class TestLookupBatchThrottling:
# IP should have no country.
assert result["9.9.9.9"].country_code is None
# Negative cache should contain the IP.
assert "9.9.9.9" in geo_service._neg_cache # type: ignore[attr-defined]
assert "9.9.9.9" in geo_service._neg_cache
# Sleep called for each retry with exponential backoff.
assert mock_sleep.call_count == max_retries
backoff_values = [call.args[0] for call in mock_sleep.call_args_list]
batch_delay: float = geo_service._BATCH_DELAY # type: ignore[attr-defined]
batch_delay: float = geo_service._BATCH_DELAY
for i, val in enumerate(backoff_values):
expected = batch_delay * (2 ** (i + 1))
assert val == pytest.approx(expected)
@@ -709,7 +710,7 @@ class TestErrorLogging:
import structlog.testing
with structlog.testing.capture_logs() as captured:
result = await geo_service.lookup("197.221.98.153", session) # type: ignore[arg-type]
result = await geo_service.lookup("197.221.98.153", session)
assert result is not None
assert result.country_code is None
@@ -733,7 +734,7 @@ class TestErrorLogging:
import structlog.testing
with structlog.testing.capture_logs() as captured:
await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type]
await geo_service.lookup("10.0.0.1", session)
request_failed = [e for e in captured if e.get("event") == "geo_lookup_request_failed"]
assert len(request_failed) == 1
@@ -757,7 +758,7 @@ class TestErrorLogging:
import structlog.testing
with structlog.testing.capture_logs() as captured:
result = await geo_service._batch_api_call(["1.2.3.4"], session) # type: ignore[attr-defined]
result = await geo_service._batch_api_call(["1.2.3.4"], session)
assert result["1.2.3.4"].country_code is None
@@ -778,7 +779,7 @@ class TestLookupCachedOnly:
def test_returns_cached_ips(self) -> None:
"""IPs already in the cache are returned in the geo_map."""
geo_service._cache["1.1.1.1"] = GeoInfo( # type: ignore[attr-defined]
geo_service._cache["1.1.1.1"] = GeoInfo(
country_code="AU", country_name="Australia", asn="AS13335", org="Cloudflare"
)
geo_map, uncached = geo_service.lookup_cached_only(["1.1.1.1"])
@@ -798,7 +799,7 @@ class TestLookupCachedOnly:
"""IPs in the negative cache within TTL are not re-queued as uncached."""
import time
geo_service._neg_cache["10.0.0.1"] = time.monotonic() # type: ignore[attr-defined]
geo_service._neg_cache["10.0.0.1"] = time.monotonic()
geo_map, uncached = geo_service.lookup_cached_only(["10.0.0.1"])
@@ -807,7 +808,7 @@ class TestLookupCachedOnly:
def test_expired_neg_cache_requeued(self) -> None:
"""IPs whose neg-cache entry has expired are listed as uncached."""
geo_service._neg_cache["10.0.0.2"] = 0.0 # epoch 0 → expired # type: ignore[attr-defined]
geo_service._neg_cache["10.0.0.2"] = 0.0 # epoch 0 → expired
_geo_map, uncached = geo_service.lookup_cached_only(["10.0.0.2"])
@@ -815,12 +816,12 @@ class TestLookupCachedOnly:
def test_mixed_ips(self) -> None:
"""A mix of cached, neg-cached, and unknown IPs is split correctly."""
geo_service._cache["1.2.3.4"] = GeoInfo( # type: ignore[attr-defined]
geo_service._cache["1.2.3.4"] = GeoInfo(
country_code="DE", country_name="Germany", asn=None, org=None
)
import time
geo_service._neg_cache["5.5.5.5"] = time.monotonic() # type: ignore[attr-defined]
geo_service._neg_cache["5.5.5.5"] = time.monotonic()
geo_map, uncached = geo_service.lookup_cached_only(["1.2.3.4", "5.5.5.5", "9.9.9.9"])
@@ -829,7 +830,7 @@ class TestLookupCachedOnly:
def test_deduplication(self) -> None:
"""Duplicate IPs in the input appear at most once in the output."""
geo_service._cache["1.2.3.4"] = GeoInfo( # type: ignore[attr-defined]
geo_service._cache["1.2.3.4"] = GeoInfo(
country_code="US", country_name="United States", asn=None, org=None
)
@@ -866,7 +867,7 @@ class TestLookupBatchBulkWrites:
session = _make_batch_session(batch_response)
db = _make_async_db()
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type]
await geo_service.lookup_batch(ips, session, db=db)
# One executemany for the positive rows.
assert db.executemany.await_count >= 1
@@ -883,7 +884,7 @@ class TestLookupBatchBulkWrites:
session = _make_batch_session(batch_response)
db = _make_async_db()
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type]
await geo_service.lookup_batch(ips, session, db=db)
assert db.executemany.await_count >= 1
db.execute.assert_not_awaited()
@@ -905,7 +906,7 @@ class TestLookupBatchBulkWrites:
session = _make_batch_session(batch_response)
db = _make_async_db()
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type]
await geo_service.lookup_batch(ips, session, db=db)
# One executemany for positives, one for negatives.
assert db.executemany.await_count == 2

View File

@@ -64,7 +64,7 @@ async def _create_f2b_db(path: str, rows: list[dict[str, Any]]) -> None:
@pytest.fixture
async def f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc]
async def f2b_db_path(tmp_path: Path) -> str:
"""Return the path to a test fail2ban SQLite database."""
path = str(tmp_path / "fail2ban_test.sqlite3")
await _create_f2b_db(

View File

@@ -996,9 +996,6 @@ class TestGetJailBannedIps:
async def test_unknown_jail_raises_jail_not_found_error(self) -> None:
"""get_jail_banned_ips raises JailNotFoundError for unknown jail."""
responses = {
"status|ghost|short": (0, pytest.raises), # will be overridden
}
# Simulate fail2ban returning an "unknown jail" error.
class _FakeClient:
def __init__(self, **_kw: Any) -> None:

View File

@@ -270,7 +270,7 @@ class TestCrashDetection:
async def test_crash_within_window_creates_pending_recovery(self) -> None:
"""An online→offline transition within 60 s of activation must set pending_recovery."""
app = _make_app(prev_online=True)
now = datetime.datetime.now(tz=datetime.timezone.utc)
now = datetime.datetime.now(tz=datetime.UTC)
app.state.last_activation = {
"jail_name": "sshd",
"at": now - datetime.timedelta(seconds=10),
@@ -297,7 +297,7 @@ class TestCrashDetection:
app = _make_app(prev_online=True)
app.state.last_activation = {
"jail_name": "sshd",
"at": datetime.datetime.now(tz=datetime.timezone.utc)
"at": datetime.datetime.now(tz=datetime.UTC)
- datetime.timedelta(seconds=120),
}
app.state.pending_recovery = None
@@ -315,8 +315,8 @@ class TestCrashDetection:
async def test_came_online_marks_pending_recovery_resolved(self) -> None:
"""An offline→online transition must mark an existing pending_recovery as recovered."""
app = _make_app(prev_online=False)
activated_at = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(seconds=30)
detected_at = datetime.datetime.now(tz=datetime.timezone.utc)
activated_at = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(seconds=30)
detected_at = datetime.datetime.now(tz=datetime.UTC)
app.state.pending_recovery = PendingRecovery(
jail_name="sshd",
activated_at=activated_at,