"""Unit tests for backend application startup and middleware configuration."""

import asyncio
import contextlib
import io
import json
import logging
from collections.abc import Iterable, Iterator, Mapping
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import aiosqlite
import pytest
from httpx import ASGITransport, AsyncClient

from app.config import Settings
from app.db import init_db
from app.exceptions import ConfigValidationError, ConfigWriteError, JailNotFoundError
from app.main import (
    CORSMiddleware,
    _assert_middleware_order,
    _enforce_single_worker,
    _lifespan,
    create_app,
)
from app.middleware.correlation import CorrelationIdMiddleware
from app.middleware.rate_limit import RateLimitMiddleware
from app.services import setup_service
from app.utils.json_formatter import JSONFormatter

# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------

_TEST_SECRET = "test-secret-key-do-not-use-in-production"
_LIFESPAN_SECRET = "test-lifespan-secret-that-is-long-enough!!"
_STARTUP_SECRET = "test-startup-secret-that-is-long-enough!!!"
_CONCURRENCY_SECRET = "test-concurrency-secret-that-is-long-enough!!!"

# Settings fields shared by every test; the helpers below add paths and overrides.
# Only ever read via ``{**_BASE_SETTINGS, ...}`` copies, never mutated.
_BASE_SETTINGS: dict[str, Any] = {
    "fail2ban_socket": "/tmp/fake_fail2ban.sock",
    "session_secret": _TEST_SECRET,
    "session_duration_minutes": 60,
    "timezone": "UTC",
    "log_level": "debug",
}

# Background-task registration hooks stubbed out by every lifespan test.
_TASK_REGISTER_TARGETS: tuple[str, ...] = (
    "app.tasks.health_check.register",
    "app.tasks.blocklist_import.register",
    "app.tasks.geo_cache_flush.register",
    "app.tasks.geo_re_resolve.register",
    "app.tasks.history_sync.register",
    "app.tasks.session_cleanup.register",
    "app.tasks.rate_limiter_cleanup.register",
)


def _static_settings(**overrides: Any) -> Settings:
    """Return Settings that point at fixed ``/tmp`` paths.

    This helper creates nothing on disk. Keyword arguments override or extend
    the default Settings fields.
    """
    return Settings(
        **{
            **_BASE_SETTINGS,
            "database_path": "/tmp/test.db",
            "fail2ban_config_dir": "/tmp/fail2ban",
            **overrides,
        }
    )


def _make_settings(
    tmp_path: Path,
    *,
    database_name: str = "bangui.db",
    session_secret: str = _TEST_SECRET,
    **overrides: Any,
) -> Settings:
    """Return Settings rooted in *tmp_path*, creating ``tmp_path/fail2ban`` if missing.

    Args:
        tmp_path: Per-test temporary directory.
        database_name: File name of the database inside *tmp_path*.
        session_secret: Session secret to configure.
        **overrides: Extra Settings fields that override the defaults.

    Returns:
        A Settings instance using the temporary database and fail2ban config dir.
    """
    fail2ban_config_dir = tmp_path / "fail2ban"
    # exist_ok lets a single test build several Settings objects in one tmp_path.
    fail2ban_config_dir.mkdir(exist_ok=True)
    return Settings(
        **{
            **_BASE_SETTINGS,
            "database_path": str(tmp_path / database_name),
            "fail2ban_config_dir": str(fail2ban_config_dir),
            "session_secret": session_secret,
            **overrides,
        }
    )


def _cors_middleware(app: Any) -> list[Any]:
    """Return the entries of ``app.user_middleware`` whose class is CORSMiddleware."""
    return [middleware for middleware in app.user_middleware if middleware.cls is CORSMiddleware]


def _make_mock_scheduler() -> MagicMock:
    """Return a scheduler double whose ``start``/``shutdown`` calls can be asserted."""
    scheduler = MagicMock()
    scheduler.start = MagicMock()
    scheduler.shutdown = MagicMock()
    return scheduler


def _make_mock_http_session() -> MagicMock:
    """Return an HTTP session double with an awaitable ``close``."""
    session = MagicMock()
    session.close = AsyncMock()
    return session


