refactoring-backend #3
@@ -292,6 +292,98 @@ blocklist_service.py (Public API)
|
||||
- Logging is contextual and tied to the appropriate layer
|
||||
- Retry logic and transient error handling are isolated
|
||||
|
||||
#### Startup DAG (`app/startup_dag.py`, `app/startup.py`)
|
||||
|
||||
The startup process is orchestrated by an explicit **Directed Acyclic Graph (DAG)** that defines all resource initialization stages, their dependencies, health checks, and rollback strategy. This replaces implicit ordering with explicit, documented prerequisites.
|
||||
|
||||
**Why This Exists:**
|
||||
|
||||
Previously, startup resources were created in a procedural sequence without documented dependencies. If a stage was reordered or a prerequisite was missed, initialization could fail in non-obvious ways. Partial failures could leave stale resources (open database connections, HTTP sessions, running schedulers) that prevented clean rollback.
|
||||
|
||||
**Startup Stages (in order):**
|
||||
|
||||
```
|
||||
1. WORKER_MODE
|
||||
└─ Validates that BANGUI_WORKERS=1 (scheduler cannot run in multiple workers)
|
||||
|
||||
2. DATABASE
|
||||
├─ Prerequisite: WORKER_MODE
|
||||
├─ Creates database directory
|
||||
├─ Initializes database schema
|
||||
├─ Caches setup completion state
|
||||
└─ Loads persisted runtime settings
|
||||
|
||||
3. GEO_CACHE
|
||||
├─ Prerequisite: DATABASE
|
||||
├─ Loads IP geolocation cache from database
|
||||
├─ Counts unresolved IPs
|
||||
├─ Initializes MaxMind GeoLite2 database
|
||||
└─ Configures HTTP fallback (if enabled)
|
||||
|
||||
4. HTTP_SESSION
|
||||
├─ Prerequisite: GEO_CACHE
|
||||
├─ Creates aiohttp.ClientSession
|
||||
└─ Configures timeouts and connection limits
|
||||
|
||||
5. SCHEDULER
|
||||
├─ Prerequisite: HTTP_SESSION
|
||||
├─ Creates APScheduler AsyncIOScheduler
|
||||
└─ Starts the scheduler
|
||||
|
||||
6. TASKS
|
||||
├─ Prerequisite: SCHEDULER
|
||||
├─ Registers health_check task (fail2ban connectivity probe)
|
||||
├─ Registers blocklist_import task (scheduled imports)
|
||||
├─ Registers geo_cache_cleanup task (stale entry purge)
|
||||
├─ Registers geo_cache_flush task (periodic persistence)
|
||||
├─ Registers geo_re_resolve task (stale record re-resolution)
|
||||
├─ Registers history_sync task (ban history sync)
|
||||
└─ Registers session_cleanup task (expired session purge)
|
||||
```
|
||||
|
||||
**Failure Mode & Rollback:**
|
||||
|
||||
If any stage fails:
|
||||
|
||||
1. All completed stages are rolled back **in reverse order** (Tasks → Scheduler → HTTP_SESSION → GEO_CACHE → DATABASE → WORKER_MODE)
|
||||
2. Each rollback suppresses exceptions to ensure all resources are cleaned up
|
||||
3. Database connections are closed
|
||||
4. HTTP sessions are closed
|
||||
5. The scheduler is shut down
|
||||
6. The application startup fails with a clear error message
|
||||
|
||||
**Health Checks:**
|
||||
|
||||
After all stages complete, a final health check verifies:
|
||||
- All resources have initialized successfully
|
||||
- Resources pass their individual health_check() methods
|
||||
- No failures occurred during any stage
|
||||
|
||||
**Implementation:**
|
||||
|
||||
- **StartupDAG**: Orchestrates the entire flow, manages prerequisites, and handles failures
|
||||
- **StartupStage**: Enum defining the 6 startup stages
|
||||
- **StageDependency**: Defines stage metadata (description, prerequisites, rollback policy)
|
||||
- **StartupContext**: Tracks registered resources, completed stages, and failure state
|
||||
- **startup_shared_resources()**: Main entry point that builds and executes the DAG
|
||||
- **_stage_*()**: Functions that implement each stage's initialization logic
|
||||
|
||||
**Example Usage in Tests:**
|
||||
|
||||
```python
|
||||
# Test that a stage with missing prerequisites fails
|
||||
dag = StartupDAG()
|
||||
dag.register_stage(StartupStage.HTTP_SESSION, "Create HTTP session",
|
||||
prerequisites=frozenset([StartupStage.DATABASE]))
|
||||
dag.register_stage(StartupStage.SCHEDULER, "Create scheduler")
|
||||
|
||||
async def http_session_func():
|
||||
return aiohttp.ClientSession()
|
||||
|
||||
# This will raise RuntimeError because DATABASE hasn't completed
|
||||
await dag.execute_stage(StartupStage.HTTP_SESSION, http_session_func)
|
||||
```
|
||||
|
||||
#### Mappers (`app/mappers/`)
|
||||
|
||||
The response mapping layer. Mappers convert domain models (returned by services) to response models (consumed by HTTP routers). This layer enforces the separation between business logic and API shape.
|
||||
|
||||
@@ -1,43 +1,3 @@
|
||||
## 8) Inconsistent modeling style (TypedDict vs Pydantic)
|
||||
- Where found:
|
||||
- [backend/app/services/jail_service.py](backend/app/services/jail_service.py)
|
||||
- [backend/app/models](backend/app/models)
|
||||
- Why this is needed:
|
||||
- Mixed validation/serialization behavior increases maintenance cost.
|
||||
- Goal:
|
||||
- Standardize model type usage by layer.
|
||||
- What to do:
|
||||
- Define when TypedDict is allowed.
|
||||
- Convert external-facing structures to Pydantic consistently.
|
||||
- Possible traps and issues:
|
||||
- Performance and strictness differences may alter runtime behavior.
|
||||
- Docs changes needed:
|
||||
- Add modeling conventions section.
|
||||
- Doc references:
|
||||
- [Docs/Backend-Development.md](Docs/Backend-Development.md)
|
||||
|
||||
---
|
||||
|
||||
## 9) Repository protocol coverage is incomplete
|
||||
- Where found:
|
||||
- [backend/app/repositories/protocols.py](backend/app/repositories/protocols.py)
|
||||
- [backend/app/repositories](backend/app/repositories)
|
||||
- Why this is needed:
|
||||
- Missing protocols reduce mockability and static contract checks.
|
||||
- Goal:
|
||||
- Full protocol coverage for repository interfaces.
|
||||
- What to do:
|
||||
- Add protocol definitions for missing repositories.
|
||||
- Validate implementation compatibility in tests/CI.
|
||||
- Possible traps and issues:
|
||||
- Protocol drift if methods evolve without synchronized updates.
|
||||
- Docs changes needed:
|
||||
- Add repository protocol checklist.
|
||||
- Doc references:
|
||||
- [Docs/Backend-Development.md](Docs/Backend-Development.md)
|
||||
|
||||
---
|
||||
|
||||
## 10) Startup sequence depends on implicit ordering
|
||||
- Where found:
|
||||
- [backend/app/startup.py](backend/app/startup.py)
|
||||
|
||||
@@ -3,14 +3,27 @@
|
||||
This module contains shared startup logic extracted from ``app.main`` so that
|
||||
initialisation is easier to reason about and unit test. The lifespan handler
|
||||
in ``app.main`` delegates resource creation and task registration here.
|
||||
|
||||
The startup process is orchestrated by StartupDAG, which ensures all resources
|
||||
are initialized in the correct order with explicit dependency tracking, and
|
||||
cleanly rolls back on failure.
|
||||
|
||||
Startup Stages (in order):
|
||||
1. WORKER_MODE: Verify single-worker mode (no multi-worker scheduler conflicts)
|
||||
2. DATABASE: Initialize database schema and cache setup completion state
|
||||
3. GEO_CACHE: Load and configure IP geolocation cache
|
||||
4. HTTP_SESSION: Create shared aiohttp session with timeouts
|
||||
5. SCHEDULER: Create APScheduler instance and register background tasks
|
||||
6. TASKS: Verify all tasks are registered
|
||||
|
||||
See StartupDAG in app.startup_dag for full dependency graph and rollback logic.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from contextlib import suppress
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import aiohttp
|
||||
import structlog
|
||||
@@ -19,6 +32,7 @@ from apscheduler.schedulers.asyncio import AsyncIOScheduler # type: ignore[impo
|
||||
from app.db import init_db, open_db
|
||||
from app.services import setup_service
|
||||
from app.services.geo_cache import GeoCache
|
||||
from app.startup_dag import StartupDAG, StartupStage
|
||||
from app.tasks import (
|
||||
blocklist_import,
|
||||
geo_cache_cleanup,
|
||||
@@ -105,15 +119,146 @@ async def startup_shared_resources(
|
||||
) -> tuple[aiohttp.ClientSession, AsyncIOScheduler]:
|
||||
"""Create shared resources needed during the application lifespan.
|
||||
|
||||
This function orchestrates the entire startup sequence through a StartupDAG,
|
||||
ensuring all resources are initialized in the correct order with explicit
|
||||
dependency tracking. If any stage fails, all completed resources are cleanly
|
||||
rolled back.
|
||||
|
||||
The startup stages are:
|
||||
1. WORKER_MODE: Validate single-worker configuration
|
||||
2. DATABASE: Initialize database and load setup state
|
||||
3. GEO_CACHE: Load IP geolocation cache
|
||||
4. HTTP_SESSION: Create shared aiohttp session
|
||||
5. SCHEDULER: Create and start APScheduler
|
||||
6. TASKS: Register all background jobs
|
||||
|
||||
Args:
|
||||
app: The FastAPI application instance.
|
||||
settings: Resolved application settings.
|
||||
|
||||
Returns:
|
||||
A tuple of ``(http_session, scheduler)``.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If any startup stage fails or prerequisites are not met.
|
||||
"""
|
||||
dag = StartupDAG()
|
||||
|
||||
# Register all startup stages with their dependencies.
|
||||
dag.register_stage(
|
||||
StartupStage.WORKER_MODE,
|
||||
"Verify single-worker mode (scheduler must not run in multiple workers)",
|
||||
prerequisites=frozenset(),
|
||||
)
|
||||
dag.register_stage(
|
||||
StartupStage.DATABASE,
|
||||
"Initialize database schema and load setup state",
|
||||
prerequisites=frozenset([StartupStage.WORKER_MODE]),
|
||||
)
|
||||
dag.register_stage(
|
||||
StartupStage.GEO_CACHE,
|
||||
"Load IP geolocation cache from database",
|
||||
prerequisites=frozenset([StartupStage.DATABASE]),
|
||||
)
|
||||
dag.register_stage(
|
||||
StartupStage.HTTP_SESSION,
|
||||
"Create shared aiohttp session with configured timeouts",
|
||||
prerequisites=frozenset([StartupStage.GEO_CACHE]),
|
||||
)
|
||||
dag.register_stage(
|
||||
StartupStage.SCHEDULER,
|
||||
"Create and start APScheduler for background jobs",
|
||||
prerequisites=frozenset([StartupStage.HTTP_SESSION]),
|
||||
)
|
||||
dag.register_stage(
|
||||
StartupStage.TASKS,
|
||||
"Register all background jobs (import, cleanup, health checks)",
|
||||
prerequisites=frozenset([StartupStage.SCHEDULER]),
|
||||
)
|
||||
|
||||
try:
|
||||
# Stage 1: Validate single-worker mode
|
||||
await dag.execute_stage(
|
||||
StartupStage.WORKER_MODE,
|
||||
_stage_check_worker_mode,
|
||||
)
|
||||
|
||||
# Stage 2: Initialize database
|
||||
startup_db = await dag.execute_stage(
|
||||
StartupStage.DATABASE,
|
||||
lambda: _stage_init_database(app, settings),
|
||||
)
|
||||
|
||||
# Stage 3: Load GeoCache
|
||||
geo_cache = await dag.execute_stage(
|
||||
StartupStage.GEO_CACHE,
|
||||
lambda: _stage_init_geo_cache(settings, startup_db),
|
||||
)
|
||||
|
||||
# Stage 4: Create HTTP session
|
||||
http_session = await dag.execute_stage(
|
||||
StartupStage.HTTP_SESSION,
|
||||
lambda: _stage_create_http_session(settings),
|
||||
)
|
||||
|
||||
# Stage 5: Create and start scheduler
|
||||
scheduler = await dag.execute_stage(
|
||||
StartupStage.SCHEDULER,
|
||||
lambda: _stage_create_scheduler(),
|
||||
)
|
||||
|
||||
# Stage 6: Register tasks
|
||||
await dag.execute_stage(
|
||||
StartupStage.TASKS,
|
||||
lambda: _stage_register_tasks(app, scheduler),
|
||||
)
|
||||
|
||||
# Verify all resources are healthy
|
||||
if not await dag.health_check():
|
||||
raise RuntimeError("Startup health check failed")
|
||||
|
||||
# Store the geo_cache on app state for dependency injection
|
||||
app.state.geo_cache = geo_cache
|
||||
|
||||
log.info(
|
||||
"startup_completed_successfully",
|
||||
stages=len(dag.context.completed_stages),
|
||||
)
|
||||
|
||||
return http_session, scheduler
|
||||
|
||||
except Exception:
|
||||
# Clean up on failure
|
||||
log.error("startup_failed_rolling_back_resources")
|
||||
await dag.rollback()
|
||||
# Ensure database is closed if it was initialized
|
||||
if StartupStage.DATABASE in dag.context.completed_stages:
|
||||
startup_db = dag.context.get_resource(StartupStage.DATABASE)
|
||||
await startup_db.close()
|
||||
raise
|
||||
|
||||
|
||||
async def _stage_check_worker_mode() -> None:
|
||||
"""Check that the application is running with a single worker.
|
||||
|
||||
This is stage 1 of the startup DAG.
|
||||
"""
|
||||
_check_single_worker_mode()
|
||||
|
||||
|
||||
async def _stage_init_database(app: FastAPI, settings: Settings) -> Any:
|
||||
"""Initialize database schema and load setup state.
|
||||
|
||||
This is stage 2 of the startup DAG. It:
|
||||
1. Creates database directory if needed
|
||||
2. Opens the database connection
|
||||
3. Initializes schema
|
||||
4. Caches setup completion state
|
||||
5. Loads persisted runtime settings
|
||||
|
||||
Returns:
|
||||
The database connection object.
|
||||
"""
|
||||
db_path: Path = Path(settings.database_path)
|
||||
await run_blocking(db_path.parent.mkdir, parents=True, exist_ok=True)
|
||||
|
||||
@@ -121,6 +266,7 @@ async def startup_shared_resources(
|
||||
|
||||
original_db_path = db_path.resolve()
|
||||
startup_db = await open_db(settings.database_path)
|
||||
|
||||
try:
|
||||
await init_db(startup_db)
|
||||
setup_complete = await setup_service.is_setup_complete(startup_db)
|
||||
@@ -144,36 +290,53 @@ async def startup_shared_resources(
|
||||
if persisted_runtime_settings:
|
||||
updated_settings = settings.model_copy(update=persisted_runtime_settings)
|
||||
set_runtime_settings(app, updated_settings)
|
||||
settings = updated_settings
|
||||
log.info(
|
||||
"runtime_settings_overridden_from_setup",
|
||||
overrides=persisted_runtime_settings,
|
||||
)
|
||||
|
||||
# Create and initialize the GeoCache instance
|
||||
geo_cache = GeoCache(allow_http_fallback=settings.geoip_allow_http_fallback)
|
||||
if Path(settings.database_path).resolve() != original_db_path:
|
||||
runtime_db = await open_db(settings.database_path)
|
||||
try:
|
||||
await geo_cache.load_cache_from_db(runtime_db)
|
||||
unresolved_count = await geo_cache.count_unresolved(runtime_db)
|
||||
finally:
|
||||
await runtime_db.close()
|
||||
else:
|
||||
await geo_cache.load_cache_from_db(startup_db)
|
||||
unresolved_count = await geo_cache.count_unresolved(startup_db)
|
||||
finally:
|
||||
except Exception:
|
||||
await startup_db.close()
|
||||
raise
|
||||
|
||||
return startup_db
|
||||
|
||||
|
||||
async def _stage_init_geo_cache(settings: Settings, startup_db: Any) -> GeoCache:
|
||||
"""Load IP geolocation cache.
|
||||
|
||||
This is stage 3 of the startup DAG. It:
|
||||
1. Creates GeoCache instance with configured settings
|
||||
2. Loads cache from database
|
||||
3. Counts unresolved IPs
|
||||
4. Initializes GeoIP database
|
||||
5. Logs warnings if necessary
|
||||
|
||||
Returns:
|
||||
The GeoCache instance.
|
||||
"""
|
||||
geo_cache = GeoCache(allow_http_fallback=settings.geoip_allow_http_fallback)
|
||||
|
||||
db_path: Path = Path(settings.database_path)
|
||||
original_db_path = db_path.resolve()
|
||||
|
||||
if db_path.resolve() != original_db_path:
|
||||
runtime_db = await open_db(settings.database_path)
|
||||
try:
|
||||
await geo_cache.load_cache_from_db(runtime_db)
|
||||
unresolved_count = await geo_cache.count_unresolved(runtime_db)
|
||||
finally:
|
||||
await runtime_db.close()
|
||||
else:
|
||||
await geo_cache.load_cache_from_db(startup_db)
|
||||
unresolved_count = await geo_cache.count_unresolved(startup_db)
|
||||
|
||||
await run_blocking(ensure_jail_configs, Path(settings.fail2ban_config_dir) / "jail.d")
|
||||
|
||||
if unresolved_count > 0:
|
||||
log.warning("geo_cache_unresolved_ips", unresolved=unresolved_count)
|
||||
|
||||
http_session: aiohttp.ClientSession = _create_http_session(settings)
|
||||
geo_cache.init_geoip(settings.geoip_db_path)
|
||||
|
||||
# Warn if HTTP fallback is enabled (security warning).
|
||||
if settings.geoip_allow_http_fallback:
|
||||
log.warning(
|
||||
"geoip_http_fallback_enabled",
|
||||
@@ -184,26 +347,58 @@ async def startup_shared_resources(
|
||||
),
|
||||
)
|
||||
|
||||
app.state.geo_cache = geo_cache
|
||||
return geo_cache
|
||||
|
||||
scheduler: AsyncIOScheduler | None = None
|
||||
try:
|
||||
scheduler = AsyncIOScheduler(timezone="UTC")
|
||||
scheduler.start()
|
||||
|
||||
health_check.register(app)
|
||||
await blocklist_import.register(app)
|
||||
geo_cache_cleanup.register(app)
|
||||
geo_cache_flush.register(app)
|
||||
geo_re_resolve.register(app)
|
||||
history_sync.register(app)
|
||||
session_cleanup.register(app)
|
||||
async def _stage_create_http_session(settings: Settings) -> aiohttp.ClientSession:
|
||||
"""Create shared aiohttp session with configured timeouts.
|
||||
|
||||
return http_session, scheduler
|
||||
except Exception:
|
||||
with suppress(Exception):
|
||||
await http_session.close()
|
||||
if scheduler is not None:
|
||||
with suppress(Exception):
|
||||
scheduler.shutdown(wait=False)
|
||||
raise
|
||||
This is stage 4 of the startup DAG.
|
||||
|
||||
Returns:
|
||||
The aiohttp ClientSession instance.
|
||||
"""
|
||||
return _create_http_session(settings)
|
||||
|
||||
|
||||
async def _stage_create_scheduler() -> AsyncIOScheduler:
|
||||
"""Create and start APScheduler.
|
||||
|
||||
This is stage 5 of the startup DAG.
|
||||
|
||||
Returns:
|
||||
The AsyncIOScheduler instance.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If scheduler creation or startup fails.
|
||||
"""
|
||||
scheduler: AsyncIOScheduler = AsyncIOScheduler(timezone="UTC")
|
||||
scheduler.start()
|
||||
return scheduler
|
||||
|
||||
|
||||
async def _stage_register_tasks(app: FastAPI, scheduler: AsyncIOScheduler) -> None:
|
||||
"""Register all background jobs.
|
||||
|
||||
This is stage 6 of the startup DAG. It registers:
|
||||
- health_check: Periodic fail2ban connectivity probe
|
||||
- blocklist_import: Scheduled blocklist download and application
|
||||
- geo_cache_cleanup: Periodic purge of stale geo cache entries
|
||||
- geo_cache_flush: Periodic geo cache persistence
|
||||
- geo_re_resolve: Periodic re-resolution of stale records
|
||||
- history_sync: Periodic synchronization of ban history
|
||||
- session_cleanup: Periodic cleanup of expired sessions
|
||||
|
||||
Args:
|
||||
app: The FastAPI application instance.
|
||||
scheduler: The APScheduler scheduler to register tasks with.
|
||||
"""
|
||||
health_check.register(app)
|
||||
await blocklist_import.register(app)
|
||||
geo_cache_cleanup.register(app)
|
||||
geo_cache_flush.register(app)
|
||||
geo_re_resolve.register(app)
|
||||
history_sync.register(app)
|
||||
session_cleanup.register(app)
|
||||
|
||||
log.info("startup_tasks_registered", count=7)
|
||||
|
||||
316
backend/app/startup_dag.py
Normal file
316
backend/app/startup_dag.py
Normal file
@@ -0,0 +1,316 @@
|
||||
"""Startup dependency graph and resource initialization orchestration.
|
||||
|
||||
This module defines an explicit startup DAG (directed acyclic graph) that orchestrates
|
||||
the initialization of all shared application resources. Each stage has well-defined
|
||||
dependencies, prerequisites, and health checks. This makes lifecycle changes safe and
|
||||
failure modes predictable.
|
||||
|
||||
The DAG ensures that:
|
||||
- Resources are initialized in the correct order
|
||||
- Failed stages can be cleanly rolled back
|
||||
- Partial failures don't leave the application in an undefined state
|
||||
- Health checks verify each stage completed successfully
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
|
||||
class StartupStage(Enum):
|
||||
"""Enumeration of startup stages in dependency order."""
|
||||
|
||||
WORKER_MODE = "worker_mode"
|
||||
DATABASE = "database"
|
||||
GEO_CACHE = "geo_cache"
|
||||
HTTP_SESSION = "http_session"
|
||||
SCHEDULER = "scheduler"
|
||||
TASKS = "tasks"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class StageDependency:
|
||||
"""Defines a single stage and its prerequisites."""
|
||||
|
||||
stage: StartupStage
|
||||
description: str
|
||||
prerequisites: frozenset[StartupStage]
|
||||
rollback_on_failure: bool = True
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate that prerequisites are logically ordered."""
|
||||
if self.stage in self.prerequisites:
|
||||
raise ValueError(f"Stage {self.stage} cannot depend on itself")
|
||||
|
||||
|
||||
class StartupResource(ABC):
|
||||
"""Base class for resources created during startup."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def stage(self) -> StartupStage:
|
||||
"""Return the stage this resource belongs to."""
|
||||
|
||||
@abstractmethod
|
||||
async def health_check(self) -> bool:
|
||||
"""Return True if the resource is healthy and operational.
|
||||
|
||||
Returns:
|
||||
bool: True if healthy, False otherwise.
|
||||
"""
|
||||
|
||||
|
||||
class StartupContext:
|
||||
"""Tracks resources and state across startup stages."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize an empty startup context."""
|
||||
self.resources: dict[StartupStage, Any] = {}
|
||||
self.completed_stages: set[StartupStage] = set()
|
||||
self.failed_stage: StartupStage | None = None
|
||||
self.error: Exception | None = None
|
||||
|
||||
def register_resource(self, stage: StartupStage, resource: Any) -> None:
|
||||
"""Register a resource created during a startup stage.
|
||||
|
||||
Args:
|
||||
stage: The startup stage.
|
||||
resource: The resource object.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the stage is already registered.
|
||||
"""
|
||||
if stage in self.resources:
|
||||
raise RuntimeError(f"Resource for stage {stage} is already registered")
|
||||
self.resources[stage] = resource
|
||||
self.completed_stages.add(stage)
|
||||
|
||||
def get_resource(self, stage: StartupStage) -> Any:
|
||||
"""Retrieve a previously registered resource.
|
||||
|
||||
Args:
|
||||
stage: The startup stage.
|
||||
|
||||
Returns:
|
||||
The resource object.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the resource has not been registered.
|
||||
"""
|
||||
if stage not in self.resources:
|
||||
raise RuntimeError(f"Resource for stage {stage} is not available")
|
||||
return self.resources[stage]
|
||||
|
||||
def mark_failed(self, stage: StartupStage, error: Exception) -> None:
|
||||
"""Mark a stage as failed with an associated error.
|
||||
|
||||
Args:
|
||||
stage: The startup stage that failed.
|
||||
error: The exception that caused the failure.
|
||||
"""
|
||||
self.failed_stage = stage
|
||||
self.error = error
|
||||
|
||||
def is_healthy(self) -> bool:
|
||||
"""Check if all registered resources pass their health checks.
|
||||
|
||||
Returns:
|
||||
bool: True if all resources are healthy.
|
||||
"""
|
||||
return self.failed_stage is None and self.error is None
|
||||
|
||||
|
||||
class StartupDAG:
|
||||
"""Orchestrates the startup of all shared application resources.
|
||||
|
||||
The DAG ensures resources are initialized in the correct order, validates
|
||||
prerequisites, and cleanly rolls back on failure. Health checks verify each
|
||||
stage completed successfully.
|
||||
|
||||
Startup Flow:
|
||||
1. Validate single-worker mode (detects multi-worker misconfiguration)
|
||||
2. Initialize database schema and load configuration
|
||||
3. Load and initialize GeoCache (MaxMind + HTTP fallback config)
|
||||
4. Create shared aiohttp session (with configured timeouts)
|
||||
5. Create and start APScheduler with background tasks
|
||||
6. Register all background jobs (import, cleanup, health check, etc.)
|
||||
|
||||
Rollback on Failure:
|
||||
If any stage fails, all completed stages are rolled back in reverse order.
|
||||
This ensures:
|
||||
- Database connections are closed
|
||||
- HTTP sessions are closed
|
||||
- Scheduler is shut down
|
||||
- No stale resources remain open
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the startup DAG with no stages registered."""
|
||||
self.stages: dict[StartupStage, StageDependency] = {}
|
||||
self.context: StartupContext = StartupContext()
|
||||
|
||||
def register_stage(
|
||||
self,
|
||||
stage: StartupStage,
|
||||
description: str,
|
||||
prerequisites: frozenset[StartupStage] | None = None,
|
||||
) -> None:
|
||||
"""Register a startup stage with its prerequisites.
|
||||
|
||||
Args:
|
||||
stage: The startup stage identifier.
|
||||
description: Human-readable description of what this stage does.
|
||||
prerequisites: Frozenset of stages that must complete before this one.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the stage is already registered.
|
||||
"""
|
||||
if stage in self.stages:
|
||||
raise RuntimeError(f"Stage {stage} is already registered")
|
||||
self.stages[stage] = StageDependency(
|
||||
stage=stage,
|
||||
description=description,
|
||||
prerequisites=prerequisites or frozenset(),
|
||||
)
|
||||
|
||||
def _validate_prerequisites(self, stage: StartupStage) -> None:
|
||||
"""Validate that all prerequisites for a stage are complete.
|
||||
|
||||
Args:
|
||||
stage: The startup stage to validate.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If any prerequisite is not complete.
|
||||
"""
|
||||
if stage not in self.stages:
|
||||
raise RuntimeError(f"Stage {stage} is not registered")
|
||||
|
||||
dependency = self.stages[stage]
|
||||
for prereq in dependency.prerequisites:
|
||||
if prereq not in self.context.completed_stages:
|
||||
raise RuntimeError(
|
||||
f"Stage {stage} requires {prereq} but it has not completed"
|
||||
)
|
||||
|
||||
async def execute_stage(
|
||||
self,
|
||||
stage: StartupStage,
|
||||
stage_func: Any, # Callable that returns the resource(s)
|
||||
) -> Any:
|
||||
"""Execute a single startup stage with validation and error handling.
|
||||
|
||||
Args:
|
||||
stage: The startup stage to execute.
|
||||
stage_func: An async callable that returns the resource(s).
|
||||
|
||||
Returns:
|
||||
The resource(s) created by the stage function.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If prerequisites are not met or stage already completed.
|
||||
"""
|
||||
if stage in self.context.completed_stages:
|
||||
raise RuntimeError(f"Stage {stage} has already completed")
|
||||
|
||||
self._validate_prerequisites(stage)
|
||||
dependency = self.stages[stage]
|
||||
|
||||
try:
|
||||
log.info(
|
||||
"startup_stage_beginning",
|
||||
stage=stage.value,
|
||||
description=dependency.description,
|
||||
)
|
||||
resource = await stage_func()
|
||||
self.context.register_resource(stage, resource)
|
||||
log.info(
|
||||
"startup_stage_completed",
|
||||
stage=stage.value,
|
||||
description=dependency.description,
|
||||
)
|
||||
return resource
|
||||
except Exception as exc:
|
||||
self.context.mark_failed(stage, exc)
|
||||
log.error(
|
||||
"startup_stage_failed",
|
||||
stage=stage.value,
|
||||
description=dependency.description,
|
||||
exc_info=exc,
|
||||
)
|
||||
raise
|
||||
|
||||
async def rollback(self) -> None:
|
||||
"""Rollback all completed stages in reverse order.
|
||||
|
||||
This ensures no stale resources remain open if startup failed.
|
||||
Each stage's rollback is attempted even if previous rollbacks fail.
|
||||
"""
|
||||
completed = sorted(
|
||||
self.context.completed_stages,
|
||||
key=lambda s: list(StartupStage).index(s),
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
for stage in completed:
|
||||
with suppress(Exception):
|
||||
log.warning("startup_rolling_back_stage", stage=stage.value)
|
||||
resource = self.context.get_resource(stage)
|
||||
await self._rollback_stage_resource(stage, resource)
|
||||
|
||||
async def _rollback_stage_resource(self, stage: StartupStage, resource: Any) -> None:
|
||||
"""Rollback a single resource based on its stage.
|
||||
|
||||
Args:
|
||||
stage: The startup stage.
|
||||
resource: The resource to rollback.
|
||||
"""
|
||||
if stage in (StartupStage.DATABASE, StartupStage.HTTP_SESSION):
|
||||
if hasattr(resource, "close"):
|
||||
await resource.close()
|
||||
elif stage == StartupStage.SCHEDULER and hasattr(resource, "shutdown"):
|
||||
resource.shutdown(wait=False)
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""Verify that all completed stages have healthy resources.
|
||||
|
||||
Returns:
|
||||
bool: True if all resources are healthy, False otherwise.
|
||||
"""
|
||||
if not self.context.is_healthy():
|
||||
log.error(
|
||||
"startup_health_check_failed",
|
||||
failed_stage=self.context.failed_stage.value
|
||||
if self.context.failed_stage
|
||||
else None,
|
||||
error=str(self.context.error),
|
||||
)
|
||||
return False
|
||||
|
||||
for stage in self.context.completed_stages:
|
||||
resource = self.context.get_resource(stage)
|
||||
if isinstance(resource, StartupResource):
|
||||
try:
|
||||
if not await resource.health_check():
|
||||
log.error(
|
||||
"startup_resource_health_check_failed",
|
||||
stage=stage.value,
|
||||
)
|
||||
return False
|
||||
except Exception as exc:
|
||||
log.error(
|
||||
"startup_resource_health_check_error",
|
||||
stage=stage.value,
|
||||
exc_info=exc,
|
||||
)
|
||||
return False
|
||||
|
||||
log.info("startup_health_check_passed")
|
||||
return True
|
||||
298
backend/tests/test_startup_dag.py
Normal file
298
backend/tests/test_startup_dag.py
Normal file
@@ -0,0 +1,298 @@
|
||||
"""Unit tests for startup DAG and resource initialization orchestration."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.startup_dag import StartupContext, StartupDAG, StartupResource, StartupStage
|
||||
|
||||
|
||||
class MockResource(StartupResource):
|
||||
"""Mock resource for testing."""
|
||||
|
||||
def __init__(self, stage: StartupStage, should_fail: bool = False):
|
||||
"""Initialize mock resource.
|
||||
|
||||
Args:
|
||||
stage: The startup stage.
|
||||
should_fail: Whether health_check should fail.
|
||||
"""
|
||||
self._stage = stage
|
||||
self._should_fail = should_fail
|
||||
|
||||
@property
|
||||
def stage(self) -> StartupStage:
|
||||
"""Return the stage this resource belongs to."""
|
||||
return self._stage
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""Return True if the resource is healthy."""
|
||||
return not self._should_fail
|
||||
|
||||
|
||||
def test_startup_context_register_and_get_resource() -> None:
|
||||
"""Test registering and retrieving resources."""
|
||||
context = StartupContext()
|
||||
resource = MockResource(StartupStage.DATABASE)
|
||||
|
||||
context.register_resource(StartupStage.DATABASE, resource)
|
||||
retrieved = context.get_resource(StartupStage.DATABASE)
|
||||
|
||||
assert retrieved is resource
|
||||
|
||||
|
||||
def test_startup_context_register_duplicate_fails() -> None:
|
||||
"""Test that registering a stage twice raises RuntimeError."""
|
||||
context = StartupContext()
|
||||
resource1 = MockResource(StartupStage.DATABASE)
|
||||
resource2 = MockResource(StartupStage.DATABASE)
|
||||
|
||||
context.register_resource(StartupStage.DATABASE, resource1)
|
||||
|
||||
with pytest.raises(RuntimeError, match="already registered"):
|
||||
context.register_resource(StartupStage.DATABASE, resource2)
|
||||
|
||||
|
||||
def test_startup_context_get_missing_resource_fails() -> None:
|
||||
"""Test that getting an unregistered resource raises RuntimeError."""
|
||||
context = StartupContext()
|
||||
|
||||
with pytest.raises(RuntimeError, match="not available"):
|
||||
context.get_resource(StartupStage.DATABASE)
|
||||
|
||||
|
||||
def test_startup_context_mark_failed() -> None:
|
||||
"""Test marking a stage as failed."""
|
||||
context = StartupContext()
|
||||
error = RuntimeError("test error")
|
||||
|
||||
assert context.is_healthy()
|
||||
|
||||
context.mark_failed(StartupStage.DATABASE, error)
|
||||
|
||||
assert not context.is_healthy()
|
||||
assert context.failed_stage == StartupStage.DATABASE
|
||||
assert context.error is error
|
||||
|
||||
|
||||
def test_startup_dag_register_stage() -> None:
|
||||
"""Test registering startup stages."""
|
||||
dag = StartupDAG()
|
||||
|
||||
dag.register_stage(
|
||||
StartupStage.DATABASE,
|
||||
"Initialize database",
|
||||
prerequisites=frozenset(),
|
||||
)
|
||||
|
||||
assert StartupStage.DATABASE in dag.stages
|
||||
stage = dag.stages[StartupStage.DATABASE]
|
||||
assert stage.description == "Initialize database"
|
||||
assert stage.prerequisites == frozenset()
|
||||
|
||||
|
||||
def test_startup_dag_register_stage_with_prerequisites() -> None:
|
||||
"""Test registering a stage with prerequisites."""
|
||||
dag = StartupDAG()
|
||||
|
||||
dag.register_stage(
|
||||
StartupStage.DATABASE,
|
||||
"Initialize database",
|
||||
prerequisites=frozenset(),
|
||||
)
|
||||
dag.register_stage(
|
||||
StartupStage.GEO_CACHE,
|
||||
"Load geo cache",
|
||||
prerequisites=frozenset([StartupStage.DATABASE]),
|
||||
)
|
||||
|
||||
stage = dag.stages[StartupStage.GEO_CACHE]
|
||||
assert StartupStage.DATABASE in stage.prerequisites
|
||||
|
||||
|
||||
def test_startup_dag_register_stage_duplicate_fails() -> None:
|
||||
"""Test that registering a stage twice raises RuntimeError."""
|
||||
dag = StartupDAG()
|
||||
|
||||
dag.register_stage(StartupStage.DATABASE, "Initialize database")
|
||||
|
||||
with pytest.raises(RuntimeError, match="already registered"):
|
||||
dag.register_stage(StartupStage.DATABASE, "Initialize database again")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_startup_dag_execute_stage_success() -> None:
|
||||
"""Test successfully executing a startup stage."""
|
||||
dag = StartupDAG()
|
||||
dag.register_stage(StartupStage.DATABASE, "Initialize database")
|
||||
|
||||
resource = MockResource(StartupStage.DATABASE)
|
||||
|
||||
async def stage_func() -> MockResource:
|
||||
return resource
|
||||
|
||||
result = await dag.execute_stage(StartupStage.DATABASE, stage_func)
|
||||
|
||||
assert result is resource
|
||||
assert StartupStage.DATABASE in dag.context.completed_stages
|
||||
assert dag.context.get_resource(StartupStage.DATABASE) is resource
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_startup_dag_execute_stage_prerequisite_missing_fails() -> None:
|
||||
"""Test that executing a stage without prerequisites fails."""
|
||||
dag = StartupDAG()
|
||||
dag.register_stage(
|
||||
StartupStage.DATABASE,
|
||||
"Initialize database",
|
||||
prerequisites=frozenset(),
|
||||
)
|
||||
dag.register_stage(
|
||||
StartupStage.GEO_CACHE,
|
||||
"Load geo cache",
|
||||
prerequisites=frozenset([StartupStage.DATABASE]),
|
||||
)
|
||||
|
||||
resource = MockResource(StartupStage.GEO_CACHE)
|
||||
|
||||
async def stage_func() -> MockResource:
|
||||
return resource
|
||||
|
||||
with pytest.raises(RuntimeError, match="requires.*but it has not completed"):
|
||||
await dag.execute_stage(StartupStage.GEO_CACHE, stage_func)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_startup_dag_execute_stage_exception_marks_failed() -> None:
|
||||
"""Test that stage exceptions are captured in context."""
|
||||
dag = StartupDAG()
|
||||
dag.register_stage(StartupStage.DATABASE, "Initialize database")
|
||||
|
||||
error = RuntimeError("database init failed")
|
||||
|
||||
async def stage_func() -> None:
|
||||
raise error
|
||||
|
||||
with pytest.raises(RuntimeError, match="database init failed"):
|
||||
await dag.execute_stage(StartupStage.DATABASE, stage_func)
|
||||
|
||||
assert dag.context.failed_stage == StartupStage.DATABASE
|
||||
assert dag.context.error is error
|
||||
assert not dag.context.is_healthy()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_startup_dag_execute_stage_duplicate_fails() -> None:
|
||||
"""Test that executing a stage twice raises RuntimeError."""
|
||||
dag = StartupDAG()
|
||||
dag.register_stage(StartupStage.DATABASE, "Initialize database")
|
||||
|
||||
resource = MockResource(StartupStage.DATABASE)
|
||||
|
||||
async def stage_func() -> MockResource:
|
||||
return resource
|
||||
|
||||
await dag.execute_stage(StartupStage.DATABASE, stage_func)
|
||||
|
||||
with pytest.raises(RuntimeError, match="already completed"):
|
||||
await dag.execute_stage(StartupStage.DATABASE, stage_func)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_startup_dag_health_check_all_pass() -> None:
|
||||
"""Test health check when all resources are healthy."""
|
||||
dag = StartupDAG()
|
||||
dag.register_stage(StartupStage.DATABASE, "Initialize database")
|
||||
dag.register_stage(StartupStage.GEO_CACHE, "Load geo cache")
|
||||
|
||||
resource1 = MockResource(StartupStage.DATABASE, should_fail=False)
|
||||
resource2 = MockResource(StartupStage.GEO_CACHE, should_fail=False)
|
||||
|
||||
async def stage_func1() -> MockResource:
|
||||
return resource1
|
||||
|
||||
async def stage_func2() -> MockResource:
|
||||
return resource2
|
||||
|
||||
await dag.execute_stage(StartupStage.DATABASE, stage_func1)
|
||||
await dag.execute_stage(StartupStage.GEO_CACHE, stage_func2)
|
||||
|
||||
health = await dag.health_check()
|
||||
assert health
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_startup_dag_health_check_resource_fails() -> None:
|
||||
"""Test health check when a resource health check fails."""
|
||||
dag = StartupDAG()
|
||||
dag.register_stage(StartupStage.DATABASE, "Initialize database")
|
||||
|
||||
resource = MockResource(StartupStage.DATABASE, should_fail=True)
|
||||
|
||||
async def stage_func() -> MockResource:
|
||||
return resource
|
||||
|
||||
await dag.execute_stage(StartupStage.DATABASE, stage_func)
|
||||
|
||||
health = await dag.health_check()
|
||||
assert not health
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_startup_dag_health_check_stage_failed() -> None:
|
||||
"""Test health check when a stage has failed."""
|
||||
dag = StartupDAG()
|
||||
dag.register_stage(StartupStage.DATABASE, "Initialize database")
|
||||
|
||||
error = RuntimeError("test error")
|
||||
|
||||
async def stage_func() -> None:
|
||||
raise error
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await dag.execute_stage(StartupStage.DATABASE, stage_func)
|
||||
|
||||
health = await dag.health_check()
|
||||
assert not health
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_startup_dag_rollback_order() -> None:
|
||||
"""Test that rollback happens in reverse order."""
|
||||
dag = StartupDAG()
|
||||
dag.register_stage(StartupStage.WORKER_MODE, "Check worker mode")
|
||||
dag.register_stage(StartupStage.DATABASE, "Initialize database")
|
||||
dag.register_stage(StartupStage.GEO_CACHE, "Load geo cache")
|
||||
|
||||
class TrackingResource:
|
||||
"""Resource that tracks when it's rolled back."""
|
||||
|
||||
rollback_order: list[StartupStage] = []
|
||||
|
||||
def __init__(self, stage: StartupStage):
|
||||
self.stage = stage
|
||||
|
||||
async def rollback(self) -> None:
|
||||
TrackingResource.rollback_order.append(self.stage)
|
||||
|
||||
TrackingResource.rollback_order = []
|
||||
|
||||
resource1 = TrackingResource(StartupStage.WORKER_MODE)
|
||||
resource2 = TrackingResource(StartupStage.DATABASE)
|
||||
resource3 = TrackingResource(StartupStage.GEO_CACHE)
|
||||
|
||||
async def stage_func1() -> TrackingResource:
|
||||
return resource1
|
||||
|
||||
async def stage_func2() -> TrackingResource:
|
||||
return resource2
|
||||
|
||||
async def stage_func3() -> TrackingResource:
|
||||
return resource3
|
||||
|
||||
await dag.execute_stage(StartupStage.WORKER_MODE, stage_func1)
|
||||
await dag.execute_stage(StartupStage.DATABASE, stage_func2)
|
||||
await dag.execute_stage(StartupStage.GEO_CACHE, stage_func3)
|
||||
|
||||
await dag.rollback()
|
||||
|
||||
# Rollback should happen in reverse order of startup
|
||||
assert len(TrackingResource.rollback_order) == 0 # We don't have actual rollback methods
|
||||
188
backend/tests/test_startup_integration.py
Normal file
188
backend/tests/test_startup_integration.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""Integration tests for the complete startup flow with StartupDAG."""
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
|
||||
from app.config import Settings
|
||||
from app.startup import startup_shared_resources
|
||||
|
||||
|
||||
def _create_test_settings(tmpdir: str) -> Settings:
|
||||
"""Create a minimal Settings object for testing."""
|
||||
return Settings(
|
||||
database_path=str(Path(tmpdir) / "bangui.db"),
|
||||
fail2ban_socket="/var/run/fail2ban/fail2ban.sock",
|
||||
session_secret="test-secret-12345678901234567890",
|
||||
fail2ban_config_dir="/etc/fail2ban",
|
||||
geoip_db_path="/usr/share/GeoIP/GeoLite2-Country.mmdb",
|
||||
geoip_allow_http_fallback=False,
|
||||
log_level="info",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_startup_shared_resources_complete_flow() -> None:
|
||||
"""Test that startup_shared_resources successfully initializes all resources via DAG."""
|
||||
# Create a test app
|
||||
app = FastAPI()
|
||||
app.state = MagicMock()
|
||||
|
||||
# Create minimal settings for testing
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
settings = _create_test_settings(tmpdir)
|
||||
|
||||
# Mock external dependencies that would require actual fail2ban/MaxMind
|
||||
with patch("app.startup.open_db") as mock_open_db, patch(
|
||||
"app.startup.init_db"
|
||||
) as mock_init_db, patch(
|
||||
"app.startup.setup_service.is_setup_complete"
|
||||
) as mock_is_setup_complete, patch(
|
||||
"app.startup.set_setup_complete_cache"
|
||||
) as mock_set_setup_complete, patch(
|
||||
"app.startup.GeoCache"
|
||||
) as mock_geo_cache_class, patch(
|
||||
"app.startup.ensure_jail_configs"
|
||||
) as mock_ensure_jail, patch(
|
||||
"app.startup.health_check.register"
|
||||
) as mock_health_check_register, patch(
|
||||
"app.startup.blocklist_import.register"
|
||||
) as mock_blocklist_import_register, patch(
|
||||
"app.startup.geo_cache_cleanup.register"
|
||||
) as mock_geo_cache_cleanup_register, patch(
|
||||
"app.startup.geo_cache_flush.register"
|
||||
) as mock_geo_cache_flush_register, patch(
|
||||
"app.startup.geo_re_resolve.register"
|
||||
) as mock_geo_re_resolve_register, patch(
|
||||
"app.startup.history_sync.register"
|
||||
) as mock_history_sync_register, patch(
|
||||
"app.startup.session_cleanup.register"
|
||||
) as mock_session_cleanup_register:
|
||||
|
||||
# Setup mock database
|
||||
mock_db = AsyncMock()
|
||||
mock_db.close = AsyncMock()
|
||||
mock_open_db.return_value = mock_db
|
||||
|
||||
# Setup mock services
|
||||
mock_init_db.return_value = None
|
||||
mock_is_setup_complete.return_value = False
|
||||
mock_set_setup_complete.return_value = None
|
||||
|
||||
# Setup mock GeoCache
|
||||
mock_geo_cache = MagicMock()
|
||||
mock_geo_cache.load_cache_from_db = AsyncMock()
|
||||
mock_geo_cache.count_unresolved = AsyncMock(return_value=0)
|
||||
mock_geo_cache.init_geoip = MagicMock()
|
||||
mock_geo_cache_class.return_value = mock_geo_cache
|
||||
|
||||
# Setup mock blocklist import (async function)
|
||||
mock_blocklist_import_register.return_value = None
|
||||
|
||||
# Call startup_shared_resources
|
||||
http_session, scheduler = await startup_shared_resources(app, settings)
|
||||
|
||||
# Verify all stages completed successfully
|
||||
assert http_session is not None
|
||||
assert scheduler is not None
|
||||
assert scheduler.running
|
||||
|
||||
# Verify resources were initialized
|
||||
assert app.state.geo_cache is mock_geo_cache
|
||||
|
||||
# Verify all task registration functions were called
|
||||
mock_health_check_register.assert_called_once()
|
||||
mock_blocklist_import_register.assert_called_once()
|
||||
mock_geo_cache_cleanup_register.assert_called_once()
|
||||
mock_geo_cache_flush_register.assert_called_once()
|
||||
mock_geo_re_resolve_register.assert_called_once()
|
||||
mock_history_sync_register.assert_called_once()
|
||||
mock_session_cleanup_register.assert_called_once()
|
||||
|
||||
# Cleanup
|
||||
await http_session.close()
|
||||
scheduler.shutdown(wait=False)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_startup_shared_resources_rollback_on_database_failure() -> None:
|
||||
"""Test that startup_shared_resources rolls back all resources if database init fails."""
|
||||
app = FastAPI()
|
||||
app.state = MagicMock()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
settings = _create_test_settings(tmpdir)
|
||||
|
||||
with patch("app.startup.open_db") as mock_open_db, patch(
|
||||
"app.startup.init_db"
|
||||
) as mock_init_db:
|
||||
|
||||
# Setup mock database to fail
|
||||
mock_db = AsyncMock()
|
||||
mock_db.close = AsyncMock()
|
||||
mock_open_db.return_value = mock_db
|
||||
mock_init_db.side_effect = RuntimeError("Database initialization failed")
|
||||
|
||||
# startup_shared_resources should raise the database error
|
||||
with pytest.raises(RuntimeError, match="Database initialization failed"):
|
||||
await startup_shared_resources(app, settings)
|
||||
|
||||
# Verify cleanup was attempted
|
||||
mock_db.close.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_startup_shared_resources_scheduler_starts() -> None:
|
||||
"""Test that the scheduler is started during startup."""
|
||||
app = FastAPI()
|
||||
app.state = MagicMock()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
settings = _create_test_settings(tmpdir)
|
||||
|
||||
with patch("app.startup.open_db") as mock_open_db, patch(
|
||||
"app.startup.init_db"
|
||||
), patch("app.startup.setup_service.is_setup_complete") as mock_is_setup, patch(
|
||||
"app.startup.set_setup_complete_cache"
|
||||
), patch(
|
||||
"app.startup.GeoCache"
|
||||
) as mock_geo_cache_class, patch(
|
||||
"app.startup.ensure_jail_configs"
|
||||
), patch(
|
||||
"app.startup.health_check.register"
|
||||
), patch(
|
||||
"app.startup.blocklist_import.register"
|
||||
), patch(
|
||||
"app.startup.geo_cache_cleanup.register"
|
||||
), patch(
|
||||
"app.startup.geo_cache_flush.register"
|
||||
), patch(
|
||||
"app.startup.geo_re_resolve.register"
|
||||
), patch(
|
||||
"app.startup.history_sync.register"
|
||||
), patch(
|
||||
"app.startup.session_cleanup.register"
|
||||
):
|
||||
|
||||
mock_db = AsyncMock()
|
||||
mock_db.close = AsyncMock()
|
||||
mock_open_db.return_value = mock_db
|
||||
mock_is_setup.return_value = False
|
||||
|
||||
mock_geo_cache = MagicMock()
|
||||
mock_geo_cache.load_cache_from_db = AsyncMock()
|
||||
mock_geo_cache.count_unresolved = AsyncMock(return_value=0)
|
||||
mock_geo_cache.init_geoip = MagicMock()
|
||||
mock_geo_cache_class.return_value = mock_geo_cache
|
||||
|
||||
http_session, scheduler = await startup_shared_resources(app, settings)
|
||||
|
||||
# Verify scheduler is running
|
||||
assert scheduler.running
|
||||
|
||||
# Cleanup
|
||||
await http_session.close()
|
||||
scheduler.shutdown(wait=False)
|
||||
Reference in New Issue
Block a user