Introduce service/repository dependency protocols and tests
This commit is contained in:
@@ -48,6 +48,8 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue.
|
|||||||
- Issue: Routers and higher-level services currently import concrete service modules directly, which prevents clean substitution and dependency override testing.
|
- Issue: Routers and higher-level services currently import concrete service modules directly, which prevents clean substitution and dependency override testing.
|
||||||
- Propose: Define protocols or abstract base classes for major services and repositories, then wire concrete implementations through FastAPI dependency providers.
|
- Propose: Define protocols or abstract base classes for major services and repositories, then wire concrete implementations through FastAPI dependency providers.
|
||||||
- Test: Add tests that override a service dependency with a fake implementation and verify the router behavior remains correct.
|
- Test: Add tests that override a service dependency with a fake implementation and verify the router behavior remains correct.
|
||||||
|
- Status: completed
|
||||||
|
- Completed: Added service/repository protocols, injected auth/jail services via FastAPI dependencies, and added router tests for dependency overrides.
|
||||||
|
|
||||||
7. Move operational orchestration out of routers and into service/task layer
|
7. Move operational orchestration out of routers and into service/task layer
|
||||||
- Goal: Keep routers thin and move operational control flow into service or task components.
|
- Goal: Keep routers thin and move operational control flow into service or task components.
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ from app.config import Settings
|
|||||||
from app.models.auth import Session
|
from app.models.auth import Session
|
||||||
from app.models.config import PendingRecovery
|
from app.models.config import PendingRecovery
|
||||||
from app.models.server import ServerStatus
|
from app.models.server import ServerStatus
|
||||||
|
from app.repositories.protocols import SessionRepository
|
||||||
|
from app.services.protocols import AuthService, JailService
|
||||||
from app.utils.runtime_state import RuntimeState, get_effective_settings
|
from app.utils.runtime_state import RuntimeState, get_effective_settings
|
||||||
from app.utils.session_cache import SessionCache
|
from app.utils.session_cache import SessionCache
|
||||||
|
|
||||||
@@ -169,6 +171,27 @@ async def get_session_cache(request: Request) -> SessionCache:
|
|||||||
return session_cache
|
return session_cache
|
||||||
|
|
||||||
|
|
||||||
|
async def get_auth_service() -> AuthService:
|
||||||
|
"""Provide the concrete authentication service implementation."""
|
||||||
|
from app.services import auth_service # noqa: PLC0415
|
||||||
|
|
||||||
|
return cast(AuthService, auth_service)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_jail_service() -> JailService:
|
||||||
|
"""Provide the concrete jail service implementation."""
|
||||||
|
from app.services import jail_service # noqa: PLC0415
|
||||||
|
|
||||||
|
return cast(JailService, jail_service)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_session_repo() -> SessionRepository:
|
||||||
|
"""Provide the concrete session repository implementation."""
|
||||||
|
from app.repositories import session_repo # noqa: PLC0415
|
||||||
|
|
||||||
|
return session_repo
|
||||||
|
|
||||||
|
|
||||||
async def get_app_state(request: Request) -> AppState:
|
async def get_app_state(request: Request) -> AppState:
|
||||||
"""Provide the application state object for the current request."""
|
"""Provide the application state object for the current request."""
|
||||||
return cast("AppState", request.app.state)
|
return cast("AppState", request.app.state)
|
||||||
@@ -194,6 +217,8 @@ async def require_auth(
|
|||||||
db: Annotated[aiosqlite.Connection, Depends(get_db)],
|
db: Annotated[aiosqlite.Connection, Depends(get_db)],
|
||||||
settings: Annotated[Settings, Depends(get_settings)],
|
settings: Annotated[Settings, Depends(get_settings)],
|
||||||
session_cache: Annotated[SessionCache, Depends(get_session_cache)],
|
session_cache: Annotated[SessionCache, Depends(get_session_cache)],
|
||||||
|
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
||||||
|
session_repo: Annotated[SessionRepository, Depends(get_session_repo)],
|
||||||
) -> Session:
|
) -> Session:
|
||||||
"""Validate the session token and return the active session.
|
"""Validate the session token and return the active session.
|
||||||
|
|
||||||
@@ -218,7 +243,6 @@ async def require_auth(
|
|||||||
Raises:
|
Raises:
|
||||||
HTTPException: 401 if no valid session token is found.
|
HTTPException: 401 if no valid session token is found.
|
||||||
"""
|
"""
|
||||||
from app.services import auth_service # noqa: PLC0415
|
|
||||||
|
|
||||||
token: str | None = request.cookies.get(_COOKIE_NAME)
|
token: str | None = request.cookies.get(_COOKIE_NAME)
|
||||||
if not token:
|
if not token:
|
||||||
@@ -240,7 +264,12 @@ async def require_auth(
|
|||||||
return cached
|
return cached
|
||||||
|
|
||||||
try:
|
try:
|
||||||
session = await auth_service.validate_session(db, token, settings.session_secret)
|
session = await auth_service.validate_session(
|
||||||
|
db,
|
||||||
|
token,
|
||||||
|
settings.session_secret,
|
||||||
|
session_repo=session_repo,
|
||||||
|
)
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
@@ -264,6 +293,9 @@ Fail2BanStartCommandDep = Annotated[str, Depends(get_fail2ban_start_command)]
|
|||||||
ServerStatusDep = Annotated[ServerStatus, Depends(get_server_status)]
|
ServerStatusDep = Annotated[ServerStatus, Depends(get_server_status)]
|
||||||
PendingRecoveryDep = Annotated[PendingRecovery | None, Depends(get_pending_recovery)]
|
PendingRecoveryDep = Annotated[PendingRecovery | None, Depends(get_pending_recovery)]
|
||||||
SessionCacheDep = Annotated[SessionCache, Depends(get_session_cache)]
|
SessionCacheDep = Annotated[SessionCache, Depends(get_session_cache)]
|
||||||
|
AuthServiceDep = Annotated[AuthService, Depends(get_auth_service)]
|
||||||
|
JailServiceDep = Annotated[JailService, Depends(get_jail_service)]
|
||||||
|
SessionRepoDep = Annotated[SessionRepository, Depends(get_session_repo)]
|
||||||
AppStateDep = Annotated[AppState, Depends(get_app_state)]
|
AppStateDep = Annotated[AppState, Depends(get_app_state)]
|
||||||
AppDep = Annotated[FastAPI, Depends(get_app)]
|
AppDep = Annotated[FastAPI, Depends(get_app)]
|
||||||
AuthDep = Annotated[Session, Depends(require_auth)]
|
AuthDep = Annotated[Session, Depends(require_auth)]
|
||||||
|
|||||||
47
backend/app/repositories/protocols.py
Normal file
47
backend/app/repositories/protocols.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
"""Repository interface protocols for dependency injection.
|
||||||
|
|
||||||
|
Routers and services can depend on these abstractions instead of concrete
|
||||||
|
module implementations, making the backend easier to test and extend.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
|
||||||
|
from app.models.auth import Session
|
||||||
|
|
||||||
|
|
||||||
|
class SessionRepository(Protocol):
|
||||||
|
"""Protocol for session persistence operations."""
|
||||||
|
|
||||||
|
async def create_session(
|
||||||
|
self,
|
||||||
|
db: aiosqlite.Connection,
|
||||||
|
token: str,
|
||||||
|
created_at: str,
|
||||||
|
expires_at: str,
|
||||||
|
) -> Session:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def get_session(
|
||||||
|
self,
|
||||||
|
db: aiosqlite.Connection,
|
||||||
|
token: str,
|
||||||
|
) -> Session | None:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def delete_session(
|
||||||
|
self,
|
||||||
|
db: aiosqlite.Connection,
|
||||||
|
token: str,
|
||||||
|
) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def delete_expired_sessions(
|
||||||
|
self,
|
||||||
|
db: aiosqlite.Connection,
|
||||||
|
now_iso: str,
|
||||||
|
) -> int:
|
||||||
|
...
|
||||||
@@ -12,9 +12,15 @@ from __future__ import annotations
|
|||||||
import structlog
|
import structlog
|
||||||
from fastapi import APIRouter, HTTPException, Request, Response, status
|
from fastapi import APIRouter, HTTPException, Request, Response, status
|
||||||
|
|
||||||
from app.dependencies import DbDep, SessionCacheDep, SettingsDep
|
from app.dependencies import (
|
||||||
|
AuthServiceDep,
|
||||||
|
DbDep,
|
||||||
|
SessionCacheDep,
|
||||||
|
SettingsDep,
|
||||||
|
SessionRepoDep,
|
||||||
|
)
|
||||||
from app.models.auth import LoginRequest, LoginResponse, LogoutResponse
|
from app.models.auth import LoginRequest, LoginResponse, LogoutResponse
|
||||||
from app.services import auth_service
|
from app.services.auth_service import sign_session_token
|
||||||
|
|
||||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||||
|
|
||||||
@@ -33,6 +39,8 @@ async def login(
|
|||||||
response: Response,
|
response: Response,
|
||||||
db: DbDep,
|
db: DbDep,
|
||||||
settings: SettingsDep,
|
settings: SettingsDep,
|
||||||
|
auth_service: AuthServiceDep,
|
||||||
|
session_repo: SessionRepoDep,
|
||||||
) -> LoginResponse:
|
) -> LoginResponse:
|
||||||
"""Verify the master password and return a session token.
|
"""Verify the master password and return a session token.
|
||||||
|
|
||||||
@@ -56,6 +64,7 @@ async def login(
|
|||||||
db,
|
db,
|
||||||
password=body.password,
|
password=body.password,
|
||||||
session_duration_minutes=settings.session_duration_minutes,
|
session_duration_minutes=settings.session_duration_minutes,
|
||||||
|
session_repo=session_repo,
|
||||||
)
|
)
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -63,7 +72,7 @@ async def login(
|
|||||||
detail=str(exc),
|
detail=str(exc),
|
||||||
) from exc
|
) from exc
|
||||||
|
|
||||||
signed_token = auth_service.sign_session_token(
|
signed_token = sign_session_token(
|
||||||
session.token,
|
session.token,
|
||||||
settings.session_secret,
|
settings.session_secret,
|
||||||
)
|
)
|
||||||
@@ -89,6 +98,8 @@ async def logout(
|
|||||||
db: DbDep,
|
db: DbDep,
|
||||||
settings: SettingsDep,
|
settings: SettingsDep,
|
||||||
session_cache: SessionCacheDep,
|
session_cache: SessionCacheDep,
|
||||||
|
auth_service: AuthServiceDep,
|
||||||
|
session_repo: SessionRepoDep,
|
||||||
) -> LogoutResponse:
|
) -> LogoutResponse:
|
||||||
"""Invalidate the active session.
|
"""Invalidate the active session.
|
||||||
|
|
||||||
@@ -107,7 +118,12 @@ async def logout(
|
|||||||
"""
|
"""
|
||||||
token = _extract_token(request)
|
token = _extract_token(request)
|
||||||
if token:
|
if token:
|
||||||
raw_token = await auth_service.logout(db, token, settings.session_secret)
|
raw_token = await auth_service.logout(
|
||||||
|
db,
|
||||||
|
token,
|
||||||
|
settings.session_secret,
|
||||||
|
session_repo=session_repo,
|
||||||
|
)
|
||||||
if raw_token:
|
if raw_token:
|
||||||
session_cache.invalidate(raw_token)
|
session_cache.invalidate(raw_token)
|
||||||
session_cache.invalidate(token)
|
session_cache.invalidate(token)
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from app.dependencies import (
|
|||||||
DbDep,
|
DbDep,
|
||||||
Fail2BanSocketDep,
|
Fail2BanSocketDep,
|
||||||
HttpSessionDep,
|
HttpSessionDep,
|
||||||
|
JailServiceDep,
|
||||||
)
|
)
|
||||||
from app.exceptions import JailNotFoundError, JailOperationError
|
from app.exceptions import JailNotFoundError, JailOperationError
|
||||||
from app.models.ban import JailBannedIpsResponse
|
from app.models.ban import JailBannedIpsResponse
|
||||||
@@ -37,7 +38,7 @@ from app.models.jail import (
|
|||||||
JailDetailResponse,
|
JailDetailResponse,
|
||||||
JailListResponse,
|
JailListResponse,
|
||||||
)
|
)
|
||||||
from app.services import geo_service, jail_service
|
from app.services import geo_service
|
||||||
from app.utils.fail2ban_client import Fail2BanConnectionError
|
from app.utils.fail2ban_client import Fail2BanConnectionError
|
||||||
|
|
||||||
router: APIRouter = APIRouter(prefix="/api/jails", tags=["Jails"])
|
router: APIRouter = APIRouter(prefix="/api/jails", tags=["Jails"])
|
||||||
@@ -107,6 +108,7 @@ def _conflict(message: str) -> HTTPException:
|
|||||||
async def get_jails(
|
async def get_jails(
|
||||||
_auth: AuthDep,
|
_auth: AuthDep,
|
||||||
socket_path: Fail2BanSocketDep,
|
socket_path: Fail2BanSocketDep,
|
||||||
|
jail_service: JailServiceDep,
|
||||||
) -> JailListResponse:
|
) -> JailListResponse:
|
||||||
"""Return a summary of every active fail2ban jail.
|
"""Return a summary of every active fail2ban jail.
|
||||||
|
|
||||||
@@ -135,6 +137,7 @@ async def get_jail(
|
|||||||
_auth: AuthDep,
|
_auth: AuthDep,
|
||||||
name: _NamePath,
|
name: _NamePath,
|
||||||
socket_path: Fail2BanSocketDep,
|
socket_path: Fail2BanSocketDep,
|
||||||
|
jail_service: JailServiceDep,
|
||||||
) -> JailDetailResponse:
|
) -> JailDetailResponse:
|
||||||
"""Return the complete configuration and runtime state for one jail.
|
"""Return the complete configuration and runtime state for one jail.
|
||||||
|
|
||||||
@@ -174,6 +177,7 @@ async def get_jail(
|
|||||||
async def reload_all_jails(
|
async def reload_all_jails(
|
||||||
_auth: AuthDep,
|
_auth: AuthDep,
|
||||||
socket_path: Fail2BanSocketDep,
|
socket_path: Fail2BanSocketDep,
|
||||||
|
jail_service: JailServiceDep,
|
||||||
) -> JailCommandResponse:
|
) -> JailCommandResponse:
|
||||||
"""Reload every fail2ban jail to apply configuration changes.
|
"""Reload every fail2ban jail to apply configuration changes.
|
||||||
|
|
||||||
@@ -208,6 +212,7 @@ async def start_jail(
|
|||||||
_auth: AuthDep,
|
_auth: AuthDep,
|
||||||
name: _NamePath,
|
name: _NamePath,
|
||||||
socket_path: Fail2BanSocketDep,
|
socket_path: Fail2BanSocketDep,
|
||||||
|
jail_service: JailServiceDep,
|
||||||
) -> JailCommandResponse:
|
) -> JailCommandResponse:
|
||||||
"""Start a fail2ban jail that is currently stopped.
|
"""Start a fail2ban jail that is currently stopped.
|
||||||
|
|
||||||
@@ -243,6 +248,7 @@ async def stop_jail(
|
|||||||
_auth: AuthDep,
|
_auth: AuthDep,
|
||||||
name: _NamePath,
|
name: _NamePath,
|
||||||
socket_path: Fail2BanSocketDep,
|
socket_path: Fail2BanSocketDep,
|
||||||
|
jail_service: JailServiceDep,
|
||||||
) -> JailCommandResponse:
|
) -> JailCommandResponse:
|
||||||
"""Stop a running fail2ban jail.
|
"""Stop a running fail2ban jail.
|
||||||
|
|
||||||
@@ -279,6 +285,7 @@ async def toggle_idle(
|
|||||||
_auth: AuthDep,
|
_auth: AuthDep,
|
||||||
name: _NamePath,
|
name: _NamePath,
|
||||||
socket_path: Fail2BanSocketDep,
|
socket_path: Fail2BanSocketDep,
|
||||||
|
jail_service: JailServiceDep,
|
||||||
on: bool = Body(..., description="``true`` to enable idle, ``false`` to disable."),
|
on: bool = Body(..., description="``true`` to enable idle, ``false`` to disable."),
|
||||||
) -> JailCommandResponse:
|
) -> JailCommandResponse:
|
||||||
"""Enable or disable idle mode for a fail2ban jail.
|
"""Enable or disable idle mode for a fail2ban jail.
|
||||||
@@ -323,6 +330,7 @@ async def reload_jail(
|
|||||||
_auth: AuthDep,
|
_auth: AuthDep,
|
||||||
name: _NamePath,
|
name: _NamePath,
|
||||||
socket_path: Fail2BanSocketDep,
|
socket_path: Fail2BanSocketDep,
|
||||||
|
jail_service: JailServiceDep,
|
||||||
) -> JailCommandResponse:
|
) -> JailCommandResponse:
|
||||||
"""Reload a single fail2ban jail to pick up configuration changes.
|
"""Reload a single fail2ban jail to pick up configuration changes.
|
||||||
|
|
||||||
@@ -371,6 +379,7 @@ async def get_ignore_list(
|
|||||||
_auth: AuthDep,
|
_auth: AuthDep,
|
||||||
name: _NamePath,
|
name: _NamePath,
|
||||||
socket_path: Fail2BanSocketDep,
|
socket_path: Fail2BanSocketDep,
|
||||||
|
jail_service: JailServiceDep,
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
"""Return the current ignore list (IP whitelist) for a fail2ban jail.
|
"""Return the current ignore list (IP whitelist) for a fail2ban jail.
|
||||||
|
|
||||||
@@ -404,6 +413,7 @@ async def add_ignore_ip(
|
|||||||
name: _NamePath,
|
name: _NamePath,
|
||||||
body: IgnoreIpRequest,
|
body: IgnoreIpRequest,
|
||||||
socket_path: Fail2BanSocketDep,
|
socket_path: Fail2BanSocketDep,
|
||||||
|
jail_service: JailServiceDep,
|
||||||
) -> JailCommandResponse:
|
) -> JailCommandResponse:
|
||||||
"""Add an IP address or CIDR network to a jail's ignore list.
|
"""Add an IP address or CIDR network to a jail's ignore list.
|
||||||
|
|
||||||
@@ -453,6 +463,7 @@ async def del_ignore_ip(
|
|||||||
name: _NamePath,
|
name: _NamePath,
|
||||||
body: IgnoreIpRequest,
|
body: IgnoreIpRequest,
|
||||||
socket_path: Fail2BanSocketDep,
|
socket_path: Fail2BanSocketDep,
|
||||||
|
jail_service: JailServiceDep,
|
||||||
) -> JailCommandResponse:
|
) -> JailCommandResponse:
|
||||||
"""Remove an IP address or CIDR network from a jail's ignore list.
|
"""Remove an IP address or CIDR network from a jail's ignore list.
|
||||||
|
|
||||||
@@ -492,6 +503,7 @@ async def toggle_ignore_self(
|
|||||||
_auth: AuthDep,
|
_auth: AuthDep,
|
||||||
name: _NamePath,
|
name: _NamePath,
|
||||||
socket_path: Fail2BanSocketDep,
|
socket_path: Fail2BanSocketDep,
|
||||||
|
jail_service: JailServiceDep,
|
||||||
on: bool = Body(..., description="``true`` to enable ignoreself, ``false`` to disable."),
|
on: bool = Body(..., description="``true`` to enable ignoreself, ``false`` to disable."),
|
||||||
) -> JailCommandResponse:
|
) -> JailCommandResponse:
|
||||||
"""Toggle the ``ignoreself`` flag for a fail2ban jail.
|
"""Toggle the ``ignoreself`` flag for a fail2ban jail.
|
||||||
@@ -543,6 +555,7 @@ async def get_jail_banned_ips(
|
|||||||
name: _NamePath,
|
name: _NamePath,
|
||||||
socket_path: Fail2BanSocketDep,
|
socket_path: Fail2BanSocketDep,
|
||||||
http_session: HttpSessionDep,
|
http_session: HttpSessionDep,
|
||||||
|
jail_service: JailServiceDep,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
page_size: int = 25,
|
page_size: int = 25,
|
||||||
search: str | None = None,
|
search: str | None = None,
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ if TYPE_CHECKING:
|
|||||||
import aiosqlite
|
import aiosqlite
|
||||||
|
|
||||||
from app.models.auth import Session
|
from app.models.auth import Session
|
||||||
|
from app.repositories.protocols import SessionRepository
|
||||||
|
|
||||||
from app.repositories import session_repo
|
from app.repositories import session_repo
|
||||||
from app.utils.constants import SESSION_TOKEN_BYTES, SESSION_TOKEN_SIGNATURE_SEPARATOR
|
from app.utils.constants import SESSION_TOKEN_BYTES, SESSION_TOKEN_SIGNATURE_SEPARATOR
|
||||||
@@ -80,6 +81,7 @@ async def login(
|
|||||||
db: aiosqlite.Connection,
|
db: aiosqlite.Connection,
|
||||||
password: str,
|
password: str,
|
||||||
session_duration_minutes: int,
|
session_duration_minutes: int,
|
||||||
|
session_repository: SessionRepository = session_repo,
|
||||||
) -> Session:
|
) -> Session:
|
||||||
"""Verify *password* and create a new session on success.
|
"""Verify *password* and create a new session on success.
|
||||||
|
|
||||||
@@ -108,7 +110,7 @@ async def login(
|
|||||||
created_iso = now.isoformat()
|
created_iso = now.isoformat()
|
||||||
expires_iso = add_minutes(now, session_duration_minutes).isoformat()
|
expires_iso = add_minutes(now, session_duration_minutes).isoformat()
|
||||||
|
|
||||||
session = await session_repo.create_session(
|
session = await session_repository.create_session(
|
||||||
db, token=token, created_at=created_iso, expires_at=expires_iso
|
db, token=token, created_at=created_iso, expires_at=expires_iso
|
||||||
)
|
)
|
||||||
log.info("bangui_login_success", token_prefix=token[:8])
|
log.info("bangui_login_success", token_prefix=token[:8])
|
||||||
@@ -119,6 +121,7 @@ async def validate_session(
|
|||||||
db: aiosqlite.Connection,
|
db: aiosqlite.Connection,
|
||||||
token: str,
|
token: str,
|
||||||
session_secret: str | None = None,
|
session_secret: str | None = None,
|
||||||
|
session_repository: SessionRepository = session_repo,
|
||||||
) -> Session:
|
) -> Session:
|
||||||
"""Return the session for *token* if it is valid and not expired.
|
"""Return the session for *token* if it is valid and not expired.
|
||||||
|
|
||||||
@@ -139,13 +142,13 @@ async def validate_session(
|
|||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
raise ValueError("Session token is invalid.") from exc
|
raise ValueError("Session token is invalid.") from exc
|
||||||
|
|
||||||
session = await session_repo.get_session(db, token)
|
session = await session_repository.get_session(db, token)
|
||||||
if session is None:
|
if session is None:
|
||||||
raise ValueError("Session not found.")
|
raise ValueError("Session not found.")
|
||||||
|
|
||||||
now_iso = utc_now().isoformat()
|
now_iso = utc_now().isoformat()
|
||||||
if session.expires_at <= now_iso:
|
if session.expires_at <= now_iso:
|
||||||
await session_repo.delete_session(db, token)
|
await session_repository.delete_session(db, token)
|
||||||
raise ValueError("Session has expired.")
|
raise ValueError("Session has expired.")
|
||||||
|
|
||||||
return session
|
return session
|
||||||
@@ -155,6 +158,7 @@ async def logout(
|
|||||||
db: aiosqlite.Connection,
|
db: aiosqlite.Connection,
|
||||||
token: str,
|
token: str,
|
||||||
session_secret: str | None = None,
|
session_secret: str | None = None,
|
||||||
|
session_repository: SessionRepository = session_repo,
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
"""Invalidate the session identified by *token*.
|
"""Invalidate the session identified by *token*.
|
||||||
|
|
||||||
@@ -173,6 +177,6 @@ async def logout(
|
|||||||
log.warning("bangui_logout_invalid_token", token_prefix=token[:8])
|
log.warning("bangui_logout_invalid_token", token_prefix=token[:8])
|
||||||
return None
|
return None
|
||||||
|
|
||||||
await session_repo.delete_session(db, token)
|
await session_repository.delete_session(db, token)
|
||||||
log.info("bangui_logout", token_prefix=token[:8])
|
log.info("bangui_logout", token_prefix=token[:8])
|
||||||
return token
|
return token
|
||||||
|
|||||||
105
backend/app/services/protocols.py
Normal file
105
backend/app/services/protocols.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
"""Service interface protocols for dependency injection.
|
||||||
|
|
||||||
|
These structural protocols define the public contract that routers and higher
|
||||||
|
layers depend on, without binding them to concrete module implementations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
|
||||||
|
from app.models.auth import Session
|
||||||
|
from app.models.ban import JailBannedIpsResponse
|
||||||
|
from app.models.jail import JailDetailResponse, JailListResponse
|
||||||
|
|
||||||
|
|
||||||
|
class AuthService(Protocol):
|
||||||
|
"""Protocol for authentication service operations."""
|
||||||
|
|
||||||
|
async def login(
|
||||||
|
self,
|
||||||
|
db: aiosqlite.Connection,
|
||||||
|
password: str,
|
||||||
|
session_duration_minutes: int,
|
||||||
|
session_repo: object | None = None,
|
||||||
|
) -> Session:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def validate_session(
|
||||||
|
self,
|
||||||
|
db: aiosqlite.Connection,
|
||||||
|
token: str,
|
||||||
|
session_secret: str | None = None,
|
||||||
|
session_repo: object | None = None,
|
||||||
|
) -> Session:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def logout(
|
||||||
|
self,
|
||||||
|
db: aiosqlite.Connection,
|
||||||
|
token: str,
|
||||||
|
session_secret: str | None = None,
|
||||||
|
session_repo: object | None = None,
|
||||||
|
) -> str | None:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class JailService(Protocol):
|
||||||
|
"""Protocol for jail management service operations."""
|
||||||
|
|
||||||
|
async def list_jails(self, socket_path: str) -> JailListResponse:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def get_jail(self, socket_path: str, name: str) -> JailDetailResponse:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def reload_all(self, socket_path: str) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def start_jail(self, socket_path: str, name: str) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def stop_jail(self, socket_path: str, name: str) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def set_idle(self, socket_path: str, name: str, *, on: bool) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def reload_jail(self, socket_path: str, name: str) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def get_ignore_list(self, socket_path: str, name: str) -> list[str]:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def add_ignore_ip(self, socket_path: str, name: str, ip: str) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def del_ignore_ip(self, socket_path: str, name: str, ip: str) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def set_ignore_self(self, socket_path: str, name: str, *, on: bool) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def get_jail_banned_ips(
|
||||||
|
self,
|
||||||
|
socket_path: str,
|
||||||
|
jail_name: str,
|
||||||
|
page: int,
|
||||||
|
page_size: int,
|
||||||
|
search: str | None = None,
|
||||||
|
*,
|
||||||
|
geo_batch_lookup: object,
|
||||||
|
http_session: object,
|
||||||
|
app_db: aiosqlite.Connection,
|
||||||
|
) -> JailBannedIpsResponse:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def lookup_ip(
|
||||||
|
self,
|
||||||
|
socket_path: str,
|
||||||
|
ip: str,
|
||||||
|
geo_enricher: object,
|
||||||
|
) -> object:
|
||||||
|
...
|
||||||
189
backend/tests/test_routers/test_dependency_injection.py
Normal file
189
backend/tests/test_routers/test_dependency_injection.py
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
"""Router dependency injection tests.
|
||||||
|
|
||||||
|
These tests verify that routers can consume service abstractions via FastAPI
|
||||||
|
dependencies and that those dependencies can be overridden cleanly for unit
|
||||||
|
testing without touching concrete implementations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
|
||||||
|
from app.config import Settings
|
||||||
|
from app.db import init_db
|
||||||
|
from app.dependencies import get_auth_service, get_jail_service
|
||||||
|
from app.main import create_app
|
||||||
|
from app.models.auth import Session
|
||||||
|
from app.models.jail import JailListResponse
|
||||||
|
from app.utils.setup_state import set_setup_complete_cache
|
||||||
|
|
||||||
|
|
||||||
|
class FakeAuthService:
|
||||||
|
async def login(
|
||||||
|
self,
|
||||||
|
_db: aiosqlite.Connection,
|
||||||
|
password: str,
|
||||||
|
session_duration_minutes: int,
|
||||||
|
session_repo: object | None = None,
|
||||||
|
) -> Session:
|
||||||
|
return Session(
|
||||||
|
id=1,
|
||||||
|
token="fake-token",
|
||||||
|
created_at="2025-01-01T00:00:00Z",
|
||||||
|
expires_at="2099-01-01T00:00:00Z",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def validate_session(
|
||||||
|
self,
|
||||||
|
_db: aiosqlite.Connection,
|
||||||
|
token: str,
|
||||||
|
session_secret: str | None = None,
|
||||||
|
session_repo: object | None = None,
|
||||||
|
) -> Session:
|
||||||
|
return Session(
|
||||||
|
id=1,
|
||||||
|
token=token,
|
||||||
|
created_at="2025-01-01T00:00:00Z",
|
||||||
|
expires_at="2099-01-01T00:00:00Z",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def logout(
|
||||||
|
self,
|
||||||
|
_db: aiosqlite.Connection,
|
||||||
|
token: str,
|
||||||
|
session_secret: str | None = None,
|
||||||
|
session_repo: object | None = None,
|
||||||
|
) -> str | None:
|
||||||
|
return token
|
||||||
|
|
||||||
|
|
||||||
|
class FakeJailService:
|
||||||
|
async def list_jails(self, _socket_path: str) -> JailListResponse:
|
||||||
|
return JailListResponse(jails=[], total=0)
|
||||||
|
|
||||||
|
async def get_jail(self, _socket_path: str, _name: str) -> JailListResponse:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def reload_all(self, _socket_path: str) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def start_jail(self, _socket_path: str, _name: str) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def stop_jail(self, _socket_path: str, _name: str) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def set_idle(self, socket_path: str, name: str, *, on: bool) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def reload_jail(self, socket_path: str, name: str) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def get_ignore_list(self, socket_path: str, name: str) -> list[str]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def add_ignore_ip(self, socket_path: str, name: str, ip: str) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def del_ignore_ip(self, socket_path: str, name: str, ip: str) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def set_ignore_self(self, socket_path: str, name: str, *, on: bool) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def get_jail_banned_ips(
|
||||||
|
self,
|
||||||
|
socket_path: str,
|
||||||
|
jail_name: str,
|
||||||
|
page: int,
|
||||||
|
page_size: int,
|
||||||
|
search: str | None = None,
|
||||||
|
*,
|
||||||
|
geo_batch_lookup: object,
|
||||||
|
http_session: object,
|
||||||
|
app_db: aiosqlite.Connection,
|
||||||
|
) -> JailListResponse:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
async def _build_app(settings: Settings):
|
||||||
|
app = create_app(settings=settings)
|
||||||
|
set_setup_complete_cache(app, True)
|
||||||
|
db = await aiosqlite.connect(settings.database_path)
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
await init_db(db)
|
||||||
|
return app, db
|
||||||
|
|
||||||
|
|
||||||
|
async def test_auth_login_uses_injected_auth_service(tmp_path: Path) -> None:
|
||||||
|
settings = Settings(
|
||||||
|
database_path=str(tmp_path / "test_bangui.db"),
|
||||||
|
fail2ban_socket="/tmp/fake_fail2ban.sock",
|
||||||
|
fail2ban_config_dir=str(tmp_path / "fail2ban"),
|
||||||
|
session_secret="test-secret-key-do-not-use-in-production",
|
||||||
|
session_duration_minutes=60,
|
||||||
|
timezone="UTC",
|
||||||
|
log_level="debug",
|
||||||
|
)
|
||||||
|
|
||||||
|
app, db = await _build_app(settings)
|
||||||
|
def _fake_auth_service() -> FakeAuthService:
|
||||||
|
return FakeAuthService()
|
||||||
|
|
||||||
|
app.dependency_overrides[get_auth_service] = _fake_auth_service
|
||||||
|
|
||||||
|
transport = ASGITransport(app=app)
|
||||||
|
async with AsyncClient(
|
||||||
|
transport=transport,
|
||||||
|
base_url="http://test",
|
||||||
|
) as client:
|
||||||
|
response = await client.post(
|
||||||
|
"/api/auth/login",
|
||||||
|
json={"password": "ignored"},
|
||||||
|
)
|
||||||
|
|
||||||
|
await db.close()
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["token"].startswith("fake-token")
|
||||||
|
assert response.cookies.get("bangui_session") is not None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_jail_list_uses_injected_jail_service_and_auth(tmp_path: Path) -> None:
|
||||||
|
settings = Settings(
|
||||||
|
database_path=str(tmp_path / "test_bangui.db"),
|
||||||
|
fail2ban_socket="/tmp/fake_fail2ban.sock",
|
||||||
|
fail2ban_config_dir=str(tmp_path / "fail2ban"),
|
||||||
|
session_secret="test-secret-key-do-not-use-in-production",
|
||||||
|
session_duration_minutes=60,
|
||||||
|
timezone="UTC",
|
||||||
|
log_level="debug",
|
||||||
|
)
|
||||||
|
|
||||||
|
app, db = await _build_app(settings)
|
||||||
|
def _fake_auth_service() -> FakeAuthService:
|
||||||
|
return FakeAuthService()
|
||||||
|
|
||||||
|
def _fake_jail_service() -> FakeJailService:
|
||||||
|
return FakeJailService()
|
||||||
|
|
||||||
|
app.dependency_overrides[get_auth_service] = _fake_auth_service
|
||||||
|
app.dependency_overrides[get_jail_service] = _fake_jail_service
|
||||||
|
|
||||||
|
transport = ASGITransport(app=app)
|
||||||
|
async with AsyncClient(
|
||||||
|
transport=transport,
|
||||||
|
base_url="http://test",
|
||||||
|
) as client:
|
||||||
|
response = await client.get(
|
||||||
|
"/api/jails",
|
||||||
|
headers={"Cookie": "bangui_session=fake-token"},
|
||||||
|
)
|
||||||
|
|
||||||
|
await db.close()
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"jails": [], "total": 0}
|
||||||
Reference in New Issue
Block a user