Task 8: Standardize modeling style (TypedDict vs Pydantic)

Convert inconsistent modeling style to standardized Pydantic models for all
external-facing data structures while maintaining TypedDict compatibility where
appropriate for internal layer-private structures.

Changes:
- Replaced the IpLookupResult TypedDict with the IpLookupResponse Pydantic
  model as the return type of jail_service.lookup_ip(), matching the routers
- Added GeoCacheEntry Pydantic model for geo cache repository rows
- Replaced the GeoCacheRow TypedDict with an alias of GeoCacheEntry
- Replaced the ImportLogRow TypedDict with an alias of ImportLogEntry
- Updated routers and services to work with Pydantic models
- Updated all tests to use Pydantic model field access (attributes)
  instead of dict subscripting

Documentation:
- Added 'Model Type Usage by Layer' section to Backend-Development.md
- Defines when TypedDict is allowed (internal structures) vs Pydantic
  (external-facing, cross-boundary data)
- Provides clear guidance on modeling conventions per layer

Benefits:
- Consistent validation and serialization behavior
- Better IDE support and type checking
- Clearer separation of concerns by layer
- Reduced maintenance cost from mixed validation approaches

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-04-28 07:53:30 +02:00
parent 3888c5eb3f
commit 52a4d04d92
10 changed files with 127 additions and 85 deletions


@@ -401,6 +401,69 @@ async def delete_log_path(
- **Never use string prefix matching** for path validation (e.g., `path.startswith("/var/log")`). The helper uses `Path.relative_to()` to prevent bypasses like `/var/log_evil/file.log`.
- Symlinks are resolved before validating to prevent symlink-based escapes.
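
The two rules above can be sketched as follows. This is a minimal illustration, not the project's actual helper: the function name `is_within` and the example paths are assumptions made for this sketch.

```python
from pathlib import Path


def is_within(base: str, candidate: str) -> bool:
    """Return True if candidate resolves to a location inside base.

    Hypothetical helper for illustration only.
    """
    base_path = Path(base).resolve()
    # resolve() follows symlinks and normalises ".." before the comparison.
    candidate_path = Path(candidate).resolve()
    try:
        # relative_to() compares whole path components, not characters.
        candidate_path.relative_to(base_path)
    except ValueError:
        return False
    return True


# A naive prefix check accepts a sibling directory that shares the prefix.
print("/var/log_evil/file.log".startswith("/var/log"))
# The component-wise check rejects it.
print(is_within("/var/log", "/var/log_evil/file.log"))
```

Because `relative_to()` works on path components, `/var/log_evil` is never treated as a child of `/var/log`, which is the bypass the string-prefix rule warns about.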

### Model Type Usage by Layer

**Pydantic models** are mandatory for all **external-facing** data structures: anything that crosses layer boundaries or is serialized to HTTP responses. **TypedDict** may be used **only** for internal, layer-private data structures, where it provides precise typing without runtime overhead.

**Rules:**

1. **Routers (HTTP boundary):** All request and response types **must be Pydantic models**. FastAPI uses these for validation, serialization, and OpenAPI documentation.
   - Use Pydantic request models for request bodies and query parameters.
   - Use Pydantic response models in the `response_model` parameter.

   ```python
   # Good: Pydantic models for router layer
   class JailStatsRequest(BaseModel):
       jail_name: str

   class JailStatsResponse(BaseModel):
       jail_name: str
       active_bans: int

   @router.post("/stats", response_model=JailStatsResponse)
   async def get_stats(req: JailStatsRequest) -> JailStatsResponse:
       ...
   ```

2. **Services (business logic):** Return types should be **Pydantic models** if the result is:
   - Returned to a router (the likely case, since such results become API responses).
   - Used across multiple services (shared interfaces).
   - Exposed to external consumers (even indirectly).

   If a service returns a purely internal intermediate result used by a single caller, TypedDict is acceptable but should be rare.

   ```python
   # Good: service returns Pydantic (may be used by multiple routers)
   async def get_jail_details(name: str) -> JailDetailResponse:
       ...

   # Acceptable: purely internal utility result
   def _parse_fail2ban_response(raw: str) -> ParsedResponse:
       """Internal helper used only by this service."""
       ...
   ```

3. **Repositories (data access):** Return types may use **TypedDict** because they represent **raw database rows** that:
   - Are layer-private (only called by their own service).
   - Do not cross HTTP boundaries directly.
   - Benefit from lightweight typing without runtime validation.

   ```python
   # Good: TypedDict for raw repository rows
   class GeoRow(TypedDict):
       ip: str
       country_code: str | None

   async def load_all(db: aiosqlite.Connection) -> list[GeoRow]:
       ...
   ```

   If a repository result becomes part of a service's public interface (returned to routers or other services), convert it to a Pydantic model.

4. **Utilities and helpers:** Internal helper results may use TypedDict if they are not part of a public module interface.

**Migration path:** Existing internal TypedDicts (e.g., `GeoCacheRow`, `ImportLogRow`) may remain as TypedDicts so long as they stay within their layer. If a type needs to cross layer boundaries (repo → service → router), convert it to a Pydantic model incrementally as you refactor that data flow.

---

## 6. Async Rules


@@ -1,22 +1,3 @@
-## 7) Service layer coupled to response/presentation models
-- Where found:
-  - [backend/app/services/ban_service.py](backend/app/services/ban_service.py)
-- Why this is needed:
-  - Domain logic becomes tied to API shape and slows model evolution.
-- Goal:
-  - Keep service returns domain-centered; map to API models at router boundary.
-- What to do:
-  - Introduce service DTO/domain objects.
-  - Map to response models in routers or dedicated mappers.
-- Possible traps and issues:
-  - Temporary duplicate model definitions during migration.
-- Docs changes needed:
-  - Add layer responsibilities and mapping policy.
-- Doc references:
-  - [Docs/Architekture.md](Docs/Architekture.md)
----
 ## 8) Inconsistent modeling style (TypedDict vs Pydantic)
 - Where found:
   - [backend/app/services/jail_service.py](backend/app/services/jail_service.py)


@@ -42,6 +42,33 @@ class GeoDetail(BaseModel):
     )
+class GeoCacheEntry(BaseModel):
+    """A single cached geolocation entry for an IP address.
+    Represents a row from the ``geo_cache`` table in the application database.
+    """
+    model_config = ConfigDict(strict=True)
+    ip: str = Field(..., description="IP address (IPv4 or IPv6).")
+    country_code: str | None = Field(
+        default=None,
+        description="ISO 3166-1 alpha-2 country code.",
+    )
+    country_name: str | None = Field(
+        default=None,
+        description="Human-readable country name.",
+    )
+    asn: str | None = Field(
+        default=None,
+        description="Autonomous System Number (e.g. ``'AS3320'``).",
+    )
+    org: str | None = Field(
+        default=None,
+        description="Organisation associated with the ASN.",
+    )
 class GeoCacheStatsResponse(BaseModel):
     """Response for ``GET /api/geo/stats``.


@@ -9,22 +9,18 @@ connection lifetimes.
 from __future__ import annotations
-from typing import TYPE_CHECKING, TypedDict
+from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from collections.abc import Sequence
     import aiosqlite
+from app.models.geo import GeoCacheEntry
-class GeoCacheRow(TypedDict):
-    """A single row from the ``geo_cache`` table."""
-    ip: str
-    country_code: str | None
-    country_name: str | None
-    asn: str | None
-    org: str | None
+# Alias for backward compatibility with protocols
+GeoCacheRow = GeoCacheEntry
 async def load_all(db: aiosqlite.Connection) -> list[GeoCacheRow]:


@@ -8,26 +8,18 @@ table. All methods are plain async functions that accept a
 from __future__ import annotations
 import math
-from typing import TYPE_CHECKING, TypedDict, cast
+from typing import TYPE_CHECKING, cast
 if TYPE_CHECKING:
     from collections.abc import Mapping
     import aiosqlite
+from app.models.blocklist import ImportLogEntry
-class ImportLogRow(TypedDict):
-    """Row shape returned by queries on the import_log table."""
-    id: int
-    source_id: int | None
-    source_url: str
-    timestamp: str
-    ips_imported: int
-    ips_skipped: int
-    errors: str | None
+# Alias for backward compatibility with protocols
+ImportLogRow = ImportLogEntry
 async def add_log(
     db: aiosqlite.Connection,
     *,
@@ -158,13 +150,13 @@ def compute_total_pages(total: int, page_size: int) -> int:
 def _row_to_dict(row: object) -> ImportLogRow:
-    """Convert an aiosqlite row to a plain Python dict.
+    """Convert an aiosqlite row to an ImportLogEntry Pydantic model.
     Args:
         row: An :class:`aiosqlite.Row` or similar mapping returned by a cursor.
     Returns:
-        Dict mapping column names to Python values.
+        ImportLogEntry Pydantic model instance.
     """
     mapping = cast("Mapping[str, object]", row)
-    return cast("ImportLogRow", dict(mapping))
+    return ImportLogEntry(**mapping)


@@ -8,10 +8,7 @@ Provides the IP enrichment endpoints:
 from __future__ import annotations
-from typing import TYPE_CHECKING, Annotated
+from typing import Annotated
-if TYPE_CHECKING:
-    from app.services.jail_service import IpLookupResult
 from fastapi import APIRouter, Path
@@ -57,14 +54,12 @@ async def lookup_ip(
         HTTPException: 400 when *ip* is not a valid IP address.
         HTTPException: 502 when fail2ban is unreachable.
     """
-    result: IpLookupResult = await jail_service.lookup_ip(
+    return await jail_service.lookup_ip(
         socket_path,
         ip,
         http_session=http_session,
     )
-    return IpLookupResponse(**result)
 # ---------------------------------------------------------------------------
 # POST /api/geo/re-resolve


@@ -18,14 +18,14 @@ from __future__ import annotations
 import asyncio
 import contextlib
 import ipaddress
-from typing import TYPE_CHECKING, TypedDict, cast
+from typing import TYPE_CHECKING, cast
 import structlog
 from app.exceptions import JailNotFoundError, JailOperationError
 from app.models.ban import ActiveBan, JailBannedIpsResponse
 from app.models.config import BantimeEscalation
-from app.models.geo import GeoDetail
+from app.models.geo import GeoDetail, IpLookupResponse
 from app.models.jail import (
     Jail,
     JailDetailResponse,
@@ -63,18 +63,6 @@ log: structlog.stdlib.BoundLogger = structlog.get_logger()
 __all__ = ["reload_all"]
-class IpLookupResult(TypedDict):
-    """Result returned by :func:`lookup_ip`.
-    This is intentionally a :class:`TypedDict` to provide precise typing for
-    callers (e.g. routers) while keeping the implementation flexible.
-    """
-    ip: str
-    currently_banned_in: list[str]
-    geo: GeoDetail | None
 # ---------------------------------------------------------------------------
 # Constants
 # ---------------------------------------------------------------------------
@@ -998,7 +986,7 @@ async def lookup_ip(
     ip: str,
     geo_enricher: GeoEnricher | None = None,
     http_session: aiohttp.ClientSession | None = None,
-) -> IpLookupResult:
+) -> IpLookupResponse:
     """Return ban status and history for a single IP address.
     Checks every running jail for whether the IP is currently banned.
@@ -1075,11 +1063,11 @@ async def lookup_ip(
     log.info("ip_lookup_completed", ip=ip, banned_in_jails=currently_banned_in)
-    return {
-        "ip": ip,
-        "currently_banned_in": currently_banned_in,
-        "geo": geo,
-    }
+    return IpLookupResponse(
+        ip=ip,
+        currently_banned_in=currently_banned_in,
+        geo=geo,
+    )
 async def unban_all_ips(socket_path: str) -> int:


@@ -173,7 +173,7 @@ class TestImportLogRepo:
         )
         items, total = await import_log_repo.list_logs(db, source_id=source_id)
         assert total == 1
-        assert items[0]["source_url"] == "https://s.test/"
+        assert items[0].source_url == "https://s.test/"
     async def test_get_last_log_empty(self, db: aiosqlite.Connection) -> None:
         """get_last_log returns None when no logs exist."""
@@ -200,7 +200,7 @@ class TestImportLogRepo:
         )
         last = await import_log_repo.get_last_log(db)
         assert last is not None
-        assert last["source_url"] == "https://last.test/"
+        assert last.source_url == "https://last.test/"
     async def test_compute_total_pages(self) -> None:
         """compute_total_pages returns correct page count."""


@@ -82,7 +82,7 @@ async def test_load_all_and_count_unresolved(tmp_path: Path) -> None:
     unresolved = await geo_cache_repo.count_unresolved(db)
     assert unresolved == 1
-    assert any(row["ip"] == "6.6.6.6" and row["country_code"] == "FR" for row in rows)
+    assert any(row.ip == "6.6.6.6" and row.country_code == "FR" for row in rows)
 @pytest.mark.asyncio


@@ -847,8 +847,8 @@ class TestLookupIp:
         with _patch_client(responses):
             result = await jail_service.lookup_ip(_SOCKET, "1.2.3.4")
-        assert result["ip"] == "1.2.3.4"
-        assert "sshd" in result["currently_banned_in"]
+        assert result.ip == "1.2.3.4"
+        assert "sshd" in result.currently_banned_in
     async def test_geo_enricher_returns_geo_detail(self) -> None:
         """lookup_ip converts GeoInfo from the enricher into GeoDetail."""
@@ -868,11 +868,11 @@ class TestLookupIp:
                 geo_enricher=_enricher,
             )
-        assert isinstance(result["geo"], GeoDetail)
-        assert result["geo"].country_code == "DE"
-        assert result["geo"].country_name == "Germany"
-        assert result["geo"].asn == "AS123"
-        assert result["geo"].org == "Acme"
+        assert isinstance(result.geo, GeoDetail)
+        assert result.geo.country_code == "DE"
+        assert result.geo.country_name == "Germany"
+        assert result.geo.asn == "AS123"
+        assert result.geo.org == "Acme"
     async def test_http_session_uses_geo_service_lookup(self) -> None:
         """lookup_ip uses geo_service.lookup when http_session is provided."""
@@ -896,11 +896,11 @@ class TestLookupIp:
             )
         mock_lookup.assert_awaited_once_with("1.2.3.4", mock_session)
-        assert isinstance(result["geo"], GeoDetail)
-        assert result["geo"].country_code == "JP"
-        assert result["geo"].country_name == "Japan"
-        assert result["geo"].asn is None
-        assert result["geo"].org is None
+        assert isinstance(result.geo, GeoDetail)
+        assert result.geo.country_code == "JP"
+        assert result.geo.country_name == "Japan"
+        assert result.geo.asn is None
+        assert result.geo.org is None
     async def test_invalid_ip_raises(self) -> None:
         """lookup_ip raises ValueError for invalid IP."""
@@ -917,7 +917,7 @@ class TestLookupIp:
         with _patch_client(responses):
             result = await jail_service.lookup_ip(_SOCKET, "9.9.9.9")
-        assert result["currently_banned_in"] == []
+        assert result.currently_banned_in == []
 # ---------------------------------------------------------------------------