"""Unit tests for backend application startup and middleware configuration."""

import asyncio
from collections.abc import Iterator
from contextlib import ExitStack, contextmanager
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import aiosqlite
import pytest
from httpx import ASGITransport, AsyncClient

from app.config import Settings
from app.db import init_db
from app.main import CORSMiddleware, _lifespan, create_app
from app.services import setup_service

# Background-task ``register`` hooks that every lifespan test replaces with a
# plain mock so entering the lifespan does not register real jobs.
_TASK_REGISTER_TARGETS = (
    "app.tasks.health_check.register",
    "app.tasks.blocklist_import.register",
    "app.tasks.geo_cache_flush.register",
    "app.tasks.geo_re_resolve.register",
    "app.tasks.history_sync.register",
)


def _make_settings(**overrides: Any) -> Settings:
    """Build a ``Settings`` object from the baseline values shared by the tests.

    Args:
        **overrides: Fields that replace a baseline value or add a field the
            baseline does not set (for example ``cors_allowed_origins``).

    Returns:
        The ``Settings`` instance built from the merged keyword arguments.
    """
    values: dict[str, Any] = {
        "database_path": "/tmp/test.db",
        "fail2ban_socket": "/tmp/fake_fail2ban.sock",
        "fail2ban_config_dir": "/tmp/fail2ban",
        "session_secret": "test-secret-key-do-not-use-in-production",
        "session_duration_minutes": 60,
        "timezone": "UTC",
        "log_level": "debug",
    }
    values.update(overrides)
    return Settings(**values)


def _cors_middleware(app: Any) -> list[Any]:
    """Return the entries of ``app.user_middleware`` whose class is ``CORSMiddleware``."""
    return [entry for entry in app.user_middleware if entry.cls is CORSMiddleware]


def _make_lifespan_doubles() -> tuple[MagicMock, MagicMock]:
    """Create the scheduler and HTTP-session doubles used by the lifespan tests.

    Returns:
        A ``(scheduler, http_session)`` pair. The scheduler exposes mock
        ``start``/``shutdown`` methods; the session's ``close`` is an
        ``AsyncMock`` so it can be awaited and checked with
        ``assert_awaited_once``.
    """
    scheduler = MagicMock()
    scheduler.start = MagicMock()
    scheduler.shutdown = MagicMock()
    http_session = MagicMock()
    http_session.close = AsyncMock()
    return scheduler, http_session


@contextmanager
def _patched_startup(
    scheduler: MagicMock,
    http_session: MagicMock,
    *,
    stub_persistence: bool = True,
) -> Iterator[MagicMock]:
    """Patch the collaborators that the startup tests replace with mocks.

    Args:
        scheduler: Object returned by the patched ``app.startup.AsyncIOScheduler``.
        http_session: Object returned by the patched
            ``app.startup.aiohttp.ClientSession``.
        stub_persistence: When true (the default), also replace
            ``app.startup.init_db`` and
            ``app.services.setup_service.is_setup_complete`` with async mocks.
            Pass ``False`` for tests that rely on the real implementations.

    Yields:
        The mock standing in for ``aiohttp.ClientSession`` so callers can
        inspect how the session was constructed.
    """
    with ExitStack() as stack:
        stack.enter_context(patch("app.startup.ensure_jail_configs"))
        client_session_cls = stack.enter_context(
            patch("app.startup.aiohttp.ClientSession", return_value=http_session)
        )
        stack.enter_context(patch("app.startup.AsyncIOScheduler", return_value=scheduler))
        if stub_persistence:
            stack.enter_context(patch("app.startup.init_db", new=AsyncMock()))
            stack.enter_context(
                patch(
                    "app.services.setup_service.is_setup_complete",
                    new=AsyncMock(return_value=False),
                )
            )
        stack.enter_context(patch("app.services.geo_service.init_geoip"))
        stack.enter_context(
            patch(
                "app.services.geo_service.load_cache_from_db",
                new=AsyncMock(return_value=None),
            )
        )
        stack.enter_context(
            patch(
                "app.services.geo_service.count_unresolved",
                new=AsyncMock(return_value=0),
            )
        )
        for target in _TASK_REGISTER_TARGETS:
            stack.enter_context(patch(target))
        yield client_session_cls


