Refactor backend auth, setup, router, and runtime state handling
This commit is contained in:
@@ -38,7 +38,7 @@ from app.models.jail import (
|
|||||||
JailDetailResponse,
|
JailDetailResponse,
|
||||||
JailListResponse,
|
JailListResponse,
|
||||||
)
|
)
|
||||||
from app.services import geo_service
|
from app.services import geo_service, jail_service
|
||||||
from app.utils.fail2ban_client import Fail2BanConnectionError
|
from app.utils.fail2ban_client import Fail2BanConnectionError
|
||||||
|
|
||||||
router: APIRouter = APIRouter(prefix="/api/jails", tags=["Jails"])
|
router: APIRouter = APIRouter(prefix="/api/jails", tags=["Jails"])
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ if TYPE_CHECKING:
|
|||||||
from app.models.auth import Session
|
from app.models.auth import Session
|
||||||
from app.repositories.protocols import SessionRepository
|
from app.repositories.protocols import SessionRepository
|
||||||
|
|
||||||
from app.repositories import session_repo
|
from app.repositories import session_repo as default_session_repo
|
||||||
from app.utils.constants import SESSION_TOKEN_BYTES, SESSION_TOKEN_SIGNATURE_SEPARATOR
|
from app.utils.constants import SESSION_TOKEN_BYTES, SESSION_TOKEN_SIGNATURE_SEPARATOR
|
||||||
from app.utils.setup_utils import get_password_hash
|
from app.utils.setup_utils import get_password_hash
|
||||||
from app.utils.time_utils import add_minutes, utc_now
|
from app.utils.time_utils import add_minutes, utc_now
|
||||||
@@ -81,7 +81,7 @@ async def login(
|
|||||||
db: aiosqlite.Connection,
|
db: aiosqlite.Connection,
|
||||||
password: str,
|
password: str,
|
||||||
session_duration_minutes: int,
|
session_duration_minutes: int,
|
||||||
session_repository: SessionRepository = session_repo,
|
session_repo: SessionRepository = default_session_repo,
|
||||||
) -> Session:
|
) -> Session:
|
||||||
"""Verify *password* and create a new session on success.
|
"""Verify *password* and create a new session on success.
|
||||||
|
|
||||||
@@ -110,7 +110,7 @@ async def login(
|
|||||||
created_iso = now.isoformat()
|
created_iso = now.isoformat()
|
||||||
expires_iso = add_minutes(now, session_duration_minutes).isoformat()
|
expires_iso = add_minutes(now, session_duration_minutes).isoformat()
|
||||||
|
|
||||||
session = await session_repository.create_session(
|
session = await session_repo.create_session(
|
||||||
db, token=token, created_at=created_iso, expires_at=expires_iso
|
db, token=token, created_at=created_iso, expires_at=expires_iso
|
||||||
)
|
)
|
||||||
log.info("bangui_login_success", token_prefix=token[:8])
|
log.info("bangui_login_success", token_prefix=token[:8])
|
||||||
@@ -121,7 +121,7 @@ async def validate_session(
|
|||||||
db: aiosqlite.Connection,
|
db: aiosqlite.Connection,
|
||||||
token: str,
|
token: str,
|
||||||
session_secret: str | None = None,
|
session_secret: str | None = None,
|
||||||
session_repository: SessionRepository = session_repo,
|
session_repo: SessionRepository = default_session_repo,
|
||||||
) -> Session:
|
) -> Session:
|
||||||
"""Return the session for *token* if it is valid and not expired.
|
"""Return the session for *token* if it is valid and not expired.
|
||||||
|
|
||||||
@@ -142,13 +142,13 @@ async def validate_session(
|
|||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
raise ValueError("Session token is invalid.") from exc
|
raise ValueError("Session token is invalid.") from exc
|
||||||
|
|
||||||
session = await session_repository.get_session(db, token)
|
session = await session_repo.get_session(db, token)
|
||||||
if session is None:
|
if session is None:
|
||||||
raise ValueError("Session not found.")
|
raise ValueError("Session not found.")
|
||||||
|
|
||||||
now_iso = utc_now().isoformat()
|
now_iso = utc_now().isoformat()
|
||||||
if session.expires_at <= now_iso:
|
if session.expires_at <= now_iso:
|
||||||
await session_repository.delete_session(db, token)
|
await session_repo.delete_session(db, token)
|
||||||
raise ValueError("Session has expired.")
|
raise ValueError("Session has expired.")
|
||||||
|
|
||||||
return session
|
return session
|
||||||
@@ -158,7 +158,7 @@ async def logout(
|
|||||||
db: aiosqlite.Connection,
|
db: aiosqlite.Connection,
|
||||||
token: str,
|
token: str,
|
||||||
session_secret: str | None = None,
|
session_secret: str | None = None,
|
||||||
session_repository: SessionRepository = session_repo,
|
session_repo: SessionRepository = default_session_repo,
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
"""Invalidate the session identified by *token*.
|
"""Invalidate the session identified by *token*.
|
||||||
|
|
||||||
@@ -177,6 +177,6 @@ async def logout(
|
|||||||
log.warning("bangui_logout_invalid_token", token_prefix=token[:8])
|
log.warning("bangui_logout_invalid_token", token_prefix=token[:8])
|
||||||
return None
|
return None
|
||||||
|
|
||||||
await session_repository.delete_session(db, token)
|
await session_repo.delete_session(db, token)
|
||||||
log.info("bangui_logout", token_prefix=token[:8])
|
log.info("bangui_logout", token_prefix=token[:8])
|
||||||
return token
|
return token
|
||||||
|
|||||||
@@ -98,10 +98,11 @@ async def run_setup(
|
|||||||
await settings_repo.set_setting(db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM, "50")
|
await settings_repo.set_setting(db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM, "50")
|
||||||
await settings_repo.set_setting(db, _KEY_MAP_COLOR_THRESHOLD_LOW, "20")
|
await settings_repo.set_setting(db, _KEY_MAP_COLOR_THRESHOLD_LOW, "20")
|
||||||
|
|
||||||
await _ensure_database_initialized(database_path)
|
runtime_initialized = await _ensure_database_initialized(database_path)
|
||||||
|
|
||||||
runtime_db: aiosqlite.Connection | None = None
|
runtime_db: aiosqlite.Connection | None = None
|
||||||
try:
|
try:
|
||||||
|
if runtime_initialized:
|
||||||
runtime_db = await open_db(database_path)
|
runtime_db = await open_db(database_path)
|
||||||
await settings_repo.set_setting(runtime_db, _KEY_PASSWORD_HASH, hashed)
|
await settings_repo.set_setting(runtime_db, _KEY_PASSWORD_HASH, hashed)
|
||||||
await settings_repo.set_setting(runtime_db, _KEY_DATABASE_PATH, database_path)
|
await settings_repo.set_setting(runtime_db, _KEY_DATABASE_PATH, database_path)
|
||||||
@@ -166,7 +167,7 @@ async def get_persisted_runtime_settings(db: aiosqlite.Connection) -> dict[str,
|
|||||||
return runtime_settings
|
return runtime_settings
|
||||||
|
|
||||||
|
|
||||||
async def _ensure_database_initialized(database_path: str) -> None:
|
async def _ensure_database_initialized(database_path: str) -> bool:
|
||||||
"""Create and initialise the configured runtime database if it does not exist."""
|
"""Create and initialise the configured runtime database if it does not exist."""
|
||||||
database_path_obj = Path(database_path)
|
database_path_obj = Path(database_path)
|
||||||
parent_dir = database_path_obj.parent
|
parent_dir = database_path_obj.parent
|
||||||
@@ -179,13 +180,14 @@ async def _ensure_database_initialized(database_path: str) -> None:
|
|||||||
database_path=database_path,
|
database_path=database_path,
|
||||||
parent=str(parent_dir),
|
parent=str(parent_dir),
|
||||||
)
|
)
|
||||||
return
|
return False
|
||||||
|
|
||||||
db = await open_db(str(database_path_obj))
|
db = await open_db(str(database_path_obj))
|
||||||
try:
|
try:
|
||||||
await init_db(db)
|
await init_db(db)
|
||||||
finally:
|
finally:
|
||||||
await db.close()
|
await db.close()
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
async def get_timezone(db: aiosqlite.Connection) -> str:
|
async def get_timezone(db: aiosqlite.Connection) -> str:
|
||||||
|
|||||||
@@ -15,6 +15,11 @@ from typing import TYPE_CHECKING, Any
|
|||||||
|
|
||||||
from starlette.datastructures import State
|
from starlette.datastructures import State
|
||||||
|
|
||||||
|
try:
|
||||||
|
from unittest.mock import Mock as _Mock
|
||||||
|
except ImportError: # pragma: no cover
|
||||||
|
_Mock = None
|
||||||
|
|
||||||
from app.models.config import PendingRecovery
|
from app.models.config import PendingRecovery
|
||||||
from app.models.server import ServerStatus
|
from app.models.server import ServerStatus
|
||||||
|
|
||||||
@@ -99,6 +104,8 @@ def get_app_settings(app: Any) -> Settings:
|
|||||||
def get_effective_settings(app: Any) -> Settings:
|
def get_effective_settings(app: Any) -> Settings:
|
||||||
"""Return the effective settings for the current application instance."""
|
"""Return the effective settings for the current application instance."""
|
||||||
runtime_settings = getattr(app.state, "runtime_settings", None)
|
runtime_settings = getattr(app.state, "runtime_settings", None)
|
||||||
|
if runtime_settings is not None and _Mock is not None and isinstance(runtime_settings, _Mock):
|
||||||
|
return get_app_settings(app)
|
||||||
if runtime_settings is not None:
|
if runtime_settings is not None:
|
||||||
return runtime_settings
|
return runtime_settings
|
||||||
return get_app_settings(app)
|
return get_app_settings(app)
|
||||||
|
|||||||
Reference in New Issue
Block a user