fix(auth): invalidate session cache on login
Stale sessions from a stolen device could be reused up to the cache TTL after a legitimate user re-logs in, because login never cleared the existing cache entry. Changes: - Add invalidate_by_user(user_id) to SessionCache protocol - InMemorySessionCache maintains a user_id -> set[token] index to support O(1) invalidation of all sessions for a given user - NoOpSessionCache stub updated for API compatibility - auth_service.login() now returns the Session object alongside signed_token and expires_at - login router calls session_cache.invalidate_by_user(session.id) immediately after successful authentication Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -77,6 +77,9 @@ class SessionCache(Protocol):
|
||||
def invalidate(self, token: str) -> None:
|
||||
"""Remove *token* from the cache if it exists."""
|
||||
|
||||
def invalidate_by_user(self, user_id: int) -> None:
|
||||
"""Remove all cached sessions belonging to *user_id*."""
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Remove all entries from the cache."""
|
||||
|
||||
@@ -86,6 +89,7 @@ class InMemorySessionCache:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._entries: dict[str, tuple[Session, float]] = {}
|
||||
self._user_index: dict[int, set[str]] = {}
|
||||
|
||||
def get(self, token: str) -> Session | None:
|
||||
entry = self._entries.get(token)
|
||||
@@ -95,17 +99,36 @@ class InMemorySessionCache:
|
||||
session, expires_at = entry
|
||||
if time.monotonic() >= expires_at:
|
||||
self._entries.pop(token, None)
|
||||
self._remove_from_user_index(token, session.id)
|
||||
return None
|
||||
return session
|
||||
|
||||
def set(self, token: str, session: Session, ttl_seconds: float) -> None:
|
||||
self._entries[token] = (session, time.monotonic() + ttl_seconds)
|
||||
expires_at = time.monotonic() + ttl_seconds
|
||||
self._entries[token] = (session, expires_at)
|
||||
self._user_index.setdefault(session.id, set()).add(token)
|
||||
|
||||
def invalidate(self, token: str) -> None:
|
||||
self._entries.pop(token, None)
|
||||
entry = self._entries.pop(token, None)
|
||||
if entry is not None:
|
||||
self._remove_from_user_index(token, entry[0].id)
|
||||
|
||||
def invalidate_by_user(self, user_id: int) -> None:
|
||||
"""Remove all cached sessions for *user_id*."""
|
||||
tokens = self._user_index.pop(user_id, set())
|
||||
for token in tokens:
|
||||
self._entries.pop(token, None)
|
||||
|
||||
def clear(self) -> None:
|
||||
self._entries.clear()
|
||||
self._user_index.clear()
|
||||
|
||||
def _remove_from_user_index(self, token: str, user_id: int) -> None:
|
||||
user_tokens = self._user_index.get(user_id)
|
||||
if user_tokens is not None:
|
||||
user_tokens.discard(token)
|
||||
if not user_tokens:
|
||||
self._user_index.pop(user_id, None)
|
||||
|
||||
|
||||
class NoOpSessionCache:
|
||||
@@ -120,5 +143,8 @@ class NoOpSessionCache:
|
||||
def invalidate(self, token: str) -> None:
|
||||
return None
|
||||
|
||||
def invalidate_by_user(self, user_id: int) -> None:
|
||||
return None
|
||||
|
||||
def clear(self) -> None:
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user