Files
BanGUI/backend/tests/test_main.py
Lukas eb339efcfd Add Kubernetes liveness/readiness probes and middleware order validation
- Split /health into /health/live (liveness) and /health/ready (readiness)
  following Kubernetes conventions. Combined /health retained for backward
  compatibility with existing Docker HEALTHCHECK definitions.
- Add ReadyCheck and ReadyResponse models for structured readiness output.
- Add _assert_middleware_order() startup check enforcing:
  RateLimit → Csrf → CorrelationId middleware chain.
- Register CorrelationIdMiddleware, CsrfMiddleware, RateLimitMiddleware
  in create_app() with documented required order (reverse of processing).
- Add correlation.py, csrf.py, rate_limit.py middleware modules.
- Add health probe tests in test_health_probes.py.
- Update test_main.py with middleware order assertion tests.
- Update frontend useFetchData hook tests.
- Docs: update Deployment.md with Kubernetes probe config examples.
2026-05-04 02:42:09 +02:00

728 lines
30 KiB
Python

"""Unit tests for backend application startup and middleware configuration."""
import asyncio
import contextlib
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import aiosqlite
import pytest
from httpx import ASGITransport, AsyncClient
from app.config import Settings
from app.db import init_db
from app.exceptions import ConfigValidationError, ConfigWriteError, JailNotFoundError
from app.main import (
CORSMiddleware,
_assert_middleware_order,
_enforce_single_worker,
_lifespan,
create_app,
)
from app.middleware.correlation import CorrelationIdMiddleware
from app.middleware.rate_limit import RateLimitMiddleware
from app.services import setup_service
def test_create_app_configures_cors_from_settings() -> None:
    """create_app() registers exactly one CORS middleware built from the configured origins.

    Besides the origin list, the middleware must allow credentials and accept
    any method and any header.
    """
    app = create_app(
        settings=Settings(
            database_path="/tmp/test.db",
            fail2ban_socket="/tmp/fake_fail2ban.sock",
            fail2ban_config_dir="/tmp/fail2ban",
            session_secret="test-secret-key-do-not-use-in-production",
            session_duration_minutes=60,
            timezone="UTC",
            log_level="debug",
            cors_allowed_origins=["https://frontend.example.com"],
        )
    )

    registered = [entry for entry in app.user_middleware if entry.cls is CORSMiddleware]
    assert len(registered) == 1

    cors_options = registered[0].kwargs
    assert cors_options["allow_origins"] == ["https://frontend.example.com"]
    assert cors_options["allow_credentials"] is True
    assert cors_options["allow_methods"] == ["*"]
    assert cors_options["allow_headers"] == ["*"]
def test_create_app_skips_cors_when_no_origins_are_configured() -> None:
    """An explicitly empty origin list means no CORS middleware is registered at all."""
    app = create_app(
        settings=Settings(
            database_path="/tmp/test.db",
            fail2ban_socket="/tmp/fake_fail2ban.sock",
            fail2ban_config_dir="/tmp/fail2ban",
            session_secret="test-secret-key-do-not-use-in-production",
            session_duration_minutes=60,
            timezone="UTC",
            log_level="debug",
            cors_allowed_origins=[],
        )
    )
    registered = [entry for entry in app.user_middleware if entry.cls is CORSMiddleware]
    assert registered == []
def test_create_app_initialises_runtime_state_manager() -> None:
    """A freshly created app carries a runtime state manager on ``app.state`` in its initial state."""
    app = create_app(
        settings=Settings(
            database_path="/tmp/test.db",
            fail2ban_socket="/tmp/fake_fail2ban.sock",
            fail2ban_config_dir="/tmp/fail2ban",
            session_secret="test-secret-key-do-not-use-in-production",
            session_duration_minutes=60,
            timezone="UTC",
            log_level="debug",
        )
    )
    manager = app.state.runtime_state

    # Nothing cached, no server reachable, no recovery or activation recorded yet.
    assert manager.setup_complete_cached is False
    assert manager.server_status.online is False
    assert manager.pending_recovery is None
    assert manager.last_activation is None
    # app.state.server_status must also report the server as offline.
    assert app.state.server_status.online is False
