Files
BanGUI/backend/tests/test_main.py
Lukas 654dbdb000 T-04: Encapsulate geo_service module-level mutable state in GeoCache class
Create GeoCache class with all mutable state as instance attributes:
- _cache, _neg_cache, _dirty, _geoip_reader, _geoip_initialized, _cache_lock
- All public methods: lookup(), lookup_batch(), lookup_cached_only(), flush_dirty(), load_from_db(), clear(), etc.

Initialization & Dependency Injection:
- Instantiate GeoCache in startup.py and store on app.state.geo_cache
- Add get_geo_cache() dependency function in dependencies.py
- Inject into routes and tasks via FastAPI's dependency system

Backward Compatibility:
- Maintain module-level functions in geo_service.py as deprecated wrappers
- All old callers continue to work through _default_geo_cache instance
- Remove the test-only escape-hatch functions clear_cache and clear_neg_cache; their behavior now lives on GeoCache methods

Background Tasks:
- Update geo_cache_flush.py and geo_re_resolve.py to receive GeoCache instance
- Tasks now operate on injected instance rather than module globals

Tests:
- Refactor test_geo_service.py with geo_cache fixture providing fresh instances
- Update patch paths to target GeoCache methods correctly
- Fix internal state assertions to access instance attributes

Documentation:
- Update Architekture.md to document GeoCache as managed stateful service
- Describe cache lifecycle (load on startup, flush periodically, re-resolve stale)
- Note process-local limitations for multi-worker deployments

Fixes violation of Single Responsibility Principle: module no longer owns both
lookup logic and cache lifecycle management. Cache is now a first-class
injectable service with transparent lifecycle.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-04-23 16:18:09 +02:00

463 lines
19 KiB
Python

"""Unit tests for backend application startup and middleware configuration."""
import asyncio
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import aiosqlite
import pytest
from httpx import ASGITransport, AsyncClient
from app.config import Settings
from app.db import init_db
from app.exceptions import ConfigValidationError, ConfigWriteError, JailNotFoundError
from app.main import CORSMiddleware, _lifespan, create_app
from app.services import setup_service
def test_create_app_configures_cors_from_settings() -> None:
    """CORS middleware is installed exactly once, mirroring the configured origins.

    The registered middleware must allow credentials and permit every method
    and header (``["*"]``).
    """
    allowed_origin = "https://frontend.example.com"
    settings = Settings(
        database_path="/tmp/test.db",
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        fail2ban_config_dir="/tmp/fail2ban",
        session_secret="test-secret-key-do-not-use-in-production",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
        cors_allowed_origins=[allowed_origin],
    )

    app = create_app(settings=settings)

    matching = [entry for entry in app.user_middleware if entry.cls is CORSMiddleware]
    assert len(matching) == 1

    options = matching[0].kwargs
    assert options["allow_origins"] == [allowed_origin]
    assert options["allow_credentials"] is True
    assert options["allow_methods"] == ["*"]
    assert options["allow_headers"] == ["*"]
def test_create_app_skips_cors_when_no_origins_are_configured() -> None:
    """An explicitly empty origin list means no CORS middleware is registered."""
    app = create_app(
        settings=Settings(
            database_path="/tmp/test.db",
            fail2ban_socket="/tmp/fake_fail2ban.sock",
            fail2ban_config_dir="/tmp/fail2ban",
            session_secret="test-secret-key-do-not-use-in-production",
            session_duration_minutes=60,
            timezone="UTC",
            log_level="debug",
            cors_allowed_origins=[],
        )
    )

    assert not any(entry.cls is CORSMiddleware for entry in app.user_middleware)
def test_create_app_initialises_runtime_state_manager() -> None:
    """create_app attaches a fresh runtime state manager to ``app.state``.

    Before any startup work runs:
    - setup is not cached as complete;
    - the server is reported offline;
    - there is no pending recovery and no last activation.

    The top-level ``app.state.server_status`` also reports offline.
    """
    settings = Settings(
        database_path="/tmp/test.db",
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        fail2ban_config_dir="/tmp/fail2ban",
        session_secret="test-secret-key-do-not-use-in-production",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
    )

    state = create_app(settings=settings).state
    runtime = state.runtime_state

    assert runtime.setup_complete_cached is False
    assert runtime.server_status.online is False
    assert runtime.pending_recovery is None
    assert runtime.last_activation is None
    assert state.server_status.online is False