@contextlib.contextmanager
def _patched_startup(
    *,
    scheduler: MagicMock,
    http_session: MagicMock,
    mock_init_db: bool = True,
    setup_complete: bool | None = False,
    load_cache_from_db: AsyncMock | None = None,
    register_overrides: Mapping[str, Mapping[str, Any]] | None = None,
    extra_patches: Iterable[contextlib.AbstractContextManager[Any]] = (),
) -> Iterator[MagicMock]:
    """Stub out the startup collaborators these lifespan tests do not exercise.

    Args:
        scheduler: Object returned in place of ``AsyncIOScheduler()``.
        http_session: Object returned in place of ``aiohttp.ClientSession(...)``.
        mock_init_db: When False, ``app.startup.init_db`` is left real.
        setup_complete: Return value for ``setup_service.is_setup_complete``;
            ``None`` leaves the real implementation in place.
        load_cache_from_db: Replacement for ``GeoCache.load_cache_from_db``;
            defaults to an ``AsyncMock`` returning ``None``.
        register_overrides: Extra ``patch()`` kwargs (e.g. ``side_effect``) keyed by
            a target from ``_TASK_REGISTER_TARGETS``.
        extra_patches: Additional un-entered patchers to activate alongside these.

    Yields:
        The mock standing in for ``aiohttp.ClientSession``, so callers can inspect
        how the session was constructed.
    """
    overrides = register_overrides or {}
    load_cache = load_cache_from_db if load_cache_from_db is not None else AsyncMock(return_value=None)
    with contextlib.ExitStack() as stack:
        for extra in extra_patches:
            stack.enter_context(extra)
        stack.enter_context(patch("app.startup.ensure_jail_configs"))
        client_session_factory = stack.enter_context(
            patch("app.startup.aiohttp.ClientSession", return_value=http_session)
        )
        stack.enter_context(patch("app.startup.AsyncIOScheduler", return_value=scheduler))
        if mock_init_db:
            stack.enter_context(patch("app.startup.init_db", new=AsyncMock()))
        stack.enter_context(patch("app.startup.acquire_scheduler_lock", new=AsyncMock(return_value=True)))
        stack.enter_context(patch("app.services.geo_cache.GeoCache.init_geoip"))
        stack.enter_context(patch("app.services.geo_cache.GeoCache.load_cache_from_db", new=load_cache))
        stack.enter_context(
            patch("app.services.geo_cache.GeoCache.count_unresolved", new=AsyncMock(return_value=0))
        )
        if setup_complete is not None:
            stack.enter_context(
                patch(
                    "app.services.setup_service.is_setup_complete",
                    new=AsyncMock(return_value=setup_complete),
                )
            )
        for target in _TASK_REGISTER_TARGETS:
            stack.enter_context(patch(target, **overrides.get(target, {})))
        yield client_session_factory


@contextlib.contextmanager
def _capture_json_logs(logger: logging.Logger) -> Iterator[io.StringIO]:
    """Route *logger* through a JSONFormatter handler at DEBUG level for the block.

    Yields the buffer the handler writes to. On exit the handler is removed and
    the logger's previous level is restored, so other tests are unaffected.
    """
    output = io.StringIO()
    handler = logging.StreamHandler(stream=output)
    handler.setFormatter(JSONFormatter())
    previous_level = logger.level
    logger.addHandler(handler)
    logger.setLevel(logging.DEBUG)
    try:
        yield output
    finally:
        logger.removeHandler(handler)
        logger.setLevel(previous_level)


# ---------------------------------------------------------------------------
# App factory
# ---------------------------------------------------------------------------


def test_create_app_configures_cors_from_settings() -> None:
    """The FastAPI app registers CORS middleware with the configured origins."""
    app = create_app(settings=_static_settings(cors_allowed_origins=["https://frontend.example.com"]))

    cors_middleware = _cors_middleware(app)

    assert len(cors_middleware) == 1
    cors_kwargs = cors_middleware[0].kwargs
    assert cors_kwargs["allow_origins"] == ["https://frontend.example.com"]
    assert cors_kwargs["allow_credentials"] is True
    assert cors_kwargs["allow_methods"] == ["*"]
    assert cors_kwargs["allow_headers"] == ["*"]


def test_create_app_skips_cors_when_no_origins_are_configured() -> None:
    """The FastAPI app does not add CORS middleware when no origins are configured."""
    app = create_app(settings=_static_settings(cors_allowed_origins=[]))

    assert _cors_middleware(app) == []


