79 lines
2.3 KiB
Python
79 lines
2.3 KiB
Python
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
from unittest.mock import MagicMock
|
|
|
|
import aiohttp
|
|
import pytest
|
|
from fastapi import FastAPI
|
|
from starlette.requests import Request
|
|
|
|
from app.config import Settings
|
|
from app.dependencies import (
|
|
ApplicationContext,
|
|
get_app_context,
|
|
get_http_session,
|
|
get_scheduler,
|
|
get_settings,
|
|
get_session_cache,
|
|
)
|
|
from app.main import create_app
|
|
from app.models.server import ServerStatus
|
|
|
|
|
|
def _make_test_request(app: FastAPI) -> Request:
    """Return a bare ``GET /`` request whose ASGI scope points at *app*.

    Only an HTTP scope is supplied to ``Request`` (no ``receive`` callable),
    which is enough for dependencies that read ``request.app``.
    """
    endpoint = ("test", 0)
    http_scope = dict(
        type="http",
        method="GET",
        path="/",
        headers=[],
        query_string=b"",
        client=endpoint,
        server=endpoint,
        scheme="http",
        app=app,
    )
    return Request(http_scope)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_app_context_dependency_exposes_shared_resources(test_settings: Settings) -> None:
    """``get_app_context`` and the per-resource getters return the objects on ``app.state``.

    The app lifespan is not run here; the state attributes the test needs are
    assigned by hand before the context is resolved from a synthetic request.
    ``session_cache`` and ``runtime_state`` are not assigned, so this test
    expects ``create_app()`` to have already placed them on ``app.state``.
    """
    app = create_app(settings=test_settings)
    scheduler = MagicMock()

    # ``async with`` closes the session even when an assertion below fails;
    # a trailing ``await session.close()`` would be skipped on failure and
    # leak the session.
    async with aiohttp.ClientSession() as session:
        app.state.http_session = session
        app.state.scheduler = scheduler
        app.state.server_status = ServerStatus(online=False)
        app.state.pending_recovery = None
        app.state.last_activation = None

        request = _make_test_request(app)
        app_context = await get_app_context(request)

        # The aggregated context must hand back the very same objects (identity,
        # not equality) that live on app.state.
        assert isinstance(app_context, ApplicationContext)
        assert app_context.settings is test_settings
        assert app_context.http_session is session
        assert app_context.scheduler is scheduler
        assert app_context.session_cache is app.state.session_cache
        assert app_context.runtime_state is app.state.runtime_state

        # Each narrow getter must delegate to the same context.
        assert await get_settings(app_context) is test_settings
        assert await get_http_session(app_context) is session
        assert await get_scheduler(app_context) is scheduler
        assert await get_session_cache(app_context) is app.state.session_cache
|
|
|
|
|
|
def test_request_app_state_access_is_only_allowed_in_dependencies() -> None:
    """Fail if a module under ``app/`` other than ``dependencies.py`` uses ``request.app.state``.

    This is a plain substring scan of every ``*.py`` file below the ``app``
    package that sits next to this test directory. Mentions inside comments
    or strings are therefore reported too.
    """
    app_root = Path(__file__).resolve().parents[1] / "app"
    # Without this guard, a moved test file or renamed package makes rglob()
    # yield nothing and the test passes vacuously.
    assert app_root.is_dir(), f"Application package not found at {app_root}"

    # NOTE(review): the exemption matches every file named dependencies.py,
    # not only app/dependencies.py; tighten it if nested ones should be checked.
    # Source files are read as UTF-8 (the Python source default) rather than
    # the platform's locale encoding.
    bad_modules = [
        str(path)
        for path in sorted(app_root.rglob("*.py"))
        if path.name != "dependencies.py"
        and "request.app.state" in path.read_text(encoding="utf-8")
    ]

    assert not bad_modules, f"Direct request.app.state access found in: {bad_modules}"
|