Files
BanGUI/backend/tests/test_main.py
Lukas 2db635ae19 Fix exception handler overlap issue - add DomainError catch-all handler
**Problem:** Broad exception handlers created fragility where adding a new
DomainError subclass without explicit registration would silently fall through
to the generic exception handler, losing the specific error_code and metadata.

**Solution:**
1. Import DomainError in main.py for explicit handler registration
2. Fix type hints in exception handlers from 'Exception' to specific types
   - NotFoundError handler now typed as 'NotFoundError'
   - BadRequestError handler now typed as 'BadRequestError'
   - ConflictError handler now typed as 'ConflictError'
   - DomainError handler now typed as 'DomainError'
   - ServiceUnavailableError handler now typed as 'ServiceUnavailableError'
3. Add DomainError as an explicit catch-all handler in the registration chain
   - Positioned after specific handlers, before HTTPException
   - Any unregistered DomainError subclass now gets correct error_code + metadata
4. Document the exception handler hierarchy with detailed comments
5. Update Backend-Development.md with handler hierarchy documentation
6. Update Architekture.md section 2.2 with exception handler details
7. Fix test expectations in test_main.py to verify ErrorResponse format

**Impact:** Any new DomainError subclass now automatically gets the correct HTTP 500
status, error_code, and metadata — even if the developer forgets to register an explicit handler.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-04-30 19:44:43 +02:00

513 lines
21 KiB
Python

"""Unit tests for backend application startup and middleware configuration."""
import asyncio
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import aiosqlite
import pytest
from httpx import ASGITransport, AsyncClient
from app.config import Settings
from app.db import init_db
from app.exceptions import ConfigValidationError, ConfigWriteError, JailNotFoundError
from app.main import CORSMiddleware, _lifespan, create_app
from app.services import setup_service
def test_create_app_configures_cors_from_settings() -> None:
    """CORS middleware is registered exactly once, using the configured origins.

    Besides the origin list itself, the registered middleware must allow
    credentials and use ``["*"]`` for both methods and headers.
    """
    allowed_origin = "https://frontend.example.com"
    app = create_app(
        settings=Settings(
            database_path="/tmp/test.db",
            fail2ban_socket="/tmp/fake_fail2ban.sock",
            fail2ban_config_dir="/tmp/fail2ban",
            session_secret="test-secret-key-do-not-use-in-production",
            session_duration_minutes=60,
            timezone="UTC",
            log_level="debug",
            cors_allowed_origins=[allowed_origin],
        )
    )

    registered = [entry for entry in app.user_middleware if entry.cls is CORSMiddleware]
    assert len(registered) == 1

    options = registered[0].kwargs
    assert options["allow_origins"] == [allowed_origin]
    assert options["allow_credentials"] is True
    assert options["allow_methods"] == ["*"]
    assert options["allow_headers"] == ["*"]
def test_create_app_skips_cors_when_no_origins_are_configured() -> None:
    """An explicitly empty origin list means no CORS middleware is installed."""
    app = create_app(
        settings=Settings(
            database_path="/tmp/test.db",
            fail2ban_socket="/tmp/fake_fail2ban.sock",
            fail2ban_config_dir="/tmp/fail2ban",
            session_secret="test-secret-key-do-not-use-in-production",
            session_duration_minutes=60,
            timezone="UTC",
            log_level="debug",
            cors_allowed_origins=[],
        )
    )

    assert not any(entry.cls is CORSMiddleware for entry in app.user_middleware)
def test_create_app_initialises_runtime_state_manager() -> None:
    """``create_app`` attaches a fresh runtime state object to ``app.state``.

    A newly created app has not cached setup completion and reports the
    server as offline. It has no pending recovery and no recorded last
    activation. The ``app.state.server_status`` attribute must also report
    offline.
    """
    config = Settings(
        database_path="/tmp/test.db",
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        fail2ban_config_dir="/tmp/fail2ban",
        session_secret="test-secret-key-do-not-use-in-production",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
    )
    app = create_app(settings=config)
    state = app.state
    runtime = state.runtime_state

    assert runtime.setup_complete_cached is False
    assert runtime.server_status.online is False
    assert runtime.pending_recovery is None
    assert runtime.last_activation is None
    assert state.server_status.online is False