def test_create_app_initialises_runtime_state_manager() -> None:
    """The FastAPI app exposes a dedicated runtime state manager on app.state."""
    app = create_app(settings=_static_settings())

    runtime_state = app.state.runtime_state
    assert runtime_state.setup_complete_cached is False
    assert runtime_state.server_status.online is False
    assert runtime_state.pending_recovery is None
    assert runtime_state.last_activation is None
    assert app.state.server_status.online is False


async def test_create_app_global_domain_exception_handlers() -> None:
    """Global exception handlers map domain exceptions to consistent HTTP responses."""
    app = create_app(settings=_static_settings())

    @app.get("/not-found")
    async def raise_not_found() -> None:
        raise JailNotFoundError("ssh")

    @app.get("/bad-request")
    async def raise_bad_request() -> None:
        raise ConfigValidationError("invalid payload")

    @app.get("/server-error")
    async def raise_server_error() -> None:
        raise ConfigWriteError("write failed")

    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as client:
        response = await client.get("/not-found")
        assert response.status_code == 404
        data = response.json()
        assert data["code"] == "jail_not_found"
        assert data["detail"] == "Jail not found: 'ssh'"
        assert data["metadata"] == {"jail_name": "ssh"}
        assert "correlation_id" in data

        response = await client.get("/bad-request")
        assert response.status_code == 400
        data = response.json()
        assert data["code"] == "config_validation_failed"
        assert data["detail"] == "invalid payload"
        assert "correlation_id" in data

        response = await client.get("/server-error")
        assert response.status_code == 500
        data = response.json()
        assert data["code"] == "config_write_failed"
        assert data["detail"] == "write failed"
        assert "correlation_id" in data


def test_create_app_disables_cors_by_default() -> None:
    """The FastAPI app does not add CORS middleware when no origins are configured by environment."""
    # NOTE(review): cors_allowed_origins=[] is passed explicitly, so this currently
    # duplicates the "skips_cors" test instead of exercising the Settings default.
    app = create_app(settings=_static_settings(cors_allowed_origins=[]))

    assert _cors_middleware(app) == []


def test_create_app_disables_api_docs_by_default() -> None:
    """API documentation endpoints are disabled when enable_docs is false."""
    app = create_app(settings=_static_settings(enable_docs=False))

    assert app.docs_url is None
    assert app.redoc_url is None
    assert app.openapi_url is None


def test_create_app_enables_api_docs_when_configured() -> None:
    """API documentation endpoints are enabled at /api/* when enable_docs is true."""
    app = create_app(settings=_static_settings(enable_docs=True))

    assert app.docs_url == "/api/docs"
    assert app.redoc_url == "/api/redoc"
    assert app.openapi_url == "/api/openapi.json"


# ---------------------------------------------------------------------------
# Lifespan / startup
# ---------------------------------------------------------------------------


async def test_lifespan_initialises_and_cleans_up_shared_resources(tmp_path: Path) -> None:
    """The app lifespan creates and shuts down shared resources cleanly."""
    settings = _make_settings(tmp_path, session_secret=_LIFESPAN_SECRET)
    app = create_app(settings=settings)
    mock_scheduler = _make_mock_scheduler()
    mock_http_session = _make_mock_http_session()

    with _patched_startup(scheduler=mock_scheduler, http_session=mock_http_session):
        async with _lifespan(app):
            assert app.state.http_session is mock_http_session
            assert app.state.scheduler is mock_scheduler
            assert app.state.settings is settings

    mock_http_session.close.assert_awaited_once()
    mock_scheduler.shutdown.assert_called_once_with(wait=False)


async def test_lifespan_cleans_up_resources_when_startup_fails(tmp_path: Path) -> None:
    """The lifespan must close resources if shared startup registration fails."""
    app = create_app(settings=_make_settings(tmp_path, session_secret=_LIFESPAN_SECRET))
    mock_scheduler = _make_mock_scheduler()
    mock_http_session = _make_mock_http_session()

    with (
        pytest.raises(RuntimeError, match="startup failed"),
        _patched_startup(
            scheduler=mock_scheduler,
            http_session=mock_http_session,
            # Fail the first task registration to simulate a startup crash.
            register_overrides={
                "app.tasks.health_check.register": {"side_effect": RuntimeError("startup failed")},
            },
        ),
    ):
        async with _lifespan(app):
            pass

    mock_http_session.close.assert_awaited_once()
    mock_scheduler.shutdown.assert_called_once_with(wait=False)