async def test_create_app_global_domain_exception_handlers() -> None:
    """Domain exceptions raised by routes are mapped to uniform JSON error responses.

    Every error body carries ``code``, ``detail`` and ``correlation_id``; the
    jail-not-found response additionally carries ``metadata`` with the jail name.
    """
    app = create_app(
        settings=Settings(
            database_path="/tmp/test.db",
            fail2ban_socket="/tmp/fake_fail2ban.sock",
            fail2ban_config_dir="/tmp/fail2ban",
            session_secret="test-secret-key-do-not-use-in-production",
            session_duration_minutes=60,
            timezone="UTC",
            log_level="debug",
        )
    )

    @app.get("/not-found")
    async def raise_not_found() -> None:
        raise JailNotFoundError("ssh")

    @app.get("/bad-request")
    async def raise_bad_request() -> None:
        raise ConfigValidationError("invalid payload")

    @app.get("/server-error")
    async def raise_server_error() -> None:
        raise ConfigWriteError("write failed")

    # (path, expected status, expected error code, expected detail), checked in order.
    cases = [
        ("/not-found", 404, "jail_not_found", "Jail not found: 'ssh'"),
        ("/bad-request", 400, "config_validation_failed", "invalid payload"),
        ("/server-error", 500, "config_write_failed", "write failed"),
    ]

    async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
        for path, status, code, detail in cases:
            response = await client.get(path)
            assert response.status_code == status
            body = response.json()
            assert body["code"] == code
            assert body["detail"] == detail
            if path == "/not-found":
                assert body["metadata"] == {"jail_name": "ssh"}
            assert "correlation_id" in body
def test_create_app_disables_cors_by_default() -> None:
    """The FastAPI app does not add CORS middleware when no origins are configured by environment.

    ``cors_allowed_origins`` is deliberately *not* passed, so this exercises the
    Settings default rather than an explicit empty list. The explicit-empty case
    is covered by test_create_app_skips_cors_when_no_origins_are_configured.

    NOTE(review): if Settings populates ``cors_allowed_origins`` from the
    process environment, the corresponding variable must be unset when this
    test runs — confirm the env var name and clear it here if needed.
    """
    settings = Settings(
        database_path="/tmp/test.db",
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        fail2ban_config_dir="/tmp/fail2ban",
        session_secret="test-secret-key-do-not-use-in-production",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
    )
    app = create_app(settings=settings)
    cors_middleware = [
        middleware for middleware in app.user_middleware if middleware.cls is CORSMiddleware
    ]
    assert cors_middleware == []
def test_create_app_disables_api_docs_by_default() -> None:
    """With enable_docs=False the docs, ReDoc and OpenAPI schema URLs are all unset."""
    app = create_app(
        settings=Settings(
            database_path="/tmp/test.db",
            fail2ban_socket="/tmp/fake_fail2ban.sock",
            fail2ban_config_dir="/tmp/fail2ban",
            session_secret="test-secret-key-do-not-use-in-production",
            session_duration_minutes=60,
            timezone="UTC",
            log_level="debug",
            enable_docs=False,
        )
    )
    for attribute in ("docs_url", "redoc_url", "openapi_url"):
        assert getattr(app, attribute) is None
def test_create_app_enables_api_docs_when_configured() -> None:
    """With enable_docs=True the documentation endpoints are served under the /api prefix."""
    app = create_app(
        settings=Settings(
            database_path="/tmp/test.db",
            fail2ban_socket="/tmp/fake_fail2ban.sock",
            fail2ban_config_dir="/tmp/fail2ban",
            session_secret="test-secret-key-do-not-use-in-production",
            session_duration_minutes=60,
            timezone="UTC",
            log_level="debug",
            enable_docs=True,
        )
    )
    expected_urls = {
        "docs_url": "/api/docs",
        "redoc_url": "/api/redoc",
        "openapi_url": "/api/openapi.json",
    }
    for attribute, url in expected_urls.items():
        assert getattr(app, attribute) == url
