From 3371ff83248cb87d124fc87a57ebd4d9c9bd8001 Mon Sep 17 00:00:00 2001 From: Lukas Date: Fri, 10 Apr 2026 19:51:19 +0200 Subject: [PATCH] Introduce service/repository dependency protocols and tests --- Docs/Tasks.md | 2 + backend/app/dependencies.py | 36 +++- backend/app/repositories/protocols.py | 47 +++++ backend/app/routers/auth.py | 24 ++- backend/app/routers/jails.py | 15 +- backend/app/services/auth_service.py | 12 +- backend/app/services/protocols.py | 105 ++++++++++ .../test_routers/test_dependency_injection.py | 189 ++++++++++++++++++ 8 files changed, 419 insertions(+), 11 deletions(-) create mode 100644 backend/app/repositories/protocols.py create mode 100644 backend/app/services/protocols.py create mode 100644 backend/tests/test_routers/test_dependency_injection.py diff --git a/Docs/Tasks.md b/Docs/Tasks.md index 7730807..3c856e8 100644 --- a/Docs/Tasks.md +++ b/Docs/Tasks.md @@ -48,6 +48,8 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue. - Issue: Routers and higher-level services currently import concrete service modules directly, which prevents clean substitution and dependency override testing. - Propose: Define protocols or abstract base classes for major services and repositories, then wire concrete implementations through FastAPI dependency providers. - Test: Add tests that override a service dependency with a fake implementation and verify the router behavior remains correct. + - Status: completed + - Completed: Added service/repository protocols, injected auth/jail services via FastAPI dependencies, and added router tests for dependency overrides. 7. Move operational orchestration out of routers and into service/task layer - Goal: Keep routers thin and move operational control flow into service or task components. diff --git a/backend/app/dependencies.py b/backend/app/dependencies.py index ef6bd7c..6dbce34 100644 --- a/backend/app/dependencies.py +++ b/backend/app/dependencies.py @@ -20,6 +20,8 @@ from app.config import Settings from app.models.auth import Session from app.models.config import PendingRecovery from app.models.server import ServerStatus +from app.repositories.protocols import SessionRepository +from app.services.protocols import AuthService, JailService from app.utils.runtime_state import RuntimeState, get_effective_settings from app.utils.session_cache import SessionCache @@ -169,6 +171,27 @@ async def get_session_cache(request: Request) -> SessionCache: return session_cache +async def get_auth_service() -> AuthService: + """Provide the concrete authentication service implementation.""" + from app.services import auth_service # noqa: PLC0415 + + return cast(AuthService, auth_service) + + +async def get_jail_service() -> JailService: + """Provide the concrete jail service implementation.""" + from app.services import jail_service # noqa: PLC0415 + + return cast(JailService, jail_service) + + +async def get_session_repo() -> SessionRepository: + """Provide the concrete session repository implementation.""" + from app.repositories import session_repo # noqa: PLC0415 + + return session_repo + + async def get_app_state(request: Request) -> AppState: """Provide the application state object for the current request.""" return cast("AppState", request.app.state) @@ -194,6 +217,8 @@ async def require_auth( db: Annotated[aiosqlite.Connection, Depends(get_db)], settings: Annotated[Settings, Depends(get_settings)], session_cache: Annotated[SessionCache, Depends(get_session_cache)], + auth_service: Annotated[AuthService, Depends(get_auth_service)], + session_repo: Annotated[SessionRepository, Depends(get_session_repo)], ) -> Session: """Validate the session token and return the active session. @@ -218,7 +243,6 @@ async def require_auth( Raises: HTTPException: 401 if no valid session token is found. """ - from app.services import auth_service # noqa: PLC0415 token: str | None = request.cookies.get(_COOKIE_NAME) if not token: @@ -240,7 +264,12 @@ async def require_auth( return cached try: - session = await auth_service.validate_session(db, token, settings.session_secret) + session = await auth_service.validate_session( + db, + token, + settings.session_secret, + session_repo=session_repo, + ) except ValueError as exc: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -264,6 +293,9 @@ Fail2BanStartCommandDep = Annotated[str, Depends(get_fail2ban_start_command)] ServerStatusDep = Annotated[ServerStatus, Depends(get_server_status)] PendingRecoveryDep = Annotated[PendingRecovery | None, Depends(get_pending_recovery)] SessionCacheDep = Annotated[SessionCache, Depends(get_session_cache)] +AuthServiceDep = Annotated[AuthService, Depends(get_auth_service)] +JailServiceDep = Annotated[JailService, Depends(get_jail_service)] +SessionRepoDep = Annotated[SessionRepository, Depends(get_session_repo)] AppStateDep = Annotated[AppState, Depends(get_app_state)] AppDep = Annotated[FastAPI, Depends(get_app)] AuthDep = Annotated[Session, Depends(require_auth)] diff --git a/backend/app/repositories/protocols.py b/backend/app/repositories/protocols.py new file mode 100644 index 0000000..120a086 --- /dev/null +++ b/backend/app/repositories/protocols.py @@ -0,0 +1,47 @@ +"""Repository interface protocols for dependency injection. + +Routers and services can depend on these abstractions instead of concrete +module implementations, making the backend easier to test and extend. +""" + +from __future__ import annotations + +from typing import Protocol + +import aiosqlite + +from app.models.auth import Session + + +class SessionRepository(Protocol): + """Protocol for session persistence operations.""" + + async def create_session( + self, + db: aiosqlite.Connection, + token: str, + created_at: str, + expires_at: str, + ) -> Session: + ... + + async def get_session( + self, + db: aiosqlite.Connection, + token: str, + ) -> Session | None: + ... + + async def delete_session( + self, + db: aiosqlite.Connection, + token: str, + ) -> None: + ... + + async def delete_expired_sessions( + self, + db: aiosqlite.Connection, + now_iso: str, + ) -> int: + ... diff --git a/backend/app/routers/auth.py b/backend/app/routers/auth.py index 33237e1..755fa3a 100644 --- a/backend/app/routers/auth.py +++ b/backend/app/routers/auth.py @@ -12,9 +12,15 @@ from __future__ import annotations import structlog from fastapi import APIRouter, HTTPException, Request, Response, status -from app.dependencies import DbDep, SessionCacheDep, SettingsDep +from app.dependencies import ( + AuthServiceDep, + DbDep, + SessionCacheDep, + SettingsDep, + SessionRepoDep, +) from app.models.auth import LoginRequest, LoginResponse, LogoutResponse -from app.services import auth_service +from app.services.auth_service import sign_session_token log: structlog.stdlib.BoundLogger = structlog.get_logger() @@ -33,6 +39,8 @@ async def login( response: Response, db: DbDep, settings: SettingsDep, + auth_service: AuthServiceDep, + session_repo: SessionRepoDep, ) -> LoginResponse: """Verify the master password and return a session token. @@ -56,6 +64,7 @@ async def login( db, password=body.password, session_duration_minutes=settings.session_duration_minutes, + session_repo=session_repo, ) except ValueError as exc: raise HTTPException( @@ -63,7 +72,7 @@ async def login( detail=str(exc), ) from exc - signed_token = auth_service.sign_session_token( + signed_token = sign_session_token( session.token, settings.session_secret, ) @@ -89,6 +98,8 @@ async def logout( db: DbDep, settings: SettingsDep, session_cache: SessionCacheDep, + auth_service: AuthServiceDep, + session_repo: SessionRepoDep, ) -> LogoutResponse: """Invalidate the active session. @@ -107,7 +118,12 @@ async def logout( """ token = _extract_token(request) if token: - raw_token = await auth_service.logout(db, token, settings.session_secret) + raw_token = await auth_service.logout( + db, + token, + settings.session_secret, + session_repo=session_repo, + ) if raw_token: session_cache.invalidate(raw_token) session_cache.invalidate(token) diff --git a/backend/app/routers/jails.py b/backend/app/routers/jails.py index c59d4fb..85f3e9b 100644 --- a/backend/app/routers/jails.py +++ b/backend/app/routers/jails.py @@ -28,6 +28,7 @@ from app.dependencies import ( DbDep, Fail2BanSocketDep, HttpSessionDep, + JailServiceDep, ) from app.exceptions import JailNotFoundError, JailOperationError from app.models.ban import JailBannedIpsResponse @@ -37,7 +38,7 @@ from app.models.jail import ( JailDetailResponse, JailListResponse, ) -from app.services import geo_service, jail_service +from app.services import geo_service from app.utils.fail2ban_client import Fail2BanConnectionError router: APIRouter = APIRouter(prefix="/api/jails", tags=["Jails"]) @@ -107,6 +108,7 @@ def _conflict(message: str) -> HTTPException: async def get_jails( _auth: AuthDep, socket_path: Fail2BanSocketDep, + jail_service: JailServiceDep, ) -> JailListResponse: """Return a summary of every active fail2ban jail. @@ -135,6 +137,7 @@ async def get_jail( _auth: AuthDep, name: _NamePath, socket_path: Fail2BanSocketDep, + jail_service: JailServiceDep, ) -> JailDetailResponse: """Return the complete configuration and runtime state for one jail. @@ -174,6 +177,7 @@ async def get_jail( async def reload_all_jails( _auth: AuthDep, socket_path: Fail2BanSocketDep, + jail_service: JailServiceDep, ) -> JailCommandResponse: """Reload every fail2ban jail to apply configuration changes. @@ -208,6 +212,7 @@ async def start_jail( _auth: AuthDep, name: _NamePath, socket_path: Fail2BanSocketDep, + jail_service: JailServiceDep, ) -> JailCommandResponse: """Start a fail2ban jail that is currently stopped. @@ -243,6 +248,7 @@ async def stop_jail( _auth: AuthDep, name: _NamePath, socket_path: Fail2BanSocketDep, + jail_service: JailServiceDep, ) -> JailCommandResponse: """Stop a running fail2ban jail. @@ -279,6 +285,7 @@ async def toggle_idle( _auth: AuthDep, name: _NamePath, socket_path: Fail2BanSocketDep, + jail_service: JailServiceDep, on: bool = Body(..., description="``true`` to enable idle, ``false`` to disable."), ) -> JailCommandResponse: """Enable or disable idle mode for a fail2ban jail. @@ -323,6 +330,7 @@ async def reload_jail( _auth: AuthDep, name: _NamePath, socket_path: Fail2BanSocketDep, + jail_service: JailServiceDep, ) -> JailCommandResponse: """Reload a single fail2ban jail to pick up configuration changes. @@ -371,6 +379,7 @@ async def get_ignore_list( _auth: AuthDep, name: _NamePath, socket_path: Fail2BanSocketDep, + jail_service: JailServiceDep, ) -> list[str]: """Return the current ignore list (IP whitelist) for a fail2ban jail. @@ -404,6 +413,7 @@ async def add_ignore_ip( name: _NamePath, body: IgnoreIpRequest, socket_path: Fail2BanSocketDep, + jail_service: JailServiceDep, ) -> JailCommandResponse: """Add an IP address or CIDR network to a jail's ignore list. @@ -453,6 +463,7 @@ async def del_ignore_ip( name: _NamePath, body: IgnoreIpRequest, socket_path: Fail2BanSocketDep, + jail_service: JailServiceDep, ) -> JailCommandResponse: """Remove an IP address or CIDR network from a jail's ignore list. @@ -492,6 +503,7 @@ async def toggle_ignore_self( _auth: AuthDep, name: _NamePath, socket_path: Fail2BanSocketDep, + jail_service: JailServiceDep, on: bool = Body(..., description="``true`` to enable ignoreself, ``false`` to disable."), ) -> JailCommandResponse: """Toggle the ``ignoreself`` flag for a fail2ban jail. @@ -543,6 +555,7 @@ async def get_jail_banned_ips( name: _NamePath, socket_path: Fail2BanSocketDep, http_session: HttpSessionDep, + jail_service: JailServiceDep, page: int = 1, page_size: int = 25, search: str | None = None, diff --git a/backend/app/services/auth_service.py b/backend/app/services/auth_service.py index 3d15d06..dc4ae47 100644 --- a/backend/app/services/auth_service.py +++ b/backend/app/services/auth_service.py @@ -20,6 +20,7 @@ if TYPE_CHECKING: import aiosqlite from app.models.auth import Session + from app.repositories.protocols import SessionRepository from app.repositories import session_repo from app.utils.constants import SESSION_TOKEN_BYTES, SESSION_TOKEN_SIGNATURE_SEPARATOR @@ -80,6 +81,7 @@ async def login( db: aiosqlite.Connection, password: str, session_duration_minutes: int, + session_repository: SessionRepository = session_repo, ) -> Session: """Verify *password* and create a new session on success. @@ -108,7 +110,7 @@ async def login( created_iso = now.isoformat() expires_iso = add_minutes(now, session_duration_minutes).isoformat() - session = await session_repo.create_session( + session = await session_repository.create_session( db, token=token, created_at=created_iso, expires_at=expires_iso ) log.info("bangui_login_success", token_prefix=token[:8]) @@ -119,6 +121,7 @@ async def validate_session( db: aiosqlite.Connection, token: str, session_secret: str | None = None, + session_repository: SessionRepository = session_repo, ) -> Session: """Return the session for *token* if it is valid and not expired. @@ -139,13 +142,13 @@ async def validate_session( except ValueError as exc: raise ValueError("Session token is invalid.") from exc - session = await session_repo.get_session(db, token) + session = await session_repository.get_session(db, token) if session is None: raise ValueError("Session not found.") now_iso = utc_now().isoformat() if session.expires_at <= now_iso: - await session_repo.delete_session(db, token) + await session_repository.delete_session(db, token) raise ValueError("Session has expired.") return session @@ -155,6 +158,7 @@ async def logout( db: aiosqlite.Connection, token: str, session_secret: str | None = None, + session_repository: SessionRepository = session_repo, ) -> str | None: """Invalidate the session identified by *token*. @@ -173,6 +177,6 @@ async def logout( log.warning("bangui_logout_invalid_token", token_prefix=token[:8]) return None - await session_repo.delete_session(db, token) + await session_repository.delete_session(db, token) log.info("bangui_logout", token_prefix=token[:8]) return token diff --git a/backend/app/services/protocols.py b/backend/app/services/protocols.py new file mode 100644 index 0000000..f6d4dce --- /dev/null +++ b/backend/app/services/protocols.py @@ -0,0 +1,105 @@ +"""Service interface protocols for dependency injection. + +These structural protocols define the public contract that routers and higher +layers depend on, without binding them to concrete module implementations. +""" + +from __future__ import annotations + +from typing import Protocol + +import aiosqlite + +from app.models.auth import Session +from app.models.ban import JailBannedIpsResponse +from app.models.jail import JailDetailResponse, JailListResponse + + +class AuthService(Protocol): + """Protocol for authentication service operations.""" + + async def login( + self, + db: aiosqlite.Connection, + password: str, + session_duration_minutes: int, + session_repo: object | None = None, + ) -> Session: + ... + + async def validate_session( + self, + db: aiosqlite.Connection, + token: str, + session_secret: str | None = None, + session_repo: object | None = None, + ) -> Session: + ... + + async def logout( + self, + db: aiosqlite.Connection, + token: str, + session_secret: str | None = None, + session_repo: object | None = None, + ) -> str | None: + ... + + +class JailService(Protocol): + """Protocol for jail management service operations.""" + + async def list_jails(self, socket_path: str) -> JailListResponse: + ... + + async def get_jail(self, socket_path: str, name: str) -> JailDetailResponse: + ... + + async def reload_all(self, socket_path: str) -> None: + ... + + async def start_jail(self, socket_path: str, name: str) -> None: + ... + + async def stop_jail(self, socket_path: str, name: str) -> None: + ... + + async def set_idle(self, socket_path: str, name: str, *, on: bool) -> None: + ... + + async def reload_jail(self, socket_path: str, name: str) -> None: + ... + + async def get_ignore_list(self, socket_path: str, name: str) -> list[str]: + ... + + async def add_ignore_ip(self, socket_path: str, name: str, ip: str) -> None: + ... + + async def del_ignore_ip(self, socket_path: str, name: str, ip: str) -> None: + ... + + async def set_ignore_self(self, socket_path: str, name: str, *, on: bool) -> None: + ... + + async def get_jail_banned_ips( + self, + socket_path: str, + jail_name: str, + page: int, + page_size: int, + search: str | None = None, + *, + geo_batch_lookup: object, + http_session: object, + app_db: aiosqlite.Connection, + ) -> JailBannedIpsResponse: + ... + + async def lookup_ip( + self, + socket_path: str, + ip: str, + geo_enricher: object, + ) -> object: + ... diff --git a/backend/tests/test_routers/test_dependency_injection.py b/backend/tests/test_routers/test_dependency_injection.py new file mode 100644 index 0000000..57e1405 --- /dev/null +++ b/backend/tests/test_routers/test_dependency_injection.py @@ -0,0 +1,189 @@ +"""Router dependency injection tests. + +These tests verify that routers can consume service abstractions via FastAPI +dependencies and that those dependencies can be overridden cleanly for unit +testing without touching concrete implementations. +""" + +from __future__ import annotations + +from pathlib import Path + +import aiosqlite +from httpx import ASGITransport, AsyncClient + +from app.config import Settings +from app.db import init_db +from app.dependencies import get_auth_service, get_jail_service +from app.main import create_app +from app.models.auth import Session +from app.models.jail import JailListResponse +from app.utils.setup_state import set_setup_complete_cache + + +class FakeAuthService: + async def login( + self, + _db: aiosqlite.Connection, + password: str, + session_duration_minutes: int, + session_repo: object | None = None, + ) -> Session: + return Session( + id=1, + token="fake-token", + created_at="2025-01-01T00:00:00Z", + expires_at="2099-01-01T00:00:00Z", + ) + + async def validate_session( + self, + _db: aiosqlite.Connection, + token: str, + session_secret: str | None = None, + session_repo: object | None = None, + ) -> Session: + return Session( + id=1, + token=token, + created_at="2025-01-01T00:00:00Z", + expires_at="2099-01-01T00:00:00Z", + ) + + async def logout( + self, + _db: aiosqlite.Connection, + token: str, + session_secret: str | None = None, + session_repo: object | None = None, + ) -> str | None: + return token + + +class FakeJailService: + async def list_jails(self, _socket_path: str) -> JailListResponse: + return JailListResponse(jails=[], total=0) + + async def get_jail(self, _socket_path: str, _name: str) -> JailListResponse: + raise NotImplementedError + + async def reload_all(self, _socket_path: str) -> None: + raise NotImplementedError + + async def start_jail(self, _socket_path: str, _name: str) -> None: + raise NotImplementedError + + async def stop_jail(self, _socket_path: str, _name: str) -> None: + raise NotImplementedError + + async def set_idle(self, socket_path: str, name: str, *, on: bool) -> None: + raise NotImplementedError + + async def reload_jail(self, socket_path: str, name: str) -> None: + raise NotImplementedError + + async def get_ignore_list(self, socket_path: str, name: str) -> list[str]: + raise NotImplementedError + + async def add_ignore_ip(self, socket_path: str, name: str, ip: str) -> None: + raise NotImplementedError + + async def del_ignore_ip(self, socket_path: str, name: str, ip: str) -> None: + raise NotImplementedError + + async def set_ignore_self(self, socket_path: str, name: str, *, on: bool) -> None: + raise NotImplementedError + + async def get_jail_banned_ips( + self, + socket_path: str, + jail_name: str, + page: int, + page_size: int, + search: str | None = None, + *, + geo_batch_lookup: object, + http_session: object, + app_db: aiosqlite.Connection, + ) -> JailListResponse: + raise NotImplementedError + + +async def _build_app(settings: Settings): + app = create_app(settings=settings) + set_setup_complete_cache(app, True) + db = await aiosqlite.connect(settings.database_path) + db.row_factory = aiosqlite.Row + await init_db(db) + return app, db + + +async def test_auth_login_uses_injected_auth_service(tmp_path: Path) -> None: + settings = Settings( + database_path=str(tmp_path / "test_bangui.db"), + fail2ban_socket="/tmp/fake_fail2ban.sock", + fail2ban_config_dir=str(tmp_path / "fail2ban"), + session_secret="test-secret-key-do-not-use-in-production", + session_duration_minutes=60, + timezone="UTC", + log_level="debug", + ) + + app, db = await _build_app(settings) + def _fake_auth_service() -> FakeAuthService: + return FakeAuthService() + + app.dependency_overrides[get_auth_service] = _fake_auth_service + + transport = ASGITransport(app=app) + async with AsyncClient( + transport=transport, + base_url="http://test", + ) as client: + response = await client.post( + "/api/auth/login", + json={"password": "ignored"}, + ) + + await db.close() + + assert response.status_code == 200 + assert response.json()["token"].startswith("fake-token") + assert response.cookies.get("bangui_session") is not None + + +async def test_jail_list_uses_injected_jail_service_and_auth(tmp_path: Path) -> None: + settings = Settings( + database_path=str(tmp_path / "test_bangui.db"), + fail2ban_socket="/tmp/fake_fail2ban.sock", + fail2ban_config_dir=str(tmp_path / "fail2ban"), + session_secret="test-secret-key-do-not-use-in-production", + session_duration_minutes=60, + timezone="UTC", + log_level="debug", + ) + + app, db = await _build_app(settings) + def _fake_auth_service() -> FakeAuthService: + return FakeAuthService() + + def _fake_jail_service() -> FakeJailService: + return FakeJailService() + + app.dependency_overrides[get_auth_service] = _fake_auth_service + app.dependency_overrides[get_jail_service] = _fake_jail_service + + transport = ASGITransport(app=app) + async with AsyncClient( + transport=transport, + base_url="http://test", + ) as client: + response = await client.get( + "/api/jails", + headers={"Cookie": "bangui_session=fake-token"}, + ) + + await db.close() + + assert response.status_code == 200 + assert response.json() == {"jails": [], "total": 0}