Introduce service/repository dependency protocols and tests

This commit is contained in:
2026-04-10 19:51:19 +02:00
parent 3b6e39ddad
commit 3371ff8324
8 changed files with 419 additions and 11 deletions

View File

@@ -20,6 +20,8 @@ from app.config import Settings
from app.models.auth import Session
from app.models.config import PendingRecovery
from app.models.server import ServerStatus
from app.repositories.protocols import SessionRepository
from app.services.protocols import AuthService, JailService
from app.utils.runtime_state import RuntimeState, get_effective_settings
from app.utils.session_cache import SessionCache
@@ -169,6 +171,27 @@ async def get_session_cache(request: Request) -> SessionCache:
return session_cache
async def get_auth_service() -> AuthService:
"""Provide the concrete authentication service implementation."""
from app.services import auth_service # noqa: PLC0415
return cast(AuthService, auth_service)
async def get_jail_service() -> JailService:
"""Provide the concrete jail service implementation."""
from app.services import jail_service # noqa: PLC0415
return cast(JailService, jail_service)
async def get_session_repo() -> SessionRepository:
"""Provide the concrete session repository implementation."""
from app.repositories import session_repo # noqa: PLC0415
return session_repo
async def get_app_state(request: Request) -> AppState:
"""Provide the application state object for the current request."""
return cast("AppState", request.app.state)
@@ -194,6 +217,8 @@ async def require_auth(
db: Annotated[aiosqlite.Connection, Depends(get_db)],
settings: Annotated[Settings, Depends(get_settings)],
session_cache: Annotated[SessionCache, Depends(get_session_cache)],
auth_service: Annotated[AuthService, Depends(get_auth_service)],
session_repo: Annotated[SessionRepository, Depends(get_session_repo)],
) -> Session:
"""Validate the session token and return the active session.
@@ -218,7 +243,6 @@ async def require_auth(
Raises:
HTTPException: 401 if no valid session token is found.
"""
from app.services import auth_service # noqa: PLC0415
token: str | None = request.cookies.get(_COOKIE_NAME)
if not token:
@@ -240,7 +264,12 @@ async def require_auth(
return cached
try:
session = await auth_service.validate_session(db, token, settings.session_secret)
session = await auth_service.validate_session(
db,
token,
settings.session_secret,
session_repo=session_repo,
)
except ValueError as exc:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -264,6 +293,9 @@ Fail2BanStartCommandDep = Annotated[str, Depends(get_fail2ban_start_command)]
ServerStatusDep = Annotated[ServerStatus, Depends(get_server_status)]
PendingRecoveryDep = Annotated[PendingRecovery | None, Depends(get_pending_recovery)]
SessionCacheDep = Annotated[SessionCache, Depends(get_session_cache)]
AuthServiceDep = Annotated[AuthService, Depends(get_auth_service)]
JailServiceDep = Annotated[JailService, Depends(get_jail_service)]
SessionRepoDep = Annotated[SessionRepository, Depends(get_session_repo)]
AppStateDep = Annotated[AppState, Depends(get_app_state)]
AppDep = Annotated[FastAPI, Depends(get_app)]
AuthDep = Annotated[Session, Depends(require_auth)]

View File

@@ -0,0 +1,47 @@
"""Repository interface protocols for dependency injection.
Routers and services can depend on these abstractions instead of concrete
module implementations, making the backend easier to test and extend.
"""
from __future__ import annotations
from typing import Protocol
import aiosqlite
from app.models.auth import Session
class SessionRepository(Protocol):
"""Protocol for session persistence operations."""
async def create_session(
self,
db: aiosqlite.Connection,
token: str,
created_at: str,
expires_at: str,
) -> Session:
...
async def get_session(
self,
db: aiosqlite.Connection,
token: str,
) -> Session | None:
...
async def delete_session(
self,
db: aiosqlite.Connection,
token: str,
) -> None:
...
async def delete_expired_sessions(
self,
db: aiosqlite.Connection,
now_iso: str,
) -> int:
...

View File

@@ -12,9 +12,15 @@ from __future__ import annotations
import structlog
from fastapi import APIRouter, HTTPException, Request, Response, status
from app.dependencies import DbDep, SessionCacheDep, SettingsDep
from app.dependencies import (
AuthServiceDep,
DbDep,
SessionCacheDep,
SettingsDep,
SessionRepoDep,
)
from app.models.auth import LoginRequest, LoginResponse, LogoutResponse
from app.services import auth_service
from app.services.auth_service import sign_session_token
log: structlog.stdlib.BoundLogger = structlog.get_logger()
@@ -33,6 +39,8 @@ async def login(
response: Response,
db: DbDep,
settings: SettingsDep,
auth_service: AuthServiceDep,
session_repo: SessionRepoDep,
) -> LoginResponse:
"""Verify the master password and return a session token.
@@ -56,6 +64,7 @@ async def login(
db,
password=body.password,
session_duration_minutes=settings.session_duration_minutes,
session_repo=session_repo,
)
except ValueError as exc:
raise HTTPException(
@@ -63,7 +72,7 @@ async def login(
detail=str(exc),
) from exc
signed_token = auth_service.sign_session_token(
signed_token = sign_session_token(
session.token,
settings.session_secret,
)
@@ -89,6 +98,8 @@ async def logout(
db: DbDep,
settings: SettingsDep,
session_cache: SessionCacheDep,
auth_service: AuthServiceDep,
session_repo: SessionRepoDep,
) -> LogoutResponse:
"""Invalidate the active session.
@@ -107,7 +118,12 @@ async def logout(
"""
token = _extract_token(request)
if token:
raw_token = await auth_service.logout(db, token, settings.session_secret)
raw_token = await auth_service.logout(
db,
token,
settings.session_secret,
session_repo=session_repo,
)
if raw_token:
session_cache.invalidate(raw_token)
session_cache.invalidate(token)

View File

@@ -28,6 +28,7 @@ from app.dependencies import (
DbDep,
Fail2BanSocketDep,
HttpSessionDep,
JailServiceDep,
)
from app.exceptions import JailNotFoundError, JailOperationError
from app.models.ban import JailBannedIpsResponse
@@ -37,7 +38,7 @@ from app.models.jail import (
JailDetailResponse,
JailListResponse,
)
from app.services import geo_service, jail_service
from app.services import geo_service
from app.utils.fail2ban_client import Fail2BanConnectionError
router: APIRouter = APIRouter(prefix="/api/jails", tags=["Jails"])
@@ -107,6 +108,7 @@ def _conflict(message: str) -> HTTPException:
async def get_jails(
_auth: AuthDep,
socket_path: Fail2BanSocketDep,
jail_service: JailServiceDep,
) -> JailListResponse:
"""Return a summary of every active fail2ban jail.
@@ -135,6 +137,7 @@ async def get_jail(
_auth: AuthDep,
name: _NamePath,
socket_path: Fail2BanSocketDep,
jail_service: JailServiceDep,
) -> JailDetailResponse:
"""Return the complete configuration and runtime state for one jail.
@@ -174,6 +177,7 @@ async def get_jail(
async def reload_all_jails(
_auth: AuthDep,
socket_path: Fail2BanSocketDep,
jail_service: JailServiceDep,
) -> JailCommandResponse:
"""Reload every fail2ban jail to apply configuration changes.
@@ -208,6 +212,7 @@ async def start_jail(
_auth: AuthDep,
name: _NamePath,
socket_path: Fail2BanSocketDep,
jail_service: JailServiceDep,
) -> JailCommandResponse:
"""Start a fail2ban jail that is currently stopped.
@@ -243,6 +248,7 @@ async def stop_jail(
_auth: AuthDep,
name: _NamePath,
socket_path: Fail2BanSocketDep,
jail_service: JailServiceDep,
) -> JailCommandResponse:
"""Stop a running fail2ban jail.
@@ -279,6 +285,7 @@ async def toggle_idle(
_auth: AuthDep,
name: _NamePath,
socket_path: Fail2BanSocketDep,
jail_service: JailServiceDep,
on: bool = Body(..., description="``true`` to enable idle, ``false`` to disable."),
) -> JailCommandResponse:
"""Enable or disable idle mode for a fail2ban jail.
@@ -323,6 +330,7 @@ async def reload_jail(
_auth: AuthDep,
name: _NamePath,
socket_path: Fail2BanSocketDep,
jail_service: JailServiceDep,
) -> JailCommandResponse:
"""Reload a single fail2ban jail to pick up configuration changes.
@@ -371,6 +379,7 @@ async def get_ignore_list(
_auth: AuthDep,
name: _NamePath,
socket_path: Fail2BanSocketDep,
jail_service: JailServiceDep,
) -> list[str]:
"""Return the current ignore list (IP whitelist) for a fail2ban jail.
@@ -404,6 +413,7 @@ async def add_ignore_ip(
name: _NamePath,
body: IgnoreIpRequest,
socket_path: Fail2BanSocketDep,
jail_service: JailServiceDep,
) -> JailCommandResponse:
"""Add an IP address or CIDR network to a jail's ignore list.
@@ -453,6 +463,7 @@ async def del_ignore_ip(
name: _NamePath,
body: IgnoreIpRequest,
socket_path: Fail2BanSocketDep,
jail_service: JailServiceDep,
) -> JailCommandResponse:
"""Remove an IP address or CIDR network from a jail's ignore list.
@@ -492,6 +503,7 @@ async def toggle_ignore_self(
_auth: AuthDep,
name: _NamePath,
socket_path: Fail2BanSocketDep,
jail_service: JailServiceDep,
on: bool = Body(..., description="``true`` to enable ignoreself, ``false`` to disable."),
) -> JailCommandResponse:
"""Toggle the ``ignoreself`` flag for a fail2ban jail.
@@ -543,6 +555,7 @@ async def get_jail_banned_ips(
name: _NamePath,
socket_path: Fail2BanSocketDep,
http_session: HttpSessionDep,
jail_service: JailServiceDep,
page: int = 1,
page_size: int = 25,
search: str | None = None,

View File

@@ -20,6 +20,7 @@ if TYPE_CHECKING:
import aiosqlite
from app.models.auth import Session
from app.repositories.protocols import SessionRepository
from app.repositories import session_repo
from app.utils.constants import SESSION_TOKEN_BYTES, SESSION_TOKEN_SIGNATURE_SEPARATOR
@@ -80,6 +81,7 @@ async def login(
db: aiosqlite.Connection,
password: str,
session_duration_minutes: int,
session_repository: SessionRepository = session_repo,
) -> Session:
"""Verify *password* and create a new session on success.
@@ -108,7 +110,7 @@ async def login(
created_iso = now.isoformat()
expires_iso = add_minutes(now, session_duration_minutes).isoformat()
session = await session_repo.create_session(
session = await session_repository.create_session(
db, token=token, created_at=created_iso, expires_at=expires_iso
)
log.info("bangui_login_success", token_prefix=token[:8])
@@ -119,6 +121,7 @@ async def validate_session(
db: aiosqlite.Connection,
token: str,
session_secret: str | None = None,
session_repository: SessionRepository = session_repo,
) -> Session:
"""Return the session for *token* if it is valid and not expired.
@@ -139,13 +142,13 @@ async def validate_session(
except ValueError as exc:
raise ValueError("Session token is invalid.") from exc
session = await session_repo.get_session(db, token)
session = await session_repository.get_session(db, token)
if session is None:
raise ValueError("Session not found.")
now_iso = utc_now().isoformat()
if session.expires_at <= now_iso:
await session_repo.delete_session(db, token)
await session_repository.delete_session(db, token)
raise ValueError("Session has expired.")
return session
@@ -155,6 +158,7 @@ async def logout(
db: aiosqlite.Connection,
token: str,
session_secret: str | None = None,
session_repository: SessionRepository = session_repo,
) -> str | None:
"""Invalidate the session identified by *token*.
@@ -173,6 +177,6 @@ async def logout(
log.warning("bangui_logout_invalid_token", token_prefix=token[:8])
return None
await session_repo.delete_session(db, token)
await session_repository.delete_session(db, token)
log.info("bangui_logout", token_prefix=token[:8])
return token

View File

@@ -0,0 +1,105 @@
"""Service interface protocols for dependency injection.
These structural protocols define the public contract that routers and higher
layers depend on, without binding them to concrete module implementations.
"""
from __future__ import annotations
from typing import Protocol
import aiosqlite
from app.models.auth import Session
from app.models.ban import JailBannedIpsResponse
from app.models.jail import JailDetailResponse, JailListResponse
class AuthService(Protocol):
"""Protocol for authentication service operations."""
async def login(
self,
db: aiosqlite.Connection,
password: str,
session_duration_minutes: int,
session_repo: object | None = None,
) -> Session:
...
async def validate_session(
self,
db: aiosqlite.Connection,
token: str,
session_secret: str | None = None,
session_repo: object | None = None,
) -> Session:
...
async def logout(
self,
db: aiosqlite.Connection,
token: str,
session_secret: str | None = None,
session_repo: object | None = None,
) -> str | None:
...
class JailService(Protocol):
"""Protocol for jail management service operations."""
async def list_jails(self, socket_path: str) -> JailListResponse:
...
async def get_jail(self, socket_path: str, name: str) -> JailDetailResponse:
...
async def reload_all(self, socket_path: str) -> None:
...
async def start_jail(self, socket_path: str, name: str) -> None:
...
async def stop_jail(self, socket_path: str, name: str) -> None:
...
async def set_idle(self, socket_path: str, name: str, *, on: bool) -> None:
...
async def reload_jail(self, socket_path: str, name: str) -> None:
...
async def get_ignore_list(self, socket_path: str, name: str) -> list[str]:
...
async def add_ignore_ip(self, socket_path: str, name: str, ip: str) -> None:
...
async def del_ignore_ip(self, socket_path: str, name: str, ip: str) -> None:
...
async def set_ignore_self(self, socket_path: str, name: str, *, on: bool) -> None:
...
async def get_jail_banned_ips(
self,
socket_path: str,
jail_name: str,
page: int,
page_size: int,
search: str | None = None,
*,
geo_batch_lookup: object,
http_session: object,
app_db: aiosqlite.Connection,
) -> JailBannedIpsResponse:
...
async def lookup_ip(
self,
socket_path: str,
ip: str,
geo_enricher: object,
) -> object:
...

View File

@@ -0,0 +1,189 @@
"""Router dependency injection tests.
These tests verify that routers can consume service abstractions via FastAPI
dependencies and that those dependencies can be overridden cleanly for unit
testing without touching concrete implementations.
"""
from __future__ import annotations
from pathlib import Path
import aiosqlite
from httpx import ASGITransport, AsyncClient
from app.config import Settings
from app.db import init_db
from app.dependencies import get_auth_service, get_jail_service
from app.main import create_app
from app.models.auth import Session
from app.models.jail import JailListResponse
from app.utils.setup_state import set_setup_complete_cache
class FakeAuthService:
async def login(
self,
_db: aiosqlite.Connection,
password: str,
session_duration_minutes: int,
session_repo: object | None = None,
) -> Session:
return Session(
id=1,
token="fake-token",
created_at="2025-01-01T00:00:00Z",
expires_at="2099-01-01T00:00:00Z",
)
async def validate_session(
self,
_db: aiosqlite.Connection,
token: str,
session_secret: str | None = None,
session_repo: object | None = None,
) -> Session:
return Session(
id=1,
token=token,
created_at="2025-01-01T00:00:00Z",
expires_at="2099-01-01T00:00:00Z",
)
async def logout(
self,
_db: aiosqlite.Connection,
token: str,
session_secret: str | None = None,
session_repo: object | None = None,
) -> str | None:
return token
class FakeJailService:
async def list_jails(self, _socket_path: str) -> JailListResponse:
return JailListResponse(jails=[], total=0)
async def get_jail(self, _socket_path: str, _name: str) -> JailListResponse:
raise NotImplementedError
async def reload_all(self, _socket_path: str) -> None:
raise NotImplementedError
async def start_jail(self, _socket_path: str, _name: str) -> None:
raise NotImplementedError
async def stop_jail(self, _socket_path: str, _name: str) -> None:
raise NotImplementedError
async def set_idle(self, socket_path: str, name: str, *, on: bool) -> None:
raise NotImplementedError
async def reload_jail(self, socket_path: str, name: str) -> None:
raise NotImplementedError
async def get_ignore_list(self, socket_path: str, name: str) -> list[str]:
raise NotImplementedError
async def add_ignore_ip(self, socket_path: str, name: str, ip: str) -> None:
raise NotImplementedError
async def del_ignore_ip(self, socket_path: str, name: str, ip: str) -> None:
raise NotImplementedError
async def set_ignore_self(self, socket_path: str, name: str, *, on: bool) -> None:
raise NotImplementedError
async def get_jail_banned_ips(
self,
socket_path: str,
jail_name: str,
page: int,
page_size: int,
search: str | None = None,
*,
geo_batch_lookup: object,
http_session: object,
app_db: aiosqlite.Connection,
) -> JailListResponse:
raise NotImplementedError
async def _build_app(settings: Settings):
app = create_app(settings=settings)
set_setup_complete_cache(app, True)
db = await aiosqlite.connect(settings.database_path)
db.row_factory = aiosqlite.Row
await init_db(db)
return app, db
async def test_auth_login_uses_injected_auth_service(tmp_path: Path) -> None:
settings = Settings(
database_path=str(tmp_path / "test_bangui.db"),
fail2ban_socket="/tmp/fake_fail2ban.sock",
fail2ban_config_dir=str(tmp_path / "fail2ban"),
session_secret="test-secret-key-do-not-use-in-production",
session_duration_minutes=60,
timezone="UTC",
log_level="debug",
)
app, db = await _build_app(settings)
def _fake_auth_service() -> FakeAuthService:
return FakeAuthService()
app.dependency_overrides[get_auth_service] = _fake_auth_service
transport = ASGITransport(app=app)
async with AsyncClient(
transport=transport,
base_url="http://test",
) as client:
response = await client.post(
"/api/auth/login",
json={"password": "ignored"},
)
await db.close()
assert response.status_code == 200
assert response.json()["token"].startswith("fake-token")
assert response.cookies.get("bangui_session") is not None
async def test_jail_list_uses_injected_jail_service_and_auth(tmp_path: Path) -> None:
settings = Settings(
database_path=str(tmp_path / "test_bangui.db"),
fail2ban_socket="/tmp/fake_fail2ban.sock",
fail2ban_config_dir=str(tmp_path / "fail2ban"),
session_secret="test-secret-key-do-not-use-in-production",
session_duration_minutes=60,
timezone="UTC",
log_level="debug",
)
app, db = await _build_app(settings)
def _fake_auth_service() -> FakeAuthService:
return FakeAuthService()
def _fake_jail_service() -> FakeJailService:
return FakeJailService()
app.dependency_overrides[get_auth_service] = _fake_auth_service
app.dependency_overrides[get_jail_service] = _fake_jail_service
transport = ASGITransport(app=app)
async with AsyncClient(
transport=transport,
base_url="http://test",
) as client:
response = await client.get(
"/api/jails",
headers={"Cookie": "bangui_session=fake-token"},
)
await db.close()
assert response.status_code == 200
assert response.json() == {"jails": [], "total": 0}