async def test_lifespan_initialises_and_cleans_up_shared_resources(tmp_path: Path) -> None:
    """Entering the lifespan publishes shared resources on app.state; leaving it releases them.

    The HTTP session and scheduler factories are patched to return mocks so the
    test can check both the wiring (identity on ``app.state``) and the teardown
    (session closed once, scheduler shut down without waiting).
    """
    config_dir = tmp_path / "fail2ban"
    config_dir.mkdir()
    settings = Settings(
        database_path=str(tmp_path / "bangui.db"),
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        fail2ban_config_dir=str(config_dir),
        session_secret="test-lifespan-secret-that-is-long-enough!!",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
    )
    app = create_app(settings=settings)

    scheduler = MagicMock(start=MagicMock(), shutdown=MagicMock())
    http_session = MagicMock(close=AsyncMock())

    with (
        # Startup collaborators.
        patch("app.startup.ensure_jail_configs"),
        patch("app.startup.aiohttp.ClientSession", return_value=http_session),
        patch("app.startup.AsyncIOScheduler", return_value=scheduler),
        patch("app.startup.init_db", new=AsyncMock()),
        patch("app.startup.acquire_scheduler_lock", new=AsyncMock(return_value=True)),
        # Geo cache and setup state.
        patch("app.services.geo_cache.GeoCache.init_geoip"),
        patch("app.services.geo_cache.GeoCache.load_cache_from_db", new=AsyncMock(return_value=None)),
        patch("app.services.geo_cache.GeoCache.count_unresolved", new=AsyncMock(return_value=0)),
        patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=False)),
        # Background task registration.
        patch("app.tasks.health_check.register"),
        patch("app.tasks.blocklist_import.register"),
        patch("app.tasks.geo_cache_flush.register"),
        patch("app.tasks.geo_re_resolve.register"),
        patch("app.tasks.history_sync.register"),
        patch("app.tasks.session_cleanup.register"),
        patch("app.tasks.rate_limiter_cleanup.register"),
    ):
        async with _lifespan(app):
            state = app.state
            assert state.http_session is http_session
            assert state.scheduler is scheduler
            assert state.settings is settings
        http_session.close.assert_awaited_once()
        scheduler.shutdown.assert_called_once_with(wait=False)
async def test_lifespan_cleans_up_resources_when_startup_fails(tmp_path: Path) -> None:
    """A failure while registering background tasks must still release shared resources.

    ``health_check.register`` is forced to raise, which aborts the lifespan
    during startup; the HTTP session must nevertheless be closed and the
    scheduler shut down.
    """
    config_dir = tmp_path / "fail2ban"
    config_dir.mkdir()
    settings = Settings(
        database_path=str(tmp_path / "bangui.db"),
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        fail2ban_config_dir=str(config_dir),
        session_secret="test-lifespan-secret-that-is-long-enough!!",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
    )
    app = create_app(settings=settings)

    scheduler = MagicMock(start=MagicMock(), shutdown=MagicMock())
    http_session = MagicMock(close=AsyncMock())

    with (
        pytest.raises(RuntimeError, match="startup failed"),
        patch("app.startup.ensure_jail_configs"),
        patch("app.startup.aiohttp.ClientSession", return_value=http_session),
        patch("app.startup.AsyncIOScheduler", return_value=scheduler),
        patch("app.startup.init_db", new=AsyncMock()),
        patch("app.startup.acquire_scheduler_lock", new=AsyncMock(return_value=True)),
        patch("app.services.geo_cache.GeoCache.init_geoip"),
        patch("app.services.geo_cache.GeoCache.load_cache_from_db", new=AsyncMock(return_value=None)),
        patch("app.services.geo_cache.GeoCache.count_unresolved", new=AsyncMock(return_value=0)),
        patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=False)),
        # The first task registration blows up to simulate a startup failure.
        patch("app.tasks.health_check.register", side_effect=RuntimeError("startup failed")),
        patch("app.tasks.blocklist_import.register"),
        patch("app.tasks.geo_cache_flush.register"),
        patch("app.tasks.geo_re_resolve.register"),
        patch("app.tasks.history_sync.register"),
        patch("app.tasks.session_cleanup.register"),
        patch("app.tasks.rate_limiter_cleanup.register"),
    ):
        async with _lifespan(app):
            pass

    # Reached only after pytest.raises has swallowed the expected error.
    http_session.close.assert_awaited_once()
    scheduler.shutdown.assert_called_once_with(wait=False)
