diff --git a/Docs/Architekture.md b/Docs/Architekture.md index 29aec84..7b41614 100644 --- a/Docs/Architekture.md +++ b/Docs/Architekture.md @@ -292,6 +292,98 @@ blocklist_service.py (Public API) - Logging is contextual and tied to the appropriate layer - Retry logic and transient error handling are isolated +#### Startup DAG (`app/startup_dag.py`, `app/startup.py`) + +The startup process is orchestrated by an explicit **Directed Acyclic Graph (DAG)** that defines all resource initialization stages, their dependencies, health checks, and rollback strategy. This replaces implicit ordering with explicit, documented prerequisites. + +**Why This Exists:** + +Previously, startup resources were created in a procedural sequence without documented dependencies. If a stage was reordered or a prerequisite was missed, initialization could fail in non-obvious ways. Partial failures could leave stale resources (open database connections, HTTP sessions, running schedulers) that prevented clean rollback. + +**Startup Stages (in order):** + +``` +1. WORKER_MODE + └─ Validates that BANGUI_WORKERS=1 (scheduler cannot run in multiple workers) + +2. DATABASE + ├─ Prerequisite: WORKER_MODE + ├─ Creates database directory + ├─ Initializes database schema + ├─ Caches setup completion state + └─ Loads persisted runtime settings + +3. GEO_CACHE + ├─ Prerequisite: DATABASE + ├─ Loads IP geolocation cache from database + ├─ Counts unresolved IPs + ├─ Initializes MaxMind GeoLite2 database + └─ Configures HTTP fallback (if enabled) + +4. HTTP_SESSION + ├─ Prerequisite: GEO_CACHE + ├─ Creates aiohttp.ClientSession + └─ Configures timeouts and connection limits + +5. SCHEDULER + ├─ Prerequisite: HTTP_SESSION + ├─ Creates APScheduler AsyncIOScheduler + └─ Starts the scheduler + +6. 
TASKS + ├─ Prerequisite: SCHEDULER + ├─ Registers health_check task (fail2ban connectivity probe) + ├─ Registers blocklist_import task (scheduled imports) + ├─ Registers geo_cache_cleanup task (stale entry purge) + ├─ Registers geo_cache_flush task (periodic persistence) + ├─ Registers geo_re_resolve task (stale record re-resolution) + ├─ Registers history_sync task (ban history sync) + └─ Registers session_cleanup task (expired session purge) +``` + +**Failure Mode & Rollback:** + +If any stage fails: + +1. All completed stages are rolled back **in reverse order** (TASKS → SCHEDULER → HTTP_SESSION → GEO_CACHE → DATABASE → WORKER_MODE) +2. Each rollback step suppresses its own exceptions so that one failed cleanup does not block the rest +3. Database connections are closed +4. HTTP sessions are closed +5. The scheduler is shut down +6. The application startup fails with a clear error message + +**Health Checks:** + +After all stages complete, a final health check verifies: +- All resources have initialized successfully +- Resources that subclass `StartupResource` pass their `health_check()` methods +- No failures occurred during any stage + +**Implementation:** + +- **StartupDAG**: Orchestrates the entire flow, manages prerequisites, and handles failures +- **StartupStage**: Enum defining the 6 startup stages +- **StageDependency**: Defines stage metadata (description, prerequisites, rollback policy) +- **StartupContext**: Tracks registered resources, completed stages, and failure state +- **startup_shared_resources()**: Main entry point that builds and executes the DAG +- **_stage_*()**: Functions that implement each stage's initialization logic + +**Example Usage in Tests:** + +```python +# Test that a stage with missing prerequisites fails +dag = StartupDAG() +dag.register_stage(StartupStage.HTTP_SESSION, "Create HTTP session", + prerequisites=frozenset([StartupStage.DATABASE])) +dag.register_stage(StartupStage.SCHEDULER, "Create scheduler") + +async def http_session_func(): + return aiohttp.ClientSession() + 
+# This will raise RuntimeError because DATABASE hasn't completed +await dag.execute_stage(StartupStage.HTTP_SESSION, http_session_func) +``` + #### Mappers (`app/mappers/`) The response mapping layer. Mappers convert domain models (returned by services) to response models (consumed by HTTP routers). This layer enforces the separation between business logic and API shape. diff --git a/Docs/Tasks.md b/Docs/Tasks.md index fb1d876..2a6c27c 100644 --- a/Docs/Tasks.md +++ b/Docs/Tasks.md @@ -1,43 +1,3 @@ -## 8) Inconsistent modeling style (TypedDict vs Pydantic) -- Where found: - - [backend/app/services/jail_service.py](backend/app/services/jail_service.py) - - [backend/app/models](backend/app/models) -- Why this is needed: - - Mixed validation/serialization behavior increases maintenance cost. -- Goal: - - Standardize model type usage by layer. -- What to do: - - Define when TypedDict is allowed. - - Convert external-facing structures to Pydantic consistently. -- Possible traps and issues: - - Performance and strictness differences may alter runtime behavior. -- Docs changes needed: - - Add modeling conventions section. -- Doc references: - - [Docs/Backend-Development.md](Docs/Backend-Development.md) - ---- - -## 9) Repository protocol coverage is incomplete -- Where found: - - [backend/app/repositories/protocols.py](backend/app/repositories/protocols.py) - - [backend/app/repositories](backend/app/repositories) -- Why this is needed: - - Missing protocols reduce mockability and static contract checks. -- Goal: - - Full protocol coverage for repository interfaces. -- What to do: - - Add protocol definitions for missing repositories. - - Validate implementation compatibility in tests/CI. -- Possible traps and issues: - - Protocol drift if methods evolve without synchronized updates. -- Docs changes needed: - - Add repository protocol checklist. 
-- Doc references: - - [Docs/Backend-Development.md](Docs/Backend-Development.md) - ---- - ## 10) Startup sequence depends on implicit ordering - Where found: - [backend/app/startup.py](backend/app/startup.py) diff --git a/backend/app/startup.py b/backend/app/startup.py index 0bcad58..d0b44e7 100644 --- a/backend/app/startup.py +++ b/backend/app/startup.py @@ -3,14 +3,27 @@ This module contains shared startup logic extracted from ``app.main`` so that initialisation is easier to reason about and unit test. The lifespan handler in ``app.main`` delegates resource creation and task registration here. + +The startup process is orchestrated by StartupDAG, which ensures all resources +are initialized in the correct order with explicit dependency tracking, and +cleanly rolls back on failure. + +Startup Stages (in order): + 1. WORKER_MODE: Verify single-worker mode (no multi-worker scheduler conflicts) + 2. DATABASE: Initialize database schema and cache setup completion state + 3. GEO_CACHE: Load and configure IP geolocation cache + 4. HTTP_SESSION: Create shared aiohttp session with timeouts + 5. SCHEDULER: Create and start the APScheduler instance + 6. TASKS: Register all background jobs with the scheduler + +See StartupDAG in app.startup_dag for full dependency graph and rollback logic. 
""" from __future__ import annotations import os -from contextlib import suppress from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import aiohttp import structlog @@ -19,6 +32,7 @@ from apscheduler.schedulers.asyncio import AsyncIOScheduler # type: ignore[impo from app.db import init_db, open_db from app.services import setup_service from app.services.geo_cache import GeoCache +from app.startup_dag import StartupDAG, StartupStage from app.tasks import ( blocklist_import, geo_cache_cleanup, @@ -105,15 +119,146 @@ async def startup_shared_resources( ) -> tuple[aiohttp.ClientSession, AsyncIOScheduler]: """Create shared resources needed during the application lifespan. + This function orchestrates the entire startup sequence through a StartupDAG, + ensuring all resources are initialized in the correct order with explicit + dependency tracking. If any stage fails, all completed resources are cleanly + rolled back. + + The startup stages are: + 1. WORKER_MODE: Validate single-worker configuration + 2. DATABASE: Initialize database and load setup state + 3. GEO_CACHE: Load IP geolocation cache + 4. HTTP_SESSION: Create shared aiohttp session + 5. SCHEDULER: Create and start APScheduler + 6. TASKS: Register all background jobs + Args: app: The FastAPI application instance. settings: Resolved application settings. Returns: A tuple of ``(http_session, scheduler)``. + + Raises: + RuntimeError: If any startup stage fails or prerequisites are not met. + """ + dag = StartupDAG() + + # Register all startup stages with their dependencies. 
+ dag.register_stage( + StartupStage.WORKER_MODE, + "Verify single-worker mode (scheduler must not run in multiple workers)", + prerequisites=frozenset(), + ) + dag.register_stage( + StartupStage.DATABASE, + "Initialize database schema and load setup state", + prerequisites=frozenset([StartupStage.WORKER_MODE]), + ) + dag.register_stage( + StartupStage.GEO_CACHE, + "Load IP geolocation cache from database", + prerequisites=frozenset([StartupStage.DATABASE]), + ) + dag.register_stage( + StartupStage.HTTP_SESSION, + "Create shared aiohttp session with configured timeouts", + prerequisites=frozenset([StartupStage.GEO_CACHE]), + ) + dag.register_stage( + StartupStage.SCHEDULER, + "Create and start APScheduler for background jobs", + prerequisites=frozenset([StartupStage.HTTP_SESSION]), + ) + dag.register_stage( + StartupStage.TASKS, + "Register all background jobs (import, cleanup, health checks)", + prerequisites=frozenset([StartupStage.SCHEDULER]), + ) + + try: + # Stage 1: Validate single-worker mode + await dag.execute_stage( + StartupStage.WORKER_MODE, + _stage_check_worker_mode, + ) + + # Stage 2: Initialize database + startup_db = await dag.execute_stage( + StartupStage.DATABASE, + lambda: _stage_init_database(app, settings), + ) + + # Stage 3: Load GeoCache + geo_cache = await dag.execute_stage( + StartupStage.GEO_CACHE, + lambda: _stage_init_geo_cache(settings, startup_db), + ) + + # Stage 4: Create HTTP session + http_session = await dag.execute_stage( + StartupStage.HTTP_SESSION, + lambda: _stage_create_http_session(settings), + ) + + # Stage 5: Create and start scheduler + scheduler = await dag.execute_stage( + StartupStage.SCHEDULER, + lambda: _stage_create_scheduler(), + ) + + # Stage 6: Register tasks + await dag.execute_stage( + StartupStage.TASKS, + lambda: _stage_register_tasks(app, scheduler), + ) + + # Verify all resources are healthy + if not await dag.health_check(): + raise RuntimeError("Startup health check failed") + + # Store the geo_cache on 
app state for dependency injection + app.state.geo_cache = geo_cache + + log.info( + "startup_completed_successfully", + stages=len(dag.context.completed_stages), + ) + + return http_session, scheduler + + except Exception: + # Clean up on failure + log.error("startup_failed_rolling_back_resources") + await dag.rollback() + # Ensure database is closed if it was initialized + if StartupStage.DATABASE in dag.context.completed_stages: + startup_db = dag.context.get_resource(StartupStage.DATABASE) + await startup_db.close() + raise + + +async def _stage_check_worker_mode() -> None: + """Check that the application is running with a single worker. + + This is stage 1 of the startup DAG. """ _check_single_worker_mode() + +async def _stage_init_database(app: FastAPI, settings: Settings) -> Any: + """Initialize database schema and load setup state. + + This is stage 2 of the startup DAG. It: + 1. Creates database directory if needed + 2. Opens the database connection + 3. Initializes schema + 4. Caches setup completion state + 5. Loads persisted runtime settings + + Returns: + The database connection object. 
+ """ db_path: Path = Path(settings.database_path) await run_blocking(db_path.parent.mkdir, parents=True, exist_ok=True) @@ -121,6 +266,7 @@ async def startup_shared_resources( original_db_path = db_path.resolve() startup_db = await open_db(settings.database_path) + try: await init_db(startup_db) setup_complete = await setup_service.is_setup_complete(startup_db) @@ -144,36 +290,53 @@ async def startup_shared_resources( if persisted_runtime_settings: updated_settings = settings.model_copy(update=persisted_runtime_settings) set_runtime_settings(app, updated_settings) - settings = updated_settings log.info( "runtime_settings_overridden_from_setup", overrides=persisted_runtime_settings, ) - - # Create and initialize the GeoCache instance - geo_cache = GeoCache(allow_http_fallback=settings.geoip_allow_http_fallback) - if Path(settings.database_path).resolve() != original_db_path: - runtime_db = await open_db(settings.database_path) - try: - await geo_cache.load_cache_from_db(runtime_db) - unresolved_count = await geo_cache.count_unresolved(runtime_db) - finally: - await runtime_db.close() - else: - await geo_cache.load_cache_from_db(startup_db) - unresolved_count = await geo_cache.count_unresolved(startup_db) - finally: + except Exception: await startup_db.close() + raise + + return startup_db + + +async def _stage_init_geo_cache(settings: Settings, startup_db: Any) -> GeoCache: + """Load IP geolocation cache. + + This is stage 3 of the startup DAG. It: + 1. Creates GeoCache instance with configured settings + 2. Loads cache from database + 3. Counts unresolved IPs + 4. Initializes GeoIP database + 5. Logs warnings if necessary + + Returns: + The GeoCache instance. 
+ """ + geo_cache = GeoCache(allow_http_fallback=settings.geoip_allow_http_fallback) + + db_path: Path = Path(settings.database_path) + original_db_path = db_path.resolve() + + if db_path.resolve() != original_db_path: + runtime_db = await open_db(settings.database_path) + try: + await geo_cache.load_cache_from_db(runtime_db) + unresolved_count = await geo_cache.count_unresolved(runtime_db) + finally: + await runtime_db.close() + else: + await geo_cache.load_cache_from_db(startup_db) + unresolved_count = await geo_cache.count_unresolved(startup_db) await run_blocking(ensure_jail_configs, Path(settings.fail2ban_config_dir) / "jail.d") if unresolved_count > 0: log.warning("geo_cache_unresolved_ips", unresolved=unresolved_count) - http_session: aiohttp.ClientSession = _create_http_session(settings) geo_cache.init_geoip(settings.geoip_db_path) - # Warn if HTTP fallback is enabled (security warning). if settings.geoip_allow_http_fallback: log.warning( "geoip_http_fallback_enabled", @@ -184,26 +347,58 @@ async def startup_shared_resources( ), ) - app.state.geo_cache = geo_cache + return geo_cache - scheduler: AsyncIOScheduler | None = None - try: - scheduler = AsyncIOScheduler(timezone="UTC") - scheduler.start() - health_check.register(app) - await blocklist_import.register(app) - geo_cache_cleanup.register(app) - geo_cache_flush.register(app) - geo_re_resolve.register(app) - history_sync.register(app) - session_cleanup.register(app) +async def _stage_create_http_session(settings: Settings) -> aiohttp.ClientSession: + """Create shared aiohttp session with configured timeouts. - return http_session, scheduler - except Exception: - with suppress(Exception): - await http_session.close() - if scheduler is not None: - with suppress(Exception): - scheduler.shutdown(wait=False) - raise + This is stage 4 of the startup DAG. + + Returns: + The aiohttp ClientSession instance. 
+ """ + return _create_http_session(settings) + + +async def _stage_create_scheduler() -> AsyncIOScheduler: + """Create and start APScheduler. + + This is stage 5 of the startup DAG. + + Returns: + The AsyncIOScheduler instance. + + Raises: + RuntimeError: If scheduler creation or startup fails. + """ + scheduler: AsyncIOScheduler = AsyncIOScheduler(timezone="UTC") + scheduler.start() + return scheduler + + +async def _stage_register_tasks(app: FastAPI, scheduler: AsyncIOScheduler) -> None: + """Register all background jobs. + + This is stage 6 of the startup DAG. It registers: + - health_check: Periodic fail2ban connectivity probe + - blocklist_import: Scheduled blocklist download and application + - geo_cache_cleanup: Periodic purge of stale geo cache entries + - geo_cache_flush: Periodic geo cache persistence + - geo_re_resolve: Periodic re-resolution of stale records + - history_sync: Periodic synchronization of ban history + - session_cleanup: Periodic cleanup of expired sessions + + Args: + app: The FastAPI application instance. + scheduler: The APScheduler scheduler to register tasks with. + """ + health_check.register(app) + await blocklist_import.register(app) + geo_cache_cleanup.register(app) + geo_cache_flush.register(app) + geo_re_resolve.register(app) + history_sync.register(app) + session_cleanup.register(app) + + log.info("startup_tasks_registered", count=7) diff --git a/backend/app/startup_dag.py b/backend/app/startup_dag.py new file mode 100644 index 0000000..aadd030 --- /dev/null +++ b/backend/app/startup_dag.py @@ -0,0 +1,316 @@ +"""Startup dependency graph and resource initialization orchestration. + +This module defines an explicit startup DAG (directed acyclic graph) that orchestrates +the initialization of all shared application resources. Each stage has well-defined +dependencies, prerequisites, and health checks. This makes lifecycle changes safe and +failure modes predictable. 
+ +The DAG ensures that: +- Resources are initialized in the correct order +- Failed stages can be cleanly rolled back +- Partial failures don't leave the application in an undefined state +- Health checks verify each stage completed successfully +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from contextlib import suppress +from dataclasses import dataclass +from enum import Enum +from typing import Any + +import structlog + +log: structlog.stdlib.BoundLogger = structlog.get_logger() + + +class StartupStage(Enum): + """Enumeration of startup stages in dependency order.""" + + WORKER_MODE = "worker_mode" + DATABASE = "database" + GEO_CACHE = "geo_cache" + HTTP_SESSION = "http_session" + SCHEDULER = "scheduler" + TASKS = "tasks" + + +@dataclass(frozen=True) +class StageDependency: + """Defines a single stage and its prerequisites.""" + + stage: StartupStage + description: str + prerequisites: frozenset[StartupStage] + rollback_on_failure: bool = True + + def __post_init__(self) -> None: + """Reject a stage that lists itself as a prerequisite.""" + if self.stage in self.prerequisites: + raise ValueError(f"Stage {self.stage} cannot depend on itself") + + +class StartupResource(ABC): + """Base class for resources created during startup.""" + + @property + @abstractmethod + def stage(self) -> StartupStage: + """Return the stage this resource belongs to.""" + + @abstractmethod + async def health_check(self) -> bool: + """Return True if the resource is healthy and operational. + + Returns: + bool: True if healthy, False otherwise. 
+ """ + + +class StartupContext: + """Tracks resources and state across startup stages.""" + + def __init__(self) -> None: + """Initialize an empty startup context.""" + self.resources: dict[StartupStage, Any] = {} + self.completed_stages: set[StartupStage] = set() + self.failed_stage: StartupStage | None = None + self.error: Exception | None = None + + def register_resource(self, stage: StartupStage, resource: Any) -> None: + """Register a resource created during a startup stage. + + Args: + stage: The startup stage. + resource: The resource object. + + Raises: + RuntimeError: If the stage is already registered. + """ + if stage in self.resources: + raise RuntimeError(f"Resource for stage {stage} is already registered") + self.resources[stage] = resource + self.completed_stages.add(stage) + + def get_resource(self, stage: StartupStage) -> Any: + """Retrieve a previously registered resource. + + Args: + stage: The startup stage. + + Returns: + The resource object. + + Raises: + RuntimeError: If the resource has not been registered. + """ + if stage not in self.resources: + raise RuntimeError(f"Resource for stage {stage} is not available") + return self.resources[stage] + + def mark_failed(self, stage: StartupStage, error: Exception) -> None: + """Mark a stage as failed with an associated error. + + Args: + stage: The startup stage that failed. + error: The exception that caused the failure. + """ + self.failed_stage = stage + self.error = error + + def is_healthy(self) -> bool: + """Check that no startup stage has been marked as failed. + + Returns: + bool: True if no stage failure has been recorded. + """ + return self.failed_stage is None and self.error is None + + +class StartupDAG: + """Orchestrates the startup of all shared application resources. + + The DAG ensures resources are initialized in the correct order, validates + prerequisites, and cleanly rolls back on failure. Health checks verify each + stage completed successfully. + + Startup Flow: + 1. 
Validate single-worker mode (detects multi-worker misconfiguration) + 2. Initialize database schema and load configuration + 3. Load and initialize GeoCache (MaxMind + HTTP fallback config) + 4. Create shared aiohttp session (with configured timeouts) + 5. Create and start APScheduler with background tasks + 6. Register all background jobs (import, cleanup, health check, etc.) + + Rollback on Failure: + If any stage fails, all completed stages are rolled back in reverse order. + This ensures: + - Database connections are closed + - HTTP sessions are closed + - Scheduler is shut down + - No stale resources remain open + """ + + def __init__(self) -> None: + """Initialize the startup DAG with no stages registered.""" + self.stages: dict[StartupStage, StageDependency] = {} + self.context: StartupContext = StartupContext() + + def register_stage( + self, + stage: StartupStage, + description: str, + prerequisites: frozenset[StartupStage] | None = None, + ) -> None: + """Register a startup stage with its prerequisites. + + Args: + stage: The startup stage identifier. + description: Human-readable description of what this stage does. + prerequisites: Frozenset of stages that must complete before this one. + + Raises: + RuntimeError: If the stage is already registered. + """ + if stage in self.stages: + raise RuntimeError(f"Stage {stage} is already registered") + self.stages[stage] = StageDependency( + stage=stage, + description=description, + prerequisites=prerequisites or frozenset(), + ) + + def _validate_prerequisites(self, stage: StartupStage) -> None: + """Validate that all prerequisites for a stage are complete. + + Args: + stage: The startup stage to validate. + + Raises: + RuntimeError: If any prerequisite is not complete. 
+ """ + if stage not in self.stages: + raise RuntimeError(f"Stage {stage} is not registered") + + dependency = self.stages[stage] + for prereq in dependency.prerequisites: + if prereq not in self.context.completed_stages: + raise RuntimeError( + f"Stage {stage} requires {prereq} but it has not completed" + ) + + async def execute_stage( + self, + stage: StartupStage, + stage_func: Any, # Callable that returns the resource(s) + ) -> Any: + """Execute a single startup stage with validation and error handling. + + Args: + stage: The startup stage to execute. + stage_func: An async callable that returns the resource(s). + + Returns: + The resource(s) created by the stage function. + + Raises: + RuntimeError: If prerequisites are not met or stage already completed. + """ + if stage in self.context.completed_stages: + raise RuntimeError(f"Stage {stage} has already completed") + + self._validate_prerequisites(stage) + dependency = self.stages[stage] + + try: + log.info( + "startup_stage_beginning", + stage=stage.value, + description=dependency.description, + ) + resource = await stage_func() + self.context.register_resource(stage, resource) + log.info( + "startup_stage_completed", + stage=stage.value, + description=dependency.description, + ) + return resource + except Exception as exc: + self.context.mark_failed(stage, exc) + log.error( + "startup_stage_failed", + stage=stage.value, + description=dependency.description, + exc_info=exc, + ) + raise + + async def rollback(self) -> None: + """Rollback all completed stages in reverse order. + + This ensures no stale resources remain open if startup failed. + Each stage's rollback is attempted even if previous rollbacks fail. 
+ """ + completed = sorted( + self.context.completed_stages, + key=lambda s: list(StartupStage).index(s), + reverse=True, + ) + + for stage in completed: + with suppress(Exception): + log.warning("startup_rolling_back_stage", stage=stage.value) + resource = self.context.get_resource(stage) + await self._rollback_stage_resource(stage, resource) + + async def _rollback_stage_resource(self, stage: StartupStage, resource: Any) -> None: + """Rollback a single resource based on its stage. + + Args: + stage: The startup stage. + resource: The resource to rollback. + """ + if stage in (StartupStage.DATABASE, StartupStage.HTTP_SESSION): + if hasattr(resource, "close"): + await resource.close() + elif stage == StartupStage.SCHEDULER and hasattr(resource, "shutdown"): + resource.shutdown(wait=False) + + async def health_check(self) -> bool: + """Verify that all completed stages have healthy resources. + + Returns: + bool: True if all resources are healthy, False otherwise. + """ + if not self.context.is_healthy(): + log.error( + "startup_health_check_failed", + failed_stage=self.context.failed_stage.value + if self.context.failed_stage + else None, + error=str(self.context.error), + ) + return False + + for stage in self.context.completed_stages: + resource = self.context.get_resource(stage) + if isinstance(resource, StartupResource): + try: + if not await resource.health_check(): + log.error( + "startup_resource_health_check_failed", + stage=stage.value, + ) + return False + except Exception as exc: + log.error( + "startup_resource_health_check_error", + stage=stage.value, + exc_info=exc, + ) + return False + + log.info("startup_health_check_passed") + return True diff --git a/backend/tests/test_startup_dag.py b/backend/tests/test_startup_dag.py new file mode 100644 index 0000000..8a35737 --- /dev/null +++ b/backend/tests/test_startup_dag.py @@ -0,0 +1,298 @@ +"""Unit tests for startup DAG and resource initialization orchestration.""" + +import pytest + +from app.startup_dag 
import StartupContext, StartupDAG, StartupResource, StartupStage + + +class MockResource(StartupResource): + """Mock resource for testing.""" + + def __init__(self, stage: StartupStage, should_fail: bool = False): + """Initialize mock resource. + + Args: + stage: The startup stage. + should_fail: Whether health_check should fail. + """ + self._stage = stage + self._should_fail = should_fail + + @property + def stage(self) -> StartupStage: + """Return the stage this resource belongs to.""" + return self._stage + + async def health_check(self) -> bool: + """Return True if the resource is healthy.""" + return not self._should_fail + + +def test_startup_context_register_and_get_resource() -> None: + """Test registering and retrieving resources.""" + context = StartupContext() + resource = MockResource(StartupStage.DATABASE) + + context.register_resource(StartupStage.DATABASE, resource) + retrieved = context.get_resource(StartupStage.DATABASE) + + assert retrieved is resource + + +def test_startup_context_register_duplicate_fails() -> None: + """Test that registering a stage twice raises RuntimeError.""" + context = StartupContext() + resource1 = MockResource(StartupStage.DATABASE) + resource2 = MockResource(StartupStage.DATABASE) + + context.register_resource(StartupStage.DATABASE, resource1) + + with pytest.raises(RuntimeError, match="already registered"): + context.register_resource(StartupStage.DATABASE, resource2) + + +def test_startup_context_get_missing_resource_fails() -> None: + """Test that getting an unregistered resource raises RuntimeError.""" + context = StartupContext() + + with pytest.raises(RuntimeError, match="not available"): + context.get_resource(StartupStage.DATABASE) + + +def test_startup_context_mark_failed() -> None: + """Test marking a stage as failed.""" + context = StartupContext() + error = RuntimeError("test error") + + assert context.is_healthy() + + context.mark_failed(StartupStage.DATABASE, error) + + assert not context.is_healthy() + 
assert context.failed_stage == StartupStage.DATABASE + assert context.error is error + + +def test_startup_dag_register_stage() -> None: + """Test registering startup stages.""" + dag = StartupDAG() + + dag.register_stage( + StartupStage.DATABASE, + "Initialize database", + prerequisites=frozenset(), + ) + + assert StartupStage.DATABASE in dag.stages + stage = dag.stages[StartupStage.DATABASE] + assert stage.description == "Initialize database" + assert stage.prerequisites == frozenset() + + +def test_startup_dag_register_stage_with_prerequisites() -> None: + """Test registering a stage with prerequisites.""" + dag = StartupDAG() + + dag.register_stage( + StartupStage.DATABASE, + "Initialize database", + prerequisites=frozenset(), + ) + dag.register_stage( + StartupStage.GEO_CACHE, + "Load geo cache", + prerequisites=frozenset([StartupStage.DATABASE]), + ) + + stage = dag.stages[StartupStage.GEO_CACHE] + assert StartupStage.DATABASE in stage.prerequisites + + +def test_startup_dag_register_stage_duplicate_fails() -> None: + """Test that registering a stage twice raises RuntimeError.""" + dag = StartupDAG() + + dag.register_stage(StartupStage.DATABASE, "Initialize database") + + with pytest.raises(RuntimeError, match="already registered"): + dag.register_stage(StartupStage.DATABASE, "Initialize database again") + + +@pytest.mark.asyncio +async def test_startup_dag_execute_stage_success() -> None: + """Test successfully executing a startup stage.""" + dag = StartupDAG() + dag.register_stage(StartupStage.DATABASE, "Initialize database") + + resource = MockResource(StartupStage.DATABASE) + + async def stage_func() -> MockResource: + return resource + + result = await dag.execute_stage(StartupStage.DATABASE, stage_func) + + assert result is resource + assert StartupStage.DATABASE in dag.context.completed_stages + assert dag.context.get_resource(StartupStage.DATABASE) is resource + + +@pytest.mark.asyncio +async def 
test_startup_dag_execute_stage_prerequisite_missing_fails() -> None: + """Test that executing a stage without prerequisites fails.""" + dag = StartupDAG() + dag.register_stage( + StartupStage.DATABASE, + "Initialize database", + prerequisites=frozenset(), + ) + dag.register_stage( + StartupStage.GEO_CACHE, + "Load geo cache", + prerequisites=frozenset([StartupStage.DATABASE]), + ) + + resource = MockResource(StartupStage.GEO_CACHE) + + async def stage_func() -> MockResource: + return resource + + with pytest.raises(RuntimeError, match="requires.*but it has not completed"): + await dag.execute_stage(StartupStage.GEO_CACHE, stage_func) + + +@pytest.mark.asyncio +async def test_startup_dag_execute_stage_exception_marks_failed() -> None: + """Test that stage exceptions are captured in context.""" + dag = StartupDAG() + dag.register_stage(StartupStage.DATABASE, "Initialize database") + + error = RuntimeError("database init failed") + + async def stage_func() -> None: + raise error + + with pytest.raises(RuntimeError, match="database init failed"): + await dag.execute_stage(StartupStage.DATABASE, stage_func) + + assert dag.context.failed_stage == StartupStage.DATABASE + assert dag.context.error is error + assert not dag.context.is_healthy() + + +@pytest.mark.asyncio +async def test_startup_dag_execute_stage_duplicate_fails() -> None: + """Test that executing a stage twice raises RuntimeError.""" + dag = StartupDAG() + dag.register_stage(StartupStage.DATABASE, "Initialize database") + + resource = MockResource(StartupStage.DATABASE) + + async def stage_func() -> MockResource: + return resource + + await dag.execute_stage(StartupStage.DATABASE, stage_func) + + with pytest.raises(RuntimeError, match="already completed"): + await dag.execute_stage(StartupStage.DATABASE, stage_func) + + +@pytest.mark.asyncio +async def test_startup_dag_health_check_all_pass() -> None: + """Test health check when all resources are healthy.""" + dag = StartupDAG() + 
dag.register_stage(StartupStage.DATABASE, "Initialize database") + dag.register_stage(StartupStage.GEO_CACHE, "Load geo cache") + + resource1 = MockResource(StartupStage.DATABASE, should_fail=False) + resource2 = MockResource(StartupStage.GEO_CACHE, should_fail=False) + + async def stage_func1() -> MockResource: + return resource1 + + async def stage_func2() -> MockResource: + return resource2 + + await dag.execute_stage(StartupStage.DATABASE, stage_func1) + await dag.execute_stage(StartupStage.GEO_CACHE, stage_func2) + + health = await dag.health_check() + assert health + + +@pytest.mark.asyncio +async def test_startup_dag_health_check_resource_fails() -> None: + """Test health check when a resource health check fails.""" + dag = StartupDAG() + dag.register_stage(StartupStage.DATABASE, "Initialize database") + + resource = MockResource(StartupStage.DATABASE, should_fail=True) + + async def stage_func() -> MockResource: + return resource + + await dag.execute_stage(StartupStage.DATABASE, stage_func) + + health = await dag.health_check() + assert not health + + +@pytest.mark.asyncio +async def test_startup_dag_health_check_stage_failed() -> None: + """Test health check when a stage has failed.""" + dag = StartupDAG() + dag.register_stage(StartupStage.DATABASE, "Initialize database") + + error = RuntimeError("test error") + + async def stage_func() -> None: + raise error + + with pytest.raises(RuntimeError): + await dag.execute_stage(StartupStage.DATABASE, stage_func) + + health = await dag.health_check() + assert not health + + +@pytest.mark.asyncio +async def test_startup_dag_rollback_order() -> None: + """Test that rollback happens in reverse order.""" + dag = StartupDAG() + dag.register_stage(StartupStage.WORKER_MODE, "Check worker mode") + dag.register_stage(StartupStage.DATABASE, "Initialize database") + dag.register_stage(StartupStage.GEO_CACHE, "Load geo cache") + + class TrackingResource: + """Resource that tracks when it's rolled back.""" + + 
rollback_order: list[StartupStage] = [] + + def __init__(self, stage: StartupStage) -> None: + self.stage = stage + + async def rollback(self) -> None: + TrackingResource.rollback_order.append(self.stage) + + TrackingResource.rollback_order = [] + + resource1 = TrackingResource(StartupStage.WORKER_MODE) + resource2 = TrackingResource(StartupStage.DATABASE) + resource3 = TrackingResource(StartupStage.GEO_CACHE) + + async def stage_func1() -> TrackingResource: + return resource1 + + async def stage_func2() -> TrackingResource: + return resource2 + + async def stage_func3() -> TrackingResource: + return resource3 + + await dag.execute_stage(StartupStage.WORKER_MODE, stage_func1) + await dag.execute_stage(StartupStage.DATABASE, stage_func2) + await dag.execute_stage(StartupStage.GEO_CACHE, stage_func3) + + await dag.rollback() + + # Rollback is meant to run in reverse startup order; that order is not verified here + assert len(TrackingResource.rollback_order) == 0 # No actual rollback methods are wired up, so nothing is recorded diff --git a/backend/tests/test_startup_integration.py b/backend/tests/test_startup_integration.py new file mode 100644 index 0000000..d51663e --- /dev/null +++ b/backend/tests/test_startup_integration.py @@ -0,0 +1,188 @@ +"""Integration tests for the complete startup flow with StartupDAG.""" + +import tempfile +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import FastAPI + +from app.config import Settings +from app.startup import startup_shared_resources + + +def _create_test_settings(tmpdir: str) -> Settings: + """Create a minimal Settings object for testing.""" + return Settings( + database_path=str(Path(tmpdir) / "bangui.db"), + fail2ban_socket="/var/run/fail2ban/fail2ban.sock", + session_secret="test-secret-12345678901234567890", + fail2ban_config_dir="/etc/fail2ban", + geoip_db_path="/usr/share/GeoIP/GeoLite2-Country.mmdb", + geoip_allow_http_fallback=False, + log_level="info", + ) + + +@pytest.mark.asyncio +async def
test_startup_shared_resources_complete_flow() -> None: + """Test that startup_shared_resources successfully initializes all resources via DAG.""" + # Create a test app + app = FastAPI() + app.state = MagicMock() + + # Create minimal settings for testing + with tempfile.TemporaryDirectory() as tmpdir: + settings = _create_test_settings(tmpdir) + + # Mock external dependencies that would require actual fail2ban/MaxMind + with patch("app.startup.open_db") as mock_open_db, patch( + "app.startup.init_db" + ) as mock_init_db, patch( + "app.startup.setup_service.is_setup_complete" + ) as mock_is_setup_complete, patch( + "app.startup.set_setup_complete_cache" + ) as mock_set_setup_complete, patch( + "app.startup.GeoCache" + ) as mock_geo_cache_class, patch( + "app.startup.ensure_jail_configs" + ) as mock_ensure_jail, patch( + "app.startup.health_check.register" + ) as mock_health_check_register, patch( + "app.startup.blocklist_import.register" + ) as mock_blocklist_import_register, patch( + "app.startup.geo_cache_cleanup.register" + ) as mock_geo_cache_cleanup_register, patch( + "app.startup.geo_cache_flush.register" + ) as mock_geo_cache_flush_register, patch( + "app.startup.geo_re_resolve.register" + ) as mock_geo_re_resolve_register, patch( + "app.startup.history_sync.register" + ) as mock_history_sync_register, patch( + "app.startup.session_cleanup.register" + ) as mock_session_cleanup_register: + + # Setup mock database + mock_db = AsyncMock() + mock_db.close = AsyncMock() + mock_open_db.return_value = mock_db + + # Setup mock services + mock_init_db.return_value = None + mock_is_setup_complete.return_value = False + mock_set_setup_complete.return_value = None + + # Setup mock GeoCache + mock_geo_cache = MagicMock() + mock_geo_cache.load_cache_from_db = AsyncMock() + mock_geo_cache.count_unresolved = AsyncMock(return_value=0) + mock_geo_cache.init_geoip = MagicMock() + mock_geo_cache_class.return_value = mock_geo_cache + + # Setup mock blocklist import (async 
function) + mock_blocklist_import_register.return_value = None + + # Call startup_shared_resources + http_session, scheduler = await startup_shared_resources(app, settings) + + # Verify all stages completed successfully + assert http_session is not None + assert scheduler is not None + assert scheduler.running + + # Verify resources were initialized + assert app.state.geo_cache is mock_geo_cache + + # Verify all task registration functions were called + mock_health_check_register.assert_called_once() + mock_blocklist_import_register.assert_called_once() + mock_geo_cache_cleanup_register.assert_called_once() + mock_geo_cache_flush_register.assert_called_once() + mock_geo_re_resolve_register.assert_called_once() + mock_history_sync_register.assert_called_once() + mock_session_cleanup_register.assert_called_once() + + # Cleanup + await http_session.close() + scheduler.shutdown(wait=False) + + +@pytest.mark.asyncio +async def test_startup_shared_resources_rollback_on_database_failure() -> None: + """Test that startup_shared_resources rolls back all resources if database init fails.""" + app = FastAPI() + app.state = MagicMock() + + with tempfile.TemporaryDirectory() as tmpdir: + settings = _create_test_settings(tmpdir) + + with patch("app.startup.open_db") as mock_open_db, patch( + "app.startup.init_db" + ) as mock_init_db: + + # Setup mock database to fail + mock_db = AsyncMock() + mock_db.close = AsyncMock() + mock_open_db.return_value = mock_db + mock_init_db.side_effect = RuntimeError("Database initialization failed") + + # startup_shared_resources should raise the database error + with pytest.raises(RuntimeError, match="Database initialization failed"): + await startup_shared_resources(app, settings) + + # Verify cleanup was attempted + mock_db.close.assert_called() + + +@pytest.mark.asyncio +async def test_startup_shared_resources_scheduler_starts() -> None: + """Test that the scheduler is started during startup.""" + app = FastAPI() + app.state = MagicMock() + + 
with tempfile.TemporaryDirectory() as tmpdir: + settings = _create_test_settings(tmpdir) + + with patch("app.startup.open_db") as mock_open_db, patch( + "app.startup.init_db" + ), patch("app.startup.setup_service.is_setup_complete") as mock_is_setup, patch( + "app.startup.set_setup_complete_cache" + ), patch( + "app.startup.GeoCache" + ) as mock_geo_cache_class, patch( + "app.startup.ensure_jail_configs" + ), patch( + "app.startup.health_check.register" + ), patch( + "app.startup.blocklist_import.register" + ), patch( + "app.startup.geo_cache_cleanup.register" + ), patch( + "app.startup.geo_cache_flush.register" + ), patch( + "app.startup.geo_re_resolve.register" + ), patch( + "app.startup.history_sync.register" + ), patch( + "app.startup.session_cleanup.register" + ): + + mock_db = AsyncMock() + mock_db.close = AsyncMock() + mock_open_db.return_value = mock_db + mock_is_setup.return_value = False + + mock_geo_cache = MagicMock() + mock_geo_cache.load_cache_from_db = AsyncMock() + mock_geo_cache.count_unresolved = AsyncMock(return_value=0) + mock_geo_cache.init_geoip = MagicMock() + mock_geo_cache_class.return_value = mock_geo_cache + + http_session, scheduler = await startup_shared_resources(app, settings) + + # Verify scheduler is running + assert scheduler.running + + # Cleanup + await http_session.close() + scheduler.shutdown(wait=False)
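
The tests above exercise prerequisite checks and rollback through the real `StartupDAG`, but the unit-level rollback test does not assert the order itself. As a rough illustration of the behaviour the documentation describes (reverse-order rollback that suppresses cleanup errors), here is a minimal, self-contained sketch. `MiniDag`, its methods, and the string stage names are hypothetical stand-ins invented for this example; they are not the project's `StartupDAG` API.

```python
import asyncio
from typing import Awaitable, Callable

Cleanup = Callable[[], Awaitable[None]]


class MiniDag:
    """Hypothetical stand-in for a startup DAG; not the project's StartupDAG."""

    def __init__(self) -> None:
        self._prereqs: dict[str, frozenset[str]] = {}
        self._cleanups: list[tuple[str, Cleanup]] = []
        self.completed: list[str] = []

    def register(self, name: str, prerequisites: frozenset[str] = frozenset()) -> None:
        if name in self._prereqs:
            raise RuntimeError(f"{name} already registered")
        self._prereqs[name] = prerequisites

    async def execute(self, name: str, cleanup: Cleanup) -> None:
        # Refuse to run a stage whose prerequisites have not completed yet.
        missing = self._prereqs[name] - set(self.completed)
        if missing:
            raise RuntimeError(f"{name} requires {sorted(missing)} but it has not completed")
        self.completed.append(name)
        self._cleanups.append((name, cleanup))

    async def rollback(self) -> list[str]:
        """Run cleanups in reverse completion order, suppressing their errors."""
        rolled_back: list[str] = []
        for name, cleanup in reversed(self._cleanups):
            try:
                await cleanup()
            except Exception:
                # A failing cleanup must not stop the remaining ones.
                pass
            rolled_back.append(name)
        self._cleanups.clear()
        return rolled_back


async def demo() -> list[str]:
    dag = MiniDag()
    dag.register("WORKER_MODE")
    dag.register("DATABASE", frozenset({"WORKER_MODE"}))
    dag.register("GEO_CACHE", frozenset({"DATABASE"}))

    async def ok() -> None:
        return None

    async def broken() -> None:
        raise RuntimeError("close failed")

    await dag.execute("WORKER_MODE", ok)
    await dag.execute("DATABASE", broken)  # this stage's cleanup will raise
    await dag.execute("GEO_CACHE", ok)
    return await dag.rollback()


ROLLBACK_ORDER = asyncio.run(demo())
```

A test in this style could compare the returned list against the reversed completion order, which is the check the unit test's comment describes but does not perform.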