"""Unit tests for startup DAG and resource initialization orchestration."""

import pytest

from app.startup_dag import StartupContext, StartupDAG, StartupResource, StartupStage


class MockResource(StartupResource):
    """Mock resource for testing."""

    def __init__(self, stage: StartupStage, should_fail: bool = False) -> None:
        """Initialize mock resource.

        Args:
            stage: The startup stage.
            should_fail: Whether health_check should fail.
        """
        self._stage = stage
        self._should_fail = should_fail

    @property
    def stage(self) -> StartupStage:
        """Return the stage this resource belongs to."""
        return self._stage

    async def health_check(self) -> bool:
        """Return True if the resource is healthy."""
        return not self._should_fail


def test_startup_context_register_and_get_resource() -> None:
    """Test registering and retrieving resources."""
    context = StartupContext()
    resource = MockResource(StartupStage.DATABASE)

    context.register_resource(StartupStage.DATABASE, resource)
    retrieved = context.get_resource(StartupStage.DATABASE)

    assert retrieved is resource


def test_startup_context_register_duplicate_fails() -> None:
    """Test that registering a stage twice raises RuntimeError."""
    context = StartupContext()
    resource1 = MockResource(StartupStage.DATABASE)
    resource2 = MockResource(StartupStage.DATABASE)

    context.register_resource(StartupStage.DATABASE, resource1)

    with pytest.raises(RuntimeError, match="already registered"):
        context.register_resource(StartupStage.DATABASE, resource2)


def test_startup_context_get_missing_resource_fails() -> None:
    """Test that getting an unregistered resource raises RuntimeError."""
    context = StartupContext()

    with pytest.raises(RuntimeError, match="not available"):
        context.get_resource(StartupStage.DATABASE)


def test_startup_context_mark_failed() -> None:
    """Test marking a stage as failed."""
    context = StartupContext()
    error = RuntimeError("test error")

    assert context.is_healthy()
    context.mark_failed(StartupStage.DATABASE, error)

    assert not context.is_healthy()
    assert context.failed_stage == StartupStage.DATABASE
    assert context.error is error


def test_startup_dag_register_stage() -> None:
    """Test registering startup stages."""
    dag = StartupDAG()
    dag.register_stage(
        StartupStage.DATABASE,
        "Initialize database",
        prerequisites=frozenset(),
    )

    assert StartupStage.DATABASE in dag.stages
    stage = dag.stages[StartupStage.DATABASE]
    assert stage.description == "Initialize database"
    assert stage.prerequisites == frozenset()


def test_startup_dag_register_stage_with_prerequisites() -> None:
    """Test registering a stage with prerequisites."""
    dag = StartupDAG()
    dag.register_stage(
        StartupStage.DATABASE,
        "Initialize database",
        prerequisites=frozenset(),
    )
    dag.register_stage(
        StartupStage.GEO_CACHE,
        "Load geo cache",
        prerequisites=frozenset([StartupStage.DATABASE]),
    )

    stage = dag.stages[StartupStage.GEO_CACHE]
    assert StartupStage.DATABASE in stage.prerequisites


def test_startup_dag_register_stage_duplicate_fails() -> None:
    """Test that registering a stage twice raises RuntimeError."""
    dag = StartupDAG()
    dag.register_stage(StartupStage.DATABASE, "Initialize database")

    with pytest.raises(RuntimeError, match="already registered"):
        dag.register_stage(StartupStage.DATABASE, "Initialize database again")


@pytest.mark.asyncio
async def test_startup_dag_execute_stage_success() -> None:
    """Test successfully executing a startup stage."""
    dag = StartupDAG()
    dag.register_stage(StartupStage.DATABASE, "Initialize database")

    resource = MockResource(StartupStage.DATABASE)

    async def stage_func() -> MockResource:
        return resource

    result = await dag.execute_stage(StartupStage.DATABASE, stage_func)

    assert result is resource
    assert StartupStage.DATABASE in dag.context.completed_stages
    assert dag.context.get_resource(StartupStage.DATABASE) is resource


@pytest.mark.asyncio
async def test_startup_dag_execute_stage_prerequisite_missing_fails() -> None:
    """Test that executing a stage without prerequisites fails."""
    dag = StartupDAG()
    dag.register_stage(
        StartupStage.DATABASE,
        "Initialize database",
        prerequisites=frozenset(),
    )
    dag.register_stage(
        StartupStage.GEO_CACHE,
        "Load geo cache",
        prerequisites=frozenset([StartupStage.DATABASE]),
    )

    resource = MockResource(StartupStage.GEO_CACHE)

    async def stage_func() -> MockResource:
        return resource

    with pytest.raises(RuntimeError, match="requires.*but it has not completed"):
        await dag.execute_stage(StartupStage.GEO_CACHE, stage_func)


