"""Unit tests for backend application startup and middleware configuration.""" import asyncio from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch import aiosqlite import pytest from httpx import ASGITransport, AsyncClient from app.config import Settings from app.db import init_db from app.exceptions import ConfigValidationError, ConfigWriteError, JailNotFoundError from app.main import CORSMiddleware, _lifespan, create_app from app.services import setup_service def test_create_app_configures_cors_from_settings() -> None: """The FastAPI app registers CORS middleware with the configured origins.""" settings = Settings( database_path="/tmp/test.db", fail2ban_socket="/tmp/fake_fail2ban.sock", fail2ban_config_dir="/tmp/fail2ban", session_secret="test-secret-key-do-not-use-in-production", session_duration_minutes=60, timezone="UTC", log_level="debug", cors_allowed_origins=["https://frontend.example.com"], ) app = create_app(settings=settings) cors_middleware = [ middleware for middleware in app.user_middleware if middleware.cls is CORSMiddleware ] assert len(cors_middleware) == 1 assert cors_middleware[0].kwargs["allow_origins"] == ["https://frontend.example.com"] assert cors_middleware[0].kwargs["allow_credentials"] is True assert cors_middleware[0].kwargs["allow_methods"] == ["*"] assert cors_middleware[0].kwargs["allow_headers"] == ["*"] def test_create_app_skips_cors_when_no_origins_are_configured() -> None: """The FastAPI app does not add CORS middleware when no origins are configured.""" settings = Settings( database_path="/tmp/test.db", fail2ban_socket="/tmp/fake_fail2ban.sock", fail2ban_config_dir="/tmp/fail2ban", session_secret="test-secret-key-do-not-use-in-production", session_duration_minutes=60, timezone="UTC", log_level="debug", cors_allowed_origins=[], ) app = create_app(settings=settings) cors_middleware = [ middleware for middleware in app.user_middleware if middleware.cls is CORSMiddleware ] assert cors_middleware == [] def test_create_app_initialises_runtime_state_manager() -> None: """The FastAPI app exposes a dedicated runtime state manager on app.state.""" settings = Settings( database_path="/tmp/test.db", fail2ban_socket="/tmp/fake_fail2ban.sock", fail2ban_config_dir="/tmp/fail2ban", session_secret="test-secret-key-do-not-use-in-production", session_duration_minutes=60, timezone="UTC", log_level="debug", ) app = create_app(settings=settings) runtime_state = app.state.runtime_state assert runtime_state.setup_complete_cached is False assert runtime_state.server_status.online is False assert runtime_state.pending_recovery is None assert runtime_state.last_activation is None assert app.state.server_status.online is False async def test_create_app_global_domain_exception_handlers() -> None: """Global exception handlers map domain exceptions to consistent HTTP responses.""" settings = Settings( database_path="/tmp/test.db", fail2ban_socket="/tmp/fake_fail2ban.sock", fail2ban_config_dir="/tmp/fail2ban", session_secret="test-secret-key-do-not-use-in-production", session_duration_minutes=60, timezone="UTC", log_level="debug", ) app = create_app(settings=settings) @app.get("/not-found") async def raise_not_found() -> None: raise JailNotFoundError("ssh") @app.get("/bad-request") async def raise_bad_request() -> None: raise ConfigValidationError("invalid payload") @app.get("/server-error") async def raise_server_error() -> None: raise ConfigWriteError("write failed") transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: response = await client.get("/not-found") assert response.status_code == 404 assert response.json() == {"detail": "Jail not found: 'ssh'"} response = await client.get("/bad-request") assert response.status_code == 400 assert response.json() == {"detail": "invalid payload"} response = await client.get("/server-error") assert response.status_code == 500 assert response.json() == {"detail": "write failed"} def test_create_app_disables_cors_by_default() -> None: """The FastAPI app does not add CORS middleware when no origins are configured by environment.""" settings = Settings( database_path="/tmp/test.db", fail2ban_socket="/tmp/fake_fail2ban.sock", fail2ban_config_dir="/tmp/fail2ban", session_secret="test-secret-key-do-not-use-in-production", session_duration_minutes=60, timezone="UTC", log_level="debug", ) app = create_app(settings=settings) cors_middleware = [ middleware for middleware in app.user_middleware if middleware.cls is CORSMiddleware ] assert cors_middleware == [] async def test_lifespan_initialises_and_cleans_up_shared_resources(tmp_path: Path) -> None: """The app lifespan creates and shuts down shared resources cleanly.""" settings = Settings( database_path=str(tmp_path / "bangui.db"), fail2ban_socket="/tmp/fake_fail2ban.sock", fail2ban_config_dir=str(tmp_path / "fail2ban"), session_secret="test-lifespan-secret", session_duration_minutes=60, timezone="UTC", log_level="debug", ) app = create_app(settings=settings) mock_scheduler = MagicMock() mock_scheduler.start = MagicMock() mock_scheduler.shutdown = MagicMock() mock_http_session = MagicMock() mock_http_session.close = AsyncMock() with ( patch("app.startup.ensure_jail_configs"), patch("app.startup.aiohttp.ClientSession", return_value=mock_http_session), patch("app.startup.AsyncIOScheduler", return_value=mock_scheduler), patch("app.startup.init_db", new=AsyncMock()), patch("app.services.geo_service.init_geoip"), patch("app.services.geo_service.load_cache_from_db", new=AsyncMock(return_value=None)), patch("app.services.geo_service.count_unresolved", new=AsyncMock(return_value=0)), patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=False)), patch("app.tasks.health_check.register"), patch("app.tasks.blocklist_import.register"), patch("app.tasks.geo_cache_flush.register"), patch("app.tasks.geo_re_resolve.register"), patch("app.tasks.history_sync.register"), ): async with _lifespan(app): assert app.state.http_session is mock_http_session assert app.state.scheduler is mock_scheduler assert app.state.settings is settings mock_http_session.close.assert_awaited_once() mock_scheduler.shutdown.assert_called_once_with(wait=False) async def test_lifespan_cleans_up_resources_when_startup_fails(tmp_path: Path) -> None: """The lifespan must close resources if shared startup registration fails.""" settings = Settings( database_path=str(tmp_path / "bangui.db"), fail2ban_socket="/tmp/fake_fail2ban.sock", fail2ban_config_dir=str(tmp_path / "fail2ban"), session_secret="test-lifespan-secret", session_duration_minutes=60, timezone="UTC", log_level="debug", ) app = create_app(settings=settings) mock_scheduler = MagicMock() mock_scheduler.start = MagicMock() mock_scheduler.shutdown = MagicMock() mock_http_session = MagicMock() mock_http_session.close = AsyncMock() with pytest.raises(RuntimeError, match="startup failed"), \ patch("app.startup.ensure_jail_configs"), \ patch("app.startup.aiohttp.ClientSession", return_value=mock_http_session), \ patch("app.startup.AsyncIOScheduler", return_value=mock_scheduler), \ patch("app.startup.init_db", new=AsyncMock()), \ patch("app.services.geo_service.init_geoip"), \ patch("app.services.geo_service.load_cache_from_db", new=AsyncMock(return_value=None)), \ patch("app.services.geo_service.count_unresolved", new=AsyncMock(return_value=0)), \ patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=False)), \ patch("app.tasks.health_check.register", side_effect=RuntimeError("startup failed")), \ patch("app.tasks.blocklist_import.register"), \ patch("app.tasks.geo_cache_flush.register"), \ patch("app.tasks.geo_re_resolve.register"), \ patch("app.tasks.history_sync.register"): async with _lifespan(app): pass mock_http_session.close.assert_awaited_once() mock_scheduler.shutdown.assert_called_once_with(wait=False) async def test_http_session_is_created_with_configured_timeouts_and_limits(tmp_path: Path) -> None: """The shared HTTP client session is created with the configured limits.""" settings = Settings( database_path=str(tmp_path / "bangui.db"), fail2ban_socket="/tmp/fake_fail2ban.sock", fail2ban_config_dir=str(tmp_path / "fail2ban"), session_secret="test-lifespan-secret", session_duration_minutes=60, timezone="UTC", log_level="debug", http_request_timeout_seconds=12.5, http_connect_timeout_seconds=1.5, http_max_connections=5, http_keepalive_timeout_seconds=8.0, ) app = create_app(settings=settings) mock_scheduler = MagicMock() mock_scheduler.start = MagicMock() mock_scheduler.shutdown = MagicMock() mock_http_session = MagicMock() mock_http_session.close = AsyncMock() with ( patch("app.startup.ensure_jail_configs"), patch("app.startup.aiohttp.ClientSession", return_value=mock_http_session) as mock_client_session, patch("app.startup.AsyncIOScheduler", return_value=mock_scheduler), patch("app.startup.init_db", new=AsyncMock()), patch("app.services.geo_service.init_geoip"), patch("app.services.geo_service.load_cache_from_db", new=AsyncMock(return_value=None)), patch("app.services.geo_service.count_unresolved", new=AsyncMock(return_value=0)), patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=False)), patch("app.tasks.health_check.register"), patch("app.tasks.blocklist_import.register"), patch("app.tasks.geo_cache_flush.register"), patch("app.tasks.geo_re_resolve.register"), patch("app.tasks.history_sync.register"), ): async with _lifespan(app): assert mock_client_session.call_count == 1 kwargs = mock_client_session.call_args.kwargs timeout = kwargs["timeout"] connector = kwargs["connector"] assert timeout.total == 12.5 assert timeout.connect == 1.5 assert timeout.sock_read == 12.5 assert connector.limit == 5 assert connector.limit_per_host == 5 async def test_startup_overrides_settings_from_persisted_setup(tmp_path: Path) -> None: """Startup should replace env defaults with values persisted by setup.""" env_settings = Settings( database_path=str(tmp_path / "pointer.db"), fail2ban_socket="/tmp/fake_fail2ban.sock", fail2ban_config_dir=str(tmp_path / "fail2ban"), session_secret="test-startup-secret", session_duration_minutes=60, timezone="UTC", log_level="debug", ) app = create_app(settings=env_settings) runtime_db_path = str(tmp_path / "runtime.db") db = await aiosqlite.connect(env_settings.database_path) db.row_factory = aiosqlite.Row await init_db(db) await setup_service.run_setup( db, master_password="Supersecret1!", database_path=runtime_db_path, fail2ban_socket="/tmp/persisted.sock", timezone="Europe/Berlin", session_duration_minutes=123, ) await db.close() mock_scheduler = MagicMock() mock_scheduler.start = MagicMock() mock_scheduler.shutdown = MagicMock() mock_http_session = MagicMock() mock_http_session.close = AsyncMock() with ( patch("app.startup.ensure_jail_configs"), patch("app.startup.aiohttp.ClientSession", return_value=mock_http_session), patch("app.startup.AsyncIOScheduler", return_value=mock_scheduler), patch("app.services.geo_service.init_geoip"), patch("app.services.geo_service.load_cache_from_db", new=AsyncMock(return_value=None)), patch("app.services.geo_service.count_unresolved", new=AsyncMock(return_value=0)), patch("app.tasks.health_check.register"), patch("app.tasks.blocklist_import.register"), patch("app.tasks.geo_cache_flush.register"), patch("app.tasks.geo_re_resolve.register"), patch("app.tasks.history_sync.register"), ): async with _lifespan(app): assert app.state.runtime_settings is not None assert app.state.runtime_settings.database_path == runtime_db_path assert app.state.runtime_settings.fail2ban_socket == "/tmp/persisted.sock" assert app.state.runtime_settings.timezone == "Europe/Berlin" assert app.state.runtime_settings.session_duration_minutes == 123 assert app.state.settings.database_path == str(tmp_path / "pointer.db") assert Path(runtime_db_path).exists() async def test_startup_loads_geo_cache_from_persisted_runtime_database(tmp_path: Path) -> None: """Startup must load geo cache from the resolved runtime database.""" env_settings = Settings( database_path=str(tmp_path / "pointer.db"), fail2ban_socket="/tmp/fake_fail2ban.sock", fail2ban_config_dir=str(tmp_path / "fail2ban"), session_secret="test-startup-secret", session_duration_minutes=60, timezone="UTC", log_level="debug", ) app = create_app(settings=env_settings) runtime_db_path = str(tmp_path / "runtime.db") opened_connections: list[tuple[str, aiosqlite.Connection]] = [] async def fake_open_db(path: str) -> aiosqlite.Connection: connection = await aiosqlite.connect(path) opened_connections.append((path, connection)) return connection mock_scheduler = MagicMock() mock_scheduler.start = MagicMock() mock_scheduler.shutdown = MagicMock() mock_http_session = MagicMock() mock_http_session.close = AsyncMock() load_cache = AsyncMock() with ( patch("app.startup.ensure_jail_configs"), patch("app.startup.aiohttp.ClientSession", return_value=mock_http_session), patch("app.startup.AsyncIOScheduler", return_value=mock_scheduler), patch("app.startup.open_db", new=AsyncMock(side_effect=fake_open_db)), patch("app.startup.init_db", new=AsyncMock()), patch("app.services.geo_service.init_geoip"), patch("app.services.geo_service.load_cache_from_db", new=load_cache), patch("app.services.geo_service.count_unresolved", new=AsyncMock(return_value=0)), patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=True)), patch("app.services.setup_service.get_runtime_database_path", new=AsyncMock(return_value=runtime_db_path)), patch( "app.services.setup_service.get_persisted_runtime_settings", new=AsyncMock( return_value={ "database_path": runtime_db_path, "fail2ban_socket": "/tmp/persisted.sock", "timezone": "Europe/Berlin", "session_duration_minutes": 123, } ), ), patch("app.tasks.health_check.register"), patch("app.tasks.blocklist_import.register"), patch("app.tasks.geo_cache_flush.register"), patch("app.tasks.geo_re_resolve.register"), patch("app.tasks.history_sync.register"), ): async with _lifespan(app): loaded_db = load_cache.call_args.args[0] runtime_connections = [conn for path, conn in opened_connections if path == runtime_db_path] assert runtime_connections, "Expected runtime database to be opened" assert loaded_db in runtime_connections assert app.state.runtime_settings is not None assert app.state.runtime_settings.database_path == runtime_db_path for _, connection in opened_connections: await connection.close() async def test_concurrent_requests_use_request_scoped_db_connections(tmp_path: Path) -> None: """Concurrent requests each open and close their own database connection.""" settings = Settings( database_path=str(tmp_path / "bangui.db"), fail2ban_socket="/tmp/fake_fail2ban.sock", fail2ban_config_dir=str(tmp_path / "fail2ban"), session_secret="test-concurrency-secret", session_duration_minutes=60, timezone="UTC", log_level="debug", ) app = create_app(settings=settings) connections: list[MagicMock] = [] async def fake_open_db(database_path: str) -> MagicMock: connection = MagicMock() connection.close = AsyncMock() connections.append(connection) return connection mock_scheduler = MagicMock() mock_scheduler.start = MagicMock() mock_scheduler.shutdown = MagicMock() mock_http_session = MagicMock() mock_http_session.close = AsyncMock() with ( patch("app.startup.open_db", new=AsyncMock(side_effect=fake_open_db)), patch("app.db.open_db", new=AsyncMock(side_effect=fake_open_db)), patch("app.startup.init_db", new=AsyncMock()), patch("app.startup.ensure_jail_configs"), patch("app.startup.aiohttp.ClientSession", return_value=mock_http_session), patch("app.startup.AsyncIOScheduler", return_value=mock_scheduler), patch("app.services.geo_service.init_geoip"), patch("app.services.geo_service.load_cache_from_db", new=AsyncMock(return_value=None)), patch("app.services.geo_service.count_unresolved", new=AsyncMock(return_value=0)), patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=False)), patch("app.tasks.health_check.register"), patch("app.tasks.blocklist_import.register"), patch("app.tasks.geo_cache_flush.register"), patch("app.tasks.geo_re_resolve.register"), patch("app.tasks.history_sync.register"), ): transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: app.state.setup_complete_cached = True responses = await asyncio.gather(*(client.post("/api/auth/logout") for _ in range(5))) assert len(connections) == 5 assert len({id(connection) for connection in connections}) == 5 assert all(response.status_code == 200 for response in responses) assert all(connection.close.await_count == 1 for connection in connections)