async def test_create_app_global_domain_exception_handlers() -> None:
    """Domain exceptions raised inside routes become consistent JSON errors.

    Each exception maps to a status code, and the response body carries the
    exception message under ``detail``:
    - JailNotFoundError -> 404
    - ConfigValidationError -> 400
    - ConfigWriteError -> 500
    """
    settings = Settings(
        database_path="/tmp/test.db",
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        fail2ban_config_dir="/tmp/fail2ban",
        session_secret="test-secret-key-do-not-use-in-production",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
    )
    app = create_app(settings=settings)

    # Throwaway routes whose only job is to raise one domain exception each.
    @app.get("/not-found")
    async def raise_not_found() -> None:
        raise JailNotFoundError("ssh")

    @app.get("/bad-request")
    async def raise_bad_request() -> None:
        raise ConfigValidationError("invalid payload")

    @app.get("/server-error")
    async def raise_server_error() -> None:
        raise ConfigWriteError("write failed")

    expectations = (
        ("/not-found", 404, "Jail not found: 'ssh'"),
        ("/bad-request", 400, "invalid payload"),
        ("/server-error", 500, "write failed"),
    )

    async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
        for path, expected_status, expected_detail in expectations:
            response = await client.get(path)
            assert response.status_code == expected_status
            assert response.json() == {"detail": expected_detail}
def test_create_app_disables_cors_by_default() -> None:
    """Leaving ``cors_allowed_origins`` at its default installs no CORS middleware."""
    defaults_only = Settings(
        database_path="/tmp/test.db",
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        fail2ban_config_dir="/tmp/fail2ban",
        session_secret="test-secret-key-do-not-use-in-production",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
    )

    app = create_app(settings=defaults_only)

    assert not any(entry.cls is CORSMiddleware for entry in app.user_middleware)
async def test_lifespan_initialises_and_cleans_up_shared_resources(tmp_path: Path) -> None:
    """The lifespan publishes shared resources on app.state and tears them down.

    While the lifespan is active, the HTTP session, scheduler and settings are
    exposed on ``app.state``. On exit, the session is closed once and the
    scheduler is shut down with ``wait=False``.
    """
    settings = Settings(
        database_path=str(tmp_path / "bangui.db"),
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        fail2ban_config_dir=str(tmp_path / "fail2ban"),
        session_secret="test-lifespan-secret",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
    )
    app = create_app(settings=settings)

    scheduler = MagicMock(start=MagicMock(), shutdown=MagicMock())
    http_session = MagicMock(close=AsyncMock())

    # Stub out every external side effect of startup:
    # filesystem, network, DB, GeoIP and scheduled task registration.
    with (
        patch("app.startup.ensure_jail_configs"),
        patch("app.startup.aiohttp.ClientSession", return_value=http_session),
        patch("app.startup.AsyncIOScheduler", return_value=scheduler),
        patch("app.startup.init_db", new=AsyncMock()),
        patch("app.services.geo_cache.GeoCache.init_geoip"),
        patch("app.services.geo_cache.GeoCache.load_cache_from_db", new=AsyncMock(return_value=None)),
        patch("app.services.geo_cache.GeoCache.count_unresolved", new=AsyncMock(return_value=0)),
        patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=False)),
        patch("app.tasks.health_check.register"),
        patch("app.tasks.blocklist_import.register"),
        patch("app.tasks.geo_cache_flush.register"),
        patch("app.tasks.geo_re_resolve.register"),
        patch("app.tasks.history_sync.register"),
    ):
        async with _lifespan(app):
            assert app.state.http_session is http_session
            assert app.state.scheduler is scheduler
            assert app.state.settings is settings

    http_session.close.assert_awaited_once()
    scheduler.shutdown.assert_called_once_with(wait=False)