@pytest.mark.asyncio
async def test_startup_dag_execute_stage_exception_marks_failed() -> None:
    """Test that stage exceptions are captured in context."""
    dag = StartupDAG()
    dag.register_stage(StartupStage.DATABASE, "Initialize database")

    error = RuntimeError("database init failed")

    async def stage_func() -> None:
        raise error

    with pytest.raises(RuntimeError, match="database init failed"):
        await dag.execute_stage(StartupStage.DATABASE, stage_func)

    assert dag.context.failed_stage == StartupStage.DATABASE
    assert dag.context.error is error
    assert not dag.context.is_healthy()


@pytest.mark.asyncio
async def test_startup_dag_execute_stage_duplicate_fails() -> None:
    """Test that executing a stage twice raises RuntimeError."""
    dag = StartupDAG()
    dag.register_stage(StartupStage.DATABASE, "Initialize database")

    resource = MockResource(StartupStage.DATABASE)

    async def stage_func() -> MockResource:
        return resource

    await dag.execute_stage(StartupStage.DATABASE, stage_func)

    with pytest.raises(RuntimeError, match="already completed"):
        await dag.execute_stage(StartupStage.DATABASE, stage_func)


@pytest.mark.asyncio
async def test_startup_dag_health_check_all_pass() -> None:
    """Test health check when all resources are healthy."""
    dag = StartupDAG()
    dag.register_stage(StartupStage.DATABASE, "Initialize database")
    dag.register_stage(StartupStage.GEO_CACHE, "Load geo cache")

    resource1 = MockResource(StartupStage.DATABASE, should_fail=False)
    resource2 = MockResource(StartupStage.GEO_CACHE, should_fail=False)

    async def stage_func1() -> MockResource:
        return resource1

    async def stage_func2() -> MockResource:
        return resource2

    await dag.execute_stage(StartupStage.DATABASE, stage_func1)
    await dag.execute_stage(StartupStage.GEO_CACHE, stage_func2)

    health = await dag.health_check()
    assert health


@pytest.mark.asyncio
async def test_startup_dag_health_check_resource_fails() -> None:
    """Test health check when a resource health check fails."""
    dag = StartupDAG()
    dag.register_stage(StartupStage.DATABASE, "Initialize database")

    resource = MockResource(StartupStage.DATABASE, should_fail=True)

    async def stage_func() -> MockResource:
        return resource

    await dag.execute_stage(StartupStage.DATABASE, stage_func)

    health = await dag.health_check()
    assert not health


@pytest.mark.asyncio
async def test_startup_dag_health_check_stage_failed() -> None:
    """Test health check when a stage has failed."""
    dag = StartupDAG()
    dag.register_stage(StartupStage.DATABASE, "Initialize database")

    error = RuntimeError("test error")

    async def stage_func() -> None:
        raise error

    # Match the message so an unrelated RuntimeError cannot satisfy the check.
    with pytest.raises(RuntimeError, match="test error"):
        await dag.execute_stage(StartupStage.DATABASE, stage_func)

    health = await dag.health_check()
    assert not health


@pytest.mark.asyncio
async def test_startup_dag_rollback_order() -> None:
    """Test that rollback runs after every registered stage has completed.

    Each ``TrackingResource`` defines a ``rollback()`` coroutine that appends
    its stage to ``rollback_order``. The final assertion pins the current
    expectation that ``StartupDAG.rollback()`` does not invoke those methods.

    NOTE(review): the test name suggests rollback was meant to tear resources
    down in reverse startup order. If ``StartupDAG.rollback()`` starts calling
    ``rollback()`` on resources, change the expectation to
    ``list(reversed(startup_order))``.
    """
    dag = StartupDAG()
    dag.register_stage(StartupStage.WORKER_MODE, "Check worker mode")
    dag.register_stage(StartupStage.DATABASE, "Initialize database")
    dag.register_stage(StartupStage.GEO_CACHE, "Load geo cache")

    startup_order = [
        StartupStage.WORKER_MODE,
        StartupStage.DATABASE,
        StartupStage.GEO_CACHE,
    ]
    # A local list captured by the closure replaces a mutable class attribute,
    # so no tracking state is shared across test runs.
    rollback_order: list[StartupStage] = []

    class TrackingResource:
        """Resource that records its stage when rolled back."""

        def __init__(self, stage: StartupStage) -> None:
            self.stage = stage

        async def rollback(self) -> None:
            rollback_order.append(self.stage)

    resource1 = TrackingResource(StartupStage.WORKER_MODE)
    resource2 = TrackingResource(StartupStage.DATABASE)
    resource3 = TrackingResource(StartupStage.GEO_CACHE)

    async def stage_func1() -> TrackingResource:
        return resource1

    async def stage_func2() -> TrackingResource:
        return resource2

    async def stage_func3() -> TrackingResource:
        return resource3

    await dag.execute_stage(StartupStage.WORKER_MODE, stage_func1)
    await dag.execute_stage(StartupStage.DATABASE, stage_func2)
    await dag.execute_stage(StartupStage.GEO_CACHE, stage_func3)

    # Precondition: every stage completed, so rollback has something to undo.
    for stage in startup_order:
        assert stage in dag.context.completed_stages

    await dag.rollback()

    # No resource rollback() was invoked (see NOTE(review) in the docstring).
    assert rollback_order == []