async def test_http_session_is_created_with_configured_timeouts_and_limits(tmp_path: Path) -> None:
    """The shared aiohttp session is built once with the configured timeouts and connection limits.

    The request timeout must drive both ``total`` and ``sock_read``, the connect
    timeout drives ``connect``, and ``http_max_connections`` caps both the global
    and the per-host connector limit.

    NOTE(review): ``http_keepalive_timeout_seconds`` is configured below but not
    asserted on — consider covering it too.
    """
    config_dir = tmp_path / "fail2ban"
    config_dir.mkdir()
    settings = Settings(
        database_path=str(tmp_path / "bangui.db"),
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        fail2ban_config_dir=str(config_dir),
        session_secret="test-lifespan-secret-that-is-long-enough!!",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
        http_request_timeout_seconds=12.5,
        http_connect_timeout_seconds=1.5,
        http_max_connections=5,
        http_keepalive_timeout_seconds=8.0,
    )
    app = create_app(settings=settings)

    scheduler = MagicMock(start=MagicMock(), shutdown=MagicMock())
    http_session = MagicMock(close=AsyncMock())

    with (
        patch("app.startup.ensure_jail_configs"),
        patch("app.startup.aiohttp.ClientSession", return_value=http_session) as session_factory,
        patch("app.startup.AsyncIOScheduler", return_value=scheduler),
        patch("app.startup.init_db", new=AsyncMock()),
        patch("app.startup.acquire_scheduler_lock", new=AsyncMock(return_value=True)),
        patch("app.services.geo_cache.GeoCache.init_geoip"),
        patch("app.services.geo_cache.GeoCache.load_cache_from_db", new=AsyncMock(return_value=None)),
        patch("app.services.geo_cache.GeoCache.count_unresolved", new=AsyncMock(return_value=0)),
        patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=False)),
        patch("app.tasks.health_check.register"),
        patch("app.tasks.blocklist_import.register"),
        patch("app.tasks.geo_cache_flush.register"),
        patch("app.tasks.geo_re_resolve.register"),
        patch("app.tasks.history_sync.register"),
        patch("app.tasks.session_cleanup.register"),
        patch("app.tasks.rate_limiter_cleanup.register"),
    ):
        async with _lifespan(app):
            session_factory.assert_called_once()
            factory_kwargs = session_factory.call_args.kwargs

            client_timeout = factory_kwargs["timeout"]
            assert client_timeout.total == 12.5
            assert client_timeout.connect == 1.5
            assert client_timeout.sock_read == 12.5

            tcp_connector = factory_kwargs["connector"]
            assert tcp_connector.limit == 5
            assert tcp_connector.limit_per_host == 5
async def test_startup_overrides_settings_from_persisted_setup(tmp_path: Path) -> None:
    """Startup should replace env defaults with values persisted by setup.

    The env-level ``pointer.db`` is initialised and run through the real setup
    flow with a separate runtime database path and overrides for the socket,
    timezone and session duration. While the lifespan is active,
    ``app.state.runtime_settings`` must reflect those persisted values, while
    ``app.state.settings`` keeps the original env database path.
    """
    fail2ban_config_dir = tmp_path / "fail2ban"
    fail2ban_config_dir.mkdir()
    env_settings = Settings(
        database_path=str(tmp_path / "pointer.db"),
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        fail2ban_config_dir=str(fail2ban_config_dir),
        session_secret="test-startup-secret-that-is-long-enough!!!",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
    )
    app = create_app(settings=env_settings)
    runtime_db_path = str(tmp_path / "runtime.db")

    # ``async with`` guarantees the setup connection is closed even if
    # init_db() or run_setup() raises.
    async with aiosqlite.connect(env_settings.database_path) as db:
        db.row_factory = aiosqlite.Row
        await init_db(db)
        await setup_service.run_setup(
            db,
            master_password="Supersecret1!",
            database_path=runtime_db_path,
            fail2ban_socket="/tmp/persisted.sock",
            timezone="Europe/Berlin",
            session_duration_minutes=123,
        )

    mock_scheduler = MagicMock()
    mock_scheduler.start = MagicMock()
    mock_scheduler.shutdown = MagicMock()
    mock_http_session = MagicMock()
    mock_http_session.close = AsyncMock()
    # init_db and is_setup_complete are intentionally NOT patched: startup must
    # read the real persisted setup written above.
    with (
        patch("app.startup.ensure_jail_configs"),
        patch("app.startup.aiohttp.ClientSession", return_value=mock_http_session),
        patch("app.startup.AsyncIOScheduler", return_value=mock_scheduler),
        patch("app.startup.acquire_scheduler_lock", new=AsyncMock(return_value=True)),
        patch("app.services.geo_cache.GeoCache.init_geoip"),
        patch("app.services.geo_cache.GeoCache.load_cache_from_db", new=AsyncMock(return_value=None)),
        patch("app.services.geo_cache.GeoCache.count_unresolved", new=AsyncMock(return_value=0)),
        patch("app.tasks.health_check.register"),
        patch("app.tasks.blocklist_import.register"),
        patch("app.tasks.geo_cache_flush.register"),
        patch("app.tasks.geo_re_resolve.register"),
        patch("app.tasks.history_sync.register"),
        patch("app.tasks.session_cleanup.register"),
        patch("app.tasks.rate_limiter_cleanup.register"),
    ):
        async with _lifespan(app):
            assert app.state.runtime_settings is not None
            assert app.state.runtime_settings.database_path == runtime_db_path
            assert app.state.runtime_settings.fail2ban_socket == "/tmp/persisted.sock"
            assert app.state.runtime_settings.timezone == "Europe/Berlin"
            assert app.state.runtime_settings.session_duration_minutes == 123
            assert app.state.settings.database_path == str(tmp_path / "pointer.db")
            assert Path(runtime_db_path).exists()