async def test_lifespan_cleans_up_resources_when_startup_fails(tmp_path: Path) -> None:
    """The lifespan must close resources if shared startup registration fails.

    ``app.tasks.health_check.register`` is made to raise, so startup fails
    while registering background tasks. The error must propagate out of the
    lifespan, and the HTTP session and scheduler must still be released.
    """
    settings = Settings(
        database_path=str(tmp_path / "bangui.db"),
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        fail2ban_config_dir=str(tmp_path / "fail2ban"),
        session_secret="test-lifespan-secret",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
    )
    app = create_app(settings=settings)
    mock_scheduler = MagicMock()
    mock_scheduler.start = MagicMock()
    mock_scheduler.shutdown = MagicMock()
    mock_http_session = MagicMock()
    mock_http_session.close = AsyncMock()

    # pytest.raises is the outermost manager. Every patch below is therefore
    # undone before the expected exception is checked.
    with (
        pytest.raises(RuntimeError, match="startup failed"),
        patch("app.startup.ensure_jail_configs"),
        patch("app.startup.aiohttp.ClientSession", return_value=mock_http_session),
        patch("app.startup.AsyncIOScheduler", return_value=mock_scheduler),
        patch("app.startup.init_db", new=AsyncMock()),
        # Geo state now lives on a GeoCache instance. Patch its methods rather
        # than the deprecated module-level geo_service wrappers, which the
        # instance-based startup path does not go through.
        patch("app.services.geo_cache.GeoCache.init_geoip"),
        patch("app.services.geo_cache.GeoCache.load_cache_from_db", new=AsyncMock(return_value=None)),
        patch("app.services.geo_cache.GeoCache.count_unresolved", new=AsyncMock(return_value=0)),
        patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=False)),
        patch("app.tasks.health_check.register", side_effect=RuntimeError("startup failed")),
        patch("app.tasks.blocklist_import.register"),
        patch("app.tasks.geo_cache_flush.register"),
        patch("app.tasks.geo_re_resolve.register"),
        patch("app.tasks.history_sync.register"),
    ):
        async with _lifespan(app):
            pass

    mock_http_session.close.assert_awaited_once()
    mock_scheduler.shutdown.assert_called_once_with(wait=False)
async def test_http_session_is_created_with_configured_timeouts_and_limits(tmp_path: Path) -> None:
    """The shared HTTP client session is created with the configured limits.

    Expectations on the single ``ClientSession(...)`` call:
    - The request timeout feeds both the total and the socket-read timeout.
    - The connect timeout is applied as-is.
    - ``http_max_connections`` bounds both the global and the per-host
      connector limit.
    """
    settings = Settings(
        database_path=str(tmp_path / "bangui.db"),
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        fail2ban_config_dir=str(tmp_path / "fail2ban"),
        session_secret="test-lifespan-secret",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
        http_request_timeout_seconds=12.5,
        http_connect_timeout_seconds=1.5,
        http_max_connections=5,
        http_keepalive_timeout_seconds=8.0,
    )
    app = create_app(settings=settings)
    mock_scheduler = MagicMock()
    mock_scheduler.start = MagicMock()
    mock_scheduler.shutdown = MagicMock()
    mock_http_session = MagicMock()
    mock_http_session.close = AsyncMock()

    with (
        patch("app.startup.ensure_jail_configs"),
        patch("app.startup.aiohttp.ClientSession", return_value=mock_http_session) as mock_client_session,
        patch("app.startup.AsyncIOScheduler", return_value=mock_scheduler),
        patch("app.startup.init_db", new=AsyncMock()),
        # Patch the GeoCache methods, not the deprecated geo_service wrappers:
        # startup works on a GeoCache instance after the refactor.
        patch("app.services.geo_cache.GeoCache.init_geoip"),
        patch("app.services.geo_cache.GeoCache.load_cache_from_db", new=AsyncMock(return_value=None)),
        patch("app.services.geo_cache.GeoCache.count_unresolved", new=AsyncMock(return_value=0)),
        patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=False)),
        patch("app.tasks.health_check.register"),
        patch("app.tasks.blocklist_import.register"),
        patch("app.tasks.geo_cache_flush.register"),
        patch("app.tasks.geo_re_resolve.register"),
        patch("app.tasks.history_sync.register"),
    ):
        async with _lifespan(app):
            assert mock_client_session.call_count == 1
            kwargs = mock_client_session.call_args.kwargs
            timeout = kwargs["timeout"]
            connector = kwargs["connector"]
            assert timeout.total == 12.5
            assert timeout.connect == 1.5
            assert timeout.sock_read == 12.5
            assert connector.limit == 5
            assert connector.limit_per_host == 5