async def test_http_session_is_created_with_configured_timeouts_and_limits(tmp_path: Path) -> None:
    """The shared HTTP client session is created with the configured limits."""
    settings = _make_settings(
        tmp_path,
        session_secret=_LIFESPAN_SECRET,
        http_request_timeout_seconds=12.5,
        http_connect_timeout_seconds=1.5,
        http_max_connections=5,
        http_keepalive_timeout_seconds=8.0,
    )
    app = create_app(settings=settings)

    with _patched_startup(
        scheduler=_make_mock_scheduler(),
        http_session=_make_mock_http_session(),
    ) as mock_client_session:
        async with _lifespan(app):
            assert mock_client_session.call_count == 1
            kwargs = mock_client_session.call_args.kwargs
            timeout = kwargs["timeout"]
            connector = kwargs["connector"]
            assert timeout.total == 12.5
            assert timeout.connect == 1.5
            assert timeout.sock_read == 12.5
            assert connector.limit == 5
            assert connector.limit_per_host == 5


async def test_startup_overrides_settings_from_persisted_setup(tmp_path: Path) -> None:
    """Startup should replace env defaults with values persisted by setup."""
    env_settings = _make_settings(tmp_path, database_name="pointer.db", session_secret=_STARTUP_SECRET)
    app = create_app(settings=env_settings)
    runtime_db_path = str(tmp_path / "runtime.db")

    # Persist setup through the real setup service; the connection is closed even if setup fails.
    async with aiosqlite.connect(env_settings.database_path) as db:
        db.row_factory = aiosqlite.Row
        await init_db(db)
        await setup_service.run_setup(
            db,
            master_password="Supersecret1!",
            database_path=runtime_db_path,
            fail2ban_socket="/tmp/persisted.sock",
            timezone="Europe/Berlin",
            session_duration_minutes=123,
        )

    # init_db and is_setup_complete stay real so startup reads the persisted values.
    with _patched_startup(
        scheduler=_make_mock_scheduler(),
        http_session=_make_mock_http_session(),
        mock_init_db=False,
        setup_complete=None,
    ):
        async with _lifespan(app):
            runtime_settings = app.state.runtime_settings
            assert runtime_settings is not None
            assert runtime_settings.database_path == runtime_db_path
            assert runtime_settings.fail2ban_socket == "/tmp/persisted.sock"
            assert runtime_settings.timezone == "Europe/Berlin"
            assert runtime_settings.session_duration_minutes == 123
            assert app.state.settings.database_path == str(tmp_path / "pointer.db")
            assert Path(runtime_db_path).exists()


async def test_startup_loads_geo_cache_from_persisted_runtime_database(tmp_path: Path) -> None:
    """Startup must load geo cache from the resolved runtime database."""
    env_settings = _make_settings(tmp_path, database_name="pointer.db", session_secret=_STARTUP_SECRET)
    app = create_app(settings=env_settings)
    runtime_db_path = str(tmp_path / "runtime.db")
    # Every connection startup opens, so the test can find the runtime DB and close them all.
    opened_connections: list[tuple[str, aiosqlite.Connection]] = []

    async def fake_open_db(path: str) -> aiosqlite.Connection:
        connection = await aiosqlite.connect(path)
        opened_connections.append((path, connection))
        return connection

    # NOTE(review): load_cache is injected but never asserted on, so this test does not
    # yet prove the cache was loaded from the runtime connection; consider asserting
    # on its await args once the call signature is confirmed.
    load_cache = AsyncMock()
    extra_patches = [
        patch("app.startup.open_db", new=AsyncMock(side_effect=fake_open_db)),
        patch(
            "app.services.setup_service.get_runtime_database_path",
            new=AsyncMock(return_value=runtime_db_path),
        ),
        patch(
            "app.services.setup_service.get_persisted_runtime_settings",
            new=AsyncMock(
                return_value={
                    "database_path": runtime_db_path,
                    "fail2ban_socket": "/tmp/persisted.sock",
                    "timezone": "Europe/Berlin",
                    "session_duration_minutes": 123,
                }
            ),
        ),
        patch(
            "app.services.setup_service.get_fail2ban_db_path",
            new=AsyncMock(return_value="/tmp/fail2ban/banned.tar.bz2"),
        ),
    ]

    try:
        with _patched_startup(
            scheduler=_make_mock_scheduler(),
            http_session=_make_mock_http_session(),
            setup_complete=True,
            load_cache_from_db=load_cache,
            extra_patches=extra_patches,
        ):
            async with _lifespan(app):
                runtime_connections = [
                    conn for path, conn in opened_connections if path == runtime_db_path
                ]
                assert runtime_connections, "Expected runtime database to be opened"
                assert app.state.runtime_settings is not None
                assert app.state.runtime_settings.database_path == runtime_db_path
    finally:
        # Close test-owned connections even when an assertion above fails.
        for _, connection in opened_connections:
            await connection.close()


