Introduce explicit ApplicationContext and remove raw request.app.state usage
This commit is contained in:
@@ -8,6 +8,7 @@ directly — to keep coupling explicit and testable.
|
||||
|
||||
import datetime
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Protocol, cast
|
||||
|
||||
import aiohttp
|
||||
@@ -22,7 +23,7 @@ from app.models.config import PendingRecovery
|
||||
from app.models.server import ServerStatus
|
||||
from app.repositories.protocols import SessionRepository
|
||||
from app.services.protocols import AuthService, JailService
|
||||
from app.utils.runtime_state import RuntimeState, get_effective_settings
|
||||
from app.utils.runtime_state import RuntimeState
|
||||
from app.utils.session_cache import SessionCache
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
@@ -42,6 +43,21 @@ class AppState(Protocol):
|
||||
session_cache: SessionCache
|
||||
|
||||
|
||||
@dataclass
|
||||
class ApplicationContext:
|
||||
"""A typed wrapper around shared application lifecycle resources."""
|
||||
|
||||
settings: Settings
|
||||
http_session: aiohttp.ClientSession | None
|
||||
scheduler: AsyncIOScheduler | None
|
||||
server_status: ServerStatus
|
||||
pending_recovery: PendingRecovery | None
|
||||
last_activation: dict[str, datetime.datetime] | None
|
||||
runtime_settings: Settings | None
|
||||
runtime_state: RuntimeState
|
||||
session_cache: SessionCache | None
|
||||
|
||||
|
||||
_COOKIE_NAME = "bangui_session"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -61,7 +77,27 @@ def _session_cache_enabled(settings: Settings) -> bool:
|
||||
return settings.session_cache_enabled and settings.session_cache_ttl_seconds > 0.0
|
||||
|
||||
|
||||
async def get_db(request: Request) -> AsyncGenerator[aiosqlite.Connection, None]:
|
||||
def _build_app_context(request: Request) -> ApplicationContext:
|
||||
state = cast("AppState", request.app.state)
|
||||
return ApplicationContext(
|
||||
settings=state.settings,
|
||||
http_session=getattr(state, "http_session", None),
|
||||
scheduler=getattr(state, "scheduler", None),
|
||||
server_status=getattr(state, "server_status", ServerStatus(online=False)),
|
||||
pending_recovery=getattr(state, "pending_recovery", None),
|
||||
last_activation=getattr(state, "last_activation", None),
|
||||
runtime_settings=getattr(state, "runtime_settings", None),
|
||||
runtime_state=state.runtime_state,
|
||||
session_cache=getattr(state, "session_cache", None),
|
||||
)
|
||||
|
||||
|
||||
async def get_app_context(request: Request) -> ApplicationContext:
|
||||
"""Provide the typed application context for the current request."""
|
||||
return _build_app_context(request)
|
||||
|
||||
|
||||
async def get_db(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> AsyncGenerator[aiosqlite.Connection, None]:
|
||||
"""Provide a request-scoped :class:`aiosqlite.Connection` for the current request.
|
||||
|
||||
Opens a fresh connection for every request and closes it when the request
|
||||
@@ -69,14 +105,14 @@ async def get_db(request: Request) -> AsyncGenerator[aiosqlite.Connection, None]
|
||||
SQLite connection across concurrent requests.
|
||||
|
||||
Args:
|
||||
request: The current FastAPI request (injected automatically).
|
||||
app_context: The injected shared application context.
|
||||
|
||||
Yields:
|
||||
An open :class:`aiosqlite.Connection` for the request.
|
||||
"""
|
||||
from app.db import open_db # noqa: PLC0415
|
||||
|
||||
settings = cast("AppState", request.app.state).settings
|
||||
settings = app_context.settings
|
||||
try:
|
||||
db = await open_db(settings.database_path)
|
||||
except Exception as exc:
|
||||
@@ -92,16 +128,16 @@ async def get_db(request: Request) -> AsyncGenerator[aiosqlite.Connection, None]
|
||||
await db.close()
|
||||
|
||||
|
||||
async def get_settings(request: Request) -> Settings:
|
||||
async def get_settings(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> Settings:
|
||||
"""Provide the effective application settings for the current request."""
|
||||
return get_effective_settings(request.app)
|
||||
return app_context.runtime_settings if app_context.runtime_settings is not None else app_context.settings
|
||||
|
||||
|
||||
async def get_http_session(request: Request) -> aiohttp.ClientSession:
|
||||
"""Provide the shared HTTP client session from application state.
|
||||
async def get_http_session(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> aiohttp.ClientSession:
|
||||
"""Provide the shared HTTP client session from application context.
|
||||
|
||||
Args:
|
||||
request: The current FastAPI request.
|
||||
app_context: The injected shared application context.
|
||||
|
||||
Returns:
|
||||
A shared :class:`aiohttp.ClientSession` managed by the lifespan.
|
||||
@@ -109,22 +145,20 @@ async def get_http_session(request: Request) -> aiohttp.ClientSession:
|
||||
Raises:
|
||||
HTTPException: If the session is unavailable.
|
||||
"""
|
||||
state = cast("AppState", request.app.state)
|
||||
http_session = getattr(state, "http_session", None)
|
||||
if http_session is None:
|
||||
if app_context.http_session is None:
|
||||
log.error("http_session_unavailable")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="HTTP session is not available.",
|
||||
)
|
||||
return http_session
|
||||
return app_context.http_session
|
||||
|
||||
|
||||
async def get_scheduler(request: Request) -> AsyncIOScheduler:
|
||||
"""Provide the shared scheduler from application state.
|
||||
async def get_scheduler(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> AsyncIOScheduler:
|
||||
"""Provide the shared scheduler from application context.
|
||||
|
||||
Args:
|
||||
request: The current FastAPI request.
|
||||
app_context: The injected shared application context.
|
||||
|
||||
Returns:
|
||||
The :class:`apscheduler.schedulers.asyncio.AsyncIOScheduler` instance.
|
||||
@@ -132,15 +166,13 @@ async def get_scheduler(request: Request) -> AsyncIOScheduler:
|
||||
Raises:
|
||||
HTTPException: If the scheduler is unavailable.
|
||||
"""
|
||||
state = cast("AppState", request.app.state)
|
||||
scheduler = getattr(state, "scheduler", None)
|
||||
if scheduler is None:
|
||||
if app_context.scheduler is None:
|
||||
log.error("scheduler_unavailable")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Scheduler is not available.",
|
||||
)
|
||||
return scheduler
|
||||
return app_context.scheduler
|
||||
|
||||
|
||||
async def get_fail2ban_socket(settings: Settings = Depends(get_settings)) -> str:
|
||||
@@ -158,17 +190,15 @@ async def get_fail2ban_start_command(settings: Settings = Depends(get_settings))
|
||||
return settings.fail2ban_start_command
|
||||
|
||||
|
||||
async def get_session_cache(request: Request) -> SessionCache:
|
||||
"""Provide the configured session cache backend from application state."""
|
||||
state = cast("AppState", request.app.state)
|
||||
session_cache = getattr(state, "session_cache", None)
|
||||
if session_cache is None:
|
||||
async def get_session_cache(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> SessionCache:
|
||||
"""Provide the configured session cache backend from application context."""
|
||||
if app_context.session_cache is None:
|
||||
log.error("session_cache_unavailable")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Session cache is not available.",
|
||||
)
|
||||
return session_cache
|
||||
return app_context.session_cache
|
||||
|
||||
|
||||
async def get_auth_service() -> AuthService:
|
||||
@@ -192,9 +222,9 @@ async def get_session_repo() -> SessionRepository:
|
||||
return session_repo
|
||||
|
||||
|
||||
async def get_app_state(request: Request) -> AppState:
|
||||
async def get_app_state(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> ApplicationContext:
|
||||
"""Provide the application state object for the current request."""
|
||||
return cast("AppState", request.app.state)
|
||||
return app_context
|
||||
|
||||
|
||||
async def get_app(request: Request) -> FastAPI:
|
||||
@@ -202,15 +232,14 @@ async def get_app(request: Request) -> FastAPI:
|
||||
return request.app
|
||||
|
||||
|
||||
async def get_server_status(request: Request) -> ServerStatus:
|
||||
"""Return the cached fail2ban server status snapshot from app state."""
|
||||
state = cast("AppState", request.app.state)
|
||||
return getattr(state, "server_status", ServerStatus(online=False))
|
||||
async def get_server_status(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> ServerStatus:
|
||||
"""Return the cached fail2ban server status snapshot from application context."""
|
||||
return app_context.server_status
|
||||
|
||||
async def get_pending_recovery(request: Request) -> PendingRecovery | None:
|
||||
"""Return the current pending recovery record from app state."""
|
||||
state = cast("AppState", request.app.state)
|
||||
return getattr(state, "pending_recovery", None)
|
||||
|
||||
async def get_pending_recovery(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> PendingRecovery | None:
|
||||
"""Return the current pending recovery record from application context."""
|
||||
return app_context.pending_recovery
|
||||
|
||||
async def require_auth(
|
||||
request: Request,
|
||||
|
||||
78
backend/tests/test_dependencies.py
Normal file
78
backend/tests/test_dependencies.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from starlette.requests import Request
|
||||
|
||||
from app.config import Settings
|
||||
from app.dependencies import (
|
||||
ApplicationContext,
|
||||
get_app_context,
|
||||
get_http_session,
|
||||
get_scheduler,
|
||||
get_settings,
|
||||
get_session_cache,
|
||||
)
|
||||
from app.main import create_app
|
||||
from app.models.server import ServerStatus
|
||||
|
||||
|
||||
def _make_test_request(app: FastAPI) -> Request:
|
||||
scope = {
|
||||
"type": "http",
|
||||
"method": "GET",
|
||||
"path": "/",
|
||||
"headers": [],
|
||||
"query_string": b"",
|
||||
"client": ("test", 0),
|
||||
"server": ("test", 0),
|
||||
"scheme": "http",
|
||||
"app": app,
|
||||
}
|
||||
return Request(scope)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_app_context_dependency_exposes_shared_resources(test_settings: Settings) -> None:
|
||||
app = create_app(settings=test_settings)
|
||||
session = aiohttp.ClientSession()
|
||||
scheduler = MagicMock()
|
||||
app.state.http_session = session
|
||||
app.state.scheduler = scheduler
|
||||
app.state.server_status = ServerStatus(online=False)
|
||||
app.state.pending_recovery = None
|
||||
app.state.last_activation = None
|
||||
|
||||
request = _make_test_request(app)
|
||||
app_context = await get_app_context(request)
|
||||
|
||||
assert isinstance(app_context, ApplicationContext)
|
||||
assert app_context.settings is test_settings
|
||||
assert app_context.http_session is session
|
||||
assert app_context.scheduler is scheduler
|
||||
assert app_context.session_cache is app.state.session_cache
|
||||
assert app_context.runtime_state is app.state.runtime_state
|
||||
assert await get_settings(app_context) is test_settings
|
||||
assert await get_http_session(app_context) is session
|
||||
assert await get_scheduler(app_context) is scheduler
|
||||
assert await get_session_cache(app_context) is app.state.session_cache
|
||||
|
||||
await session.close()
|
||||
|
||||
|
||||
def test_request_app_state_access_is_only_allowed_in_dependencies() -> None:
|
||||
app_root = Path(__file__).resolve().parents[1] / "app"
|
||||
bad_modules: list[str] = []
|
||||
|
||||
for path in sorted(app_root.rglob("*.py")):
|
||||
if path.name == "dependencies.py":
|
||||
continue
|
||||
text = path.read_text()
|
||||
if "request.app.state" in text:
|
||||
bad_modules.append(str(path))
|
||||
|
||||
assert not bad_modules, f"Direct request.app.state access found in: {bad_modules}"
|
||||
Reference in New Issue
Block a user