async def test_create_app_global_domain_exception_handlers() -> None:
    """Domain exceptions raised by routes are rendered as structured error bodies.

    Each probe route raises one domain exception. The response must carry the
    mapped HTTP status, the ``code`` and ``detail`` fields and a
    ``correlation_id``. For the not-found case the ``metadata`` field is
    checked as well.
    """
    app = create_app(
        settings=Settings(
            database_path="/tmp/test.db",
            fail2ban_socket="/tmp/fake_fail2ban.sock",
            fail2ban_config_dir="/tmp/fail2ban",
            session_secret="test-secret-key-do-not-use-in-production",
            session_duration_minutes=60,
            timezone="UTC",
            log_level="debug",
        )
    )

    @app.get("/not-found")
    async def raise_not_found() -> None:
        raise JailNotFoundError("ssh")

    @app.get("/bad-request")
    async def raise_bad_request() -> None:
        raise ConfigValidationError("invalid payload")

    @app.get("/server-error")
    async def raise_server_error() -> None:
        raise ConfigWriteError("write failed")

    # (path, status, code, detail, metadata); a metadata of None means the
    # metadata field is not checked for that case.
    expectations = [
        ("/not-found", 404, "jail_not_found", "Jail not found: 'ssh'", {"jail_name": "ssh"}),
        ("/bad-request", 400, "config_validation_failed", "invalid payload", None),
        ("/server-error", 500, "config_write_failed", "write failed", None),
    ]

    async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
        for path, status, code, detail, metadata in expectations:
            response = await client.get(path)
            assert response.status_code == status
            body = response.json()
            assert body["code"] == code
            assert body["detail"] == detail
            if metadata is not None:
                assert body["metadata"] == metadata
            assert "correlation_id" in body
def test_create_app_disables_cors_by_default() -> None:
    """Leaving ``cors_allowed_origins`` unset installs no CORS middleware."""
    base_settings = Settings(
        database_path="/tmp/test.db",
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        fail2ban_config_dir="/tmp/fail2ban",
        session_secret="test-secret-key-do-not-use-in-production",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
    )
    middleware_stack = create_app(settings=base_settings).user_middleware

    cors_entries = [entry for entry in middleware_stack if entry.cls is CORSMiddleware]
    assert not cors_entries
def test_create_app_disables_api_docs_by_default() -> None:
    """With ``enable_docs=False`` the docs, ReDoc and OpenAPI URLs are all unset."""
    settings = Settings(
        database_path="/tmp/test.db",
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        fail2ban_config_dir="/tmp/fail2ban",
        session_secret="test-secret-key-do-not-use-in-production",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
        enable_docs=False,
    )
    app = create_app(settings=settings)

    for url in (app.docs_url, app.redoc_url, app.openapi_url):
        assert url is None
def test_create_app_enables_api_docs_when_configured() -> None:
    """With ``enable_docs=True`` the docs endpoints are served under ``/api``."""
    settings = Settings(
        database_path="/tmp/test.db",
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        fail2ban_config_dir="/tmp/fail2ban",
        session_secret="test-secret-key-do-not-use-in-production",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
        enable_docs=True,
    )
    app = create_app(settings=settings)

    assert (app.docs_url, app.redoc_url, app.openapi_url) == (
        "/api/docs",
        "/api/redoc",
        "/api/openapi.json",
    )
async def test_lifespan_initialises_and_cleans_up_shared_resources(tmp_path: Path) -> None:
    """Entering the lifespan publishes shared resources; leaving it releases them.

    While the lifespan is active, ``app.state`` exposes the HTTP session, the
    scheduler and the exact settings object passed to ``create_app``. On exit
    the HTTP session is closed exactly once. The scheduler is shut down once,
    with ``wait=False``.
    """
    settings = Settings(
        database_path=str(tmp_path / "bangui.db"),
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        fail2ban_config_dir=str(tmp_path / "fail2ban"),
        session_secret="test-lifespan-secret",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
    )
    app = create_app(settings=settings)

    scheduler = MagicMock()
    scheduler.start = MagicMock()
    scheduler.shutdown = MagicMock()
    http_session = MagicMock()
    http_session.close = AsyncMock()

    # NOTE(review): this test patches ``app.services.geo_cache.GeoCache.*``,
    # while the other lifespan tests in this module patch
    # ``app.services.geo_service.*``. Confirm which module startup uses.
    with (
        patch("app.startup.ensure_jail_configs"),
        patch("app.startup.aiohttp.ClientSession", return_value=http_session),
        patch("app.startup.AsyncIOScheduler", return_value=scheduler),
        patch("app.startup.init_db", new=AsyncMock()),
        patch("app.services.geo_cache.GeoCache.init_geoip"),
        patch("app.services.geo_cache.GeoCache.load_cache_from_db", new=AsyncMock(return_value=None)),
        patch("app.services.geo_cache.GeoCache.count_unresolved", new=AsyncMock(return_value=0)),
        patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=False)),
        patch("app.tasks.health_check.register"),
        patch("app.tasks.blocklist_import.register"),
        patch("app.tasks.geo_cache_flush.register"),
        patch("app.tasks.geo_re_resolve.register"),
        patch("app.tasks.history_sync.register"),
    ):
        async with _lifespan(app):
            state = app.state
            assert state.http_session is http_session
            assert state.scheduler is scheduler
            assert state.settings is settings

    http_session.close.assert_awaited_once()
    scheduler.shutdown.assert_called_once_with(wait=False)