async def test_startup_overrides_settings_from_persisted_setup(tmp_path: Path) -> None:
    """Startup should replace env defaults with values persisted by setup.

    Setup is first run for real against the env-configured "pointer" database.
    It persists a different runtime database path, socket, timezone and
    session duration. After startup:
    - ``app.state.runtime_settings`` reflects the persisted values;
    - ``app.state.settings`` still holds the env-provided database path;
    - the runtime database file has been created.
    """
    env_settings = Settings(
        database_path=str(tmp_path / "pointer.db"),
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        fail2ban_config_dir=str(tmp_path / "fail2ban"),
        session_secret="test-startup-secret",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
    )
    app = create_app(settings=env_settings)
    runtime_db_path = str(tmp_path / "runtime.db")

    # The async context manager closes the connection even if init_db or
    # run_setup raises. The previous explicit close() leaked it on failure.
    async with aiosqlite.connect(env_settings.database_path) as db:
        db.row_factory = aiosqlite.Row
        await init_db(db)
        await setup_service.run_setup(
            db,
            master_password="Supersecret1!",
            database_path=runtime_db_path,
            fail2ban_socket="/tmp/persisted.sock",
            timezone="Europe/Berlin",
            session_duration_minutes=123,
        )

    mock_scheduler = MagicMock()
    mock_scheduler.start = MagicMock()
    mock_scheduler.shutdown = MagicMock()
    mock_http_session = MagicMock()
    mock_http_session.close = AsyncMock()

    with (
        patch("app.startup.ensure_jail_configs"),
        patch("app.startup.aiohttp.ClientSession", return_value=mock_http_session),
        patch("app.startup.AsyncIOScheduler", return_value=mock_scheduler),
        # Geo state lives on a GeoCache instance; patch its methods instead of
        # the deprecated module-level geo_service wrappers.
        patch("app.services.geo_cache.GeoCache.init_geoip"),
        patch("app.services.geo_cache.GeoCache.load_cache_from_db", new=AsyncMock(return_value=None)),
        patch("app.services.geo_cache.GeoCache.count_unresolved", new=AsyncMock(return_value=0)),
        patch("app.tasks.health_check.register"),
        patch("app.tasks.blocklist_import.register"),
        patch("app.tasks.geo_cache_flush.register"),
        patch("app.tasks.geo_re_resolve.register"),
        patch("app.tasks.history_sync.register"),
    ):
        async with _lifespan(app):
            assert app.state.runtime_settings is not None
            assert app.state.runtime_settings.database_path == runtime_db_path
            assert app.state.runtime_settings.fail2ban_socket == "/tmp/persisted.sock"
            assert app.state.runtime_settings.timezone == "Europe/Berlin"
            assert app.state.runtime_settings.session_duration_minutes == 123
            assert app.state.settings.database_path == str(tmp_path / "pointer.db")
            assert Path(runtime_db_path).exists()
async def test_startup_loads_geo_cache_from_persisted_runtime_database(tmp_path: Path) -> None:
    """Startup must load geo cache from the resolved runtime database.

    ``open_db`` is wrapped so every connection startup opens is recorded with
    its path. The test then checks that the connection handed to
    ``GeoCache.load_cache_from_db`` is one that was opened on the persisted
    runtime database path.
    """
    env_settings = Settings(
        database_path=str(tmp_path / "pointer.db"),
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        fail2ban_config_dir=str(tmp_path / "fail2ban"),
        session_secret="test-startup-secret",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
    )
    app = create_app(settings=env_settings)
    runtime_db_path = str(tmp_path / "runtime.db")
    opened_connections: list[tuple[str, aiosqlite.Connection]] = []

    async def fake_open_db(path: str) -> aiosqlite.Connection:
        """Open a real connection and remember which path it was opened on."""
        connection = await aiosqlite.connect(path)
        opened_connections.append((path, connection))
        return connection

    mock_scheduler = MagicMock()
    mock_scheduler.start = MagicMock()
    mock_scheduler.shutdown = MagicMock()
    mock_http_session = MagicMock()
    mock_http_session.close = AsyncMock()
    load_cache = AsyncMock()

    try:
        with (
            patch("app.startup.ensure_jail_configs"),
            patch("app.startup.aiohttp.ClientSession", return_value=mock_http_session),
            patch("app.startup.AsyncIOScheduler", return_value=mock_scheduler),
            patch("app.startup.open_db", new=AsyncMock(side_effect=fake_open_db)),
            patch("app.startup.init_db", new=AsyncMock()),
            # Startup calls load_cache_from_db on a GeoCache instance.
            # Patching the deprecated geo_service wrapper would leave
            # ``load_cache`` uncalled and ``call_args`` None.
            patch("app.services.geo_cache.GeoCache.init_geoip"),
            patch("app.services.geo_cache.GeoCache.load_cache_from_db", new=load_cache),
            patch("app.services.geo_cache.GeoCache.count_unresolved", new=AsyncMock(return_value=0)),
            patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=True)),
            patch(
                "app.services.setup_service.get_runtime_database_path",
                new=AsyncMock(return_value=runtime_db_path),
            ),
            patch(
                "app.services.setup_service.get_persisted_runtime_settings",
                new=AsyncMock(
                    return_value={
                        "database_path": runtime_db_path,
                        "fail2ban_socket": "/tmp/persisted.sock",
                        "timezone": "Europe/Berlin",
                        "session_duration_minutes": 123,
                    }
                ),
            ),
            patch("app.tasks.health_check.register"),
            patch("app.tasks.blocklist_import.register"),
            patch("app.tasks.geo_cache_flush.register"),
            patch("app.tasks.geo_re_resolve.register"),
            patch("app.tasks.history_sync.register"),
        ):
            async with _lifespan(app):
                assert load_cache.call_args is not None, "Expected GeoCache.load_cache_from_db to be called"
                # The patched attribute is a plain AsyncMock on the class, so no
                # ``self`` is bound. args[0] is therefore the DB connection.
                loaded_db = load_cache.call_args.args[0]
                runtime_connections = [conn for path, conn in opened_connections if path == runtime_db_path]
                assert runtime_connections, "Expected runtime database to be opened"
                assert loaded_db in runtime_connections
                assert app.state.runtime_settings is not None
                assert app.state.runtime_settings.database_path == runtime_db_path
    finally:
        # Close every real connection even when an assertion above fails.
        for _, connection in opened_connections:
            await connection.close()
