diff --git a/backend/app/routers/auth.py b/backend/app/routers/auth.py index f87852c..0aa6006 100644 --- a/backend/app/routers/auth.py +++ b/backend/app/routers/auth.py @@ -60,6 +60,7 @@ async def login( session_ctx: SessionServiceContextDep, settings: SettingsDep, rate_limiter: LoginRateLimiterDep, + session_cache: SessionCacheDep, ) -> LoginResponse: """Verify the master password and return a session token. @@ -71,6 +72,10 @@ async def login( Requests during the penalty period return ``429 Too Many Requests`` with a ``Retry-After`` header. + Cache invalidation: On successful login, any existing cached sessions for + the same user are invalidated so that stale tokens (e.g., from a stolen + device) cannot be reused beyond the cache TTL window. + Args: body: Login request validated by Pydantic. response: FastAPI response object used to set the cookie. @@ -78,6 +83,7 @@ async def login( session_ctx: Session service context containing db and repository. settings: Application settings (used for session duration and trusted proxies). rate_limiter: The login rate limiter (per IP). + session_cache: Session cache for invalidating old sessions on login. Returns: :class:`~app.models.auth.LoginResponse` containing the token. @@ -94,7 +100,7 @@ async def login( raise RateLimitError("Too many login attempts. Please try again later.", retry_after_seconds=60.0) try: - signed_token, expires_at = await auth_service.login( + signed_token, expires_at, session = await auth_service.login( session_ctx.db, password=body.password, session_duration_minutes=settings.session_duration_minutes, @@ -107,6 +113,10 @@ async def login( log.warning("login_failed", client_ip=client_ip, error=str(exc)) raise AuthenticationError(str(exc)) from exc + # Invalidate any cached sessions for the same user to prevent reuse of + # stale tokens (e.g., from a stolen device) beyond the cache TTL window. + session_cache.invalidate_by_user(session.id) + response.set_cookie( key=SESSION_COOKIE_NAME, value=signed_token, diff --git a/backend/app/services/auth_service.py b/backend/app/services/auth_service.py index 72d7781..cf81d49 100644 --- a/backend/app/services/auth_service.py +++ b/backend/app/services/auth_service.py @@ -151,7 +151,7 @@ async def login( session_duration_minutes: int, session_secret: str, session_repo: SessionRepository = default_session_repo, -) -> tuple[str, str]: +) -> tuple[str, str, Session]: """Verify *password*, create a new session, and sign the token. Args: @@ -161,7 +161,8 @@ async def login( session_secret: Secret used to sign the session token. Returns: - A tuple of the signed session token and its expiry timestamp. + A tuple of the signed session token, its expiry timestamp, + and the newly created session object. Raises: ValueError: If the password is incorrect or no password hash is stored. @@ -185,7 +186,7 @@ async def login( ) signed_token = sign_session_token(session.token, session_secret) log.info("bangui_login_success", session_id=session.id) - return signed_token, session.expires_at + return signed_token, session.expires_at, session async def validate_session( diff --git a/backend/app/utils/session_cache.py b/backend/app/utils/session_cache.py index 97d8957..ddca6dd 100644 --- a/backend/app/utils/session_cache.py +++ b/backend/app/utils/session_cache.py @@ -77,6 +77,9 @@ class SessionCache(Protocol): def invalidate(self, token: str) -> None: """Remove *token* from the cache if it exists.""" + def invalidate_by_user(self, user_id: int) -> None: + """Remove all cached sessions belonging to *user_id*.""" + def clear(self) -> None: """Remove all entries from the cache.""" @@ -86,6 +89,7 @@ class InMemorySessionCache: def __init__(self) -> None: self._entries: dict[str, tuple[Session, float]] = {} + self._user_index: dict[int, set[str]] = {} def get(self, token: str) -> Session | None: entry = self._entries.get(token) @@ -95,17 +99,36 @@ class InMemorySessionCache: session, expires_at = entry if time.monotonic() >= expires_at: self._entries.pop(token, None) + self._remove_from_user_index(token, session.id) return None return session def set(self, token: str, session: Session, ttl_seconds: float) -> None: - self._entries[token] = (session, time.monotonic() + ttl_seconds) + expires_at = time.monotonic() + ttl_seconds + self._entries[token] = (session, expires_at) + self._user_index.setdefault(session.id, set()).add(token) def invalidate(self, token: str) -> None: - self._entries.pop(token, None) + entry = self._entries.pop(token, None) + if entry is not None: + self._remove_from_user_index(token, entry[0].id) + + def invalidate_by_user(self, user_id: int) -> None: + """Remove all cached sessions for *user_id*.""" + tokens = self._user_index.pop(user_id, set()) + for token in tokens: + self._entries.pop(token, None) def clear(self) -> None: self._entries.clear() + self._user_index.clear() + + def _remove_from_user_index(self, token: str, user_id: int) -> None: + user_tokens = self._user_index.get(user_id) + if user_tokens is not None: + user_tokens.discard(token) + if not user_tokens: + self._user_index.pop(user_id, None) class NoOpSessionCache: @@ -120,5 +143,8 @@ class NoOpSessionCache: def invalidate(self, token: str) -> None: return None + def invalidate_by_user(self, user_id: int) -> None: + return None + def clear(self) -> None: return None