def test_create_app_configures_cors_from_settings() -> None:
    """The FastAPI app registers CORS middleware with the configured origins."""
    settings = _make_settings(cors_allowed_origins=["https://frontend.example.com"])

    app = create_app(settings=settings)

    cors_middleware = _cors_middleware(app)
    assert len(cors_middleware) == 1
    assert cors_middleware[0].kwargs["allow_origins"] == ["https://frontend.example.com"]
    assert cors_middleware[0].kwargs["allow_credentials"] is True
    assert cors_middleware[0].kwargs["allow_methods"] == ["*"]
    assert cors_middleware[0].kwargs["allow_headers"] == ["*"]


def test_create_app_skips_cors_when_no_origins_are_configured() -> None:
    """The FastAPI app does not add CORS middleware when no origins are configured."""
    settings = _make_settings(cors_allowed_origins=[])

    app = create_app(settings=settings)

    assert _cors_middleware(app) == []


def test_create_app_initialises_runtime_state_manager() -> None:
    """The FastAPI app exposes a dedicated runtime state manager on app.state."""
    app = create_app(settings=_make_settings())

    runtime_state = app.state.runtime_state
    assert runtime_state.setup_complete_cached is False
    assert runtime_state.server_status.online is False
    assert runtime_state.pending_recovery is None
    assert runtime_state.last_activation is None
    assert app.state.server_status.online is False


def test_create_app_disables_cors_by_default() -> None:
    """The FastAPI app does not add CORS middleware when no origins are configured by environment."""
    app = create_app(settings=_make_settings())

    assert _cors_middleware(app) == []


async def test_lifespan_initialises_and_cleans_up_shared_resources(tmp_path: Path) -> None:
    """The app lifespan creates and shuts down shared resources cleanly."""
    settings = _make_settings(
        database_path=str(tmp_path / "bangui.db"),
        fail2ban_config_dir=str(tmp_path / "fail2ban"),
        session_secret="test-lifespan-secret",
    )
    app = create_app(settings=settings)
    mock_scheduler, mock_http_session = _make_lifespan_doubles()

    with _patched_startup(mock_scheduler, mock_http_session):
        async with _lifespan(app):
            assert app.state.http_session is mock_http_session
            assert app.state.scheduler is mock_scheduler
            assert app.state.settings is settings

    mock_http_session.close.assert_awaited_once()
    mock_scheduler.shutdown.assert_called_once_with(wait=False)


async def test_lifespan_cleans_up_resources_when_startup_fails(tmp_path: Path) -> None:
    """The lifespan must close resources if shared startup registration fails."""
    settings = _make_settings(
        database_path=str(tmp_path / "bangui.db"),
        fail2ban_config_dir=str(tmp_path / "fail2ban"),
        session_secret="test-lifespan-secret",
    )
    app = create_app(settings=settings)
    mock_scheduler, mock_http_session = _make_lifespan_doubles()

    with (
        _patched_startup(mock_scheduler, mock_http_session),
        # Applied after the shared patches so this failing mock is the one in
        # effect for health_check.register while the lifespan starts.
        patch(
            "app.tasks.health_check.register",
            side_effect=RuntimeError("startup failed"),
        ),
    ):
        with pytest.raises(RuntimeError, match="startup failed"):
            async with _lifespan(app):
                pass

    mock_http_session.close.assert_awaited_once()
    mock_scheduler.shutdown.assert_called_once_with(wait=False)