async def test_concurrent_requests_use_request_scoped_db_connections(tmp_path: Path) -> None:
    """Concurrent requests each open and close their own database connection."""
    app = create_app(settings=_make_settings(tmp_path, session_secret=_CONCURRENCY_SECRET))
    connections: list[MagicMock] = []

    async def fake_open_db(database_path: str) -> MagicMock:
        connection = MagicMock()
        connection.close = AsyncMock()
        connections.append(connection)
        return connection

    with _patched_startup(
        scheduler=_make_mock_scheduler(),
        http_session=_make_mock_http_session(),
        extra_patches=[
            patch("app.startup.open_db", new=AsyncMock(side_effect=fake_open_db)),
            patch("app.db.open_db", new=AsyncMock(side_effect=fake_open_db)),
        ],
    ):
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            app.state.setup_complete_cached = True
            responses = await asyncio.gather(
                *(client.post("/api/v1/auth/logout") for _ in range(5))
            )

    assert len(connections) == 5
    assert len({id(connection) for connection in connections}) == 5
    assert all(response.status_code == 200 for response in responses)
    assert all(connection.close.await_count == 1 for connection in connections)


# ---------------------------------------------------------------------------
# Logging configuration
# ---------------------------------------------------------------------------


def test_logging_configuration_no_duplicate_handlers(tmp_path: Path) -> None:
    """Calling create_app() twice leaves no more than one custom StreamHandler on root."""
    create_app(settings=_make_settings(tmp_path, database_name="test1.db"))
    create_app(
        settings=_make_settings(
            tmp_path,
            database_name="test2.db",
            session_secret=f"{_TEST_SECRET}-2",
        )
    )

    # Repeated logging configuration must not stack handlers on the root logger.
    # pytest's LogCaptureHandler is excluded because pytest installs it itself.
    root_stream_handlers = [
        h
        for h in logging.getLogger().handlers
        if isinstance(h, logging.StreamHandler)
        and not type(h).__name__.endswith("LogCaptureHandler")
    ]
    assert len(root_stream_handlers) <= 1, (
        f"Expected at most one StreamHandler after two create_app() calls, "
        f"got {len(root_stream_handlers)}: {root_stream_handlers}"
    )


def test_uvicorn_access_logs_go_through_root_handler(tmp_path: Path) -> None:
    """uvicorn.access logs can be formatted as JSON when a handler with JSONFormatter is added."""
    create_app(settings=_make_settings(tmp_path, database_name="test.db"))

    # uvicorn.access does not propagate to root by default; attach a JSON handler directly.
    uvicorn_access = logging.getLogger("uvicorn.access")
    with _capture_json_logs(uvicorn_access) as output:
        uvicorn_access.info("GET /api/v1/health 200")

    line = output.getvalue().strip()
    assert line, "Expected non-empty log output from uvicorn.access"
    parsed = json.loads(line)
    assert "event" in parsed, "JSON log must contain 'event'"
    assert "level" in parsed, "JSON log must contain 'level'"
    assert "timestamp" in parsed, "JSON log must contain 'timestamp'"


