880 lines
26 KiB
Python
880 lines
26 KiB
Python
"""Database service layer for CRUD operations.
|
|
|
|
This module provides a comprehensive service layer for database operations,
|
|
implementing the Repository pattern for clean separation of concerns.
|
|
|
|
Services:
|
|
- AnimeSeriesService: CRUD operations for anime series
|
|
- EpisodeService: CRUD operations for episodes
|
|
- DownloadQueueService: CRUD operations for download queue
|
|
- UserSessionService: CRUD operations for user sessions
|
|
|
|
All services support both async and sync operations for flexibility.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Dict, List, Optional
|
|
|
|
from sqlalchemy import delete, select, update
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.orm import Session, selectinload
|
|
|
|
from src.server.database.models import (
|
|
AnimeSeries,
|
|
DownloadPriority,
|
|
DownloadQueueItem,
|
|
DownloadStatus,
|
|
Episode,
|
|
UserSession,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# ============================================================================
|
|
# Anime Series Service
|
|
# ============================================================================
|
|
|
|
|
|
class AnimeSeriesService:
    """Service for anime series CRUD operations.

    Every method is a static coroutine operating on a caller-supplied
    ``AsyncSession``. Changes are flushed but never committed here, so the
    caller owns the transaction boundary.
    """

    # Escape character used for LIKE patterns built from user input.
    _LIKE_ESCAPE = "/"

    @staticmethod
    def _escape_like(text: str) -> str:
        """Escape LIKE wildcards in ``text`` so it is matched literally.

        Args:
            text: Raw user-supplied search text

        Returns:
            ``text`` with the escape character, ``%`` and ``_`` escaped
        """
        esc = AnimeSeriesService._LIKE_ESCAPE
        # The escape character itself must be doubled first; otherwise the
        # escapes added for "%" and "_" would be escaped a second time.
        return (
            text.replace(esc, esc + esc)
            .replace("%", esc + "%")
            .replace("_", esc + "_")
        )

    @staticmethod
    async def create(
        db: AsyncSession,
        key: str,
        name: str,
        site: str,
        folder: str,
        description: Optional[str] = None,
        status: Optional[str] = None,
        total_episodes: Optional[int] = None,
        cover_url: Optional[str] = None,
        episode_dict: Optional[Dict] = None,
    ) -> AnimeSeries:
        """Create a new anime series.

        Args:
            db: Database session
            key: Unique provider key
            name: Series name
            site: Provider site URL
            folder: Local filesystem path
            description: Optional series description
            status: Optional series status
            total_episodes: Optional total episode count
            cover_url: Optional cover image URL
            episode_dict: Optional episode dictionary

        Returns:
            Created AnimeSeries instance, flushed and refreshed

        Raises:
            IntegrityError: Presumably raised on flush when ``key`` already
                exists (the uniqueness constraint lives on the model, which
                is not visible here)
        """
        series = AnimeSeries(
            key=key,
            name=name,
            site=site,
            folder=folder,
            description=description,
            status=status,
            total_episodes=total_episodes,
            cover_url=cover_url,
            episode_dict=episode_dict,
        )
        db.add(series)
        # Flush to send the INSERT, then refresh to load generated values.
        await db.flush()
        await db.refresh(series)
        logger.info(
            "Created anime series: %s (key=%s)", series.name, series.key
        )
        return series

    @staticmethod
    async def get_by_id(db: AsyncSession, series_id: int) -> Optional[AnimeSeries]:
        """Get anime series by ID.

        Args:
            db: Database session
            series_id: Series primary key

        Returns:
            AnimeSeries instance or None if not found
        """
        result = await db.execute(
            select(AnimeSeries).where(AnimeSeries.id == series_id)
        )
        return result.scalar_one_or_none()

    @staticmethod
    async def get_by_key(db: AsyncSession, key: str) -> Optional[AnimeSeries]:
        """Get anime series by provider key.

        Args:
            db: Database session
            key: Unique provider key

        Returns:
            AnimeSeries instance or None if not found
        """
        result = await db.execute(
            select(AnimeSeries).where(AnimeSeries.key == key)
        )
        return result.scalar_one_or_none()

    @staticmethod
    async def get_all(
        db: AsyncSession,
        limit: Optional[int] = None,
        offset: int = 0,
        with_episodes: bool = False,
    ) -> List[AnimeSeries]:
        """Get all anime series.

        Args:
            db: Database session
            limit: Optional limit for results; ``None`` means no limit and
                ``0`` yields an empty list
            offset: Offset for pagination
            with_episodes: Whether to eagerly load episodes

        Returns:
            List of AnimeSeries instances ordered by primary key
        """
        # A stable ORDER BY is required for offset/limit pagination to return
        # consistent, non-overlapping pages.
        query = select(AnimeSeries).order_by(AnimeSeries.id)

        if with_episodes:
            query = query.options(selectinload(AnimeSeries.episodes))

        query = query.offset(offset)
        # Compare against None so an explicit limit of 0 is honoured instead
        # of being treated as "no limit".
        if limit is not None:
            query = query.limit(limit)

        result = await db.execute(query)
        return list(result.scalars().all())

    @staticmethod
    async def update(
        db: AsyncSession,
        series_id: int,
        **kwargs,
    ) -> Optional[AnimeSeries]:
        """Update anime series.

        Args:
            db: Database session
            series_id: Series primary key
            **kwargs: Fields to update; names that are not attributes of
                the series are skipped with a warning

        Returns:
            Updated AnimeSeries instance or None if not found
        """
        series = await AnimeSeriesService.get_by_id(db, series_id)
        if not series:
            return None

        for field, value in kwargs.items():
            if hasattr(series, field):
                setattr(series, field, value)
            else:
                # Surface typos instead of dropping them silently.
                logger.warning(
                    "Ignoring unknown anime series field %r (id=%s)",
                    field,
                    series_id,
                )

        await db.flush()
        await db.refresh(series)
        logger.info(
            "Updated anime series: %s (id=%s)", series.name, series_id
        )
        return series

    @staticmethod
    async def delete(db: AsyncSession, series_id: int) -> bool:
        """Delete anime series.

        Issues a single bulk ``DELETE`` statement. NOTE(review): a bulk
        delete does not run ORM relationship cascades in Python, so removal
        of related episodes and download items relies on database-level
        ``ON DELETE CASCADE`` foreign keys - confirm against the models.

        Args:
            db: Database session
            series_id: Series primary key

        Returns:
            True if deleted, False if not found
        """
        result = await db.execute(
            delete(AnimeSeries).where(AnimeSeries.id == series_id)
        )
        deleted = result.rowcount > 0
        if deleted:
            logger.info("Deleted anime series with id=%s", series_id)
        return deleted

    @staticmethod
    async def search(
        db: AsyncSession,
        query: str,
        limit: int = 50,
    ) -> List[AnimeSeries]:
        """Search anime series by name.

        Performs a case-insensitive substring match. ``%`` and ``_`` in
        ``query`` are matched literally rather than as LIKE wildcards.

        Args:
            db: Database session
            query: Search query
            limit: Maximum results

        Returns:
            List of matching AnimeSeries instances
        """
        pattern = f"%{AnimeSeriesService._escape_like(query)}%"
        result = await db.execute(
            select(AnimeSeries)
            .where(
                AnimeSeries.name.ilike(
                    pattern, escape=AnimeSeriesService._LIKE_ESCAPE
                )
            )
            .limit(limit)
        )
        return list(result.scalars().all())
|
|
|
|
|
|
# ============================================================================
|
|
# Episode Service
|
|
# ============================================================================
|
|
|
|
|
|
class EpisodeService:
    """CRUD helpers for episodes that belong to an anime series.

    All methods are static coroutines that take the ``AsyncSession`` to work
    on; they flush pending changes but leave committing to the caller.
    """

    @staticmethod
    async def create(
        db: AsyncSession,
        series_id: int,
        season: int,
        episode_number: int,
        title: Optional[str] = None,
        file_path: Optional[str] = None,
        file_size: Optional[int] = None,
        is_downloaded: bool = False,
    ) -> Episode:
        """Insert a new episode row.

        Args:
            db: Database session
            series_id: Foreign key to AnimeSeries
            season: Season number
            episode_number: Episode number within season
            title: Optional episode title
            file_path: Optional local file path
            file_size: Optional file size in bytes
            is_downloaded: Whether episode is downloaded

        Returns:
            The newly created Episode, flushed and refreshed
        """
        # Only episodes created as already-downloaded get a download date.
        downloaded_at = None
        if is_downloaded:
            downloaded_at = datetime.now(timezone.utc)

        new_episode = Episode(
            series_id=series_id,
            season=season,
            episode_number=episode_number,
            title=title,
            file_path=file_path,
            file_size=file_size,
            is_downloaded=is_downloaded,
            download_date=downloaded_at,
        )
        db.add(new_episode)
        await db.flush()
        await db.refresh(new_episode)

        logger.debug(
            f"Created episode: S{season:02d}E{episode_number:02d} "
            f"for series_id={series_id}"
        )
        return new_episode

    @staticmethod
    async def get_by_id(db: AsyncSession, episode_id: int) -> Optional[Episode]:
        """Look up a single episode by primary key.

        Args:
            db: Database session
            episode_id: Episode primary key

        Returns:
            The matching Episode, or None when no row matches
        """
        stmt = select(Episode).where(Episode.id == episode_id)
        rows = await db.execute(stmt)
        return rows.scalar_one_or_none()

    @staticmethod
    async def get_by_series(
        db: AsyncSession,
        series_id: int,
        season: Optional[int] = None,
    ) -> List[Episode]:
        """List the episodes of one series, optionally limited to a season.

        Args:
            db: Database session
            series_id: Foreign key to AnimeSeries
            season: Optional season filter

        Returns:
            Episodes ordered by season, then episode number
        """
        conditions = [Episode.series_id == series_id]
        if season is not None:
            conditions.append(Episode.season == season)

        stmt = (
            select(Episode)
            .where(*conditions)
            .order_by(Episode.season, Episode.episode_number)
        )
        rows = await db.execute(stmt)
        return list(rows.scalars().all())

    @staticmethod
    async def get_by_episode(
        db: AsyncSession,
        series_id: int,
        season: int,
        episode_number: int,
    ) -> Optional[Episode]:
        """Fetch one episode identified by series, season and number.

        Args:
            db: Database session
            series_id: Foreign key to AnimeSeries
            season: Season number
            episode_number: Episode number

        Returns:
            The matching Episode, or None when no row matches
        """
        stmt = (
            select(Episode)
            .where(Episode.series_id == series_id)
            .where(Episode.season == season)
            .where(Episode.episode_number == episode_number)
        )
        rows = await db.execute(stmt)
        return rows.scalar_one_or_none()

    @staticmethod
    async def mark_downloaded(
        db: AsyncSession,
        episode_id: int,
        file_path: str,
        file_size: int,
    ) -> Optional[Episode]:
        """Flag an episode as downloaded and record where the file lives.

        Args:
            db: Database session
            episode_id: Episode primary key
            file_path: Local file path
            file_size: File size in bytes

        Returns:
            The updated Episode, or None when the episode does not exist
        """
        episode = await EpisodeService.get_by_id(db, episode_id)
        if not episode:
            return None

        # Record file details together with the current UTC time.
        episode.download_date = datetime.now(timezone.utc)
        episode.file_size = file_size
        episode.file_path = file_path
        episode.is_downloaded = True

        await db.flush()
        await db.refresh(episode)

        logger.info(
            f"Marked episode as downloaded: "
            f"S{episode.season:02d}E{episode.episode_number:02d}"
        )
        return episode

    @staticmethod
    async def delete(db: AsyncSession, episode_id: int) -> bool:
        """Remove an episode by primary key.

        Args:
            db: Database session
            episode_id: Episode primary key

        Returns:
            True when a row was deleted, False when nothing matched
        """
        stmt = delete(Episode).where(Episode.id == episode_id)
        outcome = await db.execute(stmt)
        return outcome.rowcount > 0
|
|
|
|
|
|
# ============================================================================
|
|
# Download Queue Service
|
|
# ============================================================================
|
|
|
|
|
|
class DownloadQueueService:
    """Service for download queue CRUD operations.

    Provides static coroutines for managing the download queue with status
    tracking, priority ordering, and progress updates. Changes are flushed
    on the supplied ``AsyncSession`` but never committed here.
    """

    @staticmethod
    async def create(
        db: AsyncSession,
        series_id: int,
        season: int,
        episode_number: int,
        priority: DownloadPriority = DownloadPriority.NORMAL,
        download_url: Optional[str] = None,
        file_destination: Optional[str] = None,
    ) -> DownloadQueueItem:
        """Add item to download queue.

        New items always start in ``DownloadStatus.PENDING``.

        Args:
            db: Database session
            series_id: Foreign key to AnimeSeries
            season: Season number
            episode_number: Episode number
            priority: Download priority
            download_url: Optional provider download URL
            file_destination: Optional target file path

        Returns:
            Created DownloadQueueItem instance
        """
        item = DownloadQueueItem(
            series_id=series_id,
            season=season,
            episode_number=episode_number,
            status=DownloadStatus.PENDING,
            priority=priority,
            download_url=download_url,
            file_destination=file_destination,
        )
        db.add(item)
        await db.flush()
        await db.refresh(item)
        # f-string kept on purpose: format() and str() may render enum
        # members differently, so switching to %s could change the message.
        logger.info(
            f"Added to download queue: S{season:02d}E{episode_number:02d} "
            f"for series_id={series_id} with priority={priority}"
        )
        return item

    @staticmethod
    async def get_by_id(
        db: AsyncSession,
        item_id: int,
    ) -> Optional[DownloadQueueItem]:
        """Get download queue item by ID.

        Args:
            db: Database session
            item_id: Item primary key

        Returns:
            DownloadQueueItem instance or None if not found
        """
        result = await db.execute(
            select(DownloadQueueItem).where(DownloadQueueItem.id == item_id)
        )
        return result.scalar_one_or_none()

    @staticmethod
    async def get_by_status(
        db: AsyncSession,
        status: DownloadStatus,
        limit: Optional[int] = None,
    ) -> List[DownloadQueueItem]:
        """Get download queue items by status.

        Args:
            db: Database session
            status: Download status filter
            limit: Optional limit for results; ``None`` means no limit and
                ``0`` yields an empty list

        Returns:
            List of DownloadQueueItem instances ordered by priority
            (descending), then creation time (ascending)
        """
        query = select(DownloadQueueItem).where(
            DownloadQueueItem.status == status
        )

        # Priority descending, then oldest first.
        # NOTE(review): whether descending order really puts HIGH first
        # depends on how DownloadPriority is stored and compared in the
        # database - confirm against the model definition.
        query = query.order_by(
            DownloadQueueItem.priority.desc(),
            DownloadQueueItem.created_at.asc(),
        )

        # Compare against None so an explicit limit of 0 is honoured instead
        # of being treated as "no limit".
        if limit is not None:
            query = query.limit(limit)

        result = await db.execute(query)
        return list(result.scalars().all())

    @staticmethod
    async def get_pending(
        db: AsyncSession,
        limit: Optional[int] = None,
    ) -> List[DownloadQueueItem]:
        """Get pending download queue items.

        Args:
            db: Database session
            limit: Optional limit for results

        Returns:
            List of pending DownloadQueueItem instances ordered by priority
        """
        return await DownloadQueueService.get_by_status(
            db, DownloadStatus.PENDING, limit
        )

    @staticmethod
    async def get_active(db: AsyncSession) -> List[DownloadQueueItem]:
        """Get active download queue items.

        Args:
            db: Database session

        Returns:
            List of DownloadQueueItem instances in DOWNLOADING status
        """
        return await DownloadQueueService.get_by_status(
            db, DownloadStatus.DOWNLOADING
        )

    @staticmethod
    async def get_all(
        db: AsyncSession,
        with_series: bool = False,
    ) -> List[DownloadQueueItem]:
        """Get all download queue items.

        Args:
            db: Database session
            with_series: Whether to eagerly load series data

        Returns:
            List of all DownloadQueueItem instances ordered by priority
            (descending), then creation time (ascending)
        """
        query = select(DownloadQueueItem)

        if with_series:
            query = query.options(selectinload(DownloadQueueItem.series))

        query = query.order_by(
            DownloadQueueItem.priority.desc(),
            DownloadQueueItem.created_at.asc(),
        )

        result = await db.execute(query)
        return list(result.scalars().all())

    @staticmethod
    async def update_status(
        db: AsyncSession,
        item_id: int,
        status: DownloadStatus,
        error_message: Optional[str] = None,
    ) -> Optional[DownloadQueueItem]:
        """Update download queue item status.

        Side effects by status:
            - DOWNLOADING: sets ``started_at`` if it is not set yet
            - COMPLETED / FAILED: sets ``completed_at`` to the current time
            - FAILED: increments ``retry_count`` and, when given, stores
              ``error_message``

        Args:
            db: Database session
            item_id: Item primary key
            status: New download status
            error_message: Optional error message for failed status

        Returns:
            Updated DownloadQueueItem instance or None if not found
        """
        item = await DownloadQueueService.get_by_id(db, item_id)
        if not item:
            return None

        item.status = status
        now = datetime.now(timezone.utc)

        # Update timestamps based on status
        if status == DownloadStatus.DOWNLOADING and not item.started_at:
            item.started_at = now
        elif status in (DownloadStatus.COMPLETED, DownloadStatus.FAILED):
            item.completed_at = now

        if status == DownloadStatus.FAILED:
            # Count every failure, with or without a message; otherwise the
            # max_retries cap in retry_failed() never applies to failures
            # reported without an error message.
            item.retry_count += 1
            if error_message:
                item.error_message = error_message

        await db.flush()
        await db.refresh(item)
        # f-string kept on purpose: see the note in create().
        logger.debug(f"Updated download queue item {item_id} status to {status}")
        return item

    @staticmethod
    async def update_progress(
        db: AsyncSession,
        item_id: int,
        progress_percent: float,
        downloaded_bytes: int,
        total_bytes: Optional[int] = None,
        download_speed: Optional[float] = None,
    ) -> Optional[DownloadQueueItem]:
        """Update download progress.

        ``total_bytes`` and ``download_speed`` are only written when they
        are not ``None``; existing values are kept otherwise.

        Args:
            db: Database session
            item_id: Item primary key
            progress_percent: Progress percentage (0-100); stored as given,
                no range check is applied here
            downloaded_bytes: Bytes downloaded
            total_bytes: Optional total file size
            download_speed: Optional current speed (bytes/sec)

        Returns:
            Updated DownloadQueueItem instance or None if not found
        """
        item = await DownloadQueueService.get_by_id(db, item_id)
        if not item:
            return None

        item.progress_percent = progress_percent
        item.downloaded_bytes = downloaded_bytes

        if total_bytes is not None:
            item.total_bytes = total_bytes

        if download_speed is not None:
            item.download_speed = download_speed

        await db.flush()
        await db.refresh(item)
        return item

    @staticmethod
    async def delete(db: AsyncSession, item_id: int) -> bool:
        """Delete download queue item.

        Args:
            db: Database session
            item_id: Item primary key

        Returns:
            True if deleted, False if not found
        """
        result = await db.execute(
            delete(DownloadQueueItem).where(DownloadQueueItem.id == item_id)
        )
        deleted = result.rowcount > 0
        if deleted:
            logger.info("Deleted download queue item with id=%s", item_id)
        return deleted

    @staticmethod
    async def clear_completed(db: AsyncSession) -> int:
        """Clear completed downloads from queue.

        Args:
            db: Database session

        Returns:
            Number of items cleared
        """
        result = await db.execute(
            delete(DownloadQueueItem).where(
                DownloadQueueItem.status == DownloadStatus.COMPLETED
            )
        )
        count = result.rowcount
        logger.info("Cleared %s completed downloads from queue", count)
        return count

    @staticmethod
    async def retry_failed(
        db: AsyncSession,
        max_retries: int = 3,
    ) -> List[DownloadQueueItem]:
        """Retry failed downloads that haven't exceeded max retries.

        Matching items are reset to PENDING with their error message,
        progress and timestamps cleared. ``retry_count`` is left untouched
        so the cap keeps accumulating across retries.

        Args:
            db: Database session
            max_retries: Maximum number of retry attempts

        Returns:
            List of items marked for retry
        """
        result = await db.execute(
            select(DownloadQueueItem).where(
                DownloadQueueItem.status == DownloadStatus.FAILED,
                DownloadQueueItem.retry_count < max_retries,
            )
        )
        items = list(result.scalars().all())

        for item in items:
            item.status = DownloadStatus.PENDING
            item.error_message = None
            item.progress_percent = 0.0
            item.downloaded_bytes = 0
            item.started_at = None
            item.completed_at = None

        await db.flush()
        logger.info("Marked %d failed downloads for retry", len(items))
        return items
|
|
|
|
|
|
# ============================================================================
|
|
# User Session Service
|
|
# ============================================================================
|
|
|
|
|
|
class UserSessionService:
    """Service for user session CRUD operations.

    Provides static coroutines for managing user authentication sessions
    with JWT tokens. Changes are flushed on the supplied ``AsyncSession``
    but never committed here.
    """

    @staticmethod
    async def create(
        db: AsyncSession,
        session_id: str,
        token_hash: str,
        expires_at: datetime,
        user_id: Optional[str] = None,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None,
    ) -> UserSession:
        """Create a new user session.

        Args:
            db: Database session
            session_id: Unique session identifier
            token_hash: Hashed JWT token
            expires_at: Session expiration timestamp
            user_id: Optional user identifier
            ip_address: Optional client IP address
            user_agent: Optional client user agent

        Returns:
            Created UserSession instance, flushed and refreshed
        """
        session = UserSession(
            session_id=session_id,
            token_hash=token_hash,
            expires_at=expires_at,
            user_id=user_id,
            ip_address=ip_address,
            user_agent=user_agent,
        )
        db.add(session)
        await db.flush()
        await db.refresh(session)
        logger.info("Created user session: %s", session_id)
        return session

    @staticmethod
    async def get_by_session_id(
        db: AsyncSession,
        session_id: str,
    ) -> Optional[UserSession]:
        """Get session by session ID.

        Args:
            db: Database session
            session_id: Unique session identifier

        Returns:
            UserSession instance or None if not found
        """
        result = await db.execute(
            select(UserSession).where(UserSession.session_id == session_id)
        )
        return result.scalar_one_or_none()

    @staticmethod
    async def get_active_sessions(
        db: AsyncSession,
        user_id: Optional[str] = None,
    ) -> List[UserSession]:
        """Get active sessions.

        A session counts as active when its ``is_active`` flag is true and
        its ``expires_at`` lies after the current UTC time.

        Args:
            db: Database session
            user_id: Optional user ID filter; ``None`` returns sessions of
                all users, any other value (including ``""``) filters

        Returns:
            List of active UserSession instances
        """
        query = select(UserSession).where(
            UserSession.is_active.is_(True),
            UserSession.expires_at > datetime.now(timezone.utc),
        )

        # Only None disables the filter. A falsy-but-present id such as ""
        # must not widen the result to every user's sessions.
        if user_id is not None:
            query = query.where(UserSession.user_id == user_id)

        result = await db.execute(query)
        return list(result.scalars().all())

    @staticmethod
    async def update_activity(
        db: AsyncSession,
        session_id: str,
    ) -> Optional[UserSession]:
        """Update session last activity timestamp.

        Args:
            db: Database session
            session_id: Unique session identifier

        Returns:
            Updated UserSession instance or None if not found
        """
        session = await UserSessionService.get_by_session_id(db, session_id)
        if not session:
            return None

        session.last_activity = datetime.now(timezone.utc)
        await db.flush()
        await db.refresh(session)
        return session

    @staticmethod
    async def revoke(db: AsyncSession, session_id: str) -> bool:
        """Revoke a session.

        Delegates to ``UserSession.revoke()``, which is defined on the
        model (presumably it marks the session inactive - confirm there).

        Args:
            db: Database session
            session_id: Unique session identifier

        Returns:
            True if revoked, False if not found
        """
        session = await UserSessionService.get_by_session_id(db, session_id)
        if not session:
            return False

        session.revoke()
        await db.flush()
        logger.info("Revoked user session: %s", session_id)
        return True

    @staticmethod
    async def cleanup_expired(db: AsyncSession) -> int:
        """Clean up expired sessions.

        Deletes every session whose ``expires_at`` lies before the current
        UTC time, regardless of its active flag.

        Args:
            db: Database session

        Returns:
            Number of sessions deleted
        """
        result = await db.execute(
            delete(UserSession).where(
                UserSession.expires_at < datetime.now(timezone.utc)
            )
        )
        count = result.rowcount
        logger.info("Cleaned up %s expired sessions", count)
        return count
|