async def test_http_session_is_created_with_configured_timeouts_and_limits(
    tmp_path: Path,
) -> None:
    """The shared HTTP client session is created with the configured limits."""
    settings = _make_settings(
        database_path=str(tmp_path / "bangui.db"),
        fail2ban_config_dir=str(tmp_path / "fail2ban"),
        session_secret="test-lifespan-secret",
        http_request_timeout_seconds=12.5,
        http_connect_timeout_seconds=1.5,
        http_max_connections=5,
        http_keepalive_timeout_seconds=8.0,
    )
    app = create_app(settings=settings)
    mock_scheduler, mock_http_session = _make_lifespan_doubles()

    with _patched_startup(mock_scheduler, mock_http_session) as mock_client_session:
        async with _lifespan(app):
            assert mock_client_session.call_count == 1
            kwargs = mock_client_session.call_args.kwargs
            timeout = kwargs["timeout"]
            connector = kwargs["connector"]
            assert timeout.total == 12.5
            assert timeout.connect == 1.5
            assert timeout.sock_read == 12.5
            assert connector.limit == 5
            assert connector.limit_per_host == 5


async def test_startup_overrides_settings_from_persisted_setup(tmp_path: Path) -> None:
    """Startup should replace env defaults with values persisted by setup."""
    env_settings = _make_settings(
        database_path=str(tmp_path / "pointer.db"),
        fail2ban_config_dir=str(tmp_path / "fail2ban"),
        session_secret="test-startup-secret",
    )
    app = create_app(settings=env_settings)
    runtime_db_path = str(tmp_path / "runtime.db")

    # ``async with`` closes the connection even if seeding the database fails,
    # so a broken setup step cannot leak an open connection into later tests.
    async with aiosqlite.connect(env_settings.database_path) as db:
        db.row_factory = aiosqlite.Row
        await init_db(db)
        await setup_service.run_setup(
            db,
            master_password="supersecret123",
            database_path=runtime_db_path,
            fail2ban_socket="/tmp/persisted.sock",
            timezone="Europe/Berlin",
            session_duration_minutes=123,
        )

    mock_scheduler, mock_http_session = _make_lifespan_doubles()

    # init_db and is_setup_complete stay real here: the test depends on the
    # values written by run_setup above being read back during startup.
    with _patched_startup(mock_scheduler, mock_http_session, stub_persistence=False):
        async with _lifespan(app):
            assert app.state.settings.database_path == runtime_db_path
            assert app.state.settings.fail2ban_socket == "/tmp/persisted.sock"
            assert app.state.settings.timezone == "Europe/Berlin"
            assert app.state.settings.session_duration_minutes == 123
            assert Path(runtime_db_path).exists()


async def test_concurrent_requests_use_request_scoped_db_connections(tmp_path: Path) -> None:
    """Concurrent requests each open and close their own database connection."""
    settings = _make_settings(
        database_path=str(tmp_path / "bangui.db"),
        fail2ban_config_dir=str(tmp_path / "fail2ban"),
        session_secret="test-concurrency-secret",
    )
    app = create_app(settings=settings)
    connections: list[MagicMock] = []

    async def fake_open_db(database_path: str) -> MagicMock:
        # Hand out a fresh, recorded connection double per call so the test
        # can count how many were opened and check each one was closed.
        connection = MagicMock()
        connection.close = AsyncMock()
        connections.append(connection)
        return connection

    mock_scheduler, mock_http_session = _make_lifespan_doubles()

    with (
        patch("app.startup.open_db", new=AsyncMock(side_effect=fake_open_db)),
        patch("app.db.open_db", new=AsyncMock(side_effect=fake_open_db)),
        _patched_startup(mock_scheduler, mock_http_session),
    ):
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            app.state.setup_complete_cached = True
            responses = await asyncio.gather(
                *(client.post("/api/auth/logout") for _ in range(5))
            )

    assert len(connections) == 5
    assert len({id(connection) for connection in connections}) == 5
    assert all(response.status_code == 200 for response in responses)
    assert all(connection.close.await_count == 1 for connection in connections)