fixed tests
This commit is contained in:
@@ -61,17 +61,20 @@ def normalise_ip(address: str) -> str:
|
||||
IPv4-mapped IPv6 addresses (e.g. ``::ffff:192.168.1.1``) are converted
|
||||
to their IPv4 equivalent (``192.168.1.1``).
|
||||
Plain IPv4 addresses are returned unchanged.
|
||||
Non-IP strings (e.g. ``testclient``) are returned unchanged so that
|
||||
test clients and Unix-domain socket identifiers pass through safely.
|
||||
|
||||
Args:
|
||||
address: A valid IP address string.
|
||||
address: An IP address string or other identifier.
|
||||
|
||||
Returns:
|
||||
Normalised IP address string.
|
||||
|
||||
Raises:
|
||||
ValueError: If *address* is not a valid IP address.
|
||||
Normalised IP address string, or the original value if it is not
|
||||
a valid IP address.
|
||||
"""
|
||||
ip = ipaddress.ip_address(address)
|
||||
try:
|
||||
ip = ipaddress.ip_address(address)
|
||||
except ValueError:
|
||||
return address
|
||||
if isinstance(ip, ipaddress.IPv6Address) and ip.ipv4_mapped:
|
||||
return str(ip.ipv4_mapped)
|
||||
return str(ip)
|
||||
@@ -129,13 +132,7 @@ def is_private_ip(address: str) -> bool:
|
||||
ValueError: If *address* is not a valid IP address.
|
||||
"""
|
||||
ip = ipaddress.ip_address(address)
|
||||
return (
|
||||
ip.is_private
|
||||
or ip.is_loopback
|
||||
or ip.is_link_local
|
||||
or ip.is_multicast
|
||||
or ip.is_reserved
|
||||
)
|
||||
return ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_multicast or ip.is_reserved
|
||||
|
||||
|
||||
async def validate_blocklist_url(url: str) -> None:
|
||||
@@ -165,9 +162,7 @@ async def validate_blocklist_url(url: str) -> None:
|
||||
raise ValueError(f"Invalid URL format: {exc}") from exc
|
||||
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise ValueError(
|
||||
f"Invalid scheme '{parsed.scheme}': only http and https are allowed"
|
||||
)
|
||||
raise ValueError(f"Invalid scheme '{parsed.scheme}': only http and https are allowed")
|
||||
|
||||
if not parsed.hostname:
|
||||
raise ValueError("URL has no hostname")
|
||||
@@ -201,14 +196,9 @@ async def validate_blocklist_url(url: str) -> None:
|
||||
# connection time, and host mode is never used in production.
|
||||
if is_private_ip(ip_str):
|
||||
import os
|
||||
if (
|
||||
os.getenv("BANGUI_LOG_LEVEL") == "debug"
|
||||
and ipaddress.ip_address(ip_str).is_loopback
|
||||
):
|
||||
|
||||
if os.getenv("BANGUI_LOG_LEVEL") == "debug" and ipaddress.ip_address(ip_str).is_loopback:
|
||||
continue
|
||||
raise ValueError(
|
||||
f"Hostname '{hostname}' resolves to private/reserved IP: {ip_str}"
|
||||
)
|
||||
raise ValueError(f"Hostname '{hostname}' resolves to private/reserved IP: {ip_str}")
|
||||
except ipaddress.AddressValueError as exc:
|
||||
raise ValueError(f"Invalid IP address: {ip_str}") from exc
|
||||
|
||||
|
||||
@@ -26,6 +26,19 @@ class _CompatLogger:
|
||||
if v is not None:
|
||||
stdlib_kwargs[k] = v
|
||||
if kwargs:
|
||||
# Several keys are reserved in LogRecord; rename them to avoid KeyError.
|
||||
reserved_renames = {
|
||||
"message": "log_message",
|
||||
"name": "log_name",
|
||||
"filename": "log_filename",
|
||||
"funcName": "log_funcName",
|
||||
"lineno": "log_lineno",
|
||||
"module": "log_module",
|
||||
"pathname": "log_pathname",
|
||||
}
|
||||
for old_key, new_key in reserved_renames.items():
|
||||
if old_key in kwargs:
|
||||
kwargs[new_key] = kwargs.pop(old_key)
|
||||
stdlib_kwargs["extra"] = kwargs
|
||||
self._logger.log(level, event, **stdlib_kwargs)
|
||||
|
||||
@@ -50,7 +63,7 @@ class _CompatLogger:
|
||||
def exception(self, event: str, **kwargs: Any) -> None:
|
||||
self._log(logging.ERROR, event, exc_info=True, **kwargs)
|
||||
|
||||
def bind(self, **kwargs: Any) -> "_CompatLogger":
|
||||
def bind(self, **kwargs: Any) -> _CompatLogger:
|
||||
"""Return a new logger with bound context (no-op for stdlib)."""
|
||||
return self
|
||||
|
||||
|
||||
@@ -46,6 +46,7 @@ import time
|
||||
from typing import Any
|
||||
|
||||
import aiosqlite
|
||||
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = get_logger(__name__)
|
||||
@@ -133,12 +134,10 @@ async def acquire_scheduler_lock(db: aiosqlite.Connection) -> bool:
|
||||
await db.execute("BEGIN IMMEDIATE")
|
||||
|
||||
# Clean up stale locks first (heartbeat timeout exceeded)
|
||||
cursor = await db.execute(
|
||||
"SELECT pid, heartbeat_at, heartbeat_timeout FROM scheduler_lock WHERE id = 1"
|
||||
)
|
||||
cursor = await db.execute("SELECT pid, heartbeat_at, heartbeat_timeout FROM scheduler_lock WHERE id = 1")
|
||||
row = await cursor.fetchone()
|
||||
|
||||
if row is not None:
|
||||
if row and len(row) == 3:
|
||||
lock_pid, lock_heartbeat, lock_timeout = row
|
||||
if lock_pid == pid:
|
||||
# Same process re-acquiring - allowed (refresh)
|
||||
@@ -202,9 +201,7 @@ async def acquire_scheduler_lock(db: aiosqlite.Connection) -> bool:
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Failed to acquire scheduler lock due to database error: {e}"
|
||||
) from e
|
||||
raise RuntimeError(f"Failed to acquire scheduler lock due to database error: {e}") from e
|
||||
|
||||
|
||||
async def release_scheduler_lock(db: aiosqlite.Connection) -> None:
|
||||
@@ -372,9 +369,7 @@ async def get_lock_health(db: aiosqlite.Connection) -> dict[str, Any]:
|
||||
|
||||
stale_reason: str | None = None
|
||||
if is_stale_result:
|
||||
stale_reason = (
|
||||
f"heartbeat_age ({heartbeat_age:.1f}s) > timeout ({info['heartbeat_timeout']:.1f}s)"
|
||||
)
|
||||
stale_reason = f"heartbeat_age ({heartbeat_age:.1f}s) > timeout ({info['heartbeat_timeout']:.1f}s)"
|
||||
|
||||
return {
|
||||
"has_lock": True,
|
||||
|
||||
Reference in New Issue
Block a user