Compare commits
15 Commits
v0.9.19-rc
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 3af8f0571b | |||
| d5a78a251a | |||
| 904db63fa2 | |||
| d737a1c319 | |||
| 9e765c6cb7 | |||
| ecb8542496 | |||
| 97f4df4a61 | |||
| 44542b93c0 | |||
| 01a4215f60 | |||
| bc49b7cd5b | |||
| fa4fe4bbdf | |||
| ee0fe9c695 | |||
| 551db0bb9c | |||
| 4a649e7347 | |||
| 025c82a982 |
@@ -48,6 +48,7 @@ services:
|
||||
target: runtime
|
||||
container_name: bangui-backend
|
||||
restart: unless-stopped
|
||||
stop_grace_period: 30s # Give lifespan 30s to complete before SIGKILL
|
||||
depends_on:
|
||||
fail2ban:
|
||||
condition: service_healthy
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
## Task: Investigate Orphaned SQLite Shared Memory Files on Startup
|
||||
|
||||
### Issue in Detail
|
||||
|
||||
The log shows repeated warnings:
|
||||
```
|
||||
event=orphaned_sqlite_file_removed path=/data/bangui.db-shm
|
||||
```
|
||||
|
||||
This occurs at `19:39:48` and again at `19:49:39` (after restart). The `-shm` file is SQLite's shared memory file for WAL mode. Its presence indicates **unclean shutdowns** (crashes or SIGKILL instead of graceful SIGTERM).
|
||||
|
||||
### Why This Happens
|
||||
|
||||
1. **Docker stop timeout:** Docker sends SIGTERM, waits `stop_grace_period` (default 10s), then sends SIGKILL. The backend allows 25s for graceful shutdown, but if the container's `stop_grace_period` is shorter, the process is killed before cleanup completes.
|
||||
2. **Missing connection close:** If the application crashes or is killed, SQLite connections are not closed, leaving `.wal` and `.shm` files behind.
|
||||
3. **`_cleanup_wal_files()` is a workaround, not a fix:** It removes stale files on the *next* startup, but the underlying cause (unclean shutdown) remains.
|
||||
|
||||
### How to Fix It
|
||||
|
||||
1. **Verify Docker Compose `stop_grace_period`:** In `Docker/compose.prod.yml`, ensure the backend service has `stop_grace_period: 30s` (matching the 25s internal timeout + margin).
|
||||
2. **Improve shutdown logging:** Add explicit logs when the database connection is closed during lifespan shutdown.
|
||||
3. **Consider `PRAGMA journal_mode = DELETE` for single-process setups:** WAL mode is beneficial for concurrent readers, but if BanGUI runs with a single worker and single process, DELETE mode eliminates `.wal`/`.shm` files entirely. Evaluate the tradeoff.
|
||||
|
||||
### Issues and Trapfalls
|
||||
|
||||
1. **WAL mode is required for concurrent reads:** If you switch to DELETE mode, readers block writers. This may degrade API performance under load.
|
||||
2. **The `_cleanup_wal_files()` 10-second threshold:** Files modified within 10 seconds are skipped. If the container restarts rapidly (e.g., health check failure → restart), the files may not be cleaned up.
|
||||
|
||||
### Documentation References
|
||||
|
||||
- **`Docs/Deployment.md`:** Docker deployment configuration and graceful shutdown behavior.
|
||||
- **`Docs/Architekture.md`:** Deployment constraints and process-local state.
|
||||
|
||||
### Tests to Write
|
||||
|
||||
#### 1. `test_cleanup_wal_files_removes_stale_files`
|
||||
- **Setup:** Create fake `.wal` and `.shm` files with mtime > 10s ago.
|
||||
- **Action:** Call `_cleanup_wal_files()`.
|
||||
- **Assert:** Files are removed.
|
||||
|
||||
#### 2. `test_cleanup_wal_files_skips_recent_files`
|
||||
- **Setup:** Create fake `.wal` and `.shm` files with mtime < 10s ago.
|
||||
- **Action:** Call `_cleanup_wal_files()`.
|
||||
- **Assert:** Files are NOT removed.
|
||||
@@ -274,7 +274,18 @@ CREATE INDEX IF NOT EXISTS idx_import_log_source_id_desc
|
||||
|
||||
|
||||
async def _configure_connection(db: aiosqlite.Connection) -> None:
|
||||
"""Apply hardening pragmas to a newly-opened SQLite connection."""
|
||||
"""Apply hardening pragmas to a newly-opened SQLite connection.
|
||||
|
||||
WAL mode is intentionally kept despite the risk of orphaned ``.wal``/``.shm``
|
||||
files after unclean shutdowns. The benefits for concurrent readers
|
||||
(readers do not block writers) outweigh the cleanup overhead, especially
|
||||
under load. BanGUI runs as a single worker, but multiple concurrent HTTP
|
||||
requests can still issue overlapping reads; DELETE mode would serialize
|
||||
those reads behind any write, degrading API performance.
|
||||
|
||||
Orphaned files are handled by :func:`_cleanup_wal_files`, which is called
|
||||
during startup before the database is opened.
|
||||
"""
|
||||
await db.execute("PRAGMA journal_mode=WAL;")
|
||||
await db.execute("PRAGMA foreign_keys=ON;")
|
||||
await db.execute("PRAGMA busy_timeout=5000;")
|
||||
@@ -475,14 +486,75 @@ async def init_db(db: aiosqlite.Connection) -> None:
|
||||
async def open_db(database_path: str) -> aiosqlite.Connection:
|
||||
"""Open a new application SQLite connection with the standard settings.
|
||||
|
||||
Creates the parent directory if it does not exist.
|
||||
|
||||
Args:
|
||||
database_path: Path to the BanGUI SQLite database.
|
||||
|
||||
Returns:
|
||||
A configured :class:`aiosqlite.Connection` instance.
|
||||
|
||||
Raises:
|
||||
DatabasePathInvalidError: If the directory cannot be created or is inaccessible.
|
||||
DatabasePermissionDeniedError: If aiosqlite.connect raises PermissionError.
|
||||
DatabaseCorruptedError: If the database file is corrupted.
|
||||
DatabaseUnavailableError: For any other unexpected error.
|
||||
"""
|
||||
await _cleanup_wal_files(database_path)
|
||||
db = await aiosqlite.connect(database_path)
|
||||
from app.exceptions import (
|
||||
DatabaseCorruptedError,
|
||||
DatabasePathInvalidError,
|
||||
DatabasePermissionDeniedError,
|
||||
DatabaseUnavailableError,
|
||||
)
|
||||
|
||||
db_dir = Path(database_path).parent
|
||||
if not db_dir.exists():
|
||||
try:
|
||||
db_dir.mkdir(parents=True, exist_ok=True)
|
||||
except PermissionError as exc:
|
||||
log.error("database_open_failed", error=str(exc), database_path=database_path)
|
||||
raise DatabasePathInvalidError(database_path) from exc
|
||||
except OSError as exc:
|
||||
log.error("database_open_failed", error=str(exc), database_path=database_path)
|
||||
raise DatabaseUnavailableError(database_path, str(exc)) from exc
|
||||
|
||||
try:
|
||||
db = await aiosqlite.connect(database_path)
|
||||
except PermissionError as exc:
|
||||
log.error("database_open_failed", error=str(exc), database_path=database_path)
|
||||
raise DatabasePermissionDeniedError(database_path) from exc
|
||||
except aiosqlite.OperationalError as exc:
|
||||
error_msg = str(exc).lower()
|
||||
sqlite_code = getattr(exc, "sqlite_errorcode", None)
|
||||
log.error(
|
||||
"database_open_failed",
|
||||
error=str(exc),
|
||||
sqlite_errorcode=sqlite_code,
|
||||
database_path=database_path,
|
||||
)
|
||||
if "database is locked" in error_msg or "busy" in error_msg:
|
||||
raise DatabaseUnavailableError(database_path, str(exc)) from exc
|
||||
if "unable to open database file" in error_msg:
|
||||
raise DatabasePathInvalidError(database_path) from exc
|
||||
raise DatabaseUnavailableError(database_path, str(exc)) from exc
|
||||
except aiosqlite.DatabaseError as exc:
|
||||
log.error(
|
||||
"database_open_failed",
|
||||
error=str(exc),
|
||||
database_path=database_path,
|
||||
)
|
||||
raise DatabaseCorruptedError(database_path) from exc
|
||||
except OSError as exc:
|
||||
log.error("database_open_failed", error=str(exc), database_path=database_path)
|
||||
raise DatabaseUnavailableError(database_path, str(exc)) from exc
|
||||
except Exception as exc:
|
||||
log.error("database_open_failed", error=str(exc), database_path=database_path)
|
||||
raise DatabaseUnavailableError(database_path, str(exc)) from exc
|
||||
|
||||
db.row_factory = aiosqlite.Row
|
||||
await _configure_connection(db)
|
||||
try:
|
||||
await _configure_connection(db)
|
||||
except Exception:
|
||||
await db.close()
|
||||
raise
|
||||
return db
|
||||
|
||||
@@ -165,22 +165,61 @@ async def get_db(
|
||||
|
||||
Yields:
|
||||
An open :class:`aiosqlite.Connection` for the request.
|
||||
|
||||
Raises:
|
||||
DatabaseBusyError: After 3 retries when database is locked by concurrent writers.
|
||||
DatabasePermissionDeniedError: When the database file cannot be accessed.
|
||||
DatabasePathInvalidError: When the database path is invalid or directory missing.
|
||||
DatabaseCorruptedError: When the database file is corrupted.
|
||||
DatabaseUnavailableError: For any other unexpected database error.
|
||||
"""
|
||||
from app.db import open_db # noqa: PLC0415
|
||||
from app.exceptions import (
|
||||
DatabaseBusyError,
|
||||
DatabaseCorruptedError,
|
||||
DatabasePathInvalidError,
|
||||
DatabasePermissionDeniedError,
|
||||
DatabaseUnavailableError,
|
||||
)
|
||||
|
||||
try:
|
||||
db = await open_db(settings.database_path)
|
||||
except Exception as exc:
|
||||
log.error("database_open_failed", error=str(exc))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Database is not available.",
|
||||
) from exc
|
||||
db = None
|
||||
retries = 3
|
||||
retry_delay = 0.1
|
||||
last_exc = None
|
||||
|
||||
for attempt in range(1, retries + 1):
|
||||
try:
|
||||
db = await open_db(settings.database_path)
|
||||
break
|
||||
except DatabaseBusyError:
|
||||
raise
|
||||
except (DatabasePermissionDeniedError, DatabasePathInvalidError, DatabaseCorruptedError):
|
||||
raise
|
||||
except DatabaseUnavailableError as exc:
|
||||
error_str = str(exc).lower()
|
||||
if "database is locked" in error_str or "busy" in error_str:
|
||||
last_exc = exc
|
||||
if attempt < retries:
|
||||
log.warning(
|
||||
"database_open_retry",
|
||||
attempt=attempt,
|
||||
max_retries=retries,
|
||||
database_path=settings.database_path,
|
||||
)
|
||||
import asyncio
|
||||
await asyncio.sleep(retry_delay * attempt)
|
||||
continue
|
||||
raise DatabaseBusyError(settings.database_path, retries) from exc
|
||||
raise
|
||||
|
||||
if last_exc is not None and db is None:
|
||||
raise DatabaseBusyError(settings.database_path, retries)
|
||||
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
await db.close()
|
||||
if db is not None:
|
||||
await db.close()
|
||||
|
||||
|
||||
async def get_http_session(
|
||||
|
||||
@@ -473,6 +473,75 @@ class SetupAlreadyCompleteError(ConflictError):
|
||||
super().__init__("Setup has already been completed.")
|
||||
|
||||
|
||||
class DatabaseBusyError(ServiceUnavailableError):
|
||||
"""Raised when the SQLite database is locked or busy after all retries."""
|
||||
|
||||
error_code: str = "database_busy"
|
||||
|
||||
def __init__(self, database_path: str, retries: int) -> None:
|
||||
self.database_path = database_path
|
||||
self.retries = retries
|
||||
super().__init__(
|
||||
f"Database is temporarily busy after {retries} retries."
|
||||
)
|
||||
|
||||
def get_error_metadata(self) -> ErrorMetadata:
|
||||
return {"database_path": self.database_path, "retries": self.retries}
|
||||
|
||||
|
||||
class DatabasePermissionDeniedError(ServiceUnavailableError):
|
||||
"""Raised when the database file cannot be accessed due to insufficient permissions."""
|
||||
|
||||
error_code: str = "database_permission_denied"
|
||||
|
||||
def __init__(self, database_path: str) -> None:
|
||||
self.database_path = database_path
|
||||
super().__init__("Insufficient permissions to access the database file.")
|
||||
|
||||
def get_error_metadata(self) -> ErrorMetadata:
|
||||
return {"database_path": self.database_path}
|
||||
|
||||
|
||||
class DatabasePathInvalidError(ServiceUnavailableError):
|
||||
"""Raised when the database directory does not exist or the path is invalid."""
|
||||
|
||||
error_code: str = "database_path_invalid"
|
||||
|
||||
def __init__(self, database_path: str) -> None:
|
||||
self.database_path = database_path
|
||||
super().__init__("Database directory does not exist or path is invalid.")
|
||||
|
||||
def get_error_metadata(self) -> ErrorMetadata:
|
||||
return {"database_path": self.database_path}
|
||||
|
||||
|
||||
class DatabaseCorruptedError(ServiceUnavailableError):
|
||||
"""Raised when the database file is corrupted."""
|
||||
|
||||
error_code: str = "database_corrupted"
|
||||
|
||||
def __init__(self, database_path: str) -> None:
|
||||
self.database_path = database_path
|
||||
super().__init__("Database file is corrupted.")
|
||||
|
||||
def get_error_metadata(self) -> ErrorMetadata:
|
||||
return {"database_path": self.database_path}
|
||||
|
||||
|
||||
class DatabaseUnavailableError(ServiceUnavailableError):
|
||||
"""Raised for any other unexpected database error."""
|
||||
|
||||
error_code: str = "database_unavailable"
|
||||
|
||||
def __init__(self, database_path: str, error: str) -> None:
|
||||
self.database_path = database_path
|
||||
self.error = error
|
||||
super().__init__(f"Database is not available: {error}")
|
||||
|
||||
def get_error_metadata(self) -> ErrorMetadata:
|
||||
return {"database_path": self.database_path, "error": self.error}
|
||||
|
||||
|
||||
class BlocklistSourceNotFoundError(NotFoundError):
|
||||
"""Raised when a blocklist source is not found."""
|
||||
|
||||
|
||||
@@ -318,7 +318,12 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
log.error("scheduler_lock_release_failed", error=str(e))
|
||||
|
||||
# 6. Close the database connection.
|
||||
await startup_db.close()
|
||||
try:
|
||||
await startup_db.close()
|
||||
log.debug("database_connection_closed")
|
||||
except Exception as exc:
|
||||
log.error("database_connection_close_failed", error=str(exc))
|
||||
|
||||
log.info("bangui_shut_down")
|
||||
|
||||
|
||||
|
||||
@@ -26,10 +26,9 @@ from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import aiohttp
|
||||
from app.utils.logging_compat import get_logger
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler # type: ignore[import-untyped]
|
||||
|
||||
from app.db import init_db, open_db
|
||||
from app.db import _cleanup_wal_files, init_db, open_db
|
||||
from app.services import setup_service
|
||||
from app.services.dns_validated_connector import create_dns_validated_socket_factory
|
||||
from app.services.geo_cache import GeoCache
|
||||
@@ -48,6 +47,7 @@ from app.tasks import (
|
||||
from app.utils.async_utils import run_blocking
|
||||
from app.utils.fail2ban_db_utils import ensure_fail2ban_indexes
|
||||
from app.utils.jail_config import ensure_jail_configs
|
||||
from app.utils.logging_compat import get_logger
|
||||
from app.utils.runtime_state import set_runtime_settings
|
||||
from app.utils.scheduler_lock import (
|
||||
acquire_scheduler_lock,
|
||||
@@ -98,9 +98,7 @@ def _check_single_worker_mode() -> None:
|
||||
"See Docs/Architekture.md § Deployment Constraints for details."
|
||||
)
|
||||
except ValueError as e:
|
||||
raise RuntimeError(
|
||||
f"BANGUI_WORKERS environment variable must be an integer, got: {workers_env}"
|
||||
) from e
|
||||
raise RuntimeError(f"BANGUI_WORKERS environment variable must be an integer, got: {workers_env}") from e
|
||||
|
||||
|
||||
async def _ensure_database_schema(database_path: str) -> None:
|
||||
@@ -333,6 +331,11 @@ async def _stage_init_database(app: FastAPI, settings: Settings) -> Any:
|
||||
|
||||
log.debug("database_directory_ensured", directory=str(db_path.parent))
|
||||
|
||||
# Clean up orphaned WAL files from previous unclean shutdowns before
|
||||
# opening the database. This prevents stale .wal/.shm files from
|
||||
# interfering with startup or triggering misleading warnings.
|
||||
await _cleanup_wal_files(settings.database_path)
|
||||
|
||||
original_db_path = db_path.resolve()
|
||||
startup_db = await open_db(settings.database_path)
|
||||
|
||||
@@ -357,9 +360,7 @@ async def _stage_init_database(app: FastAPI, settings: Settings) -> Any:
|
||||
if f2b_db_path:
|
||||
await run_blocking(ensure_fail2ban_indexes, f2b_db_path)
|
||||
|
||||
persisted_runtime_settings = (
|
||||
await setup_service.get_persisted_runtime_settings(runtime_db)
|
||||
)
|
||||
persisted_runtime_settings = await setup_service.get_persisted_runtime_settings(runtime_db)
|
||||
finally:
|
||||
await runtime_db.close()
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "bangui-backend"
|
||||
version = "0.9.19-rc.4"
|
||||
version = "0.9.19-rc.5"
|
||||
description = "BanGUI backend — fail2ban web management interface"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
|
||||
@@ -252,6 +252,30 @@ async def test_cleanup_wal_files_removes_orphaned_files(tmp_path: Path) -> None:
|
||||
assert not shm_path.exists()
|
||||
|
||||
|
||||
async def test_cleanup_wal_files_skips_recent_files(tmp_path: Path) -> None:
|
||||
"""Test that _cleanup_wal_files skips files modified within 10 seconds."""
|
||||
db_path = str(tmp_path / "test_wal_recent.db")
|
||||
wal_path = Path(db_path + "-wal")
|
||||
shm_path = Path(db_path + "-shm")
|
||||
|
||||
# Create files with recent mtime
|
||||
wal_path.write_text("recent")
|
||||
shm_path.write_text("recent")
|
||||
recent_mtime = time.time() - 5
|
||||
os.utime(wal_path, (recent_mtime, recent_mtime))
|
||||
os.utime(shm_path, (recent_mtime, recent_mtime))
|
||||
|
||||
assert wal_path.exists()
|
||||
assert shm_path.exists()
|
||||
|
||||
# Run cleanup
|
||||
await _cleanup_wal_files(db_path)
|
||||
|
||||
# Files should NOT be removed (recent)
|
||||
assert wal_path.exists()
|
||||
assert shm_path.exists()
|
||||
|
||||
|
||||
async def test_cleanup_wal_files_handles_missing_files(tmp_path: Path) -> None:
|
||||
"""Test that _cleanup_wal_files handles non-existent files gracefully."""
|
||||
db_path = str(tmp_path / "nonexistent.db")
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from starlette.requests import Request
|
||||
@@ -19,6 +20,13 @@ from app.dependencies import (
|
||||
get_settings,
|
||||
get_settings_repo,
|
||||
)
|
||||
from app.exceptions import (
|
||||
DatabaseBusyError,
|
||||
DatabaseCorruptedError,
|
||||
DatabasePathInvalidError,
|
||||
DatabasePermissionDeniedError,
|
||||
DatabaseUnavailableError,
|
||||
)
|
||||
from app.main import create_app
|
||||
from app.models.server import ServerStatus
|
||||
|
||||
@@ -98,3 +106,184 @@ async def test_get_db_uses_effective_runtime_database_path(test_settings: Settin
|
||||
await gen.aclose()
|
||||
|
||||
mock_open_db.assert_awaited_once_with("/tmp/runtime.db")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Database error handling tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_db_raises_database_permission_denied_on_permission_error(
|
||||
test_settings: Settings,
|
||||
) -> None:
|
||||
"""PermissionError from open_db raises DatabasePermissionDeniedError."""
|
||||
with patch(
|
||||
"app.db.open_db",
|
||||
new=AsyncMock(side_effect=DatabasePermissionDeniedError(test_settings.database_path)),
|
||||
):
|
||||
gen = get_db(settings=test_settings)
|
||||
with pytest.raises(DatabasePermissionDeniedError) as exc_info:
|
||||
await gen.__anext__()
|
||||
await gen.aclose()
|
||||
|
||||
assert exc_info.value.error_code == "database_permission_denied"
|
||||
assert exc_info.value.database_path == test_settings.database_path
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_db_raises_database_path_invalid_on_missing_directory(
|
||||
test_settings: Settings,
|
||||
) -> None:
|
||||
"""sqlite3.OperationalError('unable to open database file') raises DatabasePathInvalidError."""
|
||||
with patch(
|
||||
"app.db.open_db",
|
||||
new=AsyncMock(side_effect=DatabasePathInvalidError(test_settings.database_path)),
|
||||
):
|
||||
gen = get_db(settings=test_settings)
|
||||
with pytest.raises(DatabasePathInvalidError) as exc_info:
|
||||
await gen.__anext__()
|
||||
await gen.aclose()
|
||||
|
||||
assert exc_info.value.error_code == "database_path_invalid"
|
||||
assert exc_info.value.database_path == test_settings.database_path
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_db_retries_on_database_locked(test_settings: Settings) -> None:
|
||||
"""get_db retries up to 3 times when database is locked."""
|
||||
mock_connection = MagicMock()
|
||||
mock_connection.close = AsyncMock()
|
||||
|
||||
locked_err = DatabaseUnavailableError(
|
||||
test_settings.database_path, "database is locked"
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.db.open_db",
|
||||
new=AsyncMock(side_effect=[locked_err, locked_err, mock_connection]),
|
||||
) as mock_open:
|
||||
gen = get_db(settings=test_settings)
|
||||
with patch("asyncio.sleep", new=AsyncMock()) as mock_sleep:
|
||||
connection = await gen.__anext__()
|
||||
await gen.aclose()
|
||||
|
||||
assert mock_open.call_count == 3
|
||||
assert connection is mock_connection
|
||||
assert mock_sleep.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_db_fails_after_max_retries_on_database_locked(
|
||||
test_settings: Settings,
|
||||
) -> None:
|
||||
"""After 3 retries on database locked, raises DatabaseBusyError."""
|
||||
locked_err = DatabaseUnavailableError(
|
||||
test_settings.database_path, "database is locked"
|
||||
)
|
||||
|
||||
with patch("app.db.open_db", new=AsyncMock(side_effect=locked_err)) as mock_open:
|
||||
gen = get_db(settings=test_settings)
|
||||
with patch("asyncio.sleep", new=AsyncMock()):
|
||||
with pytest.raises(DatabaseBusyError) as exc_info:
|
||||
await gen.__anext__()
|
||||
await gen.aclose()
|
||||
|
||||
assert mock_open.call_count == 3
|
||||
assert exc_info.value.error_code == "database_busy"
|
||||
assert exc_info.value.retries == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_db_raises_database_corrupted_on_malformed_db(
|
||||
test_settings: Settings,
|
||||
) -> None:
|
||||
"""sqlite3.DatabaseError('database disk image is malformed') raises DatabaseCorruptedError."""
|
||||
with patch(
|
||||
"app.db.open_db",
|
||||
new=AsyncMock(side_effect=DatabaseCorruptedError(test_settings.database_path)),
|
||||
):
|
||||
gen = get_db(settings=test_settings)
|
||||
with pytest.raises(DatabaseCorruptedError) as exc_info:
|
||||
await gen.__anext__()
|
||||
await gen.aclose()
|
||||
|
||||
assert exc_info.value.error_code == "database_corrupted"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_open_db_creates_parent_directory_if_missing(tmp_path: pytest.Path) -> None:
|
||||
"""open_db creates the parent directory when it does not exist."""
|
||||
from pathlib import Path
|
||||
|
||||
from app.db import open_db
|
||||
|
||||
db_path = str(Path(str(tmp_path)) / "subdir" / "deeper" / "bangui.db")
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.close = AsyncMock()
|
||||
mock_conn.execute = AsyncMock()
|
||||
mock_conn.commit = AsyncMock()
|
||||
|
||||
with patch("aiosqlite.connect", new=AsyncMock(return_value=mock_conn)), \
|
||||
patch("app.db._configure_connection", new=AsyncMock()):
|
||||
connection = await open_db(db_path)
|
||||
|
||||
assert connection is mock_conn
|
||||
assert Path(db_path).parent.exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_open_db_logs_specific_sqlite_error_code() -> None:
|
||||
"""open_db logs the SQLite error code when available."""
|
||||
from app.db import open_db
|
||||
|
||||
exc = aiosqlite.OperationalError("database is locked")
|
||||
exc.sqlite_errorcode = 5 # SQLITE_BUSY
|
||||
|
||||
with patch("aiosqlite.connect", new=AsyncMock(side_effect=exc)), \
|
||||
pytest.raises(DatabaseUnavailableError):
|
||||
await open_db("/tmp/test.db")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Error metadata tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_database_busy_error_metadata() -> None:
|
||||
"""DatabaseBusyError returns correct metadata."""
|
||||
err = DatabaseBusyError("/data/bangui.db", retries=3)
|
||||
assert err.error_code == "database_busy"
|
||||
metadata = err.get_error_metadata()
|
||||
assert metadata["database_path"] == "/data/bangui.db"
|
||||
assert metadata["retries"] == 3
|
||||
|
||||
|
||||
def test_database_permission_denied_error_metadata() -> None:
|
||||
"""DatabasePermissionDeniedError returns correct metadata."""
|
||||
err = DatabasePermissionDeniedError("/data/bangui.db")
|
||||
assert err.error_code == "database_permission_denied"
|
||||
assert err.get_error_metadata()["database_path"] == "/data/bangui.db"
|
||||
|
||||
|
||||
def test_database_path_invalid_error_metadata() -> None:
|
||||
"""DatabasePathInvalidError returns correct metadata."""
|
||||
err = DatabasePathInvalidError("/data/bangui.db")
|
||||
assert err.error_code == "database_path_invalid"
|
||||
assert err.get_error_metadata()["database_path"] == "/data/bangui.db"
|
||||
|
||||
|
||||
def test_database_corrupted_error_metadata() -> None:
|
||||
"""DatabaseCorruptedError returns correct metadata."""
|
||||
err = DatabaseCorruptedError("/data/bangui.db")
|
||||
assert err.error_code == "database_corrupted"
|
||||
assert err.get_error_metadata()["database_path"] == "/data/bangui.db"
|
||||
|
||||
|
||||
def test_database_unavailable_error_metadata() -> None:
|
||||
"""DatabaseUnavailableError returns correct metadata."""
|
||||
err = DatabaseUnavailableError("/data/bangui.db", "some error")
|
||||
assert err.error_code == "database_unavailable"
|
||||
metadata = err.get_error_metadata()
|
||||
assert metadata["database_path"] == "/data/bangui.db"
|
||||
assert metadata["error"] == "some error"
|
||||
|
||||
@@ -2,6 +2,9 @@
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
@@ -22,6 +25,7 @@ from app.main import (
|
||||
from app.middleware.correlation import CorrelationIdMiddleware
|
||||
from app.middleware.rate_limit import RateLimitMiddleware
|
||||
from app.services import setup_service
|
||||
from app.utils.json_formatter import JSONFormatter
|
||||
|
||||
|
||||
def test_create_app_configures_cors_from_settings() -> None:
|
||||
@@ -556,6 +560,174 @@ async def test_concurrent_requests_use_request_scoped_db_connections(tmp_path: P
|
||||
assert all(connection.close.await_count == 1 for connection in connections)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Logging configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_logging_configuration_no_duplicate_handlers(tmp_path: Path) -> None:
|
||||
"""Calling create_app() twice leaves no more than one custom StreamHandler on root."""
|
||||
fail2ban_config_dir = tmp_path / "fail2ban"
|
||||
fail2ban_config_dir.mkdir()
|
||||
|
||||
settings1 = Settings(
|
||||
database_path=str(tmp_path / "test1.db"),
|
||||
fail2ban_socket="/tmp/fake_fail2ban.sock",
|
||||
fail2ban_config_dir=str(fail2ban_config_dir),
|
||||
session_secret="test-secret-key-do-not-use-in-production",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
log_level="debug",
|
||||
)
|
||||
|
||||
create_app(settings=settings1)
|
||||
|
||||
settings2 = Settings(
|
||||
database_path=str(tmp_path / "test2.db"),
|
||||
fail2ban_socket="/tmp/fake_fail2ban.sock",
|
||||
fail2ban_config_dir=str(fail2ban_config_dir),
|
||||
session_secret="test-secret-key-do-not-use-in-production-2",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
log_level="debug",
|
||||
)
|
||||
|
||||
create_app(settings=settings2)
|
||||
# _configure_logging uses basicConfig which replaces handlers on the root logger.
|
||||
# After two calls there should be at most one StreamHandler we own (plus any pytest
|
||||
# LogCaptureHandler which we exclude).
|
||||
root_stream_handlers = [
|
||||
h for h in logging.getLogger().handlers
|
||||
if isinstance(h, logging.StreamHandler) and not type(h).__name__.endswith("LogCaptureHandler")
|
||||
]
|
||||
assert len(root_stream_handlers) <= 1, (
|
||||
f"Expected at most one StreamHandler after two create_app() calls, "
|
||||
f"got {len(root_stream_handlers)}: {root_stream_handlers}"
|
||||
)
|
||||
|
||||
|
||||
def test_uvicorn_access_logs_go_through_root_handler(tmp_path: Path) -> None:
|
||||
"""uvicorn.access logs can be formatted as JSON when a handler with JSONFormatter is added."""
|
||||
fail2ban_config_dir = tmp_path / "fail2ban"
|
||||
fail2ban_config_dir.mkdir()
|
||||
|
||||
settings = Settings(
|
||||
database_path=str(tmp_path / "test.db"),
|
||||
fail2ban_socket="/tmp/fake_fail2ban.sock",
|
||||
fail2ban_config_dir=str(fail2ban_config_dir),
|
||||
session_secret="test-secret-key-do-not-use-in-production",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
log_level="debug",
|
||||
)
|
||||
create_app(settings=settings)
|
||||
|
||||
# uvicorn.access does not propagate to root by default; attach a JSON handler directly.
|
||||
uvicorn_access = logging.getLogger("uvicorn.access")
|
||||
output = io.StringIO()
|
||||
handler = logging.StreamHandler(stream=output)
|
||||
handler.setFormatter(JSONFormatter())
|
||||
uvicorn_access.addHandler(handler)
|
||||
|
||||
try:
|
||||
uvicorn_access.setLevel(logging.DEBUG)
|
||||
uvicorn_access.info("GET /api/v1/health 200")
|
||||
line = output.getvalue().strip()
|
||||
assert line, "Expected non-empty log output from uvicorn.access"
|
||||
parsed = json.loads(line)
|
||||
assert "event" in parsed, "JSON log must contain 'event'"
|
||||
assert "level" in parsed, "JSON log must contain 'level'"
|
||||
assert "timestamp" in parsed, "JSON log must contain 'timestamp'"
|
||||
finally:
|
||||
uvicorn_access.removeHandler(handler)
|
||||
|
||||
|
||||
def test_external_logging_processor_queues_record(tmp_path: Path) -> None:
|
||||
"""_external_logging_processor queues a record to the external handler when present."""
|
||||
from app.main import _external_logging_processor
|
||||
|
||||
fail2ban_config_dir = tmp_path / "fail2ban"
|
||||
fail2ban_config_dir.mkdir()
|
||||
|
||||
settings = Settings(
|
||||
database_path=str(tmp_path / "test.db"),
|
||||
fail2ban_socket="/tmp/fake_fail2ban.sock",
|
||||
fail2ban_config_dir=str(fail2ban_config_dir),
|
||||
session_secret="test-secret-key-do-not-use-in-production",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
log_level="debug",
|
||||
)
|
||||
create_app(settings=settings)
|
||||
|
||||
from app.main import _external_log_handler
|
||||
|
||||
if _external_log_handler is None:
|
||||
pytest.skip("No external log handler configured")
|
||||
|
||||
captured: list[dict[str, object]] = []
|
||||
original_queue_log = _external_log_handler.queue_log
|
||||
|
||||
def mock_queue_log(record: dict[str, object]) -> None:
|
||||
captured.append(record)
|
||||
|
||||
_external_log_handler.queue_log = mock_queue_log
|
||||
|
||||
try:
|
||||
record = logging.makeLogRecord({"msg": "test event", "levelname": "INFO", "name": "test.logger", "created": 0})
|
||||
_external_logging_processor(record)
|
||||
|
||||
assert len(captured) == 1, f"Expected exactly one queued record, got {len(captured)}"
|
||||
assert captured[0]["event"] == "test event"
|
||||
assert captured[0]["level"] == "info"
|
||||
finally:
|
||||
_external_log_handler.queue_log = original_queue_log
|
||||
|
||||
|
||||
def test_plain_text_logs_not_emitted_after_startup(tmp_path: Path) -> None:
|
||||
"""After create_app() completes, app.db logger output is JSON, not plain text."""
|
||||
fail2ban_config_dir = tmp_path / "fail2ban"
|
||||
fail2ban_config_dir.mkdir()
|
||||
|
||||
settings = Settings(
|
||||
database_path=str(tmp_path / "test.db"),
|
||||
fail2ban_socket="/tmp/fake_fail2ban.sock",
|
||||
fail2ban_config_dir=str(fail2ban_config_dir),
|
||||
session_secret="test-secret-key-do-not-use-in-production",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
log_level="debug",
|
||||
)
|
||||
create_app(settings=settings)
|
||||
|
||||
output = io.StringIO()
|
||||
handler = logging.StreamHandler(stream=output)
|
||||
handler.setFormatter(JSONFormatter())
|
||||
db_logger = logging.getLogger("app.db")
|
||||
db_logger.addHandler(handler)
|
||||
db_logger.setLevel(logging.DEBUG)
|
||||
|
||||
try:
|
||||
db_logger.info("test_db_log")
|
||||
line = output.getvalue().strip()
|
||||
assert line, "Expected non-empty log output"
|
||||
assert not line.startswith("test_db_log "), "Log must not be plain text"
|
||||
parsed = json.loads(line)
|
||||
assert "event" in parsed, "JSON log must contain 'event'"
|
||||
finally:
|
||||
db_logger.removeHandler(handler)
|
||||
|
||||
try:
|
||||
db_logger.info("test_db_log")
|
||||
line = output.getvalue().strip()
|
||||
assert line, "Expected non-empty log output"
|
||||
assert not line.startswith("test_db_log "), "Log must not be plain text"
|
||||
parsed = json.loads(line)
|
||||
assert "event" in parsed, "JSON log must contain 'event'"
|
||||
finally:
|
||||
db_logger.removeHandler(handler)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Middleware order validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -934,6 +934,29 @@ class TestBanTrend:
|
||||
parsed = datetime.fromisoformat(bucket.timestamp)
|
||||
assert parsed.tzinfo is not None # Must be timezone-aware (UTC)
|
||||
|
||||
async def test_ban_trend_since_is_within_expected_range(self, tmp_path: Path) -> None:
|
||||
"""``since`` value is within 24h + 60s slack of the current time."""
|
||||
from app.utils.constants import TIME_RANGE_SLACK_SECONDS
|
||||
|
||||
now = int(time.time())
|
||||
# Place a ban just inside the expected range: 23 hours ago.
|
||||
# With 60s slack, since ≈ now - 24h - 60s, so 23h-ago ban should be included.
|
||||
just_inside_range = now - (23 * 3600)
|
||||
path = str(tmp_path / "test_since_range.sqlite3")
|
||||
await _create_f2b_db(
|
||||
path,
|
||||
[{"jail": "sshd", "ip": "1.2.3.4", "timeofban": just_inside_range}],
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=path),
|
||||
):
|
||||
result = await ban_service.ban_trend("/fake/sock", "24h")
|
||||
|
||||
# Ban at 23h ago must appear (within 24h + 60s window).
|
||||
assert sum(b.count for b in result.buckets) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# bans_by_jail
|
||||
|
||||
@@ -134,3 +134,15 @@ class TestSinceUnix:
|
||||
# The slack should be ~60 seconds
|
||||
assert actual_slack >= TIME_RANGE_SLACK_SECONDS - 1
|
||||
assert actual_slack <= TIME_RANGE_SLACK_SECONDS + 1
|
||||
|
||||
def test_since_unix_returns_utc_epoch(self) -> None:
|
||||
"""``since_unix('24h')`` returns a value within 24h + 60s of ``time.time()``."""
|
||||
before = int(time.time())
|
||||
result = since_unix("24h")
|
||||
after = int(time.time())
|
||||
|
||||
# Allow 2 second tolerance for execution time
|
||||
expected_min = before - (24 * 3600) - TIME_RANGE_SLACK_SECONDS - 2
|
||||
expected_max = after - (24 * 3600) - TIME_RANGE_SLACK_SECONDS + 2
|
||||
|
||||
assert expected_min <= result <= expected_max
|
||||
|
||||
Reference in New Issue
Block a user