diff --git a/backend/app/dependencies.py b/backend/app/dependencies.py index 0f123f4..df1e6ba 100644 --- a/backend/app/dependencies.py +++ b/backend/app/dependencies.py @@ -18,6 +18,9 @@ from app.config import Settings from app.models.auth import Session from app.utils.time_utils import utc_now +import aiohttp +from apscheduler.schedulers.asyncio import AsyncIOScheduler # type: ignore[import-untyped] + log: structlog.stdlib.BoundLogger = structlog.get_logger() @@ -25,6 +28,8 @@ class AppState(Protocol): """Partial view of the FastAPI application state used by dependencies.""" settings: Settings + http_session: aiohttp.ClientSession + scheduler: AsyncIOScheduler _COOKIE_NAME = "bangui_session" @@ -106,6 +111,67 @@ async def get_settings(request: Request) -> Settings: return state.settings +async def get_http_session(request: Request) -> aiohttp.ClientSession: + """Provide the shared HTTP client session from application state. + + Args: + request: The current FastAPI request. + + Returns: + A shared :class:`aiohttp.ClientSession` managed by the lifespan. + + Raises: + HTTPException: If the session is unavailable. + """ + state = cast("AppState", request.app.state) + http_session = getattr(state, "http_session", None) + if http_session is None: + log.error("http_session_unavailable") + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="HTTP session is not available.", + ) + return http_session + + +async def get_scheduler(request: Request) -> AsyncIOScheduler: + """Provide the shared scheduler from application state. + + Args: + request: The current FastAPI request. + + Returns: + The :class:`apscheduler.schedulers.asyncio.AsyncIOScheduler` instance. + + Raises: + HTTPException: If the scheduler is unavailable. + """ + state = cast("AppState", request.app.state) + scheduler = getattr(state, "scheduler", None) + if scheduler is None: + log.error("scheduler_unavailable") + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Scheduler is not available.", + ) + return scheduler + + +async def get_fail2ban_socket(settings: Settings = Depends(get_settings)) -> str: + """Provide the configured path to the fail2ban Unix domain socket.""" + return settings.fail2ban_socket + + +async def get_fail2ban_config_dir(settings: Settings = Depends(get_settings)) -> str: + """Provide the configured fail2ban configuration directory.""" + return settings.fail2ban_config_dir + + +async def get_fail2ban_start_command(settings: Settings = Depends(get_settings)) -> str: + """Provide the configured fail2ban start command.""" + return settings.fail2ban_start_command + + async def require_auth( request: Request, db: Annotated[aiosqlite.Connection, Depends(get_db)], @@ -171,4 +237,9 @@ async def require_auth( # Convenience type aliases for route signatures. DbDep = Annotated[aiosqlite.Connection, Depends(get_db)] SettingsDep = Annotated[Settings, Depends(get_settings)] +HttpSessionDep = Annotated[aiohttp.ClientSession, Depends(get_http_session)] +SchedulerDep = Annotated[AsyncIOScheduler, Depends(get_scheduler)] +Fail2BanSocketDep = Annotated[str, Depends(get_fail2ban_socket)] +Fail2BanConfigDirDep = Annotated[str, Depends(get_fail2ban_config_dir)] +Fail2BanStartCommandDep = Annotated[str, Depends(get_fail2ban_start_command)] AuthDep = Annotated[Session, Depends(require_auth)] diff --git a/backend/app/routers/bans.py b/backend/app/routers/bans.py index 4d706ec..3848e16 100644 --- a/backend/app/routers/bans.py +++ b/backend/app/routers/bans.py @@ -17,7 +17,12 @@ if TYPE_CHECKING: from fastapi import APIRouter, HTTPException, Request, status -from app.dependencies import AuthDep, DbDep +from app.dependencies import ( + AuthDep, + DbDep, + Fail2BanSocketDep, + HttpSessionDep, +) from app.models.ban import ActiveBanListResponse, BanRequest, UnbanAllResponse, UnbanRequest from app.models.jail import JailCommandResponse from app.services import geo_service, jail_service @@ -51,6 +56,8 @@ async def get_active_bans( request: Request, _auth: AuthDep, db: DbDep, + socket_path: Fail2BanSocketDep, + http_session: HttpSessionDep, ) -> ActiveBanListResponse: """Return every IP that is currently banned across all fail2ban jails. @@ -67,9 +74,6 @@ async def get_active_bans( Raises: HTTPException: 502 when fail2ban is unreachable. """ - socket_path: str = request.app.state.settings.fail2ban_socket - http_session: aiohttp.ClientSession = request.app.state.http_session - try: return await jail_service.get_active_bans( socket_path, @@ -91,6 +95,7 @@ async def ban_ip( request: Request, _auth: AuthDep, body: BanRequest, + socket_path: Fail2BanSocketDep, ) -> JailCommandResponse: """Ban an IP address in the specified fail2ban jail. @@ -111,7 +116,6 @@ async def ban_ip( HTTPException: 409 when fail2ban reports the ban failed. HTTPException: 502 when fail2ban is unreachable. """ - socket_path: str = request.app.state.settings.fail2ban_socket try: await jail_service.ban_ip(socket_path, body.jail, body.ip) return JailCommandResponse( @@ -146,6 +150,7 @@ async def unban_ip( request: Request, _auth: AuthDep, body: UnbanRequest, + socket_path: Fail2BanSocketDep, ) -> JailCommandResponse: """Unban an IP address from a specific jail or all jails. @@ -168,8 +173,6 @@ async def unban_ip( HTTPException: 409 when fail2ban reports the unban failed. HTTPException: 502 when fail2ban is unreachable. """ - socket_path: str = request.app.state.settings.fail2ban_socket - # Determine target jail (None means all jails). target_jail: str | None = None if (body.unban_all or body.jail is None) else body.jail @@ -207,6 +210,7 @@ async def unban_ip( async def unban_all( request: Request, _auth: AuthDep, + socket_path: Fail2BanSocketDep, ) -> UnbanAllResponse: """Remove all active bans from every fail2ban jail in a single operation. @@ -224,7 +228,6 @@ async def unban_all( Raises: HTTPException: 502 when fail2ban is unreachable. """ - socket_path: str = request.app.state.settings.fail2ban_socket try: count: int = await jail_service.unban_all_ips(socket_path) return UnbanAllResponse( diff --git a/backend/app/routers/blocklist.py b/backend/app/routers/blocklist.py index 055c134..a1a9590 100644 --- a/backend/app/routers/blocklist.py +++ b/backend/app/routers/blocklist.py @@ -30,7 +30,14 @@ if TYPE_CHECKING: import aiohttp from fastapi import APIRouter, Depends, HTTPException, Query, Request, status -from app.dependencies import AuthDep, get_db +from app.dependencies import ( + AuthDep, + DbDep, + Fail2BanSocketDep, + HttpSessionDep, + SchedulerDep, + get_db, +) from app.models.blocklist import ( BlocklistListResponse, BlocklistSource, @@ -117,6 +124,8 @@ async def run_import_now( request: Request, db: DbDep, _auth: AuthDep, + http_session: HttpSessionDep, + socket_path: Fail2BanSocketDep, ) -> ImportRunResult: """Download and apply all enabled blocklist sources immediately. @@ -129,8 +138,6 @@ async def run_import_now( :class:`~app.models.blocklist.ImportRunResult` with per-source results and aggregated counters. """ - http_session: aiohttp.ClientSession = request.app.state.http_session - socket_path: str = request.app.state.settings.fail2ban_socket from app.services import jail_service return await blocklist_service.import_all( @@ -151,6 +158,7 @@ async def get_schedule( request: Request, db: DbDep, _auth: AuthDep, + scheduler: SchedulerDep, ) -> ScheduleInfo: """Return the current schedule configuration and runtime metadata. @@ -165,7 +173,6 @@ async def get_schedule( :class:`~app.models.blocklist.ScheduleInfo` with config and run times. """ - scheduler = request.app.state.scheduler job = scheduler.get_job(blocklist_import_task.JOB_ID) next_run_at: str | None = None if job is not None and job.next_run_time is not None: @@ -184,6 +191,7 @@ async def update_schedule( request: Request, db: DbDep, _auth: AuthDep, + scheduler: SchedulerDep, ) -> ScheduleInfo: """Persist a new schedule configuration and reschedule the import job. @@ -200,7 +208,7 @@ async def update_schedule( # Reschedule the background job immediately. blocklist_import_task.reschedule(request.app) - job = request.app.state.scheduler.get_job(blocklist_import_task.JOB_ID) + job = scheduler.get_job(blocklist_import_task.JOB_ID) next_run_at: str | None = None if job is not None and job.next_run_time is not None: next_run_at = job.next_run_time.isoformat() diff --git a/backend/app/routers/geo.py b/backend/app/routers/geo.py index b2b54cb..3f9e4f1 100644 --- a/backend/app/routers/geo.py +++ b/backend/app/routers/geo.py @@ -18,7 +18,13 @@ if TYPE_CHECKING: import aiosqlite from fastapi import APIRouter, Depends, HTTPException, Path, Request, status -from app.dependencies import AuthDep, get_db +from app.dependencies import ( + AuthDep, + DbDep, + Fail2BanSocketDep, + HttpSessionDep, + get_db, +) from app.models.geo import GeoCacheStatsResponse, GeoDetail, GeoInfo, IpLookupResponse from app.services import geo_service, jail_service from app.utils.fail2ban_client import Fail2BanConnectionError @@ -37,6 +43,8 @@ async def lookup_ip( request: Request, _auth: AuthDep, ip: _IpPath, + socket_path: Fail2BanSocketDep, + http_session: HttpSessionDep, ) -> IpLookupResponse: """Return current ban status, geo data, and network information for an IP. @@ -56,9 +64,6 @@ async def lookup_ip( HTTPException: 400 when *ip* is not a valid IP address. HTTPException: 502 when fail2ban is unreachable. """ - socket_path: str = request.app.state.settings.fail2ban_socket - http_session: aiohttp.ClientSession = request.app.state.http_session - async def _enricher(addr: str) -> geo_service.GeoInfo | None: return await geo_service.lookup(addr, http_session) @@ -138,6 +143,7 @@ async def re_resolve_geo( request: Request, _auth: AuthDep, db: Annotated[aiosqlite.Connection, Depends(get_db)], + http_session: HttpSessionDep, ) -> dict[str, int]: """Retry geo resolution for every IP in ``geo_cache`` with a null country. @@ -163,7 +169,6 @@ async def re_resolve_geo( # Clear negative cache so these IPs bypass the TTL check. geo_service.clear_neg_cache() - http_session: aiohttp.ClientSession = request.app.state.http_session geo_map = await geo_service.lookup_batch(unresolved, http_session, db=db) resolved_count = sum( diff --git a/backend/app/routers/history.py b/backend/app/routers/history.py index 90cf5f2..c819ba8 100644 --- a/backend/app/routers/history.py +++ b/backend/app/routers/history.py @@ -22,7 +22,12 @@ if TYPE_CHECKING: from fastapi import APIRouter, HTTPException, Query, Request -from app.dependencies import AuthDep, DbDep +from app.dependencies import ( + AuthDep, + DbDep, + Fail2BanSocketDep, + HttpSessionDep, +) from app.models.ban import BanOrigin, TimeRange from app.models.history import HistoryListResponse, IpDetailResponse from app.services import geo_service, history_service @@ -41,6 +46,8 @@ async def get_history( request: Request, _auth: AuthDep, db: DbDep, + socket_path: Fail2BanSocketDep, + http_session: HttpSessionDep, range: TimeRange | None = Query( default=None, description="Optional time-range filter. Omit for all-time.", @@ -87,8 +94,6 @@ async def get_history( :class:`~app.models.history.HistoryListResponse` with paginated items and the total matching count. """ - socket_path: str = request.app.state.settings.fail2ban_socket - http_session: aiohttp.ClientSession = request.app.state.http_session async def _enricher(addr: str) -> geo_service.GeoInfo | None: return await geo_service.lookup(addr, http_session) @@ -116,6 +121,8 @@ async def get_history_archive( request: Request, _auth: AuthDep, db: DbDep, + socket_path: Fail2BanSocketDep, + http_session: HttpSessionDep, range: TimeRange | None = Query( default=None, description="Optional time-range filter. Omit for all-time.", @@ -125,8 +132,6 @@ async def get_history_archive( page: int = Query(default=1, ge=1, description="1-based page number."), page_size: int = Query(default=_DEFAULT_PAGE_SIZE, ge=1, le=500, description="Items per page (max 500)."), ) -> HistoryListResponse: - socket_path: str = request.app.state.settings.fail2ban_socket - http_session: aiohttp.ClientSession = request.app.state.http_session async def _enricher(addr: str) -> geo_service.GeoInfo | None: return await geo_service.lookup(addr, http_session) @@ -153,6 +158,8 @@ async def get_ip_history( request: Request, _auth: AuthDep, ip: str, + socket_path: Fail2BanSocketDep, + http_session: HttpSessionDep, ) -> IpDetailResponse: """Return the complete historical record for a single IP address. @@ -172,8 +179,6 @@ async def get_ip_history( Raises: HTTPException: 404 if the IP has no history in the database. """ - socket_path: str = request.app.state.settings.fail2ban_socket - http_session: aiohttp.ClientSession = request.app.state.http_session async def _enricher(addr: str) -> geo_service.GeoInfo | None: return await geo_service.lookup(addr, http_session) diff --git a/backend/app/routers/server.py b/backend/app/routers/server.py index 66c8df6..866d54c 100644 --- a/backend/app/routers/server.py +++ b/backend/app/routers/server.py @@ -12,7 +12,7 @@ from __future__ import annotations from fastapi import APIRouter, HTTPException, Request, status -from app.dependencies import AuthDep +from app.dependencies import AuthDep, Fail2BanSocketDep from app.models.server import ServerSettingsResponse, ServerSettingsUpdate from app.services import server_service from app.exceptions import ServerOperationError @@ -53,6 +53,7 @@ def _bad_request(message: str) -> HTTPException: async def get_server_settings( request: Request, _auth: AuthDep, + socket_path: Fail2BanSocketDep, ) -> ServerSettingsResponse: """Return the current fail2ban server-level settings. @@ -69,7 +70,6 @@ async def get_server_settings( Raises: HTTPException: 502 when fail2ban is unreachable. """ - socket_path: str = request.app.state.settings.fail2ban_socket try: return await server_service.get_settings(socket_path) except Fail2BanConnectionError as exc: @@ -85,6 +85,7 @@ async def update_server_settings( request: Request, _auth: AuthDep, body: ServerSettingsUpdate, + socket_path: Fail2BanSocketDep, ) -> None: """Update fail2ban server-level settings. @@ -100,7 +101,6 @@ async def update_server_settings( HTTPException: 400 when a set command is rejected by fail2ban. HTTPException: 502 when fail2ban is unreachable. """ - socket_path: str = request.app.state.settings.fail2ban_socket try: await server_service.update_settings(socket_path, body) except ServerOperationError as exc: @@ -117,6 +117,7 @@ async def update_server_settings( async def flush_logs( request: Request, _auth: AuthDep, + socket_path: Fail2BanSocketDep, ) -> dict[str, str]: """Flush and re-open fail2ban log files. @@ -134,7 +135,6 @@ async def flush_logs( HTTPException: 400 when the command is rejected. HTTPException: 502 when fail2ban is unreachable. """ - socket_path: str = request.app.state.settings.fail2ban_socket try: result = await server_service.flush_logs(socket_path) return {"message": result}