def test_external_logging_processor_queues_record(tmp_path: Path) -> None:
    """_external_logging_processor queues a record to the external handler when present."""
    from app.main import _external_logging_processor

    create_app(settings=_make_settings(tmp_path, database_name="test.db"))

    # Imported after create_app() so the module-level handler it configures is picked up.
    from app.main import _external_log_handler

    if _external_log_handler is None:
        pytest.skip("No external log handler configured")

    captured: list[dict[str, object]] = []
    original_queue_log = _external_log_handler.queue_log

    def mock_queue_log(record: dict[str, object]) -> None:
        captured.append(record)

    # Swapped in by hand (not patch.object) so a plain synchronous callable is installed.
    _external_log_handler.queue_log = mock_queue_log
    try:
        record = logging.makeLogRecord(
            {"msg": "test event", "levelname": "INFO", "name": "test.logger", "created": 0}
        )
        _external_logging_processor(record)

        assert len(captured) == 1, f"Expected exactly one queued record, got {len(captured)}"
        assert captured[0]["event"] == "test event"
        assert captured[0]["level"] == "info"
    finally:
        _external_log_handler.queue_log = original_queue_log


def test_plain_text_logs_not_emitted_after_startup(tmp_path: Path) -> None:
    """After create_app() completes, app.db logger output is JSON, not plain text."""
    create_app(settings=_make_settings(tmp_path, database_name="test.db"))

    db_logger = logging.getLogger("app.db")
    with _capture_json_logs(db_logger) as output:
        db_logger.info("test_db_log")

    line = output.getvalue().strip()
    assert line, "Expected non-empty log output"
    assert not line.startswith("test_db_log "), "Log must not be plain text"
    parsed = json.loads(line)
    assert "event" in parsed, "JSON log must contain 'event'"


# ---------------------------------------------------------------------------
# Middleware order validation
# ---------------------------------------------------------------------------


