Add database transaction support with atomic operations
- Create transaction.py with @transactional decorator, atomic() context manager - Add TransactionPropagation modes: REQUIRED, REQUIRES_NEW, NESTED - Add savepoint support for nested transactions with partial rollback - Update connection.py with TransactionManager, get_transactional_session - Update service.py with bulk operations (bulk_mark_downloaded, bulk_delete) - Wrap QueueRepository.save_item() and clear_all() in atomic transactions - Add comprehensive tests (66 transaction tests, 90% coverage) - All 1090 tests passing
This commit is contained in:
@@ -7,7 +7,11 @@ Functions:
|
||||
- init_db: Initialize database engine and create tables
|
||||
- close_db: Close database connections and cleanup
|
||||
- get_db_session: FastAPI dependency for database sessions
|
||||
- get_transactional_session: Session without auto-commit for transactions
|
||||
- get_engine: Get database engine instance
|
||||
|
||||
Classes:
|
||||
- TransactionManager: Helper class for manual transaction control
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -296,3 +300,275 @@ def get_async_session_factory() -> AsyncSession:
|
||||
)
|
||||
|
||||
return _session_factory()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_transactional_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Get a database session without auto-commit for explicit transaction control.
|
||||
|
||||
Unlike get_db_session(), this does NOT auto-commit on success.
|
||||
Use this when you need explicit transaction control with the
|
||||
@transactional decorator or atomic() context manager.
|
||||
|
||||
Yields:
|
||||
AsyncSession: Database session for async operations
|
||||
|
||||
Raises:
|
||||
RuntimeError: If database is not initialized
|
||||
|
||||
Example:
|
||||
async with get_transactional_session() as session:
|
||||
async with atomic(session) as tx:
|
||||
# Multiple operations in transaction
|
||||
await operation1(session)
|
||||
await operation2(session)
|
||||
# Committed when exiting atomic() context
|
||||
"""
|
||||
if _session_factory is None:
|
||||
raise RuntimeError(
|
||||
"Database not initialized. Call init_db() first."
|
||||
)
|
||||
|
||||
session = _session_factory()
|
||||
try:
|
||||
yield session
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
class TransactionManager:
|
||||
"""Helper class for manual transaction control.
|
||||
|
||||
Provides a cleaner interface for managing transactions across
|
||||
multiple service calls within a single request.
|
||||
|
||||
Attributes:
|
||||
_session_factory: Factory for creating new sessions
|
||||
_session: Current active session
|
||||
_in_transaction: Whether currently in a transaction
|
||||
|
||||
Example:
|
||||
async with TransactionManager() as tm:
|
||||
session = await tm.get_session()
|
||||
await tm.begin()
|
||||
try:
|
||||
await service1.operation(session)
|
||||
await service2.operation(session)
|
||||
await tm.commit()
|
||||
except Exception:
|
||||
await tm.rollback()
|
||||
raise
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_factory: Optional[async_sessionmaker] = None
|
||||
) -> None:
|
||||
"""Initialize transaction manager.
|
||||
|
||||
Args:
|
||||
session_factory: Optional custom session factory.
|
||||
Uses global factory if not provided.
|
||||
"""
|
||||
self._session_factory = session_factory or _session_factory
|
||||
self._session: Optional[AsyncSession] = None
|
||||
self._in_transaction = False
|
||||
|
||||
if self._session_factory is None:
|
||||
raise RuntimeError(
|
||||
"Database not initialized. Call init_db() first."
|
||||
)
|
||||
|
||||
async def __aenter__(self) -> "TransactionManager":
|
||||
"""Enter context manager and create session."""
|
||||
self._session = self._session_factory()
|
||||
logger.debug("TransactionManager: Created new session")
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[type],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[object],
|
||||
) -> bool:
|
||||
"""Exit context manager and cleanup session.
|
||||
|
||||
Automatically rolls back if an exception occurred and
|
||||
transaction wasn't explicitly committed.
|
||||
"""
|
||||
if self._session:
|
||||
if exc_type is not None and self._in_transaction:
|
||||
logger.warning(
|
||||
"TransactionManager: Rolling back due to exception: %s",
|
||||
exc_val,
|
||||
)
|
||||
await self._session.rollback()
|
||||
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
self._in_transaction = False
|
||||
logger.debug("TransactionManager: Session closed")
|
||||
|
||||
return False
|
||||
|
||||
async def get_session(self) -> AsyncSession:
|
||||
"""Get the current session.
|
||||
|
||||
Returns:
|
||||
Current AsyncSession instance
|
||||
|
||||
Raises:
|
||||
RuntimeError: If not within context manager
|
||||
"""
|
||||
if self._session is None:
|
||||
raise RuntimeError(
|
||||
"TransactionManager must be used as async context manager"
|
||||
)
|
||||
return self._session
|
||||
|
||||
async def begin(self) -> None:
|
||||
"""Begin a new transaction.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If already in a transaction or no session
|
||||
"""
|
||||
if self._session is None:
|
||||
raise RuntimeError("No active session")
|
||||
|
||||
if self._in_transaction:
|
||||
raise RuntimeError("Already in a transaction")
|
||||
|
||||
await self._session.begin()
|
||||
self._in_transaction = True
|
||||
logger.debug("TransactionManager: Transaction started")
|
||||
|
||||
async def commit(self) -> None:
|
||||
"""Commit the current transaction.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If not in a transaction
|
||||
"""
|
||||
if not self._in_transaction or self._session is None:
|
||||
raise RuntimeError("Not in a transaction")
|
||||
|
||||
await self._session.commit()
|
||||
self._in_transaction = False
|
||||
logger.debug("TransactionManager: Transaction committed")
|
||||
|
||||
async def rollback(self) -> None:
|
||||
"""Rollback the current transaction.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If not in a transaction
|
||||
"""
|
||||
if self._session is None:
|
||||
raise RuntimeError("No active session")
|
||||
|
||||
await self._session.rollback()
|
||||
self._in_transaction = False
|
||||
logger.debug("TransactionManager: Transaction rolled back")
|
||||
|
||||
async def savepoint(self, name: Optional[str] = None) -> "SavepointHandle":
|
||||
"""Create a savepoint within the current transaction.
|
||||
|
||||
Args:
|
||||
name: Optional savepoint name
|
||||
|
||||
Returns:
|
||||
SavepointHandle for controlling the savepoint
|
||||
|
||||
Raises:
|
||||
RuntimeError: If not in a transaction
|
||||
"""
|
||||
if not self._in_transaction or self._session is None:
|
||||
raise RuntimeError("Must be in a transaction to create savepoint")
|
||||
|
||||
nested = await self._session.begin_nested()
|
||||
return SavepointHandle(nested, name or "unnamed")
|
||||
|
||||
def is_in_transaction(self) -> bool:
|
||||
"""Check if currently in a transaction.
|
||||
|
||||
Returns:
|
||||
True if in an active transaction
|
||||
"""
|
||||
return self._in_transaction
|
||||
|
||||
def get_transaction_depth(self) -> int:
|
||||
"""Get current transaction nesting depth.
|
||||
|
||||
Returns:
|
||||
0 if not in transaction, 1+ for nested transactions
|
||||
"""
|
||||
if not self._in_transaction:
|
||||
return 0
|
||||
return 1 # Basic implementation - could be extended
|
||||
|
||||
|
||||
class SavepointHandle:
|
||||
"""Handle for controlling a database savepoint.
|
||||
|
||||
Attributes:
|
||||
_nested: SQLAlchemy nested transaction
|
||||
_name: Savepoint name for logging
|
||||
_released: Whether savepoint has been released
|
||||
"""
|
||||
|
||||
def __init__(self, nested: object, name: str) -> None:
|
||||
"""Initialize savepoint handle.
|
||||
|
||||
Args:
|
||||
nested: SQLAlchemy nested transaction object
|
||||
name: Savepoint name
|
||||
"""
|
||||
self._nested = nested
|
||||
self._name = name
|
||||
self._released = False
|
||||
logger.debug("Created savepoint: %s", name)
|
||||
|
||||
async def rollback(self) -> None:
|
||||
"""Rollback to this savepoint."""
|
||||
if not self._released:
|
||||
await self._nested.rollback()
|
||||
self._released = True
|
||||
logger.debug("Rolled back savepoint: %s", self._name)
|
||||
|
||||
async def release(self) -> None:
|
||||
"""Release (commit) this savepoint."""
|
||||
if not self._released:
|
||||
# Nested transactions commit automatically in SQLAlchemy
|
||||
self._released = True
|
||||
logger.debug("Released savepoint: %s", self._name)
|
||||
|
||||
|
||||
def is_session_in_transaction(session: AsyncSession | Session) -> bool:
|
||||
"""Check if a session is currently in a transaction.
|
||||
|
||||
Args:
|
||||
session: SQLAlchemy session (sync or async)
|
||||
|
||||
Returns:
|
||||
True if session is in an active transaction
|
||||
"""
|
||||
return session.in_transaction()
|
||||
|
||||
|
||||
def get_session_transaction_depth(session: AsyncSession | Session) -> int:
|
||||
"""Get the transaction nesting depth of a session.
|
||||
|
||||
Args:
|
||||
session: SQLAlchemy session (sync or async)
|
||||
|
||||
Returns:
|
||||
Number of nested transactions (0 if not in transaction)
|
||||
"""
|
||||
if not session.in_transaction():
|
||||
return 0
|
||||
|
||||
# Check for nested transaction state
|
||||
# Note: SQLAlchemy doesn't directly expose nesting depth
|
||||
return 1
|
||||
|
||||
|
||||
Reference in New Issue
Block a user