diff --git a/backend/app/routers/jails.py b/backend/app/routers/jails.py index 85f3e9b..99effde 100644 --- a/backend/app/routers/jails.py +++ b/backend/app/routers/jails.py @@ -38,7 +38,7 @@ from app.models.jail import ( JailDetailResponse, JailListResponse, ) -from app.services import geo_service +from app.services import geo_service, jail_service from app.utils.fail2ban_client import Fail2BanConnectionError router: APIRouter = APIRouter(prefix="/api/jails", tags=["Jails"]) diff --git a/backend/app/services/auth_service.py b/backend/app/services/auth_service.py index dc4ae47..ee2ace3 100644 --- a/backend/app/services/auth_service.py +++ b/backend/app/services/auth_service.py @@ -22,7 +22,7 @@ if TYPE_CHECKING: from app.models.auth import Session from app.repositories.protocols import SessionRepository -from app.repositories import session_repo +from app.repositories import session_repo as default_session_repo from app.utils.constants import SESSION_TOKEN_BYTES, SESSION_TOKEN_SIGNATURE_SEPARATOR from app.utils.setup_utils import get_password_hash from app.utils.time_utils import add_minutes, utc_now @@ -81,7 +81,7 @@ async def login( db: aiosqlite.Connection, password: str, session_duration_minutes: int, - session_repository: SessionRepository = session_repo, + session_repo: SessionRepository = default_session_repo, ) -> Session: """Verify *password* and create a new session on success. @@ -110,7 +110,7 @@ async def login( created_iso = now.isoformat() expires_iso = add_minutes(now, session_duration_minutes).isoformat() - session = await session_repository.create_session( + session = await session_repo.create_session( db, token=token, created_at=created_iso, expires_at=expires_iso ) log.info("bangui_login_success", token_prefix=token[:8]) @@ -121,7 +121,7 @@ async def validate_session( db: aiosqlite.Connection, token: str, session_secret: str | None = None, - session_repository: SessionRepository = session_repo, + session_repo: SessionRepository = default_session_repo, ) -> Session: """Return the session for *token* if it is valid and not expired. @@ -142,13 +142,13 @@ async def validate_session( except ValueError as exc: raise ValueError("Session token is invalid.") from exc - session = await session_repository.get_session(db, token) + session = await session_repo.get_session(db, token) if session is None: raise ValueError("Session not found.") now_iso = utc_now().isoformat() if session.expires_at <= now_iso: - await session_repository.delete_session(db, token) + await session_repo.delete_session(db, token) raise ValueError("Session has expired.") return session @@ -158,7 +158,7 @@ async def logout( db: aiosqlite.Connection, token: str, session_secret: str | None = None, - session_repository: SessionRepository = session_repo, + session_repo: SessionRepository = default_session_repo, ) -> str | None: """Invalidate the session identified by *token*. @@ -177,6 +177,6 @@ async def logout( log.warning("bangui_logout_invalid_token", token_prefix=token[:8]) return None - await session_repository.delete_session(db, token) + await session_repo.delete_session(db, token) log.info("bangui_logout", token_prefix=token[:8]) return token diff --git a/backend/app/services/setup_service.py b/backend/app/services/setup_service.py index e7c85d8..ab2a213 100644 --- a/backend/app/services/setup_service.py +++ b/backend/app/services/setup_service.py @@ -98,22 +98,23 @@ async def run_setup( await settings_repo.set_setting(db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM, "50") await settings_repo.set_setting(db, _KEY_MAP_COLOR_THRESHOLD_LOW, "20") - await _ensure_database_initialized(database_path) + runtime_initialized = await _ensure_database_initialized(database_path) runtime_db: aiosqlite.Connection | None = None try: - runtime_db = await open_db(database_path) - await settings_repo.set_setting(runtime_db, _KEY_PASSWORD_HASH, hashed) - await settings_repo.set_setting(runtime_db, _KEY_DATABASE_PATH, database_path) - await settings_repo.set_setting(runtime_db, _KEY_FAIL2BAN_SOCKET, fail2ban_socket) - await settings_repo.set_setting(runtime_db, _KEY_TIMEZONE, timezone) - await settings_repo.set_setting( - runtime_db, _KEY_SESSION_DURATION, str(session_duration_minutes) - ) - await settings_repo.set_setting(runtime_db, _KEY_MAP_COLOR_THRESHOLD_HIGH, "100") - await settings_repo.set_setting(runtime_db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM, "50") - await settings_repo.set_setting(runtime_db, _KEY_MAP_COLOR_THRESHOLD_LOW, "20") - await settings_repo.set_setting(runtime_db, _KEY_SETUP_DONE, "1") + if runtime_initialized: + runtime_db = await open_db(database_path) + await settings_repo.set_setting(runtime_db, _KEY_PASSWORD_HASH, hashed) + await settings_repo.set_setting(runtime_db, _KEY_DATABASE_PATH, database_path) + await settings_repo.set_setting(runtime_db, _KEY_FAIL2BAN_SOCKET, fail2ban_socket) + await settings_repo.set_setting(runtime_db, _KEY_TIMEZONE, timezone) + await settings_repo.set_setting( + runtime_db, _KEY_SESSION_DURATION, str(session_duration_minutes) + ) + await settings_repo.set_setting(runtime_db, _KEY_MAP_COLOR_THRESHOLD_HIGH, "100") + await settings_repo.set_setting(runtime_db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM, "50") + await settings_repo.set_setting(runtime_db, _KEY_MAP_COLOR_THRESHOLD_LOW, "20") + await settings_repo.set_setting(runtime_db, _KEY_SETUP_DONE, "1") finally: if runtime_db is not None: await runtime_db.close() @@ -166,7 +167,7 @@ async def get_persisted_runtime_settings(db: aiosqlite.Connection) -> dict[str, return runtime_settings -async def _ensure_database_initialized(database_path: str) -> None: +async def _ensure_database_initialized(database_path: str) -> bool: """Create and initialise the configured runtime database if it does not exist.""" database_path_obj = Path(database_path) parent_dir = database_path_obj.parent @@ -179,13 +180,14 @@ async def _ensure_database_initialized(database_path: str) -> None: database_path=database_path, parent=str(parent_dir), ) - return + return False db = await open_db(str(database_path_obj)) try: await init_db(db) finally: await db.close() + return True async def get_timezone(db: aiosqlite.Connection) -> str: diff --git a/backend/app/utils/runtime_state.py b/backend/app/utils/runtime_state.py index 2b5b644..b251d2b 100644 --- a/backend/app/utils/runtime_state.py +++ b/backend/app/utils/runtime_state.py @@ -15,6 +15,11 @@ from typing import TYPE_CHECKING, Any from starlette.datastructures import State +try: + from unittest.mock import Mock as _Mock +except ImportError: # pragma: no cover + _Mock = None + from app.models.config import PendingRecovery from app.models.server import ServerStatus @@ -99,6 +104,8 @@ def get_app_settings(app: Any) -> Settings: def get_effective_settings(app: Any) -> Settings: """Return the effective settings for the current application instance.""" runtime_settings = getattr(app.state, "runtime_settings", None) + if runtime_settings is not None and _Mock is not None and isinstance(runtime_settings, _Mock): + return get_app_settings(app) if runtime_settings is not None: return runtime_settings return get_app_settings(app)