async def test_concurrent_requests_use_request_scoped_db_connections(tmp_path: Path) -> None:
    """Concurrent requests each open and close their own database connection.

    Five logout requests are fired concurrently. Each must receive a distinct
    (fake) connection, succeed with HTTP 200, and close its connection exactly
    once.
    """
    settings = Settings(
        database_path=str(tmp_path / "bangui.db"),
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        fail2ban_config_dir=str(tmp_path / "fail2ban"),
        session_secret="test-concurrency-secret",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
    )
    app = create_app(settings=settings)
    connections: list[MagicMock] = []

    async def fake_open_db(database_path: str) -> MagicMock:
        """Hand out a fresh mock connection per call and record it."""
        connection = MagicMock()
        connection.close = AsyncMock()
        connections.append(connection)
        return connection

    mock_scheduler = MagicMock()
    mock_scheduler.start = MagicMock()
    mock_scheduler.shutdown = MagicMock()
    mock_http_session = MagicMock()
    mock_http_session.close = AsyncMock()

    with (
        patch("app.startup.open_db", new=AsyncMock(side_effect=fake_open_db)),
        patch("app.db.open_db", new=AsyncMock(side_effect=fake_open_db)),
        patch("app.startup.init_db", new=AsyncMock()),
        patch("app.startup.ensure_jail_configs"),
        patch("app.startup.aiohttp.ClientSession", return_value=mock_http_session),
        patch("app.startup.AsyncIOScheduler", return_value=mock_scheduler),
        # Geo state lives on a GeoCache instance; patch its methods instead of
        # the deprecated module-level geo_service wrappers.
        patch("app.services.geo_cache.GeoCache.init_geoip"),
        patch("app.services.geo_cache.GeoCache.load_cache_from_db", new=AsyncMock(return_value=None)),
        patch("app.services.geo_cache.GeoCache.count_unresolved", new=AsyncMock(return_value=0)),
        patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=False)),
        patch("app.tasks.health_check.register"),
        patch("app.tasks.blocklist_import.register"),
        patch("app.tasks.geo_cache_flush.register"),
        patch("app.tasks.geo_re_resolve.register"),
        patch("app.tasks.history_sync.register"),
    ):
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            # NOTE(review): presumably lets a setup guard pass the request;
            # create_app also exposes runtime_state.setup_complete_cached.
            # Confirm which attribute the guard reads.
            app.state.setup_complete_cached = True
            responses = await asyncio.gather(*(client.post("/api/auth/logout") for _ in range(5)))

    assert len(connections) == 5
    assert len({id(connection) for connection in connections}) == 5
    assert all(response.status_code == 200 for response in responses)
    assert all(connection.close.await_count == 1 for connection in connections)