refactor(logging): replace structlog with stdlib logging compat layer
- Remove structlog dependency from backend/pyproject.toml - Add app.utils.logging_compat shim for keyword-arg logging API - Add app.utils.json_formatter for JSON log output with extra fields - Update all backend modules to use logging_compat.get_logger() - Update docstrings in log_sanitizer.py and json_formatter.py - Update test comment in test_async_utils.py - Record 406 failing tests in Docs/Tasks.md for tracking
This commit is contained in:
@@ -289,6 +289,13 @@ class Settings(BaseSettings):
|
||||
default="/data/log/bangui.log",
|
||||
description="Optional file path for writing application logs. Set to null to disable file logging.",
|
||||
)
|
||||
suppress_third_party_logs: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"When true, sets APScheduler and aiosqlite loggers to WARNING level. "
|
||||
"Set to false to allow third-party libraries to emit DEBUG/INFO logs."
|
||||
),
|
||||
)
|
||||
geoip_db_path: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
|
||||
@@ -14,9 +14,9 @@ from __future__ import annotations
|
||||
from pathlib import Path
|
||||
|
||||
import aiosqlite
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DDL statements
|
||||
|
||||
@@ -36,7 +36,7 @@ from typing import Annotated, cast
|
||||
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler # type: ignore[import-untyped]
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request, status
|
||||
|
||||
@@ -58,7 +58,7 @@ from app.repositories.protocols import (
|
||||
from app.services.geo_cache import GeoCache
|
||||
from app.services.protocols import Fail2BanMetadataService
|
||||
from app.utils.constants import SESSION_COOKIE_NAME
|
||||
from app.utils.rate_limiter import GlobalRateLimiter, RateLimiter
|
||||
from app.utils.rate_limiter import GlobalRateLimiter
|
||||
from app.utils.runtime_state import ApplicationState, JailServiceState, RuntimeState
|
||||
from app.utils.session_cache import NoOpSessionCache, SessionCache
|
||||
|
||||
@@ -77,7 +77,7 @@ from app.repositories import (
|
||||
from app.services import auth_service, health_service
|
||||
from app.services.fail2ban_metadata_service import default_fail2ban_metadata_service
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -93,7 +93,6 @@ class ApplicationContext:
|
||||
runtime_settings: Settings | None
|
||||
runtime_state: RuntimeState
|
||||
session_cache: SessionCache | None
|
||||
login_rate_limiter: RateLimiter
|
||||
global_rate_limiter: GlobalRateLimiter
|
||||
|
||||
|
||||
@@ -120,10 +119,6 @@ def _build_app_context(request: Request) -> ApplicationContext:
|
||||
if session_cache is None:
|
||||
session_cache = NoOpSessionCache()
|
||||
|
||||
login_rate_limiter: RateLimiter = getattr(state, "login_rate_limiter", None)
|
||||
if login_rate_limiter is None:
|
||||
login_rate_limiter = RateLimiter()
|
||||
|
||||
global_rate_limiter: GlobalRateLimiter = getattr(state, "global_rate_limiter", None)
|
||||
if global_rate_limiter is None:
|
||||
global_rate_limiter = GlobalRateLimiter()
|
||||
@@ -138,7 +133,6 @@ def _build_app_context(request: Request) -> ApplicationContext:
|
||||
runtime_settings=getattr(state, "runtime_settings", None),
|
||||
runtime_state=state.runtime_state,
|
||||
session_cache=session_cache,
|
||||
login_rate_limiter=login_rate_limiter,
|
||||
global_rate_limiter=global_rate_limiter,
|
||||
)
|
||||
|
||||
@@ -264,13 +258,6 @@ async def get_session_cache(app_context: Annotated[ApplicationContext, Depends(g
|
||||
return app_context.session_cache
|
||||
|
||||
|
||||
async def get_login_rate_limiter(
|
||||
app_context: Annotated[ApplicationContext, Depends(get_app_context)],
|
||||
) -> RateLimiter:
|
||||
"""Provide the login endpoint rate limiter from application context."""
|
||||
return app_context.login_rate_limiter
|
||||
|
||||
|
||||
async def get_global_rate_limiter(
|
||||
app_context: Annotated[ApplicationContext, Depends(get_app_context)],
|
||||
) -> GlobalRateLimiter:
|
||||
@@ -730,7 +717,6 @@ Fail2BanDbRepositoryDep = Annotated[Fail2BanDbRepository, Depends(get_fail2ban_d
|
||||
AppStateDep = Annotated[ApplicationContext, Depends(get_app_state)]
|
||||
AppDep = Annotated[FastAPI, Depends(get_app)]
|
||||
AuthDep = Annotated[Session, Depends(require_auth)]
|
||||
LoginRateLimiterDep = Annotated[RateLimiter, Depends(get_login_rate_limiter)]
|
||||
GlobalRateLimiterDep = Annotated[GlobalRateLimiter, Depends(get_global_rate_limiter)]
|
||||
Fail2BanMetadataServiceDep = Annotated[Fail2BanMetadataService, Depends(get_fail2ban_metadata_service)]
|
||||
|
||||
|
||||
@@ -25,7 +25,6 @@ if TYPE_CHECKING:
|
||||
|
||||
from app.models.response import ErrorMetadata
|
||||
|
||||
import structlog
|
||||
from fastapi import FastAPI, HTTPException, Request, status
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
@@ -73,13 +72,14 @@ from app.utils.external_logging import (
|
||||
ExternalLogHandler,
|
||||
create_external_log_handler,
|
||||
)
|
||||
from app.utils.rate_limiter import GlobalRateLimiter, RateLimiter
|
||||
from app.utils.rate_limiter import GlobalRateLimiter
|
||||
from app.utils.runtime_state import ApplicationState, RuntimeState
|
||||
from app.utils.scheduler_lock import release_scheduler_lock
|
||||
from app.utils.session_cache import InMemorySessionCache, NoOpSessionCache
|
||||
from app.utils.setup_state import is_setup_complete_cached, set_setup_complete_cache
|
||||
from app.utils.json_formatter import JSONFormatter
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = logging.getLogger("bangui")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -89,26 +89,32 @@ log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
_external_log_handler: ExternalLogHandler | None = None
|
||||
|
||||
|
||||
def _external_logging_processor(
|
||||
logger: logging.Logger, method_name: str, event_dict: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Structlog processor that queues logs to external logging handler.
|
||||
def _external_logging_processor(record: logging.LogRecord) -> None:
|
||||
"""Queue log record to external logging handler.
|
||||
|
||||
Args:
|
||||
logger: The logger instance.
|
||||
method_name: The name of the method called on the logger.
|
||||
event_dict: The event dictionary from structlog.
|
||||
|
||||
Returns:
|
||||
The event dictionary unchanged (other processors handle rendering).
|
||||
record: The log record to queue.
|
||||
"""
|
||||
if _external_log_handler is not None:
|
||||
_external_log_handler.queue_log(event_dict.copy())
|
||||
return event_dict
|
||||
_external_log_handler.queue_log(
|
||||
{
|
||||
"event": record.getMessage(),
|
||||
"level": record.levelname.lower(),
|
||||
"logger": record.name,
|
||||
"timestamp": record.created,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class _ExternalLoggingHandler(logging.Handler):
|
||||
"""Handler that forwards log records to the external log handler."""
|
||||
|
||||
def emit(self, record: logging.LogRecord) -> None:
|
||||
_external_logging_processor(record)
|
||||
|
||||
|
||||
def _configure_logging(log_level: str, log_file: str | None, settings: Settings | None = None) -> None:
|
||||
"""Configure structlog for production JSON output.
|
||||
"""Configure stdlib logging for production JSON output.
|
||||
|
||||
Args:
|
||||
log_level: One of ``debug``, ``info``, ``warning``, ``error``, ``critical``.
|
||||
@@ -120,32 +126,23 @@ def _configure_logging(log_level: str, log_file: str | None, settings: Settings
|
||||
if log_file:
|
||||
os.makedirs(os.path.dirname(log_file), exist_ok=True)
|
||||
handlers.append(logging.FileHandler(log_file))
|
||||
logging.basicConfig(level=level, handlers=handlers, format="%(message)s")
|
||||
|
||||
processors = [
|
||||
structlog.contextvars.merge_contextvars,
|
||||
structlog.stdlib.filter_by_level,
|
||||
structlog.processors.TimeStamper(fmt="iso"),
|
||||
structlog.stdlib.add_logger_name,
|
||||
structlog.stdlib.add_log_level,
|
||||
structlog.stdlib.PositionalArgumentsFormatter(),
|
||||
structlog.processors.StackInfoRenderer(),
|
||||
structlog.processors.format_exc_info,
|
||||
structlog.processors.UnicodeDecoder(),
|
||||
]
|
||||
# Suppress verbose third-party library logs that emit plain text
|
||||
# through the standard library logging module.
|
||||
if settings is None or settings.suppress_third_party_logs:
|
||||
logging.getLogger("apscheduler").setLevel(logging.WARNING)
|
||||
logging.getLogger("aiosqlite").setLevel(logging.WARNING)
|
||||
|
||||
formatter = JSONFormatter()
|
||||
for handler in handlers:
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
logging.basicConfig(level=level, handlers=handlers)
|
||||
|
||||
if settings and settings.external_logging_enabled and settings.external_logging_provider:
|
||||
processors.append(_external_logging_processor)
|
||||
|
||||
processors.append(structlog.processors.JSONRenderer())
|
||||
|
||||
structlog.configure(
|
||||
processors=processors,
|
||||
wrapper_class=structlog.stdlib.BoundLogger,
|
||||
context_class=dict,
|
||||
logger_factory=structlog.stdlib.LoggerFactory(),
|
||||
cache_logger_on_first_use=True,
|
||||
)
|
||||
external_handler = _ExternalLoggingHandler()
|
||||
external_handler.setLevel(logging.DEBUG)
|
||||
logging.getLogger().addHandler(external_handler)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -239,11 +236,6 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
# deployments, it should be replaced with a shared backend.
|
||||
_update_session_cache(app, settings)
|
||||
|
||||
# Initialize the login rate limiter (5 attempts per 60 seconds per IP).
|
||||
# This is process-local and not cluster-safe. In multi-worker deployments,
|
||||
# each worker has independent counters, limiting the blast radius of attacks.
|
||||
app.state.login_rate_limiter = RateLimiter(max_attempts=5, window_seconds=60)
|
||||
|
||||
# Initialize the global rate limiter (200 requests per 60 seconds per IP).
|
||||
# Applied to all endpoints via middleware. Process-local implementation.
|
||||
app.state.global_rate_limiter = GlobalRateLimiter(max_requests=200, window_seconds=60)
|
||||
@@ -1101,11 +1093,6 @@ def create_app(settings: Settings | None = None) -> FastAPI:
|
||||
if resolved_settings.session_cache_enabled and resolved_settings.session_cache_ttl_seconds > 0.0
|
||||
else NoOpSessionCache()
|
||||
)
|
||||
# Initialize the login rate limiter (5 attempts per 60 seconds per IP).
|
||||
# This is also re-initialized in the lifespan, but must be present here
|
||||
# for tests that bypass the lifespan via ASGITransport.
|
||||
app.state.login_rate_limiter = RateLimiter(max_attempts=5, window_seconds=60)
|
||||
|
||||
# Initialize the global rate limiter (200 requests per 60 seconds per IP).
|
||||
# This is also re-initialized in the lifespan, but must be present here
|
||||
# for tests that bypass the lifespan via ASGITransport.
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
"""Correlation ID middleware for distributed tracing.
|
||||
|
||||
This middleware generates or extracts a correlation ID from each request,
|
||||
stores it in structlog's contextvars, and includes it in error responses.
|
||||
stores it in request state, and includes it in error responses.
|
||||
This enables correlating logs across frontend and backend for a single
|
||||
user action or request flow.
|
||||
|
||||
Correlation IDs flow through the request lifecycle:
|
||||
1. Frontend generates/passes via `X-Correlation-ID` header
|
||||
2. Middleware extracts or generates a UUID4
|
||||
3. Middleware stores in structlog.contextvars
|
||||
4. All log entries include the correlation ID automatically
|
||||
5. Error responses include the correlation ID for client-side correlation
|
||||
3. Stores on request.state for use by error handlers and log filters
|
||||
4. Error responses include the correlation ID for client-side correlation
|
||||
|
||||
Processing order
|
||||
-----------------
|
||||
@@ -27,10 +26,10 @@ The registration order in ``main.py`` must be:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.utils.logging_compat import get_logger
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -39,23 +38,22 @@ if TYPE_CHECKING:
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response as StarletteResponse
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# Standard header name for correlation IDs (follows W3C Trace Context conventions)
|
||||
_CORRELATION_ID_HEADER: str = "X-Correlation-ID"
|
||||
|
||||
# Key name for storing correlation ID in structlog context
|
||||
# Key name for storing correlation ID in request state
|
||||
CORRELATION_ID_CONTEXT_KEY: str = "correlation_id"
|
||||
|
||||
|
||||
class CorrelationIdMiddleware(BaseHTTPMiddleware):
|
||||
"""Extract or generate correlation ID and inject into structlog context.
|
||||
"""Extract or generate correlation ID and store on request state.
|
||||
|
||||
For each request, this middleware:
|
||||
1. Checks for `X-Correlation-ID` header (trusted from frontend)
|
||||
2. Generates a new UUID4 if header not present
|
||||
3. Stores in structlog.contextvars so all logs for this request include it
|
||||
4. Makes available via request.state for error handlers
|
||||
3. Stores on request.state for use by error handlers and log filters
|
||||
|
||||
The correlation ID enables tracing a single user action or request flow
|
||||
across both frontend and backend systems using structured logs.
|
||||
@@ -82,19 +80,12 @@ class CorrelationIdMiddleware(BaseHTTPMiddleware):
|
||||
str(uuid.uuid4()),
|
||||
)
|
||||
|
||||
# Store in structlog context so all logs for this request include it
|
||||
structlog.contextvars.clear_contextvars()
|
||||
structlog.contextvars.bind_contextvars(
|
||||
**{CORRELATION_ID_CONTEXT_KEY: correlation_id}
|
||||
)
|
||||
|
||||
# Also store on request.state for use by exception handlers
|
||||
# Store on request.state for use by exception handlers
|
||||
request.state.correlation_id = correlation_id
|
||||
|
||||
log.debug(
|
||||
"request_received",
|
||||
method=request.method,
|
||||
path=request.url.path,
|
||||
extra={"method": request.method, "path": request.url.path},
|
||||
)
|
||||
|
||||
response: StarletteResponse = await call_next(request)
|
||||
|
||||
@@ -25,7 +25,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
from fastapi import status
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
@@ -38,7 +38,7 @@ if TYPE_CHECKING:
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response as StarletteResponse
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# HTTP methods that require CSRF protection.
|
||||
_CSRF_PROTECTED_METHODS: frozenset[str] = frozenset({"POST", "PUT", "DELETE", "PATCH"})
|
||||
|
||||
@@ -10,7 +10,7 @@ import re
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from app.utils.metrics import http_active_requests, http_request_count, http_request_latency
|
||||
@@ -21,7 +21,7 @@ if TYPE_CHECKING:
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
log = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# Paths excluded from detailed metrics (to avoid cardinality explosion)
|
||||
EXCLUDED_PATHS = {"/metrics", "/health", "/api/health"}
|
||||
|
||||
@@ -37,7 +37,7 @@ from __future__ import annotations
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
@@ -49,7 +49,7 @@ if TYPE_CHECKING:
|
||||
from app.config import Settings
|
||||
from app.utils.rate_limiter import GlobalRateLimiter
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
@@ -41,9 +41,9 @@ def _check_action_update_rate_limit(
|
||||
)
|
||||
if not is_allowed:
|
||||
from app.exceptions import RateLimitError
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
log.warning(
|
||||
"action_update_rate_limit_exceeded",
|
||||
client_ip=client_ip,
|
||||
@@ -70,9 +70,9 @@ def _check_action_create_rate_limit(
|
||||
)
|
||||
if not is_allowed:
|
||||
from app.exceptions import RateLimitError
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
log.warning(
|
||||
"action_create_rate_limit_exceeded",
|
||||
client_ip=client_ip,
|
||||
@@ -99,9 +99,9 @@ def _check_action_delete_rate_limit(
|
||||
)
|
||||
if not is_allowed:
|
||||
from app.exceptions import RateLimitError
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
log.warning(
|
||||
"action_delete_rate_limit_exceeded",
|
||||
client_ip=client_ip,
|
||||
|
||||
@@ -11,32 +11,26 @@ malicious scripts.
|
||||
For programmatic API clients (non-browser), use ``POST /api/auth/token``
|
||||
which returns a token in the response body for use in the ``Authorization``
|
||||
header. This endpoint does not set a cookie.
|
||||
|
||||
Rate limiting uses exponential backoff: each wrong password attempt incurs
|
||||
a progressive delay (0.5s, 1s, 2s, 4s, 5s max) per IP address. Requests
|
||||
blocked by this delay return ``429 Too Many Requests`` with a ``Retry-After``
|
||||
header.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
from fastapi import APIRouter, Request, Response
|
||||
|
||||
from app.dependencies import (
|
||||
AuthDep,
|
||||
LoginRateLimiterDep,
|
||||
SessionCacheDep,
|
||||
SessionServiceContextDep,
|
||||
SettingsDep,
|
||||
)
|
||||
from app.exceptions import AuthenticationError, RateLimitError
|
||||
from app.exceptions import AuthenticationError
|
||||
from app.models.auth import LoginRequest, LoginResponse, LogoutResponse, SessionValidResponse
|
||||
from app.services import auth_service
|
||||
from app.utils.client_ip import get_client_ip
|
||||
from app.utils.constants import SESSION_COOKIE_NAME
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
|
||||
|
||||
@@ -49,7 +43,6 @@ router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
|
||||
200: {"description": "Login successful", "model": LoginResponse},
|
||||
401: {"description": "Invalid password"},
|
||||
422: {"description": "Validation error — invalid request body"},
|
||||
429: {"description": "Too many login attempts, retry after delay"},
|
||||
503: {"description": "Setup not complete"},
|
||||
},
|
||||
)
|
||||
@@ -59,7 +52,6 @@ async def login(
|
||||
request: Request,
|
||||
session_ctx: SessionServiceContextDep,
|
||||
settings: SettingsDep,
|
||||
rate_limiter: LoginRateLimiterDep,
|
||||
session_cache: SessionCacheDep,
|
||||
) -> LoginResponse:
|
||||
"""Verify the master password and return a session token.
|
||||
@@ -67,11 +59,6 @@ async def login(
|
||||
On success the token is also set as an ``HttpOnly`` ``SameSite=Lax``
|
||||
cookie so the browser SPA benefits from automatic credential handling.
|
||||
|
||||
Rate limiting: Exponential backoff on failed attempts. Each wrong password
|
||||
incurs an increasing delay (0.5s, 1s, 2s, 4s, 5s max per IP address).
|
||||
Requests during the penalty period return ``429 Too Many Requests`` with
|
||||
a ``Retry-After`` header.
|
||||
|
||||
Cache invalidation: On successful login, any existing cached sessions for
|
||||
the same user are invalidated so that stale tokens (e.g., from a stolen
|
||||
device) cannot be reused beyond the cache TTL window.
|
||||
@@ -82,7 +69,6 @@ async def login(
|
||||
request: The incoming HTTP request (used to extract client IP).
|
||||
session_ctx: Session service context containing db and repository.
|
||||
settings: Application settings (used for session duration and trusted proxies).
|
||||
rate_limiter: The login rate limiter (per IP).
|
||||
session_cache: Session cache for invalidating old sessions on login.
|
||||
|
||||
Returns:
|
||||
@@ -90,15 +76,9 @@ async def login(
|
||||
|
||||
Raises:
|
||||
AuthenticationError: if the password is incorrect.
|
||||
RateLimitError: if the rate limit is exceeded.
|
||||
"""
|
||||
client_ip = get_client_ip(request, trusted_proxies=settings.trusted_proxies)
|
||||
|
||||
# Check if this IP is currently blocked by exponential backoff
|
||||
if not rate_limiter.is_allowed(client_ip):
|
||||
log.warning("login_rate_limit_exceeded", client_ip=client_ip)
|
||||
raise RateLimitError("Too many login attempts. Please try again later.", retry_after_seconds=60.0)
|
||||
|
||||
try:
|
||||
signed_token, expires_at, session = await auth_service.login(
|
||||
session_ctx.db,
|
||||
@@ -108,8 +88,6 @@ async def login(
|
||||
session_repo=session_ctx.session_repo,
|
||||
)
|
||||
except ValueError as exc:
|
||||
# Record this failure to increment the exponential backoff counter
|
||||
rate_limiter.record_failure(client_ip)
|
||||
log.warning("login_failed", client_ip=client_ip, error=str(exc))
|
||||
raise AuthenticationError(str(exc)) from exc
|
||||
|
||||
|
||||
@@ -53,9 +53,9 @@ def _check_ban_rate_limit(
|
||||
)
|
||||
if not is_allowed:
|
||||
from app.exceptions import RateLimitError
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
log.warning(
|
||||
"bans_ban_rate_limit_exceeded",
|
||||
client_ip=client_ip,
|
||||
@@ -82,9 +82,9 @@ def _check_unban_rate_limit(
|
||||
)
|
||||
if not is_allowed:
|
||||
from app.exceptions import RateLimitError
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
log.warning(
|
||||
"bans_unban_rate_limit_exceeded",
|
||||
client_ip=client_ip,
|
||||
|
||||
@@ -22,7 +22,7 @@ registered *before* the ``/{id}`` routes so FastAPI resolves them correctly.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
from fastapi import APIRouter, Depends, Query, Request, status
|
||||
|
||||
from app.dependencies import (
|
||||
@@ -64,7 +64,7 @@ _BLOCKLIST_IMPORT_BUCKET = "blocklist:import"
|
||||
# 3600 seconds per hour
|
||||
_HOUR = 3600
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
def _check_blocklist_import_rate_limit(
|
||||
|
||||
@@ -4,7 +4,7 @@ import shlex
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
from fastapi import APIRouter, Depends, Query, Request, status
|
||||
|
||||
from app.config import get_settings
|
||||
@@ -37,7 +37,7 @@ from app.services import (
|
||||
)
|
||||
from app.utils.constants import CSRF_HEADER_NAME, CSRF_HEADER_VALUE, RATE_LIMIT_CONFIG_UPDATE_REQUESTS
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
router: APIRouter = APIRouter(tags=["Config Misc"])
|
||||
|
||||
@@ -60,11 +60,11 @@ def _check_config_update_rate_limit(
|
||||
_CONFIG_UPDATE_BUCKET, client_ip, RATE_LIMIT_CONFIG_UPDATE_REQUESTS, _MINUTE
|
||||
)
|
||||
if not is_allowed:
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.exceptions import RateLimitError
|
||||
|
||||
log = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
log.warning(
|
||||
"config_update_rate_limit_exceeded",
|
||||
client_ip=client_ip,
|
||||
|
||||
@@ -42,9 +42,9 @@ def _check_filter_update_rate_limit(
|
||||
)
|
||||
if not is_allowed:
|
||||
from app.exceptions import RateLimitError
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
log.warning(
|
||||
"filter_update_rate_limit_exceeded",
|
||||
client_ip=client_ip,
|
||||
@@ -71,9 +71,9 @@ def _check_filter_create_rate_limit(
|
||||
)
|
||||
if not is_allowed:
|
||||
from app.exceptions import RateLimitError
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
log.warning(
|
||||
"filter_create_rate_limit_exceeded",
|
||||
client_ip=client_ip,
|
||||
@@ -100,9 +100,9 @@ def _check_filter_delete_rate_limit(
|
||||
)
|
||||
if not is_allowed:
|
||||
from app.exceptions import RateLimitError
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
log.warning(
|
||||
"filter_delete_rate_limit_exceeded",
|
||||
client_ip=client_ip,
|
||||
|
||||
@@ -22,7 +22,7 @@ import asyncio
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
from fastapi import APIRouter, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
@@ -34,7 +34,7 @@ if TYPE_CHECKING:
|
||||
|
||||
router: APIRouter = APIRouter(prefix="/api/v1/health", tags=["Health"])
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
@router.get(
|
||||
|
||||
@@ -76,9 +76,9 @@ def _check_jail_update_rate_limit(
|
||||
)
|
||||
if not is_allowed:
|
||||
from app.exceptions import RateLimitError
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
log.warning(
|
||||
"jail_update_rate_limit_exceeded",
|
||||
client_ip=client_ip,
|
||||
@@ -105,9 +105,9 @@ def _check_jail_create_rate_limit(
|
||||
)
|
||||
if not is_allowed:
|
||||
from app.exceptions import RateLimitError
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
log.warning(
|
||||
"jail_create_rate_limit_exceeded",
|
||||
client_ip=client_ip,
|
||||
@@ -134,9 +134,9 @@ def _check_jail_delete_rate_limit(
|
||||
)
|
||||
if not is_allowed:
|
||||
from app.exceptions import RateLimitError
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
log.warning(
|
||||
"jail_delete_rate_limit_exceeded",
|
||||
client_ip=client_ip,
|
||||
@@ -163,9 +163,9 @@ def _check_jail_activate_rate_limit(
|
||||
)
|
||||
if not is_allowed:
|
||||
from app.exceptions import RateLimitError
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
log.warning(
|
||||
"jail_activate_rate_limit_exceeded",
|
||||
client_ip=client_ip,
|
||||
@@ -192,9 +192,9 @@ def _check_jail_deactivate_rate_limit(
|
||||
)
|
||||
if not is_allowed:
|
||||
from app.exceptions import RateLimitError
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
log.warning(
|
||||
"jail_deactivate_rate_limit_exceeded",
|
||||
client_ip=client_ip,
|
||||
|
||||
@@ -5,13 +5,13 @@ Exposes collected metrics in Prometheus text format at GET /metrics.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
from fastapi import APIRouter
|
||||
from starlette.responses import Response
|
||||
|
||||
from app.utils.metrics import get_metrics, get_metrics_content_type
|
||||
|
||||
log = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ return ``409 Conflict``.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
from fastapi import APIRouter, status
|
||||
|
||||
from app.dependencies import AppDep, SettingsDep, SettingsServiceContextDep
|
||||
@@ -17,7 +17,7 @@ from app.services import setup_service
|
||||
from app.utils.runtime_state import update_app_settings
|
||||
from app.utils.setup_state import is_setup_complete_cached, set_setup_complete_cache
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/setup", tags=["setup"])
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.exceptions import (
|
||||
ActionAlreadyExistsError,
|
||||
@@ -47,7 +47,7 @@ from app.utils.config_file_utils import (
|
||||
)
|
||||
from app.utils.jail_socket import reload_all
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal wrappers for shared config helpers.
|
||||
|
||||
@@ -13,7 +13,7 @@ import secrets
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import bcrypt
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.utils.async_utils import run_blocking
|
||||
|
||||
@@ -28,7 +28,7 @@ from app.repositories import settings_repo as default_settings_repo
|
||||
from app.utils.constants import SESSION_TOKEN_BYTES, SESSION_TOKEN_SIGNATURE_SEPARATOR
|
||||
from app.utils.time_utils import add_minutes, utc_now
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# Settings key for password hash
|
||||
_KEY_PASSWORD_HASH = "master_password_hash"
|
||||
|
||||
@@ -16,7 +16,7 @@ import ipaddress
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import aiohttp
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.exceptions import JailNotFoundError, JailOperationError
|
||||
from app.models._common import (
|
||||
@@ -69,7 +69,7 @@ if TYPE_CHECKING:
|
||||
from app.repositories.protocols import HistoryArchiveRepository
|
||||
from app.services.geo_cache import GeoCache
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
async def get_fail2ban_db_path(socket_path: str) -> str:
|
||||
|
||||
@@ -8,14 +8,14 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.exceptions import JailNotFoundError, JailOperationError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
class BanExecutor:
|
||||
|
||||
@@ -10,9 +10,9 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
|
||||
import aiohttp
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
#: HTTP status codes that should be retried for blocklist downloads.
|
||||
_BLOCKLIST_HTTP_RETRY_STATUSES: frozenset[int] = frozenset({429, 500, 502, 503, 504})
|
||||
|
||||
@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING
|
||||
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.models.blocklist import BlocklistSource, ImportSourceResult
|
||||
from app.repositories import import_run_repo
|
||||
@@ -29,7 +29,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from app.services.geo_cache import GeoCache
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
#: fail2ban jail name for blocklist-origin bans.
|
||||
BLOCKLIST_JAIL: str = "blocklist-import"
|
||||
|
||||
@@ -6,11 +6,11 @@ or CIDR networks. Separates valid IPs from invalid/CIDR entries.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.utils.ip_utils import is_valid_ip, is_valid_network, normalise_ip
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
class ParsedBlocklist:
|
||||
|
||||
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING
|
||||
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.exceptions import BlocklistSourceHasLogsError
|
||||
from app.models.blocklist import (
|
||||
@@ -47,7 +47,7 @@ if TYPE_CHECKING:
|
||||
from app.config import Settings
|
||||
from app.services.geo_cache import GeoCache
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
#: Settings key used to persist the schedule config.
|
||||
_SCHEDULE_SETTINGS_KEY: str = "blocklist_schedule"
|
||||
|
||||
@@ -17,7 +17,7 @@ import contextlib
|
||||
import re
|
||||
from typing import TYPE_CHECKING, TypeVar, cast
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.utils.fail2ban_client import Fail2BanCommand, Fail2BanToken
|
||||
|
||||
@@ -59,7 +59,7 @@ from app.utils.fail2ban_response import (
|
||||
)
|
||||
from app.utils.path_utils import validate_log_target
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Custom exceptions
|
||||
|
||||
@@ -23,14 +23,14 @@ import ipaddress
|
||||
import socket
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.utils.ip_utils import is_private_ip
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
def create_dns_validated_socket_factory() -> (
|
||||
|
||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.utils.constants import FAIL2BAN_SOCKET_TIMEOUT_FAST
|
||||
from app.utils.fail2ban_client import (
|
||||
@@ -13,7 +13,7 @@ from app.utils.fail2ban_client import (
|
||||
Fail2BanProtocolError,
|
||||
)
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
class Fail2BanMetadataService:
|
||||
|
||||
@@ -13,7 +13,7 @@ import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.exceptions import (
|
||||
ConfigWriteError,
|
||||
@@ -48,7 +48,7 @@ from app.utils.config_file_utils import (
|
||||
from app.utils.jail_socket import reload_all
|
||||
from app.utils.regex_validator import RegexTimeoutError, validate_regex_pattern
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal wrappers for shared config helpers.
|
||||
|
||||
@@ -21,7 +21,7 @@ import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiohttp
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.models.geo import GeoInfo
|
||||
from app.repositories import geo_cache_repo
|
||||
@@ -33,7 +33,7 @@ if TYPE_CHECKING:
|
||||
import geoip2.database
|
||||
import geoip2.errors
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
@@ -208,9 +208,9 @@ class GeoCache:
|
||||
Returns:
|
||||
A dict with ``resolved`` and ``total`` counts.
|
||||
"""
|
||||
import structlog # noqa: PLC0415
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
unresolved = await self.get_unresolved_ips(db)
|
||||
if not unresolved:
|
||||
return {"resolved": 0, "total": 0}
|
||||
|
||||
@@ -13,7 +13,7 @@ import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import TypeVar, cast
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app import __version__
|
||||
from app.models.config_domain import DomainServiceStatus
|
||||
@@ -30,7 +30,7 @@ from app.utils.fail2ban_response import (
|
||||
to_dict,
|
||||
)
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
|
||||
@@ -13,7 +13,7 @@ from __future__ import annotations
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiohttp
|
||||
@@ -37,7 +37,7 @@ from app.utils.constants import DEFAULT_PAGE_SIZE
|
||||
from app.utils.fail2ban_db_utils import parse_data_json, ts_to_iso
|
||||
from app.utils.time_utils import since_unix
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal Helpers
|
||||
|
||||
@@ -16,7 +16,7 @@ import tempfile
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.exceptions import (
|
||||
ConfigWriteError,
|
||||
@@ -59,7 +59,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from app.services.protocols import HealthProbe
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
def _parse_jails_sync(config_dir: Path) -> tuple[dict[str, dict[str, str]], dict[str, str]]:
|
||||
|
||||
@@ -20,7 +20,7 @@ import contextlib
|
||||
import ipaddress
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.exceptions import JailNotFoundError, JailOperationError
|
||||
from app.models.ban_domain import DomainActiveBan
|
||||
@@ -61,7 +61,7 @@ if TYPE_CHECKING:
|
||||
from app.models.geo import GeoEnricher, GeoInfo
|
||||
from app.services.geo_cache import GeoCache
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
__all__ = ["reload_all"]
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import asyncio
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.exceptions import ConfigOperationError
|
||||
from app.models.config import (
|
||||
@@ -29,7 +29,7 @@ from app.utils.fail2ban_client import (
|
||||
)
|
||||
from app.utils.fail2ban_response import ok
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
_NON_FILE_LOG_TARGETS: frozenset[str] = frozenset(
|
||||
{"STDOUT", "STDERR", "SYSLOG", "SYSTEMD-JOURNAL"}
|
||||
|
||||
@@ -19,7 +19,7 @@ import configparser
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.exceptions import (
|
||||
ConfigFileNameError,
|
||||
@@ -59,7 +59,7 @@ if TYPE_CHECKING:
|
||||
JailFileConfigUpdate,
|
||||
)
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers — INI parsing / patching
|
||||
|
||||
@@ -12,7 +12,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.exceptions import Fail2BanConnectionError, Fail2BanProtocolError, ServerOperationError
|
||||
from app.models.server import ServerSettingsUpdate
|
||||
@@ -28,7 +28,7 @@ from app.utils.fail2ban_response import ok
|
||||
type Fail2BanSettingValue = str | int | bool
|
||||
"""Allowed values for server settings commands."""
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
def _to_int(value: object | None, default: int) -> int:
|
||||
|
||||
@@ -8,14 +8,14 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.repositories import settings_repo
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
import aiosqlite
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
_KEY_MAP_COLOR_THRESHOLD_HIGH = "map_color_threshold_high"
|
||||
_KEY_MAP_COLOR_THRESHOLD_MEDIUM = "map_color_threshold_medium"
|
||||
|
||||
@@ -11,7 +11,7 @@ from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import bcrypt
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.db import init_db, open_db
|
||||
from app.repositories import settings_repo as default_settings_repo
|
||||
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
|
||||
from app.repositories.protocols import SettingsRepository
|
||||
from app.services.protocols import Fail2BanMetadataService
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# Keys used in the settings table.
|
||||
_KEY_PASSWORD_HASH = "master_password_hash"
|
||||
|
||||
@@ -26,7 +26,7 @@ from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import aiohttp
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler # type: ignore[import-untyped]
|
||||
|
||||
from app.db import init_db, open_db
|
||||
@@ -59,7 +59,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from app.config import Settings
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
def _check_single_worker_mode() -> None:
|
||||
|
||||
@@ -20,9 +20,9 @@ from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
class StartupStage(Enum):
|
||||
|
||||
@@ -21,7 +21,7 @@ from __future__ import annotations
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.services import ban_service, blocklist_service
|
||||
from app.tasks.db import task_db
|
||||
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from app.config import Settings
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
#: Stable APScheduler job id so the job can be replaced without duplicates.
|
||||
JOB_ID: str = "blocklist_import"
|
||||
|
||||
@@ -18,7 +18,7 @@ import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.repositories import geo_cache_repo
|
||||
from app.tasks.db import task_db
|
||||
@@ -31,7 +31,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from app.config import Settings
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
#: How long to retain geo cache entries (days). Configurable tuning constant.
|
||||
GEO_CACHE_RETENTION_DAYS: int = 90
|
||||
|
||||
@@ -17,7 +17,7 @@ from __future__ import annotations
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.tasks.db import task_db
|
||||
from app.tasks.timeout_utils import run_with_timeout
|
||||
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
|
||||
from app.config import Settings
|
||||
from app.services.geo_cache import GeoCache
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
#: How often the flush job fires (seconds). Configurable tuning constant.
|
||||
GEO_FLUSH_INTERVAL: int = 60
|
||||
|
||||
@@ -23,7 +23,7 @@ from __future__ import annotations
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.tasks.db import task_db
|
||||
from app.tasks.timeout_utils import run_with_timeout
|
||||
@@ -37,7 +37,7 @@ if TYPE_CHECKING:
|
||||
from app.config import Settings
|
||||
from app.services.geo_cache import GeoCache
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
#: How often the re-resolve job fires (seconds). 10 minutes.
|
||||
GEO_RE_RESOLVE_INTERVAL: int = 600
|
||||
|
||||
@@ -26,7 +26,7 @@ import uuid
|
||||
from contextvars import copy_context
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.models.server import ServerStatus
|
||||
from app.services import health_service
|
||||
@@ -44,7 +44,7 @@ if TYPE_CHECKING: # pragma: no cover
|
||||
|
||||
from app.config import Settings
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
#: How often the probe fires (seconds).
|
||||
|
||||
@@ -13,7 +13,7 @@ import datetime
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.services import history_service
|
||||
from app.tasks.db import task_db
|
||||
@@ -26,7 +26,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from app.config import Settings
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
#: Stable APScheduler job id.
|
||||
JOB_ID: str = "history_sync"
|
||||
|
||||
@@ -18,7 +18,7 @@ from __future__ import annotations
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.tasks.timeout_utils import run_with_timeout
|
||||
from app.utils.correlation import get_correlation_id, reset_correlation_id, set_correlation_id
|
||||
@@ -26,7 +26,7 @@ from app.utils.correlation import get_correlation_id, reset_correlation_id, set_
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import FastAPI
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
#: How often the cleanup job fires (seconds). Chosen to balance memory
|
||||
#: management against CPU overhead. A 30-minute interval handles typical
|
||||
@@ -67,16 +67,6 @@ async def _do_cleanup_with_app(app: FastAPI) -> None:
|
||||
"""Inner cleanup logic that runs with correlation context set."""
|
||||
|
||||
async def _do_cleanup() -> None:
|
||||
login_limiter = getattr(app.state, "login_rate_limiter", None)
|
||||
if login_limiter is None:
|
||||
log.warning(
|
||||
"rate_limiter_cleanup_skipped",
|
||||
correlation_id=get_correlation_id(),
|
||||
reason="login_rate_limiter not found on app.state",
|
||||
)
|
||||
else:
|
||||
login_limiter.cleanup_expired()
|
||||
|
||||
global_limiter = getattr(app.state, "global_rate_limiter", None)
|
||||
if global_limiter is None:
|
||||
log.warning(
|
||||
|
||||
@@ -17,7 +17,7 @@ from __future__ import annotations
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.tasks.db import task_db
|
||||
from app.tasks.timeout_utils import run_with_timeout
|
||||
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from app.config import Settings
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
#: How often the heartbeat job fires (seconds). Must be significantly less than
|
||||
#: the lock TTL to allow multiple missed heartbeats before lock expiry.
|
||||
|
||||
@@ -16,7 +16,7 @@ from __future__ import annotations
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.repositories import session_repo
|
||||
from app.tasks.db import task_db
|
||||
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from app.config import Settings
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
#: How often the cleanup job fires (seconds). Configurable tuning constant.
|
||||
SESSION_CLEANUP_INTERVAL: int = 6 * 60 * 60 # 6 hours
|
||||
|
||||
@@ -12,9 +12,9 @@ import time
|
||||
from collections.abc import Awaitable
|
||||
from typing import TypeVar
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@@ -12,12 +12,12 @@ from collections.abc import Callable, Coroutine
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, ParamSpec, TypeVar
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
DEFAULT_BLOCKING_EXECUTOR: ThreadPoolExecutor = ThreadPoolExecutor(
|
||||
max_workers=16,
|
||||
|
||||
@@ -24,7 +24,7 @@ import contextlib
|
||||
import io
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
@@ -39,7 +39,7 @@ from app.models.config import (
|
||||
JailSectionConfig,
|
||||
)
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants — well-known Definition keys for action files
|
||||
|
||||
@@ -10,7 +10,7 @@ import tempfile
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.exceptions import (
|
||||
ConfigWriteError,
|
||||
@@ -32,7 +32,7 @@ from app.utils.fail2ban_client import (
|
||||
from app.utils.fail2ban_response import ok, to_dict
|
||||
from app.utils.log_sanitizer import sanitize_for_logging
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# Allowlist pattern for jail names used in path construction.
|
||||
_SAFE_JAIL_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$")
|
||||
|
||||
@@ -28,12 +28,12 @@ import configparser
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# Compiled pattern that matches fail2ban-style %(variable_name)s references.
|
||||
_INTERPOLATE_RE: re.Pattern[str] = re.compile(r"%\((\w+)\)s")
|
||||
|
||||
@@ -31,12 +31,12 @@ import tempfile
|
||||
import threading
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-file lock registry
|
||||
|
||||
@@ -51,19 +51,6 @@ CSRF_HEADER_NAME: Final[str] = "X-BanGUI-Request"
|
||||
CSRF_HEADER_VALUE: Final[str] = "1"
|
||||
"""Required value of the CSRF header to pass validation."""
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Authentication penalty (brute-force resistance)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
LOGIN_PENALTY_BASE_SECONDS: Final[float] = 1.0
|
||||
"""Base penalty (seconds) for a failed login attempt."""
|
||||
|
||||
LOGIN_PENALTY_MAX_SECONDS: Final[float] = 10.0
|
||||
"""Maximum penalty (seconds) for failed login attempts."""
|
||||
|
||||
LOGIN_PENALTY_MULTIPLIER: Final[float] = 2.0
|
||||
"""Exponential multiplier applied per failed attempt."""
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Time-range presets (used by dashboard and history endpoints)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -16,9 +16,9 @@ from typing import TYPE_CHECKING, Any, Literal
|
||||
if TYPE_CHECKING:
|
||||
from aiohttp import ClientSession
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
class ExternalLogHandler(ABC):
|
||||
|
||||
@@ -24,7 +24,7 @@ from collections.abc import Mapping, Sequence, Set
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.exceptions import Fail2BanConnectionError, Fail2BanProtocolError
|
||||
|
||||
@@ -68,7 +68,7 @@ type Fail2BanResponse = tuple[int, object]
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# Attempt to reuse the vendored fail2ban package embedded in the repository.
|
||||
# If it is not on sys.path yet, load it from ``../fail2ban-master``.
|
||||
|
||||
@@ -5,9 +5,9 @@ from __future__ import annotations
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
def escape_like(s: str) -> str:
|
||||
|
||||
@@ -11,12 +11,12 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Default file contents
|
||||
|
||||
@@ -11,7 +11,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from typing import cast
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.exceptions import JailNotFoundError, JailOperationError
|
||||
from app.utils.fail2ban_client import (
|
||||
@@ -24,7 +24,7 @@ from app.utils.fail2ban_response import (
|
||||
to_dict,
|
||||
)
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# Socket communication timeout in seconds.
|
||||
SOCKET_TIMEOUT: float = 10.0
|
||||
|
||||
85
backend/app/utils/json_formatter.py
Normal file
85
backend/app/utils/json_formatter.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""JSON formatter for stdlib logging that preserves extra fields.
|
||||
|
||||
A single logging.Formatter subclass that serialises any keyword arguments
|
||||
passed via ``extra=`` into the JSON output alongside the standard record
|
||||
attributes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
# Attributes that belong to the standard LogRecord and should NOT be
|
||||
# treated as user-supplied extra fields.
|
||||
_STD_RECORD_ATTRS: frozenset[str] = frozenset(
|
||||
{
|
||||
"name",
|
||||
"msg",
|
||||
"args",
|
||||
"levelname",
|
||||
"levelno",
|
||||
"pathname",
|
||||
"filename",
|
||||
"module",
|
||||
"exc_info",
|
||||
"exc_text",
|
||||
"stack_info",
|
||||
"lineno",
|
||||
"funcName",
|
||||
"created",
|
||||
"msecs",
|
||||
"relativeCreated",
|
||||
"thread",
|
||||
"threadName",
|
||||
"processName",
|
||||
"process",
|
||||
"message",
|
||||
"asctime",
|
||||
"taskName",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class JSONFormatter(logging.Formatter):
|
||||
"""Format log records as JSON lines, including extra fields.
|
||||
|
||||
Usage::
|
||||
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(JSONFormatter())
|
||||
logging.getLogger().addHandler(handler)
|
||||
|
||||
Output keys:
|
||||
- ``event`` – the log message
|
||||
- ``level`` – lower-cased level name
|
||||
- ``timestamp`` – ISO-8601 UTC timestamp
|
||||
- ``logger`` – logger name
|
||||
- any ``extra`` fields supplied by the caller
|
||||
"""
|
||||
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
"""Return a JSON string for *record*."""
|
||||
log_dict: dict[str, Any] = {
|
||||
"event": record.getMessage(),
|
||||
"level": record.levelname.lower(),
|
||||
"timestamp": (
|
||||
datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat()
|
||||
),
|
||||
"logger": record.name,
|
||||
}
|
||||
|
||||
# Merge any extra fields attached to the record.
|
||||
for key, value in record.__dict__.items():
|
||||
if key not in _STD_RECORD_ATTRS:
|
||||
log_dict[key] = value
|
||||
|
||||
# Include exception info when present.
|
||||
if record.exc_info and not record.exc_text:
|
||||
record.exc_text = self.formatException(record.exc_info)
|
||||
if record.exc_text:
|
||||
log_dict["exception"] = record.exc_text
|
||||
|
||||
return json.dumps(log_dict, default=str)
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Log sanitization utilities for preventing sensitive data leakage.
|
||||
|
||||
All external output (subprocess, API responses, config data) passed to
|
||||
structlog MUST be sanitized first. This module provides the canonical
|
||||
logging MUST be sanitized first. This module provides the canonical
|
||||
sanitize_for_logging() function used across the codebase.
|
||||
"""
|
||||
|
||||
|
||||
63
backend/app/utils/logging_compat.py
Normal file
63
backend/app/utils/logging_compat.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Compatibility shim providing keyword-argument logging API on top of stdlib logging.
|
||||
|
||||
This module lets the rest of the codebase keep the keyword-argument logging
|
||||
style (``log.info("event", key=value)``) while using only the Python standard
|
||||
library ``logging`` module underneath.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
|
||||
class _CompatLogger:
|
||||
"""Wraps a stdlib :class:`logging.Logger` to accept keyword arguments."""
|
||||
|
||||
def __init__(self, logger: logging.Logger) -> None:
|
||||
self._logger = logger
|
||||
|
||||
def _log(self, level: int, event: str, **kwargs: Any) -> None:
|
||||
exc_info = kwargs.pop("exc_info", None)
|
||||
extra = kwargs if kwargs else None
|
||||
self._logger.log(level, event, exc_info=exc_info, extra=extra)
|
||||
|
||||
def debug(self, event: str, **kwargs: Any) -> None:
|
||||
self._log(logging.DEBUG, event, **kwargs)
|
||||
|
||||
def info(self, event: str, **kwargs: Any) -> None:
|
||||
self._log(logging.INFO, event, **kwargs)
|
||||
|
||||
def warning(self, event: str, **kwargs: Any) -> None:
|
||||
self._log(logging.WARNING, event, **kwargs)
|
||||
|
||||
def warn(self, event: str, **kwargs: Any) -> None:
|
||||
self._log(logging.WARNING, event, **kwargs)
|
||||
|
||||
def error(self, event: str, **kwargs: Any) -> None:
|
||||
self._log(logging.ERROR, event, **kwargs)
|
||||
|
||||
def critical(self, event: str, **kwargs: Any) -> None:
|
||||
self._log(logging.CRITICAL, event, **kwargs)
|
||||
|
||||
def exception(self, event: str, **kwargs: Any) -> None:
|
||||
self._log(logging.ERROR, event, exc_info=True, **kwargs)
|
||||
|
||||
def bind(self, **kwargs: Any) -> "_CompatLogger":
|
||||
"""Return a new logger with bound context (no-op for stdlib)."""
|
||||
return self
|
||||
|
||||
|
||||
def get_logger(name: str | None = None) -> _CompatLogger:
|
||||
"""Get a compatibility logger wrapping the stdlib logger for *name*.
|
||||
|
||||
If *name* is ``None`` the caller's module name is used.
|
||||
"""
|
||||
if name is None:
|
||||
import sys
|
||||
|
||||
# Walk up the stack to find the caller's module.
|
||||
frame = sys._getframe(1)
|
||||
module = frame.f_globals.get("__name__", "__main__")
|
||||
name = module
|
||||
return _CompatLogger(logging.getLogger(name))
|
||||
@@ -11,9 +11,9 @@ and get_metrics() returns an empty bytes object.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
try:
|
||||
from prometheus_client import (
|
||||
|
||||
@@ -1,46 +1,25 @@
|
||||
"""In-memory rate limiter for IP-based request throttling.
|
||||
"""In-memory global rate limiter for IP-based request throttling.
|
||||
|
||||
Implements exponential backoff for failed login attempts using failure tracking.
|
||||
Each wrong password attempt increments the failure count for that IP, and subsequent
|
||||
attempts are blocked for a duration that grows exponentially up to a maximum.
|
||||
|
||||
Uses a dictionary of deques (per IP) storing timestamps of recent failures.
|
||||
Old entries are cleaned up by a background task to prevent unbounded growth.
|
||||
Implements a sliding-window request counter per IP address. Old entries are
|
||||
cleaned up by a background task to prevent unbounded growth.
|
||||
|
||||
Process-local implementation — in multi-worker setups, each worker has
|
||||
independent counters. This constraint limits the blast radius of brute-force
|
||||
attacks to a single worker.
|
||||
independent counters. This constraint limits the blast radius of abuse to a
|
||||
single worker.
|
||||
|
||||
**How It Works:**
|
||||
**Cleanup Lifecycle**: The rate limiter state grows as IPs interact with the
|
||||
system. To prevent unbounded memory growth during long runtimes, a scheduled
|
||||
background task (rate_limiter_cleanup) calls cleanup_expired() every 30 minutes.
|
||||
This is safe because:
|
||||
|
||||
1. A successful login resets the failure counter for that IP.
|
||||
2. Each failed login (wrong password) calls record_failure() and increments the counter.
|
||||
3. is_allowed() checks if enough time has passed since the last failure based on
|
||||
the current failure count. The delay grows exponentially with each consecutive failure:
|
||||
|
||||
- 1st failure: 0.5 second penalty
|
||||
- 2nd failure: 1 second penalty (0.5 * 2^1)
|
||||
- 3rd failure: 2 seconds penalty (0.5 * 2^2)
|
||||
- 4th failure: 4 seconds penalty (0.5 * 2^3)
|
||||
- ... up to the configured maximum (default 5 seconds)
|
||||
|
||||
4. Penalties are cumulative within the window: if an attacker makes 5 failed
|
||||
attempts, they must wait the full 5 seconds before trying again (not 5 seconds
|
||||
per attempt).
|
||||
|
||||
**Cleanup Lifecycle**: The rate limiter state (_failures) grows as IPs interact
|
||||
with the system. To prevent unbounded memory growth during long runtimes, a
|
||||
scheduled background task (rate_limiter_cleanup) calls cleanup_expired() every
|
||||
30 minutes. This is safe because:
|
||||
|
||||
- cleanup_expired() only removes IPs with no recent failures (all timestamps
|
||||
- cleanup_expired() only removes IPs with no recent requests (all timestamps
|
||||
outside the rate-limit window), so active IPs are never disrupted.
|
||||
- The cleanup is non-blocking and logged for observability.
|
||||
- Individual requests already prune old timestamps from each IP's deque during
|
||||
is_allowed() and record_failure(), so cleanup primarily handles dormant IPs.
|
||||
check_allowed(), so cleanup primarily handles dormant IPs.
|
||||
|
||||
For monitoring, check logs for "rate_limiter_cleanup" events to observe how
|
||||
many IPs are being retired from memory each cleanup cycle.
|
||||
For monitoring, check logs for "global_rate_limiter_cleanup" events to observe
|
||||
how many IPs are being retired from memory each cleanup cycle.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -49,173 +28,21 @@ from collections import deque
|
||||
from time import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.utils.constants import (
|
||||
LOGIN_PENALTY_BASE_SECONDS,
|
||||
LOGIN_PENALTY_MAX_SECONDS,
|
||||
LOGIN_PENALTY_MULTIPLIER,
|
||||
)
|
||||
from app.utils.ip_utils import normalise_ip
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Mapping
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
# 5 attempts per minute per IP (300 seconds)
|
||||
DEFAULT_RATE_LIMIT_ATTEMPTS = 5
|
||||
DEFAULT_RATE_LIMIT_WINDOW_SECONDS = 60
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Track and enforce request rate limits per IP address.
|
||||
|
||||
Stores attempt timestamps in per-IP deques, removing old entries
|
||||
outside the rate limit window.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_attempts: int = DEFAULT_RATE_LIMIT_ATTEMPTS,
|
||||
window_seconds: int = DEFAULT_RATE_LIMIT_WINDOW_SECONDS,
|
||||
) -> None:
|
||||
"""Initialize the rate limiter.
|
||||
|
||||
Args:
|
||||
max_attempts: Maximum attempts allowed within the window.
|
||||
(Deprecated: now only used for cleanup window size)
|
||||
window_seconds: Time window (seconds) for rate limit.
|
||||
"""
|
||||
self.max_attempts: int = max_attempts
|
||||
self.window_seconds: int = window_seconds
|
||||
self._failures: dict[str, deque[float]] = {}
|
||||
|
||||
def is_allowed(self, ip_address: str) -> bool:
|
||||
"""Check if a request from *ip_address* is allowed.
|
||||
|
||||
Checks if the IP has accumulated failures that would currently block
|
||||
the attempt due to penalty backoff. Does NOT record a new attempt —
|
||||
that happens only on successful password verification.
|
||||
|
||||
Args:
|
||||
ip_address: The client IP address to rate-limit.
|
||||
|
||||
Returns:
|
||||
``True`` if the request is allowed (past penalty period), ``False``
|
||||
if currently blocked by exponential backoff.
|
||||
"""
|
||||
ip_address = normalise_ip(ip_address)
|
||||
now = time()
|
||||
|
||||
if ip_address not in self._failures:
|
||||
self._failures[ip_address] = deque()
|
||||
|
||||
failures = self._failures[ip_address]
|
||||
cutoff = now - self.window_seconds
|
||||
|
||||
# Remove old failures outside the window
|
||||
while failures and failures[0] < cutoff:
|
||||
failures.popleft()
|
||||
|
||||
# If no recent failures, request is allowed
|
||||
if not failures:
|
||||
return True
|
||||
|
||||
# Calculate accumulated penalty: how much time must pass before
|
||||
# the next attempt is allowed, based on failure count
|
||||
failure_count = len(failures)
|
||||
penalty = min(
|
||||
LOGIN_PENALTY_BASE_SECONDS * (LOGIN_PENALTY_MULTIPLIER ** failure_count),
|
||||
LOGIN_PENALTY_MAX_SECONDS,
|
||||
)
|
||||
|
||||
# Check if enough time has passed since the last failure
|
||||
time_since_last_failure = now - failures[-1]
|
||||
return time_since_last_failure >= penalty
|
||||
|
||||
def cleanup_expired(self) -> None:
|
||||
"""Remove all IPs with no recent failures (cleanup task).
|
||||
|
||||
Called periodically by the background task to prevent unbounded
|
||||
growth of the tracking dictionary.
|
||||
"""
|
||||
now = time()
|
||||
cutoff = now - self.window_seconds
|
||||
|
||||
ips_to_remove = []
|
||||
for ip_address, failures in self._failures.items():
|
||||
# Remove old failures
|
||||
while failures and failures[0] < cutoff:
|
||||
failures.popleft()
|
||||
# Mark IP for removal if no failures remain
|
||||
if not failures:
|
||||
ips_to_remove.append(ip_address)
|
||||
|
||||
for ip_address in ips_to_remove:
|
||||
del self._failures[ip_address]
|
||||
|
||||
if ips_to_remove:
|
||||
log.debug("rate_limiter_cleanup", removed_ips=len(ips_to_remove))
|
||||
|
||||
def get_state(self) -> Mapping[str, int]:
|
||||
"""Return a read-only view of current failure counts per IP.
|
||||
|
||||
For debugging and monitoring.
|
||||
|
||||
Returns:
|
||||
A mapping of IP addresses to their failure counts.
|
||||
"""
|
||||
now = time()
|
||||
cutoff = now - self.window_seconds
|
||||
result = {}
|
||||
for ip_address, failures in self._failures.items():
|
||||
# Count non-expired failures
|
||||
count = sum(1 for ts in failures if ts >= cutoff)
|
||||
if count > 0:
|
||||
result[ip_address] = count
|
||||
return result
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Clear all tracked failures (for testing)."""
|
||||
self._failures.clear()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Penalty strategy for failed login attempts
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def record_failure(self, ip_address: str) -> None:
|
||||
"""Record a failed login attempt.
|
||||
|
||||
Tracks failures per IP to enable exponential backoff in is_allowed().
|
||||
The penalty delay is automatically calculated in is_allowed() based on
|
||||
the failure count, providing transparent brute-force resistance.
|
||||
|
||||
Args:
|
||||
ip_address: The client IP address whose login attempt failed.
|
||||
"""
|
||||
ip_address = normalise_ip(ip_address)
|
||||
now = time()
|
||||
|
||||
if ip_address not in self._failures:
|
||||
self._failures[ip_address] = deque()
|
||||
|
||||
failures = self._failures[ip_address]
|
||||
cutoff = now - self.window_seconds
|
||||
|
||||
# Remove old failures outside the window
|
||||
while failures and failures[0] < cutoff:
|
||||
failures.popleft()
|
||||
|
||||
# Record this failure
|
||||
failures.append(now)
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
class GlobalRateLimiter:
|
||||
"""Global per-IP request rate limiter using sliding window algorithm.
|
||||
|
||||
Tracks total request count within a configurable time window per IP address.
|
||||
Unlike RateLimiter (which uses exponential backoff), this implements simple
|
||||
This implements simple
|
||||
request counting: when an IP exceeds the limit, the next request is blocked
|
||||
until the oldest request in the window expires.
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ import signal
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
try:
|
||||
from regexploit.ast.sre import SreOpParser
|
||||
@@ -25,7 +25,7 @@ except ImportError:
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
|
||||
logger = structlog.get_logger()
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Constants for regex validation
|
||||
MAX_REGEX_LENGTH = 1000
|
||||
|
||||
@@ -53,7 +53,7 @@ import datetime
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
from starlette.datastructures import State
|
||||
|
||||
from app.models.config import PendingRecovery
|
||||
@@ -63,7 +63,7 @@ from app.utils.session_cache import InMemorySessionCache, NoOpSessionCache
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from app.config import Settings
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
ActivationRecord = dict[str, datetime.datetime]
|
||||
|
||||
|
||||
@@ -46,9 +46,9 @@ import time
|
||||
from typing import Any
|
||||
|
||||
import aiosqlite
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# Lock record expires if heartbeat hasn't been updated for this many seconds.
|
||||
# This prevents stale locks from a crashed instance from blocking new startups.
|
||||
|
||||
@@ -15,7 +15,6 @@ dependencies = [
|
||||
"aiosqlite>=0.20.0",
|
||||
"aiohttp>=3.11.0",
|
||||
"apscheduler>=3.10,<4.0",
|
||||
"structlog>=24.4.0",
|
||||
"bcrypt>=4.2.0",
|
||||
"geoip2>=4.8.0",
|
||||
"prometheus-client>=0.21.0",
|
||||
|
||||
70
backend/tests/logging_capture.py
Normal file
70
backend/tests/logging_capture.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""Test utilities for capturing stdlib log records."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
|
||||
class _CaptureHandler(logging.Handler):
|
||||
"""Handler that stores every emitted record as a dict."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.records: list[dict[str, Any]] = []
|
||||
|
||||
def emit(self, record: logging.LogRecord) -> None:
|
||||
entry: dict[str, Any] = {
|
||||
"event": record.getMessage(),
|
||||
"level": record.levelname.lower(),
|
||||
"logger": record.name,
|
||||
}
|
||||
# Merge extra fields attached to the record.
|
||||
std_attrs = {
|
||||
"name",
|
||||
"msg",
|
||||
"args",
|
||||
"levelname",
|
||||
"levelno",
|
||||
"pathname",
|
||||
"filename",
|
||||
"module",
|
||||
"exc_info",
|
||||
"exc_text",
|
||||
"stack_info",
|
||||
"lineno",
|
||||
"funcName",
|
||||
"created",
|
||||
"msecs",
|
||||
"relativeCreated",
|
||||
"thread",
|
||||
"threadName",
|
||||
"processName",
|
||||
"process",
|
||||
"message",
|
||||
"asctime",
|
||||
"taskName",
|
||||
}
|
||||
for key, value in record.__dict__.items():
|
||||
if key not in std_attrs:
|
||||
entry[key] = value
|
||||
self.records.append(entry)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def capture_logs() -> Generator[list[dict[str, Any]], None, None]:
|
||||
"""Capture all log records emitted inside the context.
|
||||
|
||||
Yields a list of dicts, each representing a log entry with keys
|
||||
``event``, ``level``, ``logger`` and any extra fields.
|
||||
"""
|
||||
handler = _CaptureHandler()
|
||||
handler.setLevel(logging.DEBUG)
|
||||
root = logging.getLogger()
|
||||
root.addHandler(handler)
|
||||
try:
|
||||
yield handler.records
|
||||
finally:
|
||||
root.removeHandler(handler)
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -107,127 +106,7 @@ class TestLogin:
|
||||
response = await client.post("/api/v1/auth/login", json={})
|
||||
assert response.status_code == 422
|
||||
|
||||
async def test_login_rate_limit_returns_429_after_5_attempts(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
"""Login is blocked immediately after first failed attempt due to exponential backoff."""
|
||||
await _do_setup(client)
|
||||
limiter = client._transport.app.state.login_rate_limiter
|
||||
limiter.reset()
|
||||
|
||||
# First failed attempt is allowed
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login", json={"password": "wrongpassword"}
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
# Second attempt immediately after is blocked by 1s penalty
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login", json={"password": "wrongpassword"}
|
||||
)
|
||||
assert response.status_code == 429
|
||||
assert response.json()["detail"] == "Too many login attempts. Please try again later."
|
||||
|
||||
# Verify the failure count is correct
|
||||
state = limiter.get_state()
|
||||
assert "127.0.0.1" in state
|
||||
assert state["127.0.0.1"] >= 1
|
||||
|
||||
async def test_login_rate_limit_includes_retry_after_header(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
"""Rate-limited response includes Retry-After header."""
|
||||
await _do_setup(client)
|
||||
limiter = client._transport.app.state.login_rate_limiter
|
||||
limiter.reset()
|
||||
|
||||
# First attempt fails
|
||||
response = await client.post("/api/v1/auth/login", json={"password": "wrong"})
|
||||
assert response.status_code == 401
|
||||
|
||||
# Second immediate attempt is rate-limited
|
||||
response = await client.post("/api/v1/auth/login", json={"password": "wrong"})
|
||||
assert response.status_code == 429
|
||||
assert "retry-after" in response.headers
|
||||
assert response.headers["retry-after"] == "60"
|
||||
|
||||
async def test_login_rate_limit_per_ip(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
"""Rate limit is tracked separately per IP address."""
|
||||
await _do_setup(client)
|
||||
limiter = client._transport.app.state.login_rate_limiter
|
||||
limiter.reset()
|
||||
|
||||
# Make 1 failed attempt with default IP
|
||||
response = await client.post("/api/v1/auth/login", json={"password": "wrong"})
|
||||
assert response.status_code == 401
|
||||
|
||||
# 2nd attempt is blocked
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login", json={"password": "correct"}
|
||||
)
|
||||
assert response.status_code == 429
|
||||
|
||||
# Verify the failure count is correct
|
||||
state = limiter.get_state()
|
||||
assert "127.0.0.1" in state
|
||||
assert state["127.0.0.1"] >= 1
|
||||
|
||||
async def test_login_rate_limit_reset_after_window(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
"""Rate limit counter resets after the window expires."""
|
||||
await _do_setup(client)
|
||||
limiter = client._transport.app.state.login_rate_limiter
|
||||
limiter.reset()
|
||||
|
||||
# Make 1 failed attempt (enough to trigger exponential backoff)
|
||||
response = await client.post("/api/v1/auth/login", json={"password": "wrong"})
|
||||
assert response.status_code == 401
|
||||
|
||||
# 2nd attempt is blocked
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login", json={"password": "wrong"}
|
||||
)
|
||||
assert response.status_code == 429
|
||||
|
||||
# Reset the limiter (simulate window expiry)
|
||||
limiter.reset()
|
||||
|
||||
# Now a fresh login attempt should succeed (use correct password)
|
||||
response = await client.post(
|
||||
"/api/v1/auth/login", json={"password": "Mysecretpass1!"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
async def test_login_exponential_backoff(self, client: AsyncClient) -> None:
|
||||
"""Exponential backoff accumulates with each consecutive failure."""
|
||||
await _do_setup(client)
|
||||
limiter = client._transport.app.state.login_rate_limiter
|
||||
limiter.reset()
|
||||
|
||||
# 1st failure: 1 * 2^1 = 2s penalty
|
||||
response = await client.post("/api/v1/auth/login", json={"password": "wrong"})
|
||||
assert response.status_code == 401
|
||||
state = limiter.get_state()
|
||||
assert state["127.0.0.1"] == 1
|
||||
|
||||
# 2nd attempt blocked immediately by 2s penalty
|
||||
response = await client.post("/api/v1/auth/login", json={"password": "wrong"})
|
||||
assert response.status_code == 429
|
||||
|
||||
# After 2.1s, the penalty expires and we can try again
|
||||
# (this will record a 2nd failure, creating a 1 * 2^2 = 4s penalty)
|
||||
await asyncio.sleep(2.1)
|
||||
response = await client.post("/api/v1/auth/login", json={"password": "wrong"})
|
||||
assert response.status_code == 401
|
||||
state = limiter.get_state()
|
||||
assert state["127.0.0.1"] == 2
|
||||
|
||||
# Now blocked by 4s penalty
|
||||
response = await client.post("/api/v1/auth/login", json={"password": "wrong"})
|
||||
assert response.status_code == 429
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -790,9 +790,9 @@ class TestErrorLogging:
|
||||
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
session.get = MagicMock(return_value=mock_ctx)
|
||||
|
||||
import structlog.testing
|
||||
from tests.logging_capture import capture_logs
|
||||
|
||||
with structlog.testing.capture_logs() as captured, patch.object(
|
||||
with capture_logs() as captured, patch.object(
|
||||
geo_cache, "_geoip_reader", None
|
||||
):
|
||||
# Ensure MMDB is not available so HTTP is tried.
|
||||
@@ -817,9 +817,9 @@ class TestErrorLogging:
|
||||
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
session.get = MagicMock(return_value=mock_ctx)
|
||||
|
||||
import structlog.testing
|
||||
from tests.logging_capture import capture_logs
|
||||
|
||||
with structlog.testing.capture_logs() as captured, patch.object(
|
||||
with capture_logs() as captured, patch.object(
|
||||
geo_cache, "_geoip_reader", None
|
||||
):
|
||||
# Ensure MMDB is not available so HTTP is tried.
|
||||
@@ -844,9 +844,9 @@ class TestErrorLogging:
|
||||
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
session.post = MagicMock(return_value=mock_ctx)
|
||||
|
||||
import structlog.testing
|
||||
from tests.logging_capture import capture_logs
|
||||
|
||||
with structlog.testing.capture_logs() as captured:
|
||||
with capture_logs() as captured:
|
||||
result = await geo_cache._batch_api_call(["1.2.3.4"], session)
|
||||
|
||||
assert result["1.2.3.4"].country_code is None
|
||||
|
||||
@@ -8,7 +8,6 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
import structlog
|
||||
|
||||
from app.utils.async_utils import logged_task, run_blocking
|
||||
|
||||
@@ -108,7 +107,7 @@ async def test_logged_task_preserves_exception_info() -> None:
|
||||
with mock.patch("app.utils.async_utils.log") as mock_log:
|
||||
await logged_task(failing_coro(), "test_task")
|
||||
mock_log.exception.assert_called_once()
|
||||
# Verify the exception context is logged (structlog.exception captures
|
||||
# Verify the exception context is logged (exception captures
|
||||
# the traceback automatically)
|
||||
args, kwargs = mock_log.exception.call_args
|
||||
assert args[0] == "background_task_failed"
|
||||
|
||||
Reference in New Issue
Block a user