async def test_startup_loads_geo_cache_from_persisted_runtime_database(tmp_path: Path) -> None:
    """Startup must load the geo cache from the resolved runtime database.

    Setup is reported as complete and the persisted runtime settings point at a
    separate ``runtime.db``. ``open_db`` is replaced by a wrapper that opens a
    real aiosqlite connection and records it. This lets the test check that
    the runtime database was opened and close every recorded connection
    afterwards, even when an assertion fails.
    """
    fail2ban_config_dir = tmp_path / "fail2ban"
    fail2ban_config_dir.mkdir()
    env_settings = Settings(
        database_path=str(tmp_path / "pointer.db"),
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        fail2ban_config_dir=str(fail2ban_config_dir),
        session_secret="test-startup-secret-that-is-long-enough!!!",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
    )
    app = create_app(settings=env_settings)
    runtime_db_path = str(tmp_path / "runtime.db")
    opened_connections: list[tuple[str, aiosqlite.Connection]] = []

    async def fake_open_db(path: str) -> aiosqlite.Connection:
        # Open a real connection, but remember it so the test can inspect and close it.
        connection = await aiosqlite.connect(path)
        opened_connections.append((path, connection))
        return connection

    mock_scheduler = MagicMock()
    mock_scheduler.start = MagicMock()
    mock_scheduler.shutdown = MagicMock()
    mock_http_session = MagicMock()
    mock_http_session.close = AsyncMock()
    load_cache = AsyncMock()
    persisted_runtime_settings = {
        "database_path": runtime_db_path,
        "fail2ban_socket": "/tmp/persisted.sock",
        "timezone": "Europe/Berlin",
        "session_duration_minutes": 123,
    }
    patchers = [
        patch("app.startup.ensure_jail_configs"),
        patch("app.startup.aiohttp.ClientSession", return_value=mock_http_session),
        patch("app.startup.AsyncIOScheduler", return_value=mock_scheduler),
        patch("app.startup.open_db", new=AsyncMock(side_effect=fake_open_db)),
        patch("app.startup.init_db", new=AsyncMock()),
        patch("app.startup.acquire_scheduler_lock", new=AsyncMock(return_value=True)),
        patch("app.services.geo_cache.GeoCache.init_geoip"),
        patch("app.services.geo_cache.GeoCache.load_cache_from_db", new=load_cache),
        patch("app.services.geo_cache.GeoCache.count_unresolved", new=AsyncMock(return_value=0)),
        patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=True)),
        patch(
            "app.services.setup_service.get_runtime_database_path",
            new=AsyncMock(return_value=runtime_db_path),
        ),
        patch(
            "app.services.setup_service.get_persisted_runtime_settings",
            new=AsyncMock(return_value=persisted_runtime_settings),
        ),
        patch(
            "app.services.setup_service.get_fail2ban_db_path",
            new=AsyncMock(return_value="/tmp/fail2ban/banned.tar.bz2"),
        ),
        patch("app.tasks.health_check.register"),
        patch("app.tasks.blocklist_import.register"),
        patch("app.tasks.geo_cache_flush.register"),
        patch("app.tasks.geo_re_resolve.register"),
        patch("app.tasks.history_sync.register"),
        patch("app.tasks.session_cleanup.register"),
        patch("app.tasks.rate_limiter_cleanup.register"),
    ]
    try:
        # Entering the patches inside the ExitStack's own ``with`` guarantees that
        # every patch already applied is undone even if a later one fails to apply.
        with contextlib.ExitStack() as exit_stack:
            for patcher in patchers:
                exit_stack.enter_context(patcher)
            async with _lifespan(app):
                runtime_connections = [
                    conn for path, conn in opened_connections if path == runtime_db_path
                ]
                assert runtime_connections, "Expected runtime database to be opened"
                assert app.state.runtime_settings is not None
                assert app.state.runtime_settings.database_path == runtime_db_path
        # NOTE(review): this only checks that the loader ran. The connection it
        # received is not asserted because load_cache_from_db's signature is not
        # visible here.
        load_cache.assert_awaited()
    finally:
        for _, connection in opened_connections:
            await connection.close()
