Introduce service/repository dependency protocols and tests
This commit is contained in:
@@ -20,6 +20,8 @@ from app.config import Settings
|
||||
from app.models.auth import Session
|
||||
from app.models.config import PendingRecovery
|
||||
from app.models.server import ServerStatus
|
||||
from app.repositories.protocols import SessionRepository
|
||||
from app.services.protocols import AuthService, JailService
|
||||
from app.utils.runtime_state import RuntimeState, get_effective_settings
|
||||
from app.utils.session_cache import SessionCache
|
||||
|
||||
@@ -169,6 +171,27 @@ async def get_session_cache(request: Request) -> SessionCache:
|
||||
return session_cache
|
||||
|
||||
|
||||
async def get_auth_service() -> AuthService:
|
||||
"""Provide the concrete authentication service implementation."""
|
||||
from app.services import auth_service # noqa: PLC0415
|
||||
|
||||
return cast(AuthService, auth_service)
|
||||
|
||||
|
||||
async def get_jail_service() -> JailService:
|
||||
"""Provide the concrete jail service implementation."""
|
||||
from app.services import jail_service # noqa: PLC0415
|
||||
|
||||
return cast(JailService, jail_service)
|
||||
|
||||
|
||||
async def get_session_repo() -> SessionRepository:
|
||||
"""Provide the concrete session repository implementation."""
|
||||
from app.repositories import session_repo # noqa: PLC0415
|
||||
|
||||
return session_repo
|
||||
|
||||
|
||||
async def get_app_state(request: Request) -> AppState:
|
||||
"""Provide the application state object for the current request."""
|
||||
return cast("AppState", request.app.state)
|
||||
@@ -194,6 +217,8 @@ async def require_auth(
|
||||
db: Annotated[aiosqlite.Connection, Depends(get_db)],
|
||||
settings: Annotated[Settings, Depends(get_settings)],
|
||||
session_cache: Annotated[SessionCache, Depends(get_session_cache)],
|
||||
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
||||
session_repo: Annotated[SessionRepository, Depends(get_session_repo)],
|
||||
) -> Session:
|
||||
"""Validate the session token and return the active session.
|
||||
|
||||
@@ -218,7 +243,6 @@ async def require_auth(
|
||||
Raises:
|
||||
HTTPException: 401 if no valid session token is found.
|
||||
"""
|
||||
from app.services import auth_service # noqa: PLC0415
|
||||
|
||||
token: str | None = request.cookies.get(_COOKIE_NAME)
|
||||
if not token:
|
||||
@@ -240,7 +264,12 @@ async def require_auth(
|
||||
return cached
|
||||
|
||||
try:
|
||||
session = await auth_service.validate_session(db, token, settings.session_secret)
|
||||
session = await auth_service.validate_session(
|
||||
db,
|
||||
token,
|
||||
settings.session_secret,
|
||||
session_repo=session_repo,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
@@ -264,6 +293,9 @@ Fail2BanStartCommandDep = Annotated[str, Depends(get_fail2ban_start_command)]
|
||||
ServerStatusDep = Annotated[ServerStatus, Depends(get_server_status)]
|
||||
PendingRecoveryDep = Annotated[PendingRecovery | None, Depends(get_pending_recovery)]
|
||||
SessionCacheDep = Annotated[SessionCache, Depends(get_session_cache)]
|
||||
AuthServiceDep = Annotated[AuthService, Depends(get_auth_service)]
|
||||
JailServiceDep = Annotated[JailService, Depends(get_jail_service)]
|
||||
SessionRepoDep = Annotated[SessionRepository, Depends(get_session_repo)]
|
||||
AppStateDep = Annotated[AppState, Depends(get_app_state)]
|
||||
AppDep = Annotated[FastAPI, Depends(get_app)]
|
||||
AuthDep = Annotated[Session, Depends(require_auth)]
|
||||
|
||||
47
backend/app/repositories/protocols.py
Normal file
47
backend/app/repositories/protocols.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""Repository interface protocols for dependency injection.
|
||||
|
||||
Routers and services can depend on these abstractions instead of concrete
|
||||
module implementations, making the backend easier to test and extend.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
import aiosqlite
|
||||
|
||||
from app.models.auth import Session
|
||||
|
||||
|
||||
class SessionRepository(Protocol):
|
||||
"""Protocol for session persistence operations."""
|
||||
|
||||
async def create_session(
|
||||
self,
|
||||
db: aiosqlite.Connection,
|
||||
token: str,
|
||||
created_at: str,
|
||||
expires_at: str,
|
||||
) -> Session:
|
||||
...
|
||||
|
||||
async def get_session(
|
||||
self,
|
||||
db: aiosqlite.Connection,
|
||||
token: str,
|
||||
) -> Session | None:
|
||||
...
|
||||
|
||||
async def delete_session(
|
||||
self,
|
||||
db: aiosqlite.Connection,
|
||||
token: str,
|
||||
) -> None:
|
||||
...
|
||||
|
||||
async def delete_expired_sessions(
|
||||
self,
|
||||
db: aiosqlite.Connection,
|
||||
now_iso: str,
|
||||
) -> int:
|
||||
...
|
||||
@@ -12,9 +12,15 @@ from __future__ import annotations
|
||||
import structlog
|
||||
from fastapi import APIRouter, HTTPException, Request, Response, status
|
||||
|
||||
from app.dependencies import DbDep, SessionCacheDep, SettingsDep
|
||||
from app.dependencies import (
|
||||
AuthServiceDep,
|
||||
DbDep,
|
||||
SessionCacheDep,
|
||||
SettingsDep,
|
||||
SessionRepoDep,
|
||||
)
|
||||
from app.models.auth import LoginRequest, LoginResponse, LogoutResponse
|
||||
from app.services import auth_service
|
||||
from app.services.auth_service import sign_session_token
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
@@ -33,6 +39,8 @@ async def login(
|
||||
response: Response,
|
||||
db: DbDep,
|
||||
settings: SettingsDep,
|
||||
auth_service: AuthServiceDep,
|
||||
session_repo: SessionRepoDep,
|
||||
) -> LoginResponse:
|
||||
"""Verify the master password and return a session token.
|
||||
|
||||
@@ -56,6 +64,7 @@ async def login(
|
||||
db,
|
||||
password=body.password,
|
||||
session_duration_minutes=settings.session_duration_minutes,
|
||||
session_repo=session_repo,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(
|
||||
@@ -63,7 +72,7 @@ async def login(
|
||||
detail=str(exc),
|
||||
) from exc
|
||||
|
||||
signed_token = auth_service.sign_session_token(
|
||||
signed_token = sign_session_token(
|
||||
session.token,
|
||||
settings.session_secret,
|
||||
)
|
||||
@@ -89,6 +98,8 @@ async def logout(
|
||||
db: DbDep,
|
||||
settings: SettingsDep,
|
||||
session_cache: SessionCacheDep,
|
||||
auth_service: AuthServiceDep,
|
||||
session_repo: SessionRepoDep,
|
||||
) -> LogoutResponse:
|
||||
"""Invalidate the active session.
|
||||
|
||||
@@ -107,7 +118,12 @@ async def logout(
|
||||
"""
|
||||
token = _extract_token(request)
|
||||
if token:
|
||||
raw_token = await auth_service.logout(db, token, settings.session_secret)
|
||||
raw_token = await auth_service.logout(
|
||||
db,
|
||||
token,
|
||||
settings.session_secret,
|
||||
session_repo=session_repo,
|
||||
)
|
||||
if raw_token:
|
||||
session_cache.invalidate(raw_token)
|
||||
session_cache.invalidate(token)
|
||||
|
||||
@@ -28,6 +28,7 @@ from app.dependencies import (
|
||||
DbDep,
|
||||
Fail2BanSocketDep,
|
||||
HttpSessionDep,
|
||||
JailServiceDep,
|
||||
)
|
||||
from app.exceptions import JailNotFoundError, JailOperationError
|
||||
from app.models.ban import JailBannedIpsResponse
|
||||
@@ -37,7 +38,7 @@ from app.models.jail import (
|
||||
JailDetailResponse,
|
||||
JailListResponse,
|
||||
)
|
||||
from app.services import geo_service, jail_service
|
||||
from app.services import geo_service
|
||||
from app.utils.fail2ban_client import Fail2BanConnectionError
|
||||
|
||||
router: APIRouter = APIRouter(prefix="/api/jails", tags=["Jails"])
|
||||
@@ -107,6 +108,7 @@ def _conflict(message: str) -> HTTPException:
|
||||
async def get_jails(
|
||||
_auth: AuthDep,
|
||||
socket_path: Fail2BanSocketDep,
|
||||
jail_service: JailServiceDep,
|
||||
) -> JailListResponse:
|
||||
"""Return a summary of every active fail2ban jail.
|
||||
|
||||
@@ -135,6 +137,7 @@ async def get_jail(
|
||||
_auth: AuthDep,
|
||||
name: _NamePath,
|
||||
socket_path: Fail2BanSocketDep,
|
||||
jail_service: JailServiceDep,
|
||||
) -> JailDetailResponse:
|
||||
"""Return the complete configuration and runtime state for one jail.
|
||||
|
||||
@@ -174,6 +177,7 @@ async def get_jail(
|
||||
async def reload_all_jails(
|
||||
_auth: AuthDep,
|
||||
socket_path: Fail2BanSocketDep,
|
||||
jail_service: JailServiceDep,
|
||||
) -> JailCommandResponse:
|
||||
"""Reload every fail2ban jail to apply configuration changes.
|
||||
|
||||
@@ -208,6 +212,7 @@ async def start_jail(
|
||||
_auth: AuthDep,
|
||||
name: _NamePath,
|
||||
socket_path: Fail2BanSocketDep,
|
||||
jail_service: JailServiceDep,
|
||||
) -> JailCommandResponse:
|
||||
"""Start a fail2ban jail that is currently stopped.
|
||||
|
||||
@@ -243,6 +248,7 @@ async def stop_jail(
|
||||
_auth: AuthDep,
|
||||
name: _NamePath,
|
||||
socket_path: Fail2BanSocketDep,
|
||||
jail_service: JailServiceDep,
|
||||
) -> JailCommandResponse:
|
||||
"""Stop a running fail2ban jail.
|
||||
|
||||
@@ -279,6 +285,7 @@ async def toggle_idle(
|
||||
_auth: AuthDep,
|
||||
name: _NamePath,
|
||||
socket_path: Fail2BanSocketDep,
|
||||
jail_service: JailServiceDep,
|
||||
on: bool = Body(..., description="``true`` to enable idle, ``false`` to disable."),
|
||||
) -> JailCommandResponse:
|
||||
"""Enable or disable idle mode for a fail2ban jail.
|
||||
@@ -323,6 +330,7 @@ async def reload_jail(
|
||||
_auth: AuthDep,
|
||||
name: _NamePath,
|
||||
socket_path: Fail2BanSocketDep,
|
||||
jail_service: JailServiceDep,
|
||||
) -> JailCommandResponse:
|
||||
"""Reload a single fail2ban jail to pick up configuration changes.
|
||||
|
||||
@@ -371,6 +379,7 @@ async def get_ignore_list(
|
||||
_auth: AuthDep,
|
||||
name: _NamePath,
|
||||
socket_path: Fail2BanSocketDep,
|
||||
jail_service: JailServiceDep,
|
||||
) -> list[str]:
|
||||
"""Return the current ignore list (IP whitelist) for a fail2ban jail.
|
||||
|
||||
@@ -404,6 +413,7 @@ async def add_ignore_ip(
|
||||
name: _NamePath,
|
||||
body: IgnoreIpRequest,
|
||||
socket_path: Fail2BanSocketDep,
|
||||
jail_service: JailServiceDep,
|
||||
) -> JailCommandResponse:
|
||||
"""Add an IP address or CIDR network to a jail's ignore list.
|
||||
|
||||
@@ -453,6 +463,7 @@ async def del_ignore_ip(
|
||||
name: _NamePath,
|
||||
body: IgnoreIpRequest,
|
||||
socket_path: Fail2BanSocketDep,
|
||||
jail_service: JailServiceDep,
|
||||
) -> JailCommandResponse:
|
||||
"""Remove an IP address or CIDR network from a jail's ignore list.
|
||||
|
||||
@@ -492,6 +503,7 @@ async def toggle_ignore_self(
|
||||
_auth: AuthDep,
|
||||
name: _NamePath,
|
||||
socket_path: Fail2BanSocketDep,
|
||||
jail_service: JailServiceDep,
|
||||
on: bool = Body(..., description="``true`` to enable ignoreself, ``false`` to disable."),
|
||||
) -> JailCommandResponse:
|
||||
"""Toggle the ``ignoreself`` flag for a fail2ban jail.
|
||||
@@ -543,6 +555,7 @@ async def get_jail_banned_ips(
|
||||
name: _NamePath,
|
||||
socket_path: Fail2BanSocketDep,
|
||||
http_session: HttpSessionDep,
|
||||
jail_service: JailServiceDep,
|
||||
page: int = 1,
|
||||
page_size: int = 25,
|
||||
search: str | None = None,
|
||||
|
||||
@@ -20,6 +20,7 @@ if TYPE_CHECKING:
|
||||
import aiosqlite
|
||||
|
||||
from app.models.auth import Session
|
||||
from app.repositories.protocols import SessionRepository
|
||||
|
||||
from app.repositories import session_repo
|
||||
from app.utils.constants import SESSION_TOKEN_BYTES, SESSION_TOKEN_SIGNATURE_SEPARATOR
|
||||
@@ -80,6 +81,7 @@ async def login(
|
||||
db: aiosqlite.Connection,
|
||||
password: str,
|
||||
session_duration_minutes: int,
|
||||
session_repository: SessionRepository = session_repo,
|
||||
) -> Session:
|
||||
"""Verify *password* and create a new session on success.
|
||||
|
||||
@@ -108,7 +110,7 @@ async def login(
|
||||
created_iso = now.isoformat()
|
||||
expires_iso = add_minutes(now, session_duration_minutes).isoformat()
|
||||
|
||||
session = await session_repo.create_session(
|
||||
session = await session_repository.create_session(
|
||||
db, token=token, created_at=created_iso, expires_at=expires_iso
|
||||
)
|
||||
log.info("bangui_login_success", token_prefix=token[:8])
|
||||
@@ -119,6 +121,7 @@ async def validate_session(
|
||||
db: aiosqlite.Connection,
|
||||
token: str,
|
||||
session_secret: str | None = None,
|
||||
session_repository: SessionRepository = session_repo,
|
||||
) -> Session:
|
||||
"""Return the session for *token* if it is valid and not expired.
|
||||
|
||||
@@ -139,13 +142,13 @@ async def validate_session(
|
||||
except ValueError as exc:
|
||||
raise ValueError("Session token is invalid.") from exc
|
||||
|
||||
session = await session_repo.get_session(db, token)
|
||||
session = await session_repository.get_session(db, token)
|
||||
if session is None:
|
||||
raise ValueError("Session not found.")
|
||||
|
||||
now_iso = utc_now().isoformat()
|
||||
if session.expires_at <= now_iso:
|
||||
await session_repo.delete_session(db, token)
|
||||
await session_repository.delete_session(db, token)
|
||||
raise ValueError("Session has expired.")
|
||||
|
||||
return session
|
||||
@@ -155,6 +158,7 @@ async def logout(
|
||||
db: aiosqlite.Connection,
|
||||
token: str,
|
||||
session_secret: str | None = None,
|
||||
session_repository: SessionRepository = session_repo,
|
||||
) -> str | None:
|
||||
"""Invalidate the session identified by *token*.
|
||||
|
||||
@@ -173,6 +177,6 @@ async def logout(
|
||||
log.warning("bangui_logout_invalid_token", token_prefix=token[:8])
|
||||
return None
|
||||
|
||||
await session_repo.delete_session(db, token)
|
||||
await session_repository.delete_session(db, token)
|
||||
log.info("bangui_logout", token_prefix=token[:8])
|
||||
return token
|
||||
|
||||
105
backend/app/services/protocols.py
Normal file
105
backend/app/services/protocols.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""Service interface protocols for dependency injection.
|
||||
|
||||
These structural protocols define the public contract that routers and higher
|
||||
layers depend on, without binding them to concrete module implementations.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
import aiosqlite
|
||||
|
||||
from app.models.auth import Session
|
||||
from app.models.ban import JailBannedIpsResponse
|
||||
from app.models.jail import JailDetailResponse, JailListResponse
|
||||
|
||||
|
||||
class AuthService(Protocol):
|
||||
"""Protocol for authentication service operations."""
|
||||
|
||||
async def login(
|
||||
self,
|
||||
db: aiosqlite.Connection,
|
||||
password: str,
|
||||
session_duration_minutes: int,
|
||||
session_repo: object | None = None,
|
||||
) -> Session:
|
||||
...
|
||||
|
||||
async def validate_session(
|
||||
self,
|
||||
db: aiosqlite.Connection,
|
||||
token: str,
|
||||
session_secret: str | None = None,
|
||||
session_repo: object | None = None,
|
||||
) -> Session:
|
||||
...
|
||||
|
||||
async def logout(
|
||||
self,
|
||||
db: aiosqlite.Connection,
|
||||
token: str,
|
||||
session_secret: str | None = None,
|
||||
session_repo: object | None = None,
|
||||
) -> str | None:
|
||||
...
|
||||
|
||||
|
||||
class JailService(Protocol):
|
||||
"""Protocol for jail management service operations."""
|
||||
|
||||
async def list_jails(self, socket_path: str) -> JailListResponse:
|
||||
...
|
||||
|
||||
async def get_jail(self, socket_path: str, name: str) -> JailDetailResponse:
|
||||
...
|
||||
|
||||
async def reload_all(self, socket_path: str) -> None:
|
||||
...
|
||||
|
||||
async def start_jail(self, socket_path: str, name: str) -> None:
|
||||
...
|
||||
|
||||
async def stop_jail(self, socket_path: str, name: str) -> None:
|
||||
...
|
||||
|
||||
async def set_idle(self, socket_path: str, name: str, *, on: bool) -> None:
|
||||
...
|
||||
|
||||
async def reload_jail(self, socket_path: str, name: str) -> None:
|
||||
...
|
||||
|
||||
async def get_ignore_list(self, socket_path: str, name: str) -> list[str]:
|
||||
...
|
||||
|
||||
async def add_ignore_ip(self, socket_path: str, name: str, ip: str) -> None:
|
||||
...
|
||||
|
||||
async def del_ignore_ip(self, socket_path: str, name: str, ip: str) -> None:
|
||||
...
|
||||
|
||||
async def set_ignore_self(self, socket_path: str, name: str, *, on: bool) -> None:
|
||||
...
|
||||
|
||||
async def get_jail_banned_ips(
|
||||
self,
|
||||
socket_path: str,
|
||||
jail_name: str,
|
||||
page: int,
|
||||
page_size: int,
|
||||
search: str | None = None,
|
||||
*,
|
||||
geo_batch_lookup: object,
|
||||
http_session: object,
|
||||
app_db: aiosqlite.Connection,
|
||||
) -> JailBannedIpsResponse:
|
||||
...
|
||||
|
||||
async def lookup_ip(
|
||||
self,
|
||||
socket_path: str,
|
||||
ip: str,
|
||||
geo_enricher: object,
|
||||
) -> object:
|
||||
...
|
||||
Reference in New Issue
Block a user