async def test_lifespan_cleans_up_resources_when_startup_fails(tmp_path: Path) -> None:
    """Resources created before a startup failure are still released.

    ``app.tasks.health_check.register`` is patched to raise
    ``RuntimeError("startup failed")``. The lifespan must propagate that
    error, and it must still close the HTTP session once. The scheduler must
    be shut down once, with ``wait=False``.
    """
    settings = Settings(
        database_path=str(tmp_path / "bangui.db"),
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        fail2ban_config_dir=str(tmp_path / "fail2ban"),
        session_secret="test-lifespan-secret",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
    )
    app = create_app(settings=settings)
    mock_scheduler = MagicMock()
    mock_scheduler.start = MagicMock()
    mock_scheduler.shutdown = MagicMock()
    mock_http_session = MagicMock()
    mock_http_session.close = AsyncMock()

    # Parenthesised context managers, as in the other lifespan tests, instead
    # of backslash continuations.
    with (
        patch("app.startup.ensure_jail_configs"),
        patch("app.startup.aiohttp.ClientSession", return_value=mock_http_session),
        patch("app.startup.AsyncIOScheduler", return_value=mock_scheduler),
        patch("app.startup.init_db", new=AsyncMock()),
        patch("app.services.geo_service.init_geoip"),
        patch("app.services.geo_service.load_cache_from_db", new=AsyncMock(return_value=None)),
        patch("app.services.geo_service.count_unresolved", new=AsyncMock(return_value=0)),
        patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=False)),
        patch("app.tasks.health_check.register", side_effect=RuntimeError("startup failed")),
        patch("app.tasks.blocklist_import.register"),
        patch("app.tasks.geo_cache_flush.register"),
        patch("app.tasks.geo_re_resolve.register"),
        patch("app.tasks.history_sync.register"),
    ):
        # Keep the raises-expectation on the one statement that should raise,
        # rather than around the whole patch setup.
        with pytest.raises(RuntimeError, match="startup failed"):
            async with _lifespan(app):
                pass

    mock_http_session.close.assert_awaited_once()
    mock_scheduler.shutdown.assert_called_once_with(wait=False)
async def test_http_session_is_created_with_configured_timeouts_and_limits(tmp_path: Path) -> None:
    """The lifespan builds its shared aiohttp session from the HTTP settings.

    The session factory is called once. The request timeout is used for both
    ``total`` and ``sock_read``, and the connect timeout for ``connect``.
    ``http_max_connections`` sets both ``limit`` and ``limit_per_host`` on the
    connector. ``http_keepalive_timeout_seconds`` is configured but not
    asserted here.
    """
    settings = Settings(
        database_path=str(tmp_path / "bangui.db"),
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        fail2ban_config_dir=str(tmp_path / "fail2ban"),
        session_secret="test-lifespan-secret",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
        http_request_timeout_seconds=12.5,
        http_connect_timeout_seconds=1.5,
        http_max_connections=5,
        http_keepalive_timeout_seconds=8.0,
    )
    app = create_app(settings=settings)

    scheduler = MagicMock()
    scheduler.start = MagicMock()
    scheduler.shutdown = MagicMock()
    http_session = MagicMock()
    http_session.close = AsyncMock()

    with (
        patch("app.startup.ensure_jail_configs"),
        patch("app.startup.aiohttp.ClientSession", return_value=http_session) as session_factory,
        patch("app.startup.AsyncIOScheduler", return_value=scheduler),
        patch("app.startup.init_db", new=AsyncMock()),
        patch("app.services.geo_service.init_geoip"),
        patch("app.services.geo_service.load_cache_from_db", new=AsyncMock(return_value=None)),
        patch("app.services.geo_service.count_unresolved", new=AsyncMock(return_value=0)),
        patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=False)),
        patch("app.tasks.health_check.register"),
        patch("app.tasks.blocklist_import.register"),
        patch("app.tasks.geo_cache_flush.register"),
        patch("app.tasks.geo_re_resolve.register"),
        patch("app.tasks.history_sync.register"),
    ):
        async with _lifespan(app):
            assert session_factory.call_count == 1
            call_kwargs = session_factory.call_args.kwargs
            timeout, connector = call_kwargs["timeout"], call_kwargs["connector"]

            assert (timeout.total, timeout.connect, timeout.sock_read) == (12.5, 1.5, 12.5)
            assert (connector.limit, connector.limit_per_host) == (5, 5)