async def test_concurrent_requests_use_request_scoped_db_connections(tmp_path: Path) -> None:
    """Concurrent requests never share a DB connection, and each connection is closed exactly once.

    ``open_db`` is replaced (in both ``app.startup`` and ``app.db``) by a factory
    that returns a fresh MagicMock connection per call. Five concurrent logout
    requests must therefore produce five distinct connections.
    """
    config_dir = tmp_path / "fail2ban"
    config_dir.mkdir()
    settings = Settings(
        database_path=str(tmp_path / "bangui.db"),
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        fail2ban_config_dir=str(config_dir),
        session_secret="test-concurrency-secret-that-is-long-enough!!!",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
    )
    app = create_app(settings=settings)
    handed_out: list[MagicMock] = []

    async def fake_open_db(database_path: str) -> MagicMock:
        conn = MagicMock(close=AsyncMock())
        handed_out.append(conn)
        return conn

    scheduler = MagicMock(start=MagicMock(), shutdown=MagicMock())
    http_session = MagicMock(close=AsyncMock())
    request_count = 5

    with (
        patch("app.startup.open_db", new=AsyncMock(side_effect=fake_open_db)),
        patch("app.db.open_db", new=AsyncMock(side_effect=fake_open_db)),
        patch("app.startup.init_db", new=AsyncMock()),
        patch("app.startup.ensure_jail_configs"),
        patch("app.startup.aiohttp.ClientSession", return_value=http_session),
        patch("app.startup.AsyncIOScheduler", return_value=scheduler),
        patch("app.startup.acquire_scheduler_lock", new=AsyncMock(return_value=True)),
        patch("app.services.geo_cache.GeoCache.init_geoip"),
        patch("app.services.geo_cache.GeoCache.load_cache_from_db", new=AsyncMock(return_value=None)),
        patch("app.services.geo_cache.GeoCache.count_unresolved", new=AsyncMock(return_value=0)),
        patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=False)),
        patch("app.tasks.health_check.register"),
        patch("app.tasks.blocklist_import.register"),
        patch("app.tasks.geo_cache_flush.register"),
        patch("app.tasks.geo_re_resolve.register"),
        patch("app.tasks.history_sync.register"),
        patch("app.tasks.session_cleanup.register"),
        patch("app.tasks.rate_limiter_cleanup.register"),
    ):
        async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
            app.state.setup_complete_cached = True
            pending = [client.post("/api/v1/auth/logout") for _ in range(request_count)]
            responses = await asyncio.gather(*pending)

    assert len(handed_out) == request_count
    assert len(set(map(id, handed_out))) == request_count
    assert all(resp.status_code == 200 for resp in responses)
    assert all(conn.close.await_count == 1 for conn in handed_out)