def test_create_app_raises_on_incorrect_middleware_order(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """_assert_middleware_order() raises AssertionError when middleware order is wrong.

    The security-critical chain requires:
        RateLimitMiddleware → CsrfMiddleware → CorrelationIdMiddleware
    in user_middleware (processing order: outermost → innermost).
    """
    monkeypatch.setenv("TESTING", "1")
    app = create_app(settings=_make_settings(tmp_path))

    # Swap CorrelationIdMiddleware and RateLimitMiddleware to break the order.
    user_mw = app.user_middleware
    corr_idx = next(i for i, m in enumerate(user_mw) if m.cls.__name__ == "CorrelationIdMiddleware")
    rate_idx = next(i for i, m in enumerate(user_mw) if m.cls.__name__ == "RateLimitMiddleware")
    user_mw[corr_idx], user_mw[rate_idx] = user_mw[rate_idx], user_mw[corr_idx]

    with pytest.raises(AssertionError, match="must be registered before"):
        _assert_middleware_order(app)


def test_middleware_order_validation_passes_for_correct_order(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """_assert_middleware_order() does not raise when middleware order is correct."""
    monkeypatch.setenv("TESTING", "1")
    app = create_app(settings=_make_settings(tmp_path))

    _assert_middleware_order(app)  # Should not raise


def test_create_app_validates_middleware_order_at_startup(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """create_app() raises immediately if middleware registration order is incorrect.

    This test verifies the integration: _assert_middleware_order is called at the
    end of create_app, so a fresh app with deliberately wrong middleware order
    (simulated by patching add_middleware during creation) raises AssertionError.
    """
    monkeypatch.setenv("TESTING", "1")
    settings = _make_settings(tmp_path)

    from starlette.applications import Starlette

    original_add = Starlette.add_middleware

    def swapping_add(self, middleware_cls: type, **kwargs: object) -> None:
        """Patched add_middleware that moves CorrelationId next to RateLimit.

        CorrelationIdMiddleware's own registration is dropped and re-issued right
        after RateLimitMiddleware's, producing an order the check must reject.
        """
        if middleware_cls is CorrelationIdMiddleware:
            return  # Re-registered below, immediately after RateLimitMiddleware.
        original_add(self, middleware_cls, **kwargs)
        if middleware_cls is RateLimitMiddleware:
            original_add(self, CorrelationIdMiddleware)

    with (
        patch.object(Starlette, "add_middleware", swapping_add),
        pytest.raises(AssertionError, match="must be registered before"),
    ):
        create_app(settings=settings)


# ---------------------------------------------------------------------------
# Single-worker enforcement
# ---------------------------------------------------------------------------


def test_enforce_single_worker_allows_no_env_vars(monkeypatch: pytest.MonkeyPatch) -> None:
    """No error raised when WEB_CONCURRENCY and BANGUI_WORKERS are not set."""
    monkeypatch.delenv("WEB_CONCURRENCY", raising=False)
    monkeypatch.delenv("BANGUI_WORKERS", raising=False)

    _enforce_single_worker()  # Should not raise


def test_enforce_single_worker_allows_workers_1(monkeypatch: pytest.MonkeyPatch) -> None:
    """WEB_CONCURRENCY=1 and BANGUI_WORKERS=1 are both allowed."""
    monkeypatch.setenv("WEB_CONCURRENCY", "1")
    monkeypatch.setenv("BANGUI_WORKERS", "1")

    _enforce_single_worker()  # Should not raise


def test_enforce_single_worker_rejects_web_concurrency_2(monkeypatch: pytest.MonkeyPatch) -> None:
    """WEB_CONCURRENCY=2 raises RuntimeError."""
    monkeypatch.setenv("WEB_CONCURRENCY", "2")
    monkeypatch.delenv("BANGUI_WORKERS", raising=False)

    with pytest.raises(RuntimeError, match="WEB_CONCURRENCY"):
        _enforce_single_worker()


def test_enforce_single_worker_rejects_web_concurrency_4(monkeypatch: pytest.MonkeyPatch) -> None:
    """WEB_CONCURRENCY=4 raises RuntimeError."""
    monkeypatch.setenv("WEB_CONCURRENCY", "4")
    monkeypatch.delenv("BANGUI_WORKERS", raising=False)

    with pytest.raises(RuntimeError, match="WEB_CONCURRENCY"):
        _enforce_single_worker()


def test_enforce_single_worker_rejects_bangui_workers_2(monkeypatch: pytest.MonkeyPatch) -> None:
    """BANGUI_WORKERS=2 raises RuntimeError."""
    monkeypatch.setenv("WEB_CONCURRENCY", "1")  # WEB_CONCURRENCY=1 should pass
    monkeypatch.setenv("BANGUI_WORKERS", "2")  # but BANGUI_WORKERS=2 should fail

    with pytest.raises(RuntimeError, match="BANGUI_WORKERS"):
        _enforce_single_worker()


def test_enforce_single_worker_rejects_invalid_web_concurrency(monkeypatch: pytest.MonkeyPatch) -> None:
    """Non-integer WEB_CONCURRENCY raises RuntimeError."""
    monkeypatch.setenv("WEB_CONCURRENCY", "not-a-number")
    # Keep the test hermetic: an ambient BANGUI_WORKERS must not influence the outcome.
    monkeypatch.delenv("BANGUI_WORKERS", raising=False)

    with pytest.raises(RuntimeError, match="WEB_CONCURRENCY must be an integer"):
        _enforce_single_worker()


def test_enforce_single_worker_error_message_mentions_docs(monkeypatch: pytest.MonkeyPatch) -> None:
    """Error message references Docs/Deployment.md."""
    monkeypatch.setenv("WEB_CONCURRENCY", "2")
    monkeypatch.delenv("BANGUI_WORKERS", raising=False)

    with pytest.raises(RuntimeError) as exc_info:
        _enforce_single_worker()

    assert "Deployment.md" in str(exc_info.value)


def test_create_app_raises_when_web_concurrency_gt_1(monkeypatch: pytest.MonkeyPatch) -> None:
    """create_app() raises RuntimeError when WEB_CONCURRENCY > 1 (no TESTING set)."""
    monkeypatch.setenv("WEB_CONCURRENCY", "2")
    # Guarantee the preconditions the docstring states instead of relying on the ambient env.
    monkeypatch.delenv("TESTING", raising=False)
    monkeypatch.delenv("BANGUI_WORKERS", raising=False)

    with pytest.raises(RuntimeError, match="WEB_CONCURRENCY"):
        create_app(settings=None)  # settings=None triggers get_settings() which loads env vars


def test_create_app_skips_enforcement_when_testing_set(monkeypatch: pytest.MonkeyPatch) -> None:
    """create_app() does NOT raise when TESTING env var is set, even with workers > 1."""
    monkeypatch.setenv("WEB_CONCURRENCY", "4")
    monkeypatch.setenv("BANGUI_WORKERS", "4")
    monkeypatch.setenv("TESTING", "1")

    # Pass explicit settings to bypass get_settings() env loading.
    app = create_app(settings=_static_settings())  # Should not raise

    assert app is not None