async def test_startup_overrides_settings_from_persisted_setup(tmp_path: Path) -> None:
    """Values persisted by ``setup_service.run_setup`` take precedence at startup.

    The environment settings point at ``pointer.db``. After ``run_setup`` has
    stored runtime values there, ``app.state.runtime_settings`` must report
    the persisted database path, socket, timezone and session duration while
    the lifespan is running, and the runtime database file must exist.
    ``app.state.settings`` keeps the original pointer path.
    """
    env_settings = Settings(
        database_path=str(tmp_path / "pointer.db"),
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        fail2ban_config_dir=str(tmp_path / "fail2ban"),
        session_secret="test-startup-secret",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
    )
    app = create_app(settings=env_settings)
    runtime_db_path = str(tmp_path / "runtime.db")

    # ``async with`` closes the connection even if ``init_db`` or ``run_setup``
    # raises, so a failing test does not leave the aiosqlite connection open.
    async with aiosqlite.connect(env_settings.database_path) as db:
        db.row_factory = aiosqlite.Row
        await init_db(db)
        await setup_service.run_setup(
            db,
            master_password="Supersecret1!",
            database_path=runtime_db_path,
            fail2ban_socket="/tmp/persisted.sock",
            timezone="Europe/Berlin",
            session_duration_minutes=123,
        )

    mock_scheduler = MagicMock()
    mock_scheduler.start = MagicMock()
    mock_scheduler.shutdown = MagicMock()
    mock_http_session = MagicMock()
    mock_http_session.close = AsyncMock()

    # ``app.startup.init_db`` and ``setup_service`` are not patched here, so
    # startup runs against the real database written above.
    with (
        patch("app.startup.ensure_jail_configs"),
        patch("app.startup.aiohttp.ClientSession", return_value=mock_http_session),
        patch("app.startup.AsyncIOScheduler", return_value=mock_scheduler),
        patch("app.services.geo_service.init_geoip"),
        patch("app.services.geo_service.load_cache_from_db", new=AsyncMock(return_value=None)),
        patch("app.services.geo_service.count_unresolved", new=AsyncMock(return_value=0)),
        patch("app.tasks.health_check.register"),
        patch("app.tasks.blocklist_import.register"),
        patch("app.tasks.geo_cache_flush.register"),
        patch("app.tasks.geo_re_resolve.register"),
        patch("app.tasks.history_sync.register"),
    ):
        async with _lifespan(app):
            assert app.state.runtime_settings is not None
            assert app.state.runtime_settings.database_path == runtime_db_path
            assert app.state.runtime_settings.fail2ban_socket == "/tmp/persisted.sock"
            assert app.state.runtime_settings.timezone == "Europe/Berlin"
            assert app.state.runtime_settings.session_duration_minutes == 123
            assert app.state.settings.database_path == str(tmp_path / "pointer.db")
            assert Path(runtime_db_path).exists()