# ---------------------------------------------------------------------------
# Middleware order validation
# ---------------------------------------------------------------------------
def _make_settings(tmp_path: Path) -> Settings:
    """Build minimal Settings rooted in *tmp_path*, creating its ``fail2ban`` config dir first."""
    (config_dir := tmp_path / "fail2ban").mkdir()
    return Settings(
        database_path=str(tmp_path.joinpath("bangui.db")),
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        fail2ban_config_dir=str(config_dir),
        session_secret="test-secret-key-do-not-use-in-production",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
    )
def test_create_app_raises_on_incorrect_middleware_order(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Swapping two security middlewares makes _assert_middleware_order() fail.

    Required ``user_middleware`` sequence (processing order, outermost → innermost):
    RateLimitMiddleware → CsrfMiddleware → CorrelationIdMiddleware.
    """
    monkeypatch.setenv("TESTING", "1")
    app = create_app(settings=_make_settings(tmp_path))

    stack = app.user_middleware
    names = [entry.cls.__name__ for entry in stack]
    correlation_pos = names.index("CorrelationIdMiddleware")
    rate_limit_pos = names.index("RateLimitMiddleware")
    # Exchange the two entries in place so the chain is out of order.
    stack[correlation_pos], stack[rate_limit_pos] = stack[rate_limit_pos], stack[correlation_pos]

    with pytest.raises(AssertionError, match="must be registered before"):
        _assert_middleware_order(app)
def test_middleware_order_validation_passes_for_correct_order(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An app built by create_app() satisfies _assert_middleware_order() as-is."""
    monkeypatch.setenv("TESTING", "1")
    app = create_app(settings=_make_settings(tmp_path))
    # Any AssertionError raised here fails the test.
    _assert_middleware_order(app)
def test_create_app_validates_middleware_order_at_startup(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """create_app() raises immediately if middleware registration order is incorrect.

    This verifies the integration: _assert_middleware_order is called at the
    end of create_app(). ``Starlette.add_middleware`` is patched during creation
    so that the real CorrelationIdMiddleware registration is dropped and the
    middleware is re-registered immediately after RateLimitMiddleware instead.
    The resulting out-of-order chain must make create_app() raise AssertionError.
    """
    monkeypatch.setenv("TESTING", "1")
    settings = _make_settings(tmp_path)
    from starlette.applications import Starlette

    original_add = Starlette.add_middleware

    def reordering_add(self, middleware_cls: type, *args: object, **kwargs: object) -> None:
        """Forward to the real add_middleware, but register CorrelationId out of order.

        Positional ``*args`` are forwarded too, because Starlette's
        ``add_middleware`` accepts them and create_app() may pass them.
        """
        if middleware_cls is CorrelationIdMiddleware:
            return  # Skipped here; registered right after RateLimit below.
        original_add(self, middleware_cls, *args, **kwargs)
        if middleware_cls is RateLimitMiddleware:
            original_add(self, CorrelationIdMiddleware)

    with (
        patch.object(Starlette, "add_middleware", reordering_add),
        pytest.raises(AssertionError, match="must be registered before"),
    ):
        create_app(settings=settings)
# ---------------------------------------------------------------------------
# Single-worker enforcement
# ---------------------------------------------------------------------------
def test_enforce_single_worker_allows_no_env_vars(monkeypatch: pytest.MonkeyPatch) -> None:
    """With neither WEB_CONCURRENCY nor BANGUI_WORKERS present, enforcement passes silently."""
    for variable in ("WEB_CONCURRENCY", "BANGUI_WORKERS"):
        monkeypatch.delenv(variable, raising=False)
    _enforce_single_worker()  # must not raise
def test_enforce_single_worker_allows_workers_1(monkeypatch: pytest.MonkeyPatch) -> None:
    """A worker count of exactly one is accepted for both WEB_CONCURRENCY and BANGUI_WORKERS."""
    for variable in ("WEB_CONCURRENCY", "BANGUI_WORKERS"):
        monkeypatch.setenv(variable, "1")
    _enforce_single_worker()  # must not raise
def test_enforce_single_worker_rejects_web_concurrency_2(monkeypatch: pytest.MonkeyPatch) -> None:
    """Two web workers are refused with a RuntimeError that names WEB_CONCURRENCY."""
    monkeypatch.delenv("BANGUI_WORKERS", raising=False)
    monkeypatch.setenv("WEB_CONCURRENCY", "2")
    with pytest.raises(RuntimeError, match="WEB_CONCURRENCY"):
        _enforce_single_worker()
def test_enforce_single_worker_rejects_web_concurrency_4(monkeypatch: pytest.MonkeyPatch) -> None:
    """Four web workers are refused with a RuntimeError that names WEB_CONCURRENCY."""
    monkeypatch.delenv("BANGUI_WORKERS", raising=False)
    monkeypatch.setenv("WEB_CONCURRENCY", "4")
    with pytest.raises(RuntimeError, match="WEB_CONCURRENCY"):
        _enforce_single_worker()
def test_enforce_single_worker_rejects_bangui_workers_2(monkeypatch: pytest.MonkeyPatch) -> None:
    """BANGUI_WORKERS=2 is refused even though WEB_CONCURRENCY=1 on its own is acceptable."""
    monkeypatch.setenv("WEB_CONCURRENCY", "1")  # valid in isolation
    monkeypatch.setenv("BANGUI_WORKERS", "2")  # this one must trigger the error
    with pytest.raises(RuntimeError, match="BANGUI_WORKERS"):
        _enforce_single_worker()
def test_enforce_single_worker_rejects_invalid_web_concurrency(monkeypatch: pytest.MonkeyPatch) -> None:
    """Non-integer WEB_CONCURRENCY raises RuntimeError.

    BANGUI_WORKERS is cleared, as in the sibling tests, so a value inherited
    from the surrounding environment cannot influence which check fires.
    """
    monkeypatch.setenv("WEB_CONCURRENCY", "not-a-number")
    monkeypatch.delenv("BANGUI_WORKERS", raising=False)
    with pytest.raises(RuntimeError, match="WEB_CONCURRENCY must be an integer"):
        _enforce_single_worker()
def test_enforce_single_worker_error_message_mentions_docs(monkeypatch: pytest.MonkeyPatch) -> None:
    """Error message references Docs/Deployment.md.

    BANGUI_WORKERS is cleared so that the WEB_CONCURRENCY=2 violation is the
    only trigger, independent of the surrounding environment.
    """
    monkeypatch.setenv("WEB_CONCURRENCY", "2")
    monkeypatch.delenv("BANGUI_WORKERS", raising=False)
    with pytest.raises(RuntimeError) as exc_info:
        _enforce_single_worker()
    assert "Deployment.md" in str(exc_info.value)
def test_create_app_raises_when_web_concurrency_gt_1(monkeypatch: pytest.MonkeyPatch) -> None:
    """create_app() raises RuntimeError when WEB_CONCURRENCY > 1 (no TESTING set).

    TESTING is removed explicitly. create_app() skips the worker check when it
    is set (see test_create_app_skips_enforcement_when_testing_set), so an
    inherited TESTING variable would make this test fail spuriously.
    BANGUI_WORKERS is cleared so that only WEB_CONCURRENCY can trigger the error.
    """
    monkeypatch.setenv("WEB_CONCURRENCY", "2")
    monkeypatch.delenv("BANGUI_WORKERS", raising=False)
    monkeypatch.delenv("TESTING", raising=False)
    with pytest.raises(RuntimeError, match="WEB_CONCURRENCY"):
        create_app(settings=None)  # settings=None triggers get_settings() which loads env vars
def test_create_app_skips_enforcement_when_testing_set(monkeypatch: pytest.MonkeyPatch) -> None:
    """With TESTING set, create_app() tolerates worker counts above one."""
    for variable, value in (
        ("WEB_CONCURRENCY", "4"),
        ("BANGUI_WORKERS", "4"),
        ("TESTING", "1"),
    ):
        monkeypatch.setenv(variable, value)
    # Explicit settings keep get_settings() and its environment parsing out of the picture.
    app = create_app(
        settings=Settings(
            database_path="/tmp/test.db",
            fail2ban_socket="/tmp/fake_fail2ban.sock",
            fail2ban_config_dir="/tmp/fail2ban",
            session_secret="test-secret-key-do-not-use-in-production",
            session_duration_minutes=60,
            timezone="UTC",
            log_level="debug",
        )
    )
    assert app is not None