diff --git a/Docs/Tasks.md b/Docs/Tasks.md index 5a37717..4b1cbfa 100644 --- a/Docs/Tasks.md +++ b/Docs/Tasks.md @@ -8,71 +8,123 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue. ## Open Issues -### Backend Architecture +--- -- **Replace the single shared SQLite connection.** - - Current startup code opens one `aiosqlite.Connection` and reuses it for every request. - - This should be replaced with either a connection pool or request-scoped connections to avoid concurrency and locking issues. - - Update request dependencies, application lifecycle, and tests to use the new pattern. +### TASK-001 — WorldMap: filter companion table by selected country (server-side) -- **Refactor dependency wiring and shared resource management.** - - Remove hidden module-level import coupling between routers, services, and shared utilities. - - Introduce explicit factories or providers for shared resources such as DB, HTTP client session, scheduler, and settings. - - Ensure routers depend on injected providers rather than global state or dynamic imports. +**Status:** Done +**Priority:** Medium +**Domain:** Full-stack (backend + frontend) +**References:** `Docs/Features.md §4`, `Docs/Web-Development.md` -- **Harden fail2ban integration.** - - Remove the `sys.path` hack that locates `fail2ban-master` at runtime. - - Replace it with a deterministic packaging or configuration model so the backend does not depend on repository layout. - - Refactor `Fail2BanClient` so concurrency control is instance-based and not backed by hidden module globals. +#### Background -- **Improve startup / setup guard behavior.** - - Convert `SetupRedirectMiddleware` from an on-demand DB check into a startup/initialisation guard where possible. - - Cache setup completion in a safe way and provide an explicit invalidation path if the application state changes. - - Reduce middleware responsibility and avoid DB access during normal request dispatch. +The `GET /api/dashboard/bans/by-country` endpoint always returns the **200 most recent** ban rows in `bans` (constant `_MAX_COMPANION_BANS = 200` in `backend/app/services/ban_service.py`). `MapPage.tsx` stores a `selectedCountry` state and filters the returned rows client-side via `visibleBans`. This means the companion table can only show the fraction of a country's bans that fall within the global top-200 window. If the selected time range has, say, 1 500 bans and 300 are from China, but China's bans are not all in the top 200 overall, the table will silently display fewer than 300 rows. -- **Make deployment configuration explicit.** - - Move hard-coded environment assumptions such as CORS origins into settings. - - Ensure `fail2ban_socket`, `fail2ban_config_dir`, and startup commands are fully configurable via `Settings`. - - Document production-ready defaults separately from development defaults. +When a country is selected the companion table **must** return the complete set of bans for that country so the user sees an accurate picture. -### Reliability and Resilience +#### Desired behaviour -- **Add backend lifecycle tests for resource cleanup.** - - Verify startup opens and initialises DB, HTTP session, scheduler, and geo cache correctly. - - Verify shutdown closes those resources cleanly. +- No country selected → companion table shows the 200 most recent bans across all countries (existing behaviour, no change). +- Country selected → the server returns **all** ban entries for that country in the selected time window; no client-side row-count cap applies. +- Deselecting a country (clicking the same country again, or the "Clear filter" button) reverts to the default 200-row unfiltered view. +- The existing `visibleBans` client-side filter in `MapPage.tsx` can remain as a defensive guard but must not be the only filter. -- **Add concurrency/regression coverage for DB and fail2ban socket use.** - - Add tests that simulate multiple concurrent requests using the same DB dependency. - - Add tests around fail2ban socket retries, protocol errors, and rate limiting. +#### Implementation steps -- **Improve state caching and invalidation.** - - Add tests for session cache invalidation on logout. - - Add tests for setup completion caching so stale state is never served. +1. **Backend — router** (`backend/app/routers/dashboard.py`) + - Add `country_code: str | None = Query(default=None, description="ISO alpha-2 country code to filter companion rows.")` to `get_bans_by_country`. + - Pass it to `ban_service.bans_by_country(..., country_code=country_code)`. -### Backend Feature Work +2. **Backend — service** (`backend/app/services/ban_service.py`) + - Add `country_code: str | None = None` keyword argument to `bans_by_country`. + - After `geo_map` is built (existing geo-resolution step), collect IPs whose resolved country matches `country_code`. + - For the **fail2ban source**: call `fail2ban_db_repo.get_currently_banned` with `ip_filter=matched_ips` and no `limit` (remove the `_MAX_COMPANION_BANS` cap for filtered queries). + - For the **archive source**: filter `all_rows` to those whose IP is in `matched_ips` and return all of them (skip the `page_size=_MAX_COMPANION_BANS` call). + - When `country_code` is `None`, behaviour is identical to today. -- **Document and implement backend-safe environment-driven CORS.** - - Add support for production and local development origins through configuration. - - Avoid a hardcoded Vite origin in the core app factory. +3. **Backend — repository** (`backend/app/repositories/fail2ban_db_repo.py`) + - Add `ip_filter: list[str] | None = None` to `get_currently_banned`. + - When provided and non-empty, append `AND ip IN ({placeholders})` to the SQL `WHERE` clause, parameterised safely (never interpolated as a string). -- **Centralise scheduler job registration.** - - Refactor APScheduler registration so background tasks are registered through a common lifecycle helper. - - Ensure jobs can be discovered, replaced, and tested without requiring implicit `app.state` side effects. +4. **Backend — repository (archive)** (`backend/app/repositories/history_archive_repo.py`) + - Similarly add optional `ip_filter` to the archive companion-rows query used from `bans_by_country`. -- **Strengthen fail2ban error handling and reporting.** - - Standardise `502` responses for connection/protocol failures across all endpoints. - - Add structured logging for retries and fatal socket failures. - - Ensure the UI can distinguish offline fail2ban from internal backend failures. +5. **Frontend — API client** (`frontend/src/api/map.ts`) + - Add optional `countryCode?: string` parameter to `fetchBansByCountry`. + - When set, append `country_code=` to the query string. -- **Improve documentation of backend responsibilities.** - - Keep `Docs/Tasks.md` aligned with the backend architecture review. - - Add references to the backend modules, resource lifecycle, and dependency model in the documentation. +6. **Frontend — hook** (`frontend/src/hooks/useMapData.ts`) + - Add `countryCode?: string` to the function signature. + - Include it in the `useCallback` dependency array and pass it to `fetchBansByCountry`. + +7. **Frontend — page** (`frontend/src/pages/MapPage.tsx`) + - Pass `selectedCountry ?? undefined` as `countryCode` to `useMapData`. + - The hook's effect will re-fetch automatically when `selectedCountry` changes; the existing `useEffect` that resets `page` to 1 already covers this. + +#### Testing guidance + +- Select a country that has > 200 bans in the chosen time window; confirm the companion table shows more than the previous cap would allow. +- With no country selected, confirm only 200 rows are returned (no regression). +- Deselect the country; confirm the unfiltered 200-row view is restored. +- Test with the archive source as well as the fail2ban live source. +- Verify the `ip_filter` SQL clause is parameterised and cannot be injected. + +--- + +### TASK-002 — WorldMap: sticky table header and sticky pagination bar + +**Priority:** Low +**Domain:** Frontend only +**References:** `Docs/Features.md §4`, `Docs/Web-Design.md`, `Docs/Web-Development.md` + +#### Background + +The companion ban table in `MapPage.tsx` is wrapped in `tableWrapper` (CSS `overflow: auto; maxHeight: 420px`). Both the Fluent UI `TableHeader` row and the `.pagination` div inside `tableWrapper` scroll with the content. Once the user scrolls more than a few rows, the column header labels disappear and the pagination controls become unreachable without scrolling back to the top or bottom. + +#### Desired behaviour + +- The column header row (`TableHeader →TableRow → TableHeaderCell × 6`) must remain fixed at the **top** of the scrollable container at all times. +- The pagination / page-size bar (`.pagination` div at the bottom of `tableWrapper`) must remain fixed at the **bottom** of the scrollable container at all times. +- Rows in `TableBody` scroll normally between the two fixed ends. +- No changes to the container height, overall layout, or other pages. + +#### Implementation steps + +All changes are in `frontend/src/pages/MapPage.tsx`. + +1. **Sticky table header cells** + - In `useStyles` (`makeStyles`), add a new class: + ```ts + stickyHeaderCell: { + position: "sticky", + top: 0, + zIndex: 1, + backgroundColor: tokens.colorNeutralBackground1, + boxShadow: `0 1px 0 ${tokens.colorNeutralStroke2}`, + }, + ``` + - Apply `className={styles.stickyHeaderCell}` to **each** `TableHeaderCell` in the header row. + - Note: `position: sticky` on `` elements is unreliable across browsers for table layouts; apply it to each `` (`TableHeaderCell`) instead. + +2. **Sticky pagination bar** + - In the existing `pagination` entry in `useStyles`, add: + ```ts + position: "sticky", + bottom: 0, + zIndex: 1, + ``` + - The existing `backgroundColor: tokens.colorNeutralBackground2` already prevents table rows from bleeding through. + +3. **No other changes** — do not alter `tableWrapper`, its height, or anything outside `MapPage.tsx`. + +#### Testing guidance + +- Load the Map page with a time range that produces > 25 bans (enough to overflow the `420px` container). +- Scroll down through the table and confirm the column headers remain visible at the top. +- Scroll down and confirm the pagination bar remains visible at the bottom. +- Verify no visual artefacts (table body rows must not overlap or bleed through the sticky elements). +- Run `tsc --noEmit` — zero type errors expected. +- Run existing frontend tests: `vitest run` — no regressions. -### Priority Execution Plan -1. Fix the global SQLite connection pattern and tests. -2. Refactor dependency injection / explicit shared resources. -3. Harden fail2ban client concurrency and packaging. -4. Convert setup guard to a safer startup-driven model. -5. Add deployment-safe configuration and production-ready CORS. -6. Add lifecycle and concurrency regression tests. diff --git a/backend/app/repositories/fail2ban_db_repo.py b/backend/app/repositories/fail2ban_db_repo.py index 06b8419..ce6c10c 100644 --- a/backend/app/repositories/fail2ban_db_repo.py +++ b/backend/app/repositories/fail2ban_db_repo.py @@ -126,6 +126,7 @@ async def get_currently_banned( since: int, origin: BanOrigin | None = None, *, + ip_filter: list[str] | None = None, limit: int | None = None, offset: int | None = None, ) -> tuple[list[BanRecord], int]: @@ -135,6 +136,7 @@ async def get_currently_banned( db_path: File path to the fail2ban SQLite database. since: Unix timestamp to filter bans newer than or equal to. origin: Optional origin filter. + ip_filter: Optional list of IP addresses to restrict the result to. limit: Optional maximum number of rows to return. offset: Optional offset for pagination. @@ -142,14 +144,21 @@ async def get_currently_banned( A ``(records, total)`` tuple. """ + if ip_filter is not None and len(ip_filter) == 0: + return [], 0 + origin_clause, origin_params = _origin_sql_filter(origin) + ip_filter_clause = "" + if ip_filter is not None: + placeholder = ", ".join("?" for _ in ip_filter) + ip_filter_clause = f" AND ip IN ({placeholder})" async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db: db.row_factory = aiosqlite.Row async with db.execute( - "SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause, - (since, *origin_params), + "SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause + ip_filter_clause, + (since, *origin_params, *(ip_filter or [])), ) as cur: count_row = await cur.fetchone() total: int = int(count_row[0]) if count_row else 0 @@ -157,9 +166,9 @@ async def get_currently_banned( query = ( "SELECT jail, ip, timeofban, bancount, data " "FROM bans " - "WHERE timeofban >= ?" + origin_clause + " ORDER BY timeofban DESC" + "WHERE timeofban >= ?" + origin_clause + ip_filter_clause + " ORDER BY timeofban DESC" ) - params: list[object] = [since, *origin_params] + params: list[object] = [since, *origin_params, *(ip_filter or [])] if limit is not None: query += " LIMIT ?" params.append(limit) diff --git a/backend/app/repositories/history_archive_repo.py b/backend/app/repositories/history_archive_repo.py index 8f1b599..184e7d4 100644 --- a/backend/app/repositories/history_archive_repo.py +++ b/backend/app/repositories/history_archive_repo.py @@ -40,13 +40,16 @@ async def get_archived_history( db: aiosqlite.Connection, since: int | None = None, jail: str | None = None, - ip_filter: str | None = None, + ip_filter: str | list[str] | None = None, origin: BanOrigin | None = None, action: str | None = None, page: int = 1, page_size: int = 100, ) -> tuple[list[dict], int]: """Return a paginated archived history result set.""" + if isinstance(ip_filter, list) and len(ip_filter) == 0: + return [], 0 + wheres: list[str] = [] params: list[object] = [] @@ -59,8 +62,13 @@ async def get_archived_history( params.append(jail) if ip_filter is not None: - wheres.append("ip LIKE ?") - params.append(f"{ip_filter}%") + if isinstance(ip_filter, list): + placeholder = ", ".join("?" for _ in ip_filter) + wheres.append(f"ip IN ({placeholder})") + params.extend(ip_filter) + else: + wheres.append("ip LIKE ?") + params.append(f"{ip_filter}%") if origin == "blocklist": wheres.append("jail = ?") @@ -108,7 +116,7 @@ async def get_all_archived_history( db: aiosqlite.Connection, since: int | None = None, jail: str | None = None, - ip_filter: str | None = None, + ip_filter: str | list[str] | None = None, origin: BanOrigin | None = None, action: str | None = None, ) -> list[dict]: diff --git a/backend/app/routers/dashboard.py b/backend/app/routers/dashboard.py index af90287..644fb61 100644 --- a/backend/app/routers/dashboard.py +++ b/backend/app/routers/dashboard.py @@ -83,7 +83,10 @@ async def get_dashboard_bans( request: Request, _auth: AuthDep, range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."), - source: Literal["fail2ban", "archive"] = Query(default="fail2ban", description="Data source: 'fail2ban' or 'archive'."), + source: Literal["fail2ban", "archive"] = Query( + default="fail2ban", + description="Data source: 'fail2ban' or 'archive'.", + ), page: int = Query(default=1, ge=1, description="1-based page number."), page_size: int = Query(default=_DEFAULT_PAGE_SIZE, ge=1, le=500, description="Items per page."), origin: BanOrigin | None = Query( @@ -137,11 +140,18 @@ async def get_bans_by_country( request: Request, _auth: AuthDep, range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."), - source: Literal["fail2ban", "archive"] = Query(default="fail2ban", description="Data source: 'fail2ban' or 'archive'."), + source: Literal["fail2ban", "archive"] = Query( + default="fail2ban", + description="Data source: 'fail2ban' or 'archive'.", + ), origin: BanOrigin | None = Query( default=None, description="Filter by ban origin: 'blocklist' or 'selfblock'. Omit for all.", ), + country_code: str | None = Query( + default=None, + description="ISO alpha-2 country code to filter companion rows.", + ), ) -> BansByCountryResponse: """Return ban counts aggregated by ISO country code. @@ -173,6 +183,7 @@ async def get_bans_by_country( geo_batch_lookup=geo_service.lookup_batch, app_db=request.app.state.db, origin=origin, + country_code=country_code, ) @@ -185,7 +196,10 @@ async def get_ban_trend( request: Request, _auth: AuthDep, range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."), - source: Literal["fail2ban", "archive"] = Query(default="fail2ban", description="Data source: 'fail2ban' or 'archive'."), + source: Literal["fail2ban", "archive"] = Query( + default="fail2ban", + description="Data source: 'fail2ban' or 'archive'.", + ), origin: BanOrigin | None = Query( default=None, description="Filter by ban origin: 'blocklist' or 'selfblock'. Omit for all.", @@ -235,7 +249,10 @@ async def get_bans_by_jail( request: Request, _auth: AuthDep, range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."), - source: Literal["fail2ban", "archive"] = Query(default="fail2ban", description="Data source: 'fail2ban' or 'archive'."), + source: Literal["fail2ban", "archive"] = Query( + default="fail2ban", + description="Data source: 'fail2ban' or 'archive'.", + ), origin: BanOrigin | None = Query( default=None, description="Filter by ban origin: 'blocklist' or 'selfblock'. Omit for all.", diff --git a/backend/app/services/ban_service.py b/backend/app/services/ban_service.py index 07ae406..57a1e74 100644 --- a/backend/app/services/ban_service.py +++ b/backend/app/services/ban_service.py @@ -290,6 +290,7 @@ async def bans_by_country( geo_enricher: GeoEnricher | None = None, app_db: aiosqlite.Connection | None = None, origin: BanOrigin | None = None, + country_code: str | None = None, ) -> BansByCountryResponse: """Aggregate ban counts per country for the selected time window. @@ -350,16 +351,6 @@ async def bans_by_country( total = len(all_rows) - # companion rows for the table should be most recent - companion_rows, _ = await get_archived_history( - db=app_db, - since=since, - origin=origin, - action="ban", - page=1, - page_size=_MAX_COMPANION_BANS, - ) - agg_rows = {} for row in all_rows: ip = str(row["ip"]) @@ -393,14 +384,6 @@ async def bans_by_country( origin=origin, ) - companion_rows, _ = await fail2ban_db_repo.get_currently_banned( - db_path=db_path, - since=since, - origin=origin, - limit=_MAX_COMPANION_BANS, - offset=0, - ) - unique_ips = [r.ip for r in agg_rows] geo_map: dict[str, GeoInfo] = {} @@ -434,6 +417,54 @@ async def bans_by_country( results = await asyncio.gather(*(_safe_lookup(ip) for ip in unique_ips)) geo_map = {ip: geo for ip, geo in results if geo is not None} + companion_rows: list[dict[str, object] | fail2ban_db_repo.BanRecord] + if country_code is None: + if source == "archive": + companion_rows, _ = await get_archived_history( + db=app_db, + since=since, + origin=origin, + action="ban", + page=1, + page_size=_MAX_COMPANION_BANS, + ) + else: + companion_rows, _ = await fail2ban_db_repo.get_currently_banned( + db_path=db_path, + since=since, + origin=origin, + limit=_MAX_COMPANION_BANS, + offset=0, + ) + else: + matched_ips = [ + ip + for ip, geo in geo_map.items() + if geo is not None and geo.country_code == country_code + ] + + if source == "archive": + if matched_ips: + companion_rows = await get_all_archived_history( + db=app_db, + since=since, + origin=origin, + action="ban", + ip_filter=matched_ips, + ) + else: + companion_rows = [] + else: + if matched_ips: + companion_rows, _ = await fail2ban_db_repo.get_currently_banned( + db_path=db_path, + since=since, + origin=origin, + ip_filter=matched_ips, + ) + else: + companion_rows = [] + # Build country aggregation from the SQL-grouped rows. countries: dict[str, int] = {} country_names: dict[str, str] = {} diff --git a/backend/tests/test_repositories/test_fail2ban_db_repo.py b/backend/tests/test_repositories/test_fail2ban_db_repo.py index 98146bc..5f0c429 100644 --- a/backend/tests/test_repositories/test_fail2ban_db_repo.py +++ b/backend/tests/test_repositories/test_fail2ban_db_repo.py @@ -80,6 +80,32 @@ async def test_get_currently_banned_filters_and_pagination(tmp_path: Path) -> No assert records[0].ip == "3.3.3.3" +@pytest.mark.asyncio +async def test_get_currently_banned_filters_by_ip_list(tmp_path: Path) -> None: + db_path = str(tmp_path / "fail2ban.db") + async with aiosqlite.connect(db_path) as db: + await _create_bans_table(db) + await db.executemany( + "INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)", + [ + ("jail1", "1.1.1.1", 10, 1, "{}"), + ("jail1", "2.2.2.2", 20, 1, "{}"), + ("jail1", "3.3.3.3", 30, 1, "{}"), + ], + ) + await db.commit() + + records, total = await fail2ban_db_repo.get_currently_banned( + db_path=db_path, + since=0, + ip_filter=["2.2.2.2", "3.3.3.3"], + ) + + assert total == 2 + assert len(records) == 2 + assert {record.ip for record in records} == {"2.2.2.2", "3.3.3.3"} + + @pytest.mark.asyncio async def test_get_ban_counts_by_bucket_ignores_out_of_range_buckets(tmp_path: Path) -> None: db_path = str(tmp_path / "fail2ban.db") diff --git a/backend/tests/test_repositories/test_history_archive_repo.py b/backend/tests/test_repositories/test_history_archive_repo.py index c10997f..f9b5373 100644 --- a/backend/tests/test_repositories/test_history_archive_repo.py +++ b/backend/tests/test_repositories/test_history_archive_repo.py @@ -47,6 +47,10 @@ async def test_get_archived_history_filtering_and_pagination(app_db: str) -> Non assert total == 2 assert len(rows) == 1 + rows, total = await get_archived_history(db, ip_filter=["2.2.2.2"]) + assert total == 1 + assert rows[0]["ip"] == "2.2.2.2" + @pytest.mark.asyncio async def test_purge_archived_history(app_db: str) -> None: diff --git a/backend/tests/test_routers/test_dashboard.py b/backend/tests/test_routers/test_dashboard.py index 30a8c89..80e74ab 100644 --- a/backend/tests/test_routers/test_dashboard.py +++ b/backend/tests/test_routers/test_dashboard.py @@ -522,6 +522,19 @@ class TestDashboardBansOriginField: assert mock_fn.call_args[1]["source"] == "archive" + async def test_bans_by_country_country_code_forwarded( + self, dashboard_client: AsyncClient + ) -> None: + """The ``country_code`` query parameter is forwarded to bans_by_country.""" + mock_fn = AsyncMock(return_value=_make_bans_by_country_response()) + with patch("app.routers.dashboard.ban_service.bans_by_country", new=mock_fn): + await dashboard_client.get( + "/api/dashboard/bans/by-country?country_code=DE" + ) + + _, kwargs = mock_fn.call_args + assert kwargs.get("country_code") == "DE" + async def test_blocklist_origin_serialised_correctly( self, dashboard_client: AsyncClient ) -> None: diff --git a/backend/tests/test_services/test_ban_service.py b/backend/tests/test_services/test_ban_service.py index bf5cefd..87be876 100644 --- a/backend/tests/test_services/test_ban_service.py +++ b/backend/tests/test_services/test_ban_service.py @@ -654,6 +654,54 @@ class TestOriginFilter: assert result.total == 3 + async def test_bans_by_country_country_code_returns_all_matched_rows( + self, tmp_path: Path + ) -> None: + """``bans_by_country`` returns all companion rows for the selected country.""" + path = str(tmp_path / "fail2ban_country_filter.sqlite3") + rows = [ + { + "jail": "sshd", + "ip": "10.0.0.1", + "timeofban": _ONE_HOUR_AGO - i, + "bantime": 3600, + "bancount": 1, + "data": {"matches": ["failed login"]}, + } + for i in range(205) + ] + await _create_f2b_db(path, rows) + + from app.services import geo_service + + geo_service._cache["10.0.0.1"] = geo_service.GeoInfo( + country_code="DE", + country_name="Germany", + asn=None, + org=None, + ) + + with patch( + "app.services.ban_service.get_fail2ban_db_path", + new=AsyncMock(return_value=path), + ), patch( + "app.services.ban_service.asyncio.create_task" + ) as mock_create_task: + result = await ban_service.bans_by_country( + "/fake/sock", + "24h", + country_code="DE", + http_session=AsyncMock(), + geo_cache_lookup=geo_service.lookup_cached_only, + ) + + mock_create_task.assert_not_called() + assert result.total == 205 + assert len(result.bans) == 205 + assert all(b.country_code == "DE" for b in result.bans) + + geo_service.clear_cache() + async def test_bans_by_country_source_archive_reads_archive( self, app_db_with_archive: aiosqlite.Connection ) -> None: diff --git a/frontend/src/api/map.test.ts b/frontend/src/api/map.test.ts new file mode 100644 index 0000000..f561b71 --- /dev/null +++ b/frontend/src/api/map.test.ts @@ -0,0 +1,34 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import type { Mock } from "vitest"; +import { ENDPOINTS } from "./endpoints"; +import { fetchBansByCountry } from "./map"; +import { get } from "./client"; + +vi.mock("./client", () => ({ + get: vi.fn(), +})); + +const mockedGet = get as Mock; + +describe("fetchBansByCountry", () => { + beforeEach(() => { + mockedGet.mockReset(); + mockedGet.mockResolvedValue({ countries: {}, country_names: {}, bans: [], total: 0 }); + }); + + it("appends country_code when provided", async () => { + await fetchBansByCountry("24h", "all", "fail2ban", "US"); + + expect(get).toHaveBeenCalledWith( + `${ENDPOINTS.dashboardBansByCountry}?range=24h&country_code=US` + ); + }); + + it("does not append country_code when undefined", async () => { + await fetchBansByCountry("24h", "all", "fail2ban"); + + expect(get).toHaveBeenCalledWith( + `${ENDPOINTS.dashboardBansByCountry}?range=24h` + ); + }); +}); diff --git a/frontend/src/api/map.ts b/frontend/src/api/map.ts index 5405995..086217d 100644 --- a/frontend/src/api/map.ts +++ b/frontend/src/api/map.ts @@ -18,6 +18,7 @@ export async function fetchBansByCountry( range: TimeRange = "24h", origin: BanOriginFilter = "all", source: "fail2ban" | "archive" = "fail2ban", + countryCode?: string, ): Promise { const params = new URLSearchParams({ range }); if (origin !== "all") { @@ -26,5 +27,8 @@ export async function fetchBansByCountry( if (source !== "fail2ban") { params.set("source", source); } + if (countryCode) { + params.set("country_code", countryCode); + } return get(`${ENDPOINTS.dashboardBansByCountry}?${params.toString()}`); } diff --git a/frontend/src/hooks/useMapData.ts b/frontend/src/hooks/useMapData.ts index f1d2f76..0597cfb 100644 --- a/frontend/src/hooks/useMapData.ts +++ b/frontend/src/hooks/useMapData.ts @@ -44,6 +44,7 @@ export function useMapData( range: TimeRange = "24h", origin: BanOriginFilter = "all", source: "fail2ban" | "archive" = "fail2ban", + countryCode?: string, ): UseMapDataResult { const [data, setData] = useState(null); const [loading, setLoading] = useState(true); @@ -65,7 +66,7 @@ export function useMapData( abortRef.current?.abort(); abortRef.current = new AbortController(); - fetchBansByCountry(range, origin, source) + fetchBansByCountry(range, origin, source, countryCode) .then((resp) => { setData(resp); }) @@ -76,7 +77,7 @@ export function useMapData( setLoading(false); }); }, DEBOUNCE_MS); - }, [range, origin, source]); + }, [range, origin, source, countryCode]); useEffect((): (() => void) => { load(); diff --git a/frontend/src/pages/MapPage.tsx b/frontend/src/pages/MapPage.tsx index 98a82a1..2b796ff 100644 --- a/frontend/src/pages/MapPage.tsx +++ b/frontend/src/pages/MapPage.tsx @@ -101,7 +101,7 @@ export function MapPage(): React.JSX.Element { const source = range === "24h" ? "fail2ban" : "archive"; const { countries, countryNames, bans, total, loading, error, refresh } = - useMapData(range, originFilter, source); + useMapData(range, originFilter, source, selectedCountry ?? undefined); const { thresholds: mapThresholds,