fix(auth): invalidate session cache on login

Stale sessions from a stolen device could be reused up to the cache
TTL after a legitimate user re-logs in, because login never cleared
the existing cache entry.

Changes:
- Add invalidate_by_user(user_id) to SessionCache protocol
- InMemorySessionCache maintains a user_id -> set[token] index to
  support O(1) invalidation of all sessions for a given user
- NoOpSessionCache stub updated for API compatibility
- auth_service.login() now returns the Session object alongside
  signed_token and expires_at
- login router calls session_cache.invalidate_by_user(session.id)
  immediately after successful authentication

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
2026-05-03 20:51:51 +02:00
parent ae9313568e
commit c3cd1574dc
3 changed files with 43 additions and 6 deletions

View File

@@ -77,6 +77,9 @@ class SessionCache(Protocol):
def invalidate(self, token: str) -> None:
"""Remove *token* from the cache if it exists."""
def invalidate_by_user(self, user_id: int) -> None:
"""Remove all cached sessions belonging to *user_id*."""
def clear(self) -> None:
"""Remove all entries from the cache."""
@@ -86,6 +89,7 @@ class InMemorySessionCache:
def __init__(self) -> None:
self._entries: dict[str, tuple[Session, float]] = {}
self._user_index: dict[int, set[str]] = {}
def get(self, token: str) -> Session | None:
entry = self._entries.get(token)
@@ -95,17 +99,36 @@ class InMemorySessionCache:
session, expires_at = entry
if time.monotonic() >= expires_at:
self._entries.pop(token, None)
self._remove_from_user_index(token, session.id)
return None
return session
def set(self, token: str, session: Session, ttl_seconds: float) -> None:
self._entries[token] = (session, time.monotonic() + ttl_seconds)
expires_at = time.monotonic() + ttl_seconds
self._entries[token] = (session, expires_at)
self._user_index.setdefault(session.id, set()).add(token)
def invalidate(self, token: str) -> None:
self._entries.pop(token, None)
entry = self._entries.pop(token, None)
if entry is not None:
self._remove_from_user_index(token, entry[0].id)
def invalidate_by_user(self, user_id: int) -> None:
"""Remove all cached sessions for *user_id*."""
tokens = self._user_index.pop(user_id, set())
for token in tokens:
self._entries.pop(token, None)
def clear(self) -> None:
self._entries.clear()
self._user_index.clear()
def _remove_from_user_index(self, token: str, user_id: int) -> None:
user_tokens = self._user_index.get(user_id)
if user_tokens is not None:
user_tokens.discard(token)
if not user_tokens:
self._user_index.pop(user_id, None)
class NoOpSessionCache:
@@ -120,5 +143,8 @@ class NoOpSessionCache:
def invalidate(self, token: str) -> None:
return None
def invalidate_by_user(self, user_id: int) -> None:
return None
def clear(self) -> None:
return None