async def test_startup_loads_geo_cache_from_persisted_runtime_database(tmp_path: Path) -> None:
    """Startup loads the geo cache through the runtime database connection.

    ``setup_service`` is patched to report a completed setup whose persisted
    runtime database is ``runtime.db``. The connection passed to
    ``geo_service.load_cache_from_db`` must be one of the connections opened
    on that runtime path. ``app.state.runtime_settings`` must also reflect
    that path.
    """
    env_settings = Settings(
        database_path=str(tmp_path / "pointer.db"),
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        fail2ban_config_dir=str(tmp_path / "fail2ban"),
        session_secret="test-startup-secret",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
    )
    app = create_app(settings=env_settings)
    runtime_db_path = str(tmp_path / "runtime.db")
    opened_connections: list[tuple[str, aiosqlite.Connection]] = []

    async def fake_open_db(path: str) -> aiosqlite.Connection:
        """Open a real connection and record the path it was opened for."""
        connection = await aiosqlite.connect(path)
        opened_connections.append((path, connection))
        return connection

    mock_scheduler = MagicMock()
    mock_scheduler.start = MagicMock()
    mock_scheduler.shutdown = MagicMock()
    mock_http_session = MagicMock()
    mock_http_session.close = AsyncMock()
    load_cache = AsyncMock()

    try:
        with (
            patch("app.startup.ensure_jail_configs"),
            patch("app.startup.aiohttp.ClientSession", return_value=mock_http_session),
            patch("app.startup.AsyncIOScheduler", return_value=mock_scheduler),
            patch("app.startup.open_db", new=AsyncMock(side_effect=fake_open_db)),
            patch("app.startup.init_db", new=AsyncMock()),
            patch("app.services.geo_service.init_geoip"),
            patch("app.services.geo_service.load_cache_from_db", new=load_cache),
            patch("app.services.geo_service.count_unresolved", new=AsyncMock(return_value=0)),
            patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=True)),
            patch("app.services.setup_service.get_runtime_database_path", new=AsyncMock(return_value=runtime_db_path)),
            patch(
                "app.services.setup_service.get_persisted_runtime_settings",
                new=AsyncMock(
                    return_value={
                        "database_path": runtime_db_path,
                        "fail2ban_socket": "/tmp/persisted.sock",
                        "timezone": "Europe/Berlin",
                        "session_duration_minutes": 123,
                    }
                ),
            ),
            patch("app.tasks.health_check.register"),
            patch("app.tasks.blocklist_import.register"),
            patch("app.tasks.geo_cache_flush.register"),
            patch("app.tasks.geo_re_resolve.register"),
            patch("app.tasks.history_sync.register"),
        ):
            async with _lifespan(app):
                # Fail with a clear mock assertion, rather than an
                # AttributeError on ``None.args``, if startup never loaded the
                # geo cache.
                load_cache.assert_called()
                loaded_db = load_cache.call_args.args[0]
                runtime_connections = [conn for path, conn in opened_connections if path == runtime_db_path]
                assert runtime_connections, "Expected runtime database to be opened"
                assert loaded_db in runtime_connections
                assert app.state.runtime_settings is not None
                assert app.state.runtime_settings.database_path == runtime_db_path
    finally:
        # Close every connection the fake opened, including when an assertion
        # above failed; previously they were only closed on success.
        for _, connection in opened_connections:
            await connection.close()
async def test_concurrent_requests_use_request_scoped_db_connections(tmp_path: Path) -> None:
    """Every concurrent request opens, and then closes, its own DB connection.

    ``open_db`` is replaced by a factory that returns a fresh mock per call.
    After firing several logout requests concurrently, there must be one
    distinct connection per request and every request must return 200. Every
    connection must have been closed exactly once.
    """
    request_count = 5
    settings = Settings(
        database_path=str(tmp_path / "bangui.db"),
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        fail2ban_config_dir=str(tmp_path / "fail2ban"),
        session_secret="test-concurrency-secret",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
    )
    app = create_app(settings=settings)
    connections: list[MagicMock] = []

    async def fake_open_db(database_path: str) -> MagicMock:
        """Return a new mock connection whose ``close`` can be awaited."""
        connection = MagicMock()
        connection.close = AsyncMock()
        connections.append(connection)
        return connection

    mock_scheduler = MagicMock()
    mock_scheduler.start = MagicMock()
    mock_scheduler.shutdown = MagicMock()
    mock_http_session = MagicMock()
    mock_http_session.close = AsyncMock()

    with (
        patch("app.startup.open_db", new=AsyncMock(side_effect=fake_open_db)),
        patch("app.db.open_db", new=AsyncMock(side_effect=fake_open_db)),
        patch("app.startup.init_db", new=AsyncMock()),
        patch("app.startup.ensure_jail_configs"),
        patch("app.startup.aiohttp.ClientSession", return_value=mock_http_session),
        patch("app.startup.AsyncIOScheduler", return_value=mock_scheduler),
        patch("app.services.geo_service.init_geoip"),
        patch("app.services.geo_service.load_cache_from_db", new=AsyncMock(return_value=None)),
        patch("app.services.geo_service.count_unresolved", new=AsyncMock(return_value=0)),
        patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=False)),
        patch("app.tasks.health_check.register"),
        patch("app.tasks.blocklist_import.register"),
        patch("app.tasks.geo_cache_flush.register"),
        patch("app.tasks.geo_re_resolve.register"),
        patch("app.tasks.history_sync.register"),
    ):
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            # NOTE(review): create_app also exposes
            # ``app.state.runtime_state.setup_complete_cached``. Confirm which
            # flag the setup guard actually reads.
            app.state.setup_complete_cached = True
            responses = await asyncio.gather(
                *(client.post("/api/auth/logout") for _ in range(request_count))
            )

    # Compare whole lists instead of using ``all(...)``, so a failure shows
    # which request or connection misbehaved.
    assert len(connections) == request_count
    assert len({id(connection) for connection in connections}) == request_count
    assert [response.status_code for response in responses] == [200] * request_count
    assert [connection.close.await_count for connection in connections] == [1] * request_count