refactoring-backend #3
@@ -60,6 +60,7 @@ async def login(
|
||||
session_ctx: SessionServiceContextDep,
|
||||
settings: SettingsDep,
|
||||
rate_limiter: LoginRateLimiterDep,
|
||||
session_cache: SessionCacheDep,
|
||||
) -> LoginResponse:
|
||||
"""Verify the master password and return a session token.
|
||||
|
||||
@@ -71,6 +72,10 @@ async def login(
|
||||
Requests during the penalty period return ``429 Too Many Requests`` with
|
||||
a ``Retry-After`` header.
|
||||
|
||||
Cache invalidation: On successful login, any existing cached sessions for
|
||||
the same user are invalidated so that stale tokens (e.g., from a stolen
|
||||
device) cannot be reused beyond the cache TTL window.
|
||||
|
||||
Args:
|
||||
body: Login request validated by Pydantic.
|
||||
response: FastAPI response object used to set the cookie.
|
||||
@@ -78,6 +83,7 @@ async def login(
|
||||
session_ctx: Session service context containing db and repository.
|
||||
settings: Application settings (used for session duration and trusted proxies).
|
||||
rate_limiter: The login rate limiter (per IP).
|
||||
session_cache: Session cache for invalidating old sessions on login.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.auth.LoginResponse` containing the token.
|
||||
@@ -94,7 +100,7 @@ async def login(
|
||||
raise RateLimitError("Too many login attempts. Please try again later.", retry_after_seconds=60.0)
|
||||
|
||||
try:
|
||||
signed_token, expires_at = await auth_service.login(
|
||||
signed_token, expires_at, session = await auth_service.login(
|
||||
session_ctx.db,
|
||||
password=body.password,
|
||||
session_duration_minutes=settings.session_duration_minutes,
|
||||
@@ -107,6 +113,10 @@ async def login(
|
||||
log.warning("login_failed", client_ip=client_ip, error=str(exc))
|
||||
raise AuthenticationError(str(exc)) from exc
|
||||
|
||||
# Invalidate any cached sessions for the same user to prevent reuse of
|
||||
# stale tokens (e.g., from a stolen device) beyond the cache TTL window.
|
||||
session_cache.invalidate_by_user(session.id)
|
||||
|
||||
response.set_cookie(
|
||||
key=SESSION_COOKIE_NAME,
|
||||
value=signed_token,
|
||||
|
||||
@@ -151,7 +151,7 @@ async def login(
|
||||
session_duration_minutes: int,
|
||||
session_secret: str,
|
||||
session_repo: SessionRepository = default_session_repo,
|
||||
) -> tuple[str, str]:
|
||||
) -> tuple[str, str, Session]:
|
||||
"""Verify *password*, create a new session, and sign the token.
|
||||
|
||||
Args:
|
||||
@@ -161,7 +161,8 @@ async def login(
|
||||
session_secret: Secret used to sign the session token.
|
||||
|
||||
Returns:
|
||||
A tuple of the signed session token and its expiry timestamp.
|
||||
A tuple of the signed session token, its expiry timestamp,
|
||||
and the newly created session object.
|
||||
|
||||
Raises:
|
||||
ValueError: If the password is incorrect or no password hash is stored.
|
||||
@@ -185,7 +186,7 @@ async def login(
|
||||
)
|
||||
signed_token = sign_session_token(session.token, session_secret)
|
||||
log.info("bangui_login_success", session_id=session.id)
|
||||
return signed_token, session.expires_at
|
||||
return signed_token, session.expires_at, session
|
||||
|
||||
|
||||
async def validate_session(
|
||||
|
||||
@@ -77,6 +77,9 @@ class SessionCache(Protocol):
|
||||
def invalidate(self, token: str) -> None:
|
||||
"""Remove *token* from the cache if it exists."""
|
||||
|
||||
def invalidate_by_user(self, user_id: int) -> None:
|
||||
"""Remove all cached sessions belonging to *user_id*."""
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Remove all entries from the cache."""
|
||||
|
||||
@@ -86,6 +89,7 @@ class InMemorySessionCache:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._entries: dict[str, tuple[Session, float]] = {}
|
||||
self._user_index: dict[int, set[str]] = {}
|
||||
|
||||
def get(self, token: str) -> Session | None:
|
||||
entry = self._entries.get(token)
|
||||
@@ -95,17 +99,36 @@ class InMemorySessionCache:
|
||||
session, expires_at = entry
|
||||
if time.monotonic() >= expires_at:
|
||||
self._entries.pop(token, None)
|
||||
self._remove_from_user_index(token, session.id)
|
||||
return None
|
||||
return session
|
||||
|
||||
def set(self, token: str, session: Session, ttl_seconds: float) -> None:
|
||||
self._entries[token] = (session, time.monotonic() + ttl_seconds)
|
||||
expires_at = time.monotonic() + ttl_seconds
|
||||
self._entries[token] = (session, expires_at)
|
||||
self._user_index.setdefault(session.id, set()).add(token)
|
||||
|
||||
def invalidate(self, token: str) -> None:
|
||||
self._entries.pop(token, None)
|
||||
entry = self._entries.pop(token, None)
|
||||
if entry is not None:
|
||||
self._remove_from_user_index(token, entry[0].id)
|
||||
|
||||
def invalidate_by_user(self, user_id: int) -> None:
|
||||
"""Remove all cached sessions for *user_id*."""
|
||||
tokens = self._user_index.pop(user_id, set())
|
||||
for token in tokens:
|
||||
self._entries.pop(token, None)
|
||||
|
||||
def clear(self) -> None:
|
||||
self._entries.clear()
|
||||
self._user_index.clear()
|
||||
|
||||
def _remove_from_user_index(self, token: str, user_id: int) -> None:
|
||||
user_tokens = self._user_index.get(user_id)
|
||||
if user_tokens is not None:
|
||||
user_tokens.discard(token)
|
||||
if not user_tokens:
|
||||
self._user_index.pop(user_id, None)
|
||||
|
||||
|
||||
class NoOpSessionCache:
|
||||
@@ -120,5 +143,8 @@ class NoOpSessionCache:
|
||||
def invalidate(self, token: str) -> None:
|
||||
return None
|
||||
|
||||
def invalidate_by_user(self, user_id: int) -> None:
|
||||
return None
|
||||
|
||||
def clear(self) -> None:
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user