This commit is contained in:
2026-04-06 19:49:53 +02:00
parent 5107ff10d7
commit f0ee466603
6 changed files with 121 additions and 29 deletions

View File

@@ -18,6 +18,9 @@ from app.config import Settings
from app.models.auth import Session from app.models.auth import Session
from app.utils.time_utils import utc_now from app.utils.time_utils import utc_now
import aiohttp
from apscheduler.schedulers.asyncio import AsyncIOScheduler # type: ignore[import-untyped]
log: structlog.stdlib.BoundLogger = structlog.get_logger() log: structlog.stdlib.BoundLogger = structlog.get_logger()
@@ -25,6 +28,8 @@ class AppState(Protocol):
"""Partial view of the FastAPI application state used by dependencies.""" """Partial view of the FastAPI application state used by dependencies."""
settings: Settings settings: Settings
http_session: aiohttp.ClientSession
scheduler: AsyncIOScheduler
_COOKIE_NAME = "bangui_session" _COOKIE_NAME = "bangui_session"
@@ -106,6 +111,67 @@ async def get_settings(request: Request) -> Settings:
return state.settings return state.settings
async def get_http_session(request: Request) -> aiohttp.ClientSession:
"""Provide the shared HTTP client session from application state.
Args:
request: The current FastAPI request.
Returns:
A shared :class:`aiohttp.ClientSession` managed by the lifespan.
Raises:
HTTPException: If the session is unavailable.
"""
state = cast("AppState", request.app.state)
http_session = getattr(state, "http_session", None)
if http_session is None:
log.error("http_session_unavailable")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="HTTP session is not available.",
)
return http_session
async def get_scheduler(request: Request) -> AsyncIOScheduler:
"""Provide the shared scheduler from application state.
Args:
request: The current FastAPI request.
Returns:
The :class:`apscheduler.schedulers.asyncio.AsyncIOScheduler` instance.
Raises:
HTTPException: If the scheduler is unavailable.
"""
state = cast("AppState", request.app.state)
scheduler = getattr(state, "scheduler", None)
if scheduler is None:
log.error("scheduler_unavailable")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Scheduler is not available.",
)
return scheduler
async def get_fail2ban_socket(settings: Settings = Depends(get_settings)) -> str:
"""Provide the configured path to the fail2ban Unix domain socket."""
return settings.fail2ban_socket
async def get_fail2ban_config_dir(settings: Settings = Depends(get_settings)) -> str:
"""Provide the configured fail2ban configuration directory."""
return settings.fail2ban_config_dir
async def get_fail2ban_start_command(settings: Settings = Depends(get_settings)) -> str:
"""Provide the configured fail2ban start command."""
return settings.fail2ban_start_command
async def require_auth( async def require_auth(
request: Request, request: Request,
db: Annotated[aiosqlite.Connection, Depends(get_db)], db: Annotated[aiosqlite.Connection, Depends(get_db)],
@@ -171,4 +237,9 @@ async def require_auth(
# Convenience type aliases for route signatures. # Convenience type aliases for route signatures.
DbDep = Annotated[aiosqlite.Connection, Depends(get_db)] DbDep = Annotated[aiosqlite.Connection, Depends(get_db)]
SettingsDep = Annotated[Settings, Depends(get_settings)] SettingsDep = Annotated[Settings, Depends(get_settings)]
HttpSessionDep = Annotated[aiohttp.ClientSession, Depends(get_http_session)]
SchedulerDep = Annotated[AsyncIOScheduler, Depends(get_scheduler)]
Fail2BanSocketDep = Annotated[str, Depends(get_fail2ban_socket)]
Fail2BanConfigDirDep = Annotated[str, Depends(get_fail2ban_config_dir)]
Fail2BanStartCommandDep = Annotated[str, Depends(get_fail2ban_start_command)]
AuthDep = Annotated[Session, Depends(require_auth)] AuthDep = Annotated[Session, Depends(require_auth)]

View File

@@ -17,7 +17,12 @@ if TYPE_CHECKING:
from fastapi import APIRouter, HTTPException, Request, status from fastapi import APIRouter, HTTPException, Request, status
from app.dependencies import AuthDep, DbDep from app.dependencies import (
AuthDep,
DbDep,
Fail2BanSocketDep,
HttpSessionDep,
)
from app.models.ban import ActiveBanListResponse, BanRequest, UnbanAllResponse, UnbanRequest from app.models.ban import ActiveBanListResponse, BanRequest, UnbanAllResponse, UnbanRequest
from app.models.jail import JailCommandResponse from app.models.jail import JailCommandResponse
from app.services import geo_service, jail_service from app.services import geo_service, jail_service
@@ -51,6 +56,8 @@ async def get_active_bans(
request: Request, request: Request,
_auth: AuthDep, _auth: AuthDep,
db: DbDep, db: DbDep,
socket_path: Fail2BanSocketDep,
http_session: HttpSessionDep,
) -> ActiveBanListResponse: ) -> ActiveBanListResponse:
"""Return every IP that is currently banned across all fail2ban jails. """Return every IP that is currently banned across all fail2ban jails.
@@ -67,9 +74,6 @@ async def get_active_bans(
Raises: Raises:
HTTPException: 502 when fail2ban is unreachable. HTTPException: 502 when fail2ban is unreachable.
""" """
socket_path: str = request.app.state.settings.fail2ban_socket
http_session: aiohttp.ClientSession = request.app.state.http_session
try: try:
return await jail_service.get_active_bans( return await jail_service.get_active_bans(
socket_path, socket_path,
@@ -91,6 +95,7 @@ async def ban_ip(
request: Request, request: Request,
_auth: AuthDep, _auth: AuthDep,
body: BanRequest, body: BanRequest,
socket_path: Fail2BanSocketDep,
) -> JailCommandResponse: ) -> JailCommandResponse:
"""Ban an IP address in the specified fail2ban jail. """Ban an IP address in the specified fail2ban jail.
@@ -111,7 +116,6 @@ async def ban_ip(
HTTPException: 409 when fail2ban reports the ban failed. HTTPException: 409 when fail2ban reports the ban failed.
HTTPException: 502 when fail2ban is unreachable. HTTPException: 502 when fail2ban is unreachable.
""" """
socket_path: str = request.app.state.settings.fail2ban_socket
try: try:
await jail_service.ban_ip(socket_path, body.jail, body.ip) await jail_service.ban_ip(socket_path, body.jail, body.ip)
return JailCommandResponse( return JailCommandResponse(
@@ -146,6 +150,7 @@ async def unban_ip(
request: Request, request: Request,
_auth: AuthDep, _auth: AuthDep,
body: UnbanRequest, body: UnbanRequest,
socket_path: Fail2BanSocketDep,
) -> JailCommandResponse: ) -> JailCommandResponse:
"""Unban an IP address from a specific jail or all jails. """Unban an IP address from a specific jail or all jails.
@@ -168,8 +173,6 @@ async def unban_ip(
HTTPException: 409 when fail2ban reports the unban failed. HTTPException: 409 when fail2ban reports the unban failed.
HTTPException: 502 when fail2ban is unreachable. HTTPException: 502 when fail2ban is unreachable.
""" """
socket_path: str = request.app.state.settings.fail2ban_socket
# Determine target jail (None means all jails). # Determine target jail (None means all jails).
target_jail: str | None = None if (body.unban_all or body.jail is None) else body.jail target_jail: str | None = None if (body.unban_all or body.jail is None) else body.jail
@@ -207,6 +210,7 @@ async def unban_ip(
async def unban_all( async def unban_all(
request: Request, request: Request,
_auth: AuthDep, _auth: AuthDep,
socket_path: Fail2BanSocketDep,
) -> UnbanAllResponse: ) -> UnbanAllResponse:
"""Remove all active bans from every fail2ban jail in a single operation. """Remove all active bans from every fail2ban jail in a single operation.
@@ -224,7 +228,6 @@ async def unban_all(
Raises: Raises:
HTTPException: 502 when fail2ban is unreachable. HTTPException: 502 when fail2ban is unreachable.
""" """
socket_path: str = request.app.state.settings.fail2ban_socket
try: try:
count: int = await jail_service.unban_all_ips(socket_path) count: int = await jail_service.unban_all_ips(socket_path)
return UnbanAllResponse( return UnbanAllResponse(

View File

@@ -30,7 +30,14 @@ if TYPE_CHECKING:
import aiohttp import aiohttp
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from app.dependencies import AuthDep, get_db from app.dependencies import (
AuthDep,
DbDep,
Fail2BanSocketDep,
HttpSessionDep,
SchedulerDep,
get_db,
)
from app.models.blocklist import ( from app.models.blocklist import (
BlocklistListResponse, BlocklistListResponse,
BlocklistSource, BlocklistSource,
@@ -117,6 +124,8 @@ async def run_import_now(
request: Request, request: Request,
db: DbDep, db: DbDep,
_auth: AuthDep, _auth: AuthDep,
http_session: HttpSessionDep,
socket_path: Fail2BanSocketDep,
) -> ImportRunResult: ) -> ImportRunResult:
"""Download and apply all enabled blocklist sources immediately. """Download and apply all enabled blocklist sources immediately.
@@ -129,8 +138,6 @@ async def run_import_now(
:class:`~app.models.blocklist.ImportRunResult` with per-source :class:`~app.models.blocklist.ImportRunResult` with per-source
results and aggregated counters. results and aggregated counters.
""" """
http_session: aiohttp.ClientSession = request.app.state.http_session
socket_path: str = request.app.state.settings.fail2ban_socket
from app.services import jail_service from app.services import jail_service
return await blocklist_service.import_all( return await blocklist_service.import_all(
@@ -151,6 +158,7 @@ async def get_schedule(
request: Request, request: Request,
db: DbDep, db: DbDep,
_auth: AuthDep, _auth: AuthDep,
scheduler: SchedulerDep,
) -> ScheduleInfo: ) -> ScheduleInfo:
"""Return the current schedule configuration and runtime metadata. """Return the current schedule configuration and runtime metadata.
@@ -165,7 +173,6 @@ async def get_schedule(
:class:`~app.models.blocklist.ScheduleInfo` with config and run :class:`~app.models.blocklist.ScheduleInfo` with config and run
times. times.
""" """
scheduler = request.app.state.scheduler
job = scheduler.get_job(blocklist_import_task.JOB_ID) job = scheduler.get_job(blocklist_import_task.JOB_ID)
next_run_at: str | None = None next_run_at: str | None = None
if job is not None and job.next_run_time is not None: if job is not None and job.next_run_time is not None:
@@ -184,6 +191,7 @@ async def update_schedule(
request: Request, request: Request,
db: DbDep, db: DbDep,
_auth: AuthDep, _auth: AuthDep,
scheduler: SchedulerDep,
) -> ScheduleInfo: ) -> ScheduleInfo:
"""Persist a new schedule configuration and reschedule the import job. """Persist a new schedule configuration and reschedule the import job.
@@ -200,7 +208,7 @@ async def update_schedule(
# Reschedule the background job immediately. # Reschedule the background job immediately.
blocklist_import_task.reschedule(request.app) blocklist_import_task.reschedule(request.app)
job = request.app.state.scheduler.get_job(blocklist_import_task.JOB_ID) job = scheduler.get_job(blocklist_import_task.JOB_ID)
next_run_at: str | None = None next_run_at: str | None = None
if job is not None and job.next_run_time is not None: if job is not None and job.next_run_time is not None:
next_run_at = job.next_run_time.isoformat() next_run_at = job.next_run_time.isoformat()

View File

@@ -18,7 +18,13 @@ if TYPE_CHECKING:
import aiosqlite import aiosqlite
from fastapi import APIRouter, Depends, HTTPException, Path, Request, status from fastapi import APIRouter, Depends, HTTPException, Path, Request, status
from app.dependencies import AuthDep, get_db from app.dependencies import (
AuthDep,
DbDep,
Fail2BanSocketDep,
HttpSessionDep,
get_db,
)
from app.models.geo import GeoCacheStatsResponse, GeoDetail, GeoInfo, IpLookupResponse from app.models.geo import GeoCacheStatsResponse, GeoDetail, GeoInfo, IpLookupResponse
from app.services import geo_service, jail_service from app.services import geo_service, jail_service
from app.utils.fail2ban_client import Fail2BanConnectionError from app.utils.fail2ban_client import Fail2BanConnectionError
@@ -37,6 +43,8 @@ async def lookup_ip(
request: Request, request: Request,
_auth: AuthDep, _auth: AuthDep,
ip: _IpPath, ip: _IpPath,
socket_path: Fail2BanSocketDep,
http_session: HttpSessionDep,
) -> IpLookupResponse: ) -> IpLookupResponse:
"""Return current ban status, geo data, and network information for an IP. """Return current ban status, geo data, and network information for an IP.
@@ -56,9 +64,6 @@ async def lookup_ip(
HTTPException: 400 when *ip* is not a valid IP address. HTTPException: 400 when *ip* is not a valid IP address.
HTTPException: 502 when fail2ban is unreachable. HTTPException: 502 when fail2ban is unreachable.
""" """
socket_path: str = request.app.state.settings.fail2ban_socket
http_session: aiohttp.ClientSession = request.app.state.http_session
async def _enricher(addr: str) -> geo_service.GeoInfo | None: async def _enricher(addr: str) -> geo_service.GeoInfo | None:
return await geo_service.lookup(addr, http_session) return await geo_service.lookup(addr, http_session)
@@ -138,6 +143,7 @@ async def re_resolve_geo(
request: Request, request: Request,
_auth: AuthDep, _auth: AuthDep,
db: Annotated[aiosqlite.Connection, Depends(get_db)], db: Annotated[aiosqlite.Connection, Depends(get_db)],
http_session: HttpSessionDep,
) -> dict[str, int]: ) -> dict[str, int]:
"""Retry geo resolution for every IP in ``geo_cache`` with a null country. """Retry geo resolution for every IP in ``geo_cache`` with a null country.
@@ -163,7 +169,6 @@ async def re_resolve_geo(
# Clear negative cache so these IPs bypass the TTL check. # Clear negative cache so these IPs bypass the TTL check.
geo_service.clear_neg_cache() geo_service.clear_neg_cache()
http_session: aiohttp.ClientSession = request.app.state.http_session
geo_map = await geo_service.lookup_batch(unresolved, http_session, db=db) geo_map = await geo_service.lookup_batch(unresolved, http_session, db=db)
resolved_count = sum( resolved_count = sum(

View File

@@ -22,7 +22,12 @@ if TYPE_CHECKING:
from fastapi import APIRouter, HTTPException, Query, Request from fastapi import APIRouter, HTTPException, Query, Request
from app.dependencies import AuthDep, DbDep from app.dependencies import (
AuthDep,
DbDep,
Fail2BanSocketDep,
HttpSessionDep,
)
from app.models.ban import BanOrigin, TimeRange from app.models.ban import BanOrigin, TimeRange
from app.models.history import HistoryListResponse, IpDetailResponse from app.models.history import HistoryListResponse, IpDetailResponse
from app.services import geo_service, history_service from app.services import geo_service, history_service
@@ -41,6 +46,8 @@ async def get_history(
request: Request, request: Request,
_auth: AuthDep, _auth: AuthDep,
db: DbDep, db: DbDep,
socket_path: Fail2BanSocketDep,
http_session: HttpSessionDep,
range: TimeRange | None = Query( range: TimeRange | None = Query(
default=None, default=None,
description="Optional time-range filter. Omit for all-time.", description="Optional time-range filter. Omit for all-time.",
@@ -87,8 +94,6 @@ async def get_history(
:class:`~app.models.history.HistoryListResponse` with paginated items :class:`~app.models.history.HistoryListResponse` with paginated items
and the total matching count. and the total matching count.
""" """
socket_path: str = request.app.state.settings.fail2ban_socket
http_session: aiohttp.ClientSession = request.app.state.http_session
async def _enricher(addr: str) -> geo_service.GeoInfo | None: async def _enricher(addr: str) -> geo_service.GeoInfo | None:
return await geo_service.lookup(addr, http_session) return await geo_service.lookup(addr, http_session)
@@ -116,6 +121,8 @@ async def get_history_archive(
request: Request, request: Request,
_auth: AuthDep, _auth: AuthDep,
db: DbDep, db: DbDep,
socket_path: Fail2BanSocketDep,
http_session: HttpSessionDep,
range: TimeRange | None = Query( range: TimeRange | None = Query(
default=None, default=None,
description="Optional time-range filter. Omit for all-time.", description="Optional time-range filter. Omit for all-time.",
@@ -125,8 +132,6 @@ async def get_history_archive(
page: int = Query(default=1, ge=1, description="1-based page number."), page: int = Query(default=1, ge=1, description="1-based page number."),
page_size: int = Query(default=_DEFAULT_PAGE_SIZE, ge=1, le=500, description="Items per page (max 500)."), page_size: int = Query(default=_DEFAULT_PAGE_SIZE, ge=1, le=500, description="Items per page (max 500)."),
) -> HistoryListResponse: ) -> HistoryListResponse:
socket_path: str = request.app.state.settings.fail2ban_socket
http_session: aiohttp.ClientSession = request.app.state.http_session
async def _enricher(addr: str) -> geo_service.GeoInfo | None: async def _enricher(addr: str) -> geo_service.GeoInfo | None:
return await geo_service.lookup(addr, http_session) return await geo_service.lookup(addr, http_session)
@@ -153,6 +158,8 @@ async def get_ip_history(
request: Request, request: Request,
_auth: AuthDep, _auth: AuthDep,
ip: str, ip: str,
socket_path: Fail2BanSocketDep,
http_session: HttpSessionDep,
) -> IpDetailResponse: ) -> IpDetailResponse:
"""Return the complete historical record for a single IP address. """Return the complete historical record for a single IP address.
@@ -172,8 +179,6 @@ async def get_ip_history(
Raises: Raises:
HTTPException: 404 if the IP has no history in the database. HTTPException: 404 if the IP has no history in the database.
""" """
socket_path: str = request.app.state.settings.fail2ban_socket
http_session: aiohttp.ClientSession = request.app.state.http_session
async def _enricher(addr: str) -> geo_service.GeoInfo | None: async def _enricher(addr: str) -> geo_service.GeoInfo | None:
return await geo_service.lookup(addr, http_session) return await geo_service.lookup(addr, http_session)

View File

@@ -12,7 +12,7 @@ from __future__ import annotations
from fastapi import APIRouter, HTTPException, Request, status from fastapi import APIRouter, HTTPException, Request, status
from app.dependencies import AuthDep from app.dependencies import AuthDep, Fail2BanSocketDep
from app.models.server import ServerSettingsResponse, ServerSettingsUpdate from app.models.server import ServerSettingsResponse, ServerSettingsUpdate
from app.services import server_service from app.services import server_service
from app.exceptions import ServerOperationError from app.exceptions import ServerOperationError
@@ -53,6 +53,7 @@ def _bad_request(message: str) -> HTTPException:
async def get_server_settings( async def get_server_settings(
request: Request, request: Request,
_auth: AuthDep, _auth: AuthDep,
socket_path: Fail2BanSocketDep,
) -> ServerSettingsResponse: ) -> ServerSettingsResponse:
"""Return the current fail2ban server-level settings. """Return the current fail2ban server-level settings.
@@ -69,7 +70,6 @@ async def get_server_settings(
Raises: Raises:
HTTPException: 502 when fail2ban is unreachable. HTTPException: 502 when fail2ban is unreachable.
""" """
socket_path: str = request.app.state.settings.fail2ban_socket
try: try:
return await server_service.get_settings(socket_path) return await server_service.get_settings(socket_path)
except Fail2BanConnectionError as exc: except Fail2BanConnectionError as exc:
@@ -85,6 +85,7 @@ async def update_server_settings(
request: Request, request: Request,
_auth: AuthDep, _auth: AuthDep,
body: ServerSettingsUpdate, body: ServerSettingsUpdate,
socket_path: Fail2BanSocketDep,
) -> None: ) -> None:
"""Update fail2ban server-level settings. """Update fail2ban server-level settings.
@@ -100,7 +101,6 @@ async def update_server_settings(
HTTPException: 400 when a set command is rejected by fail2ban. HTTPException: 400 when a set command is rejected by fail2ban.
HTTPException: 502 when fail2ban is unreachable. HTTPException: 502 when fail2ban is unreachable.
""" """
socket_path: str = request.app.state.settings.fail2ban_socket
try: try:
await server_service.update_settings(socket_path, body) await server_service.update_settings(socket_path, body)
except ServerOperationError as exc: except ServerOperationError as exc:
@@ -117,6 +117,7 @@ async def update_server_settings(
async def flush_logs( async def flush_logs(
request: Request, request: Request,
_auth: AuthDep, _auth: AuthDep,
socket_path: Fail2BanSocketDep,
) -> dict[str, str]: ) -> dict[str, str]:
"""Flush and re-open fail2ban log files. """Flush and re-open fail2ban log files.
@@ -134,7 +135,6 @@ async def flush_logs(
HTTPException: 400 when the command is rejected. HTTPException: 400 when the command is rejected.
HTTPException: 502 when fail2ban is unreachable. HTTPException: 502 when fail2ban is unreachable.
""" """
socket_path: str = request.app.state.settings.fail2ban_socket
try: try:
result = await server_service.flush_logs(socket_path) result = await server_service.flush_logs(socket_path)
return {"message": result} return {"message": result}