- Create transaction.py with @transactional decorator and atomic() context manager
- Add TransactionPropagation modes: REQUIRED, REQUIRES_NEW, NESTED
- Add savepoint support for nested transactions with partial rollback
- Update connection.py with TransactionManager, get_transactional_session
- Update service.py with bulk operations (bulk_mark_downloaded, bulk_delete)
- Wrap QueueRepository.save_item() and clear_all() in atomic transactions
- Add comprehensive tests (66 transaction tests, 90% coverage)
- All 1090 tests passing
716 lines
22 KiB
Python
716 lines
22 KiB
Python
"""Transaction management utilities for SQLAlchemy.
|
|
|
|
This module provides transaction management utilities including decorators,
|
|
context managers, and helper functions for ensuring data consistency
|
|
across database operations.
|
|
|
|
Components:
|
|
- @transactional decorator: Wraps functions in transaction boundaries
|
|
- TransactionContext: Sync context manager for explicit transaction control
|
|
- atomic(): Async context manager for async operations
|
|
- TransactionPropagation: Enum for transaction propagation modes
|
|
|
|
Usage:
|
|
@transactional
|
|
async def compound_operation(session: AsyncSession, data: Model) -> Result:
|
|
# Multiple write operations here
|
|
# All succeed or all fail
|
|
pass
|
|
|
|
async with atomic(session) as tx:
|
|
# Operations here
|
|
async with tx.savepoint() as sp:
|
|
# Nested operations with partial rollback capability
|
|
pass
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import functools
|
|
import logging
|
|
from contextlib import asynccontextmanager, contextmanager
|
|
from enum import Enum
|
|
from typing import (
|
|
Any,
|
|
AsyncGenerator,
|
|
Callable,
|
|
Generator,
|
|
Optional,
|
|
ParamSpec,
|
|
TypeVar,
|
|
)
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.orm import Session
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Type variables for generic typing
|
|
T = TypeVar("T")
|
|
P = ParamSpec("P")
|
|
|
|
|
|
class TransactionPropagation(Enum):
    """How a transactional block relates to a transaction already in progress.

    Members:
        REQUIRED: join the current transaction, or begin one when none is
            active. This is the default everywhere in this module.
        REQUIRES_NEW: intended to run in a brand-new transaction, suspending
            the current one. NOTE: ``atomic()`` / ``atomic_sync()`` currently
            handle it exactly like ``REQUIRED``.
        NESTED: open a SAVEPOINT inside the current transaction so the block
            can be rolled back on its own.
    """

    # Join-or-create semantics (default).
    REQUIRED = "required"
    # Separate transaction (see NOTE in the class docstring).
    REQUIRES_NEW = "requires_new"
    # SAVEPOINT inside the active transaction.
    NESTED = "nested"
|
|
|
|
|
|
class TransactionError(Exception):
    """Signals a problem establishing or managing a transaction boundary.

    For example, ``transactional`` raises it when the decorated function is
    called without an argument matching the configured session parameter.
    """

    pass
|
|
|
|
|
|
class TransactionContext:
    """Synchronous context manager for explicit transaction control.

    Commits when the ``with`` block exits normally and rolls back when it
    raises. Contexts may be nested on the same session: the number of
    entered contexts is tracked in ``session.info``. Only the outermost
    context, or one that had to begin the transaction itself, commits. An
    inner context therefore joins the enclosing transaction instead of
    committing it early. An exception in any context still rolls back the
    session's transaction.

    Attributes:
        session: SQLAlchemy Session instance
        _savepoint_count: Counter used to auto-name savepoints
        _committed: True once this context has committed or rolled back
        _owns_transaction: True if this context performs the final commit

    Example:
        with TransactionContext(session) as tx:
            # Database operations here
            with tx.savepoint() as sp:
                # Nested operations with partial rollback
                pass
    """

    # ``session.info`` key counting the contexts currently entered on a session.
    _DEPTH_KEY = "_transaction_context_depth"

    def __init__(self, session: Session) -> None:
        """Initialize transaction context.

        Args:
            session: SQLAlchemy sync session
        """
        self.session = session
        self._savepoint_count = 0
        self._committed = False
        self._owns_transaction = True

    def __enter__(self) -> "TransactionContext":
        """Enter transaction context.

        Begins a new transaction if the session is not already in one, and
        records whether this context is responsible for committing it.

        Returns:
            Self for context manager protocol
        """
        logger.debug("Entering transaction context")

        depth = self.session.info.get(self._DEPTH_KEY, 0)

        began = False
        if not self.session.in_transaction():
            self.session.begin()
            began = True
            logger.debug("Started new transaction")

        # When an enclosing context is active, it owns the running transaction.
        # Commit only if we are outermost or had to start the transaction here.
        self._owns_transaction = began or depth == 0
        # Incremented only after begin() succeeded: __exit__ won't run otherwise.
        self.session.info[self._DEPTH_KEY] = depth + 1
        return self

    def __exit__(
        self,
        exc_type: Optional[type],
        exc_val: Optional[BaseException],
        exc_tb: Optional[Any],
    ) -> bool:
        """Exit transaction context.

        Rolls back on exception. On success, commits only if this context
        owns the transaction and has not committed already. If the commit
        itself fails, the session is rolled back before the error propagates.

        Args:
            exc_type: Exception type if raised
            exc_val: Exception value if raised
            exc_tb: Exception traceback if raised

        Returns:
            False to propagate exceptions
        """
        try:
            if exc_type is not None:
                logger.warning(
                    "Transaction rollback due to exception: %s: %s",
                    exc_type.__name__,
                    exc_val,
                )
                self.session.rollback()
                return False

            if self._owns_transaction and not self._committed:
                try:
                    self.session.commit()
                except Exception:
                    # A failed commit leaves the session needing a rollback
                    # before it can be used again.
                    self.session.rollback()
                    raise
                logger.debug("Transaction committed")
                self._committed = True

            return False
        finally:
            self._leave()

    def _leave(self) -> None:
        """Decrement the session's context depth, dropping the key at zero."""
        remaining = self.session.info.get(self._DEPTH_KEY, 1) - 1
        if remaining > 0:
            self.session.info[self._DEPTH_KEY] = remaining
        else:
            self.session.info.pop(self._DEPTH_KEY, None)

    @contextmanager
    def savepoint(self, name: Optional[str] = None) -> Generator["SavepointContext", None, None]:
        """Create a savepoint for partial rollback capability.

        Savepoints allow nested transactions where inner operations
        can be rolled back without affecting outer operations. On normal
        exit the SAVEPOINT is released, so later work runs in the parent
        transaction again. If the block raises, the SAVEPOINT is rolled back
        and the exception re-raised. Either step is skipped if the caller
        already ended the savepoint through the yielded context.

        Args:
            name: Optional savepoint name (auto-generated if not provided)

        Yields:
            SavepointContext for nested transaction control

        Example:
            with tx.savepoint() as sp:
                # Operations here can be rolled back independently
                if error_condition:
                    sp.rollback()
        """
        self._savepoint_count += 1
        savepoint_name = name or f"sp_{self._savepoint_count}"

        logger.debug("Creating savepoint: %s", savepoint_name)
        nested = self.session.begin_nested()

        sp_context = SavepointContext(nested, savepoint_name)

        def already_closed() -> bool:
            # The caller may have ended the savepoint explicitly through the
            # yielded context; ``_committed`` is read defensively.
            return sp_context._rolled_back or getattr(sp_context, "_committed", False)

        try:
            yield sp_context
        except Exception as e:
            if not already_closed():
                logger.warning(
                    "Rolling back savepoint %s due to exception: %s",
                    savepoint_name,
                    e,
                )
                nested.rollback()
            raise
        else:
            # Release the SAVEPOINT; previously it stayed open until the
            # outer commit, so later statements kept running inside it.
            if not already_closed() and nested.is_active:
                logger.debug("Releasing savepoint: %s", savepoint_name)
                nested.commit()

    def commit(self) -> None:
        """Explicitly commit the transaction.

        Use this for early commit within the context.
        """
        if not self._committed:
            self.session.commit()
            self._committed = True
            logger.debug("Transaction explicitly committed")

    def rollback(self) -> None:
        """Explicitly rollback the transaction.

        Use this for early rollback within the context.
        """
        self.session.rollback()
        self._committed = True  # Prevent double commit
        logger.debug("Transaction explicitly rolled back")
|
|
|
|
|
|
class SavepointContext:
    """Context for managing a database savepoint.

    Provides explicit control over savepoint commit/rollback. Each of
    ``commit()`` and ``rollback()`` acts at most once, and once either has
    run the other is a no-op.

    Attributes:
        _nested: SQLAlchemy nested transaction object (from ``begin_nested()``)
        _name: Savepoint name for logging
        _rolled_back: Whether rollback has been called
        _committed: Whether commit has released the savepoint
    """

    def __init__(self, nested: Any, name: str) -> None:
        """Initialize savepoint context.

        Args:
            nested: SQLAlchemy nested transaction
            name: Savepoint name for logging
        """
        self._nested = nested
        self._name = name
        self._rolled_back = False
        self._committed = False

    def rollback(self) -> None:
        """Rollback to this savepoint.

        Undoes all changes since the savepoint was created. No-op if the
        savepoint was already rolled back or released via ``commit()``.
        """
        if self._rolled_back or self._committed:
            return
        self._nested.rollback()
        self._rolled_back = True
        logger.debug("Savepoint %s rolled back", self._name)

    def commit(self) -> None:
        """Commit (release) this savepoint.

        Makes changes since the savepoint permanent within the parent
        transaction by committing the nested transaction. Previously this
        only logged and left the SAVEPOINT open. No-op if the savepoint was
        already rolled back or released.
        """
        if self._rolled_back or self._committed:
            return
        self._nested.commit()
        self._committed = True
        logger.debug("Savepoint %s committed", self._name)
|
|
|
|
|
|
class AsyncTransactionContext:
    """Asynchronous context manager for explicit transaction control.

    Commits when the ``async with`` block exits normally and rolls back when
    it raises. Contexts may be nested on the same session: the number of
    entered contexts is tracked in ``session.info``. Only the outermost
    context, or one that had to begin the transaction itself, commits. An
    inner context therefore joins the enclosing transaction instead of
    committing it early. An exception in any context still rolls back the
    session's transaction.

    Attributes:
        session: SQLAlchemy AsyncSession instance
        _savepoint_count: Counter used to auto-name savepoints
        _committed: True once this context has committed or rolled back
        _owns_transaction: True if this context performs the final commit

    Example:
        async with AsyncTransactionContext(session) as tx:
            # Database operations here
            async with tx.savepoint() as sp:
                # Nested operations with partial rollback
                pass
    """

    # ``session.info`` key counting the contexts currently entered on a session.
    _DEPTH_KEY = "_transaction_context_depth"

    def __init__(self, session: AsyncSession) -> None:
        """Initialize async transaction context.

        Args:
            session: SQLAlchemy async session
        """
        self.session = session
        self._savepoint_count = 0
        self._committed = False
        self._owns_transaction = True

    async def __aenter__(self) -> "AsyncTransactionContext":
        """Enter async transaction context.

        Begins a new transaction if the session is not already in one, and
        records whether this context is responsible for committing it.

        Returns:
            Self for context manager protocol
        """
        logger.debug("Entering async transaction context")

        depth = self.session.info.get(self._DEPTH_KEY, 0)

        began = False
        if not self.session.in_transaction():
            await self.session.begin()
            began = True
            logger.debug("Started new async transaction")

        # When an enclosing context is active, it owns the running transaction.
        # Commit only if we are outermost or had to start the transaction here.
        self._owns_transaction = began or depth == 0
        # Incremented only after begin() succeeded: __aexit__ won't run otherwise.
        self.session.info[self._DEPTH_KEY] = depth + 1
        return self

    async def __aexit__(
        self,
        exc_type: Optional[type],
        exc_val: Optional[BaseException],
        exc_tb: Optional[Any],
    ) -> bool:
        """Exit async transaction context.

        Rolls back on exception. On success, commits only if this context
        owns the transaction and has not committed already. If the commit
        itself fails, the session is rolled back before the error propagates.

        Args:
            exc_type: Exception type if raised
            exc_val: Exception value if raised
            exc_tb: Exception traceback if raised

        Returns:
            False to propagate exceptions
        """
        try:
            if exc_type is not None:
                logger.warning(
                    "Async transaction rollback due to exception: %s: %s",
                    exc_type.__name__,
                    exc_val,
                )
                await self.session.rollback()
                return False

            if self._owns_transaction and not self._committed:
                try:
                    await self.session.commit()
                except Exception:
                    # A failed commit leaves the session needing a rollback
                    # before it can be used again.
                    await self.session.rollback()
                    raise
                logger.debug("Async transaction committed")
                self._committed = True

            return False
        finally:
            self._leave()

    def _leave(self) -> None:
        """Decrement the session's context depth, dropping the key at zero."""
        remaining = self.session.info.get(self._DEPTH_KEY, 1) - 1
        if remaining > 0:
            self.session.info[self._DEPTH_KEY] = remaining
        else:
            self.session.info.pop(self._DEPTH_KEY, None)

    @asynccontextmanager
    async def savepoint(
        self, name: Optional[str] = None
    ) -> AsyncGenerator["AsyncSavepointContext", None]:
        """Create an async savepoint for partial rollback capability.

        On normal exit the SAVEPOINT is released, so later work runs in the
        parent transaction again. If the block raises, the SAVEPOINT is rolled
        back and the exception re-raised. Either step is skipped if the caller
        already ended the savepoint through the yielded context.

        Args:
            name: Optional savepoint name (auto-generated if not provided)

        Yields:
            AsyncSavepointContext for nested transaction control
        """
        self._savepoint_count += 1
        savepoint_name = name or f"sp_{self._savepoint_count}"

        logger.debug("Creating async savepoint: %s", savepoint_name)
        nested = await self.session.begin_nested()

        sp_context = AsyncSavepointContext(nested, savepoint_name, self.session)

        def already_closed() -> bool:
            # The caller may have ended the savepoint explicitly through the
            # yielded context; ``_committed`` is read defensively.
            return sp_context._rolled_back or getattr(sp_context, "_committed", False)

        try:
            yield sp_context
        except Exception as e:
            if not already_closed():
                logger.warning(
                    "Rolling back async savepoint %s due to exception: %s",
                    savepoint_name,
                    e,
                )
                await nested.rollback()
            raise
        else:
            # Release the SAVEPOINT; previously it stayed open until the
            # outer commit, so later statements kept running inside it.
            if not already_closed() and nested.is_active:
                logger.debug("Releasing async savepoint: %s", savepoint_name)
                await nested.commit()

    async def commit(self) -> None:
        """Explicitly commit the async transaction."""
        if not self._committed:
            await self.session.commit()
            self._committed = True
            logger.debug("Async transaction explicitly committed")

    async def rollback(self) -> None:
        """Explicitly rollback the async transaction."""
        await self.session.rollback()
        self._committed = True  # Prevent double commit
        logger.debug("Async transaction explicitly rolled back")
|
|
|
|
|
|
class AsyncSavepointContext:
    """Async context for managing a database savepoint.

    Each of ``commit()`` and ``rollback()`` acts at most once, and once
    either has run the other is a no-op.

    Attributes:
        _nested: SQLAlchemy nested transaction object (from ``begin_nested()``)
        _name: Savepoint name for logging
        _session: Parent session for async operations (stored; unused here)
        _rolled_back: Whether rollback has been called
        _committed: Whether commit has released the savepoint
    """

    def __init__(
        self, nested: Any, name: str, session: AsyncSession
    ) -> None:
        """Initialize async savepoint context.

        Args:
            nested: SQLAlchemy nested transaction
            name: Savepoint name for logging
            session: Parent async session
        """
        self._nested = nested
        self._name = name
        self._session = session
        self._rolled_back = False
        self._committed = False

    async def rollback(self) -> None:
        """Rollback to this savepoint asynchronously.

        No-op if the savepoint was already rolled back or released via
        ``commit()``.
        """
        if self._rolled_back or self._committed:
            return
        await self._nested.rollback()
        self._rolled_back = True
        logger.debug("Async savepoint %s rolled back", self._name)

    async def commit(self) -> None:
        """Commit (release) this savepoint asynchronously.

        Commits the nested transaction so its changes become part of the
        parent transaction. Previously this only logged and left the
        SAVEPOINT open. No-op if already rolled back or released.
        """
        if self._rolled_back or self._committed:
            return
        await self._nested.commit()
        self._committed = True
        logger.debug("Async savepoint %s committed", self._name)
|
|
|
|
|
|
class _AsyncNestedTransactionContext(AsyncTransactionContext):
    """Transaction handle yielded by ``atomic()`` in NESTED mode.

    ``commit()`` / ``rollback()`` act on the savepoint opened for the block
    rather than on the enclosing transaction.
    """

    def __init__(self, session: AsyncSession, nested: Any) -> None:
        super().__init__(session)
        self._nested = nested
        self._savepoint_closed = False
        # The enclosing transaction, not this handle, performs the final commit.
        self._committed = True

    async def commit(self) -> None:
        """Release the savepoint now, keeping its changes in the parent."""
        if not self._savepoint_closed:
            await self._nested.commit()
            self._savepoint_closed = True
            logger.debug("Nested atomic savepoint explicitly released")

    async def rollback(self) -> None:
        """Roll back only the changes made since the savepoint was opened."""
        if not self._savepoint_closed:
            await self._nested.rollback()
            self._savepoint_closed = True
            logger.debug("Nested atomic savepoint explicitly rolled back")


@asynccontextmanager
async def atomic(
    session: AsyncSession,
    propagation: TransactionPropagation = TransactionPropagation.REQUIRED,
) -> AsyncGenerator[AsyncTransactionContext, None]:
    """Async context manager for atomic database operations.

    Provides a clean interface for wrapping database operations in
    a transaction boundary with automatic commit/rollback.

    Propagation handling:
        REQUIRED: run the block inside an ``AsyncTransactionContext``, which
            begins a transaction if the session has none.
        REQUIRES_NEW: currently handled exactly like REQUIRED.
        NESTED: if the session is already in a transaction, run the block in
            a SAVEPOINT. The SAVEPOINT is released on success and rolled back
            if the block raises. In this mode ``tx.commit()`` /
            ``tx.rollback()`` act on that savepoint only. Without an active
            transaction, NESTED behaves like REQUIRED.

    Args:
        session: SQLAlchemy async session
        propagation: Transaction propagation behavior

    Yields:
        AsyncTransactionContext for transaction control

    Example:
        async with atomic(session) as tx:
            await some_operation(session)
            await another_operation(session)
            # All operations committed together or rolled back

        async with atomic(session) as tx:
            await outer_operation(session)
            async with tx.savepoint() as sp:
                await risky_operation(session)
                if error:
                    await sp.rollback()  # Only rollback nested ops
    """
    logger.debug(
        "Starting atomic block with propagation: %s",
        propagation.value,
    )

    if propagation == TransactionPropagation.NESTED and session.in_transaction():
        nested = await session.begin_nested()
        wrapper = _AsyncNestedTransactionContext(session, nested)

        try:
            yield wrapper
        except Exception as e:
            if not wrapper._savepoint_closed:
                logger.warning(
                    "Rolling back nested atomic savepoint due to: %s", e
                )
                await nested.rollback()
            raise
        else:
            # Release the SAVEPOINT; previously it was left open.
            if not wrapper._savepoint_closed and nested.is_active:
                logger.debug("Releasing nested atomic savepoint")
                await nested.commit()
    else:
        # REQUIRED, REQUIRES_NEW, or NESTED with no active transaction.
        async with AsyncTransactionContext(session) as tx:
            yield tx
|
|
|
|
|
|
class _NestedTransactionContext(TransactionContext):
    """Transaction handle yielded by ``atomic_sync()`` in NESTED mode.

    ``commit()`` / ``rollback()`` act on the savepoint opened for the block
    rather than on the enclosing transaction.
    """

    def __init__(self, session: Session, nested: Any) -> None:
        super().__init__(session)
        self._nested = nested
        self._savepoint_closed = False
        # The enclosing transaction, not this handle, performs the final commit.
        self._committed = True

    def commit(self) -> None:
        """Release the savepoint now, keeping its changes in the parent."""
        if not self._savepoint_closed:
            self._nested.commit()
            self._savepoint_closed = True
            logger.debug("Nested sync atomic savepoint explicitly released")

    def rollback(self) -> None:
        """Roll back only the changes made since the savepoint was opened."""
        if not self._savepoint_closed:
            self._nested.rollback()
            self._savepoint_closed = True
            logger.debug("Nested sync atomic savepoint explicitly rolled back")


@contextmanager
def atomic_sync(
    session: Session,
    propagation: TransactionPropagation = TransactionPropagation.REQUIRED,
) -> Generator[TransactionContext, None, None]:
    """Sync context manager for atomic database operations.

    Propagation handling:
        REQUIRED: run the block inside a ``TransactionContext``, which begins
            a transaction if the session has none.
        REQUIRES_NEW: currently handled exactly like REQUIRED.
        NESTED: if the session is already in a transaction, run the block in
            a SAVEPOINT. The SAVEPOINT is released on success and rolled back
            if the block raises. In this mode ``tx.commit()`` /
            ``tx.rollback()`` act on that savepoint only. Without an active
            transaction, NESTED behaves like REQUIRED.

    Args:
        session: SQLAlchemy sync session
        propagation: Transaction propagation behavior

    Yields:
        TransactionContext for transaction control
    """
    logger.debug(
        "Starting sync atomic block with propagation: %s",
        propagation.value,
    )

    if propagation == TransactionPropagation.NESTED and session.in_transaction():
        nested = session.begin_nested()
        wrapper = _NestedTransactionContext(session, nested)

        try:
            yield wrapper
        except Exception as e:
            if not wrapper._savepoint_closed:
                logger.warning(
                    "Rolling back nested sync savepoint due to: %s", e
                )
                nested.rollback()
            raise
        else:
            # Release the SAVEPOINT; previously it was left open.
            if not wrapper._savepoint_closed and nested.is_active:
                logger.debug("Releasing nested sync atomic savepoint")
                nested.commit()
    else:
        # REQUIRED, REQUIRES_NEW, or NESTED with no active transaction.
        with TransactionContext(session) as tx:
            yield tx
|
|
|
|
|
|
def transactional(
    propagation: TransactionPropagation = TransactionPropagation.REQUIRED,
    session_param: str = "db",
) -> Any:
    """Decorator to wrap a function in a transaction boundary.

    Automatically handles commit on success and rollback on exception.
    Coroutine functions are wrapped with ``atomic`` and plain functions with
    ``atomic_sync``. Normally used with parentheses (``@transactional()``).
    Bare ``@transactional`` is also accepted and applies the defaults.

    Args:
        propagation: Transaction propagation behavior
        session_param: Name of the session parameter in the function signature

    Returns:
        A decorator. For bare ``@transactional`` use, returns the decorated
        function directly.

    Raises:
        TypeError: If ``propagation`` is neither a TransactionPropagation
            member nor the function being decorated.
        TransactionError: At call time, if the session argument is missing.

    Example:
        @transactional()
        async def create_user_with_profile(db: AsyncSession, data: dict):
            user = await create_user(db, data['user'])
            profile = await create_profile(db, user.id, data['profile'])
            return user, profile

        @transactional(propagation=TransactionPropagation.NESTED)
        async def risky_sub_operation(db: AsyncSession, data: dict):
            # This can be rolled back without affecting parent transaction
            pass
    """
    import inspect

    if callable(propagation) and not isinstance(propagation, TransactionPropagation):
        # Bare ``@transactional``: Python passed the decorated function here.
        return transactional(session_param=session_param)(propagation)

    if not isinstance(propagation, TransactionPropagation):
        raise TypeError(
            f"propagation must be a TransactionPropagation, got {propagation!r}"
        )

    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        def session_for(args: tuple, kwargs: dict) -> Any:
            """Return the session argument of this call or raise TransactionError."""
            session = _extract_session(func, args, kwargs, session_param)
            if session is None:
                raise TransactionError(
                    f"Could not find session parameter '{session_param}' "
                    f"in function {func.__name__}"
                )
            return session

        # inspect.iscoroutinefunction replaces asyncio.iscoroutinefunction,
        # which is deprecated as of Python 3.14.
        if inspect.iscoroutinefunction(func):

            @functools.wraps(func)
            async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
                session = session_for(args, kwargs)

                logger.debug(
                    "Starting transaction for %s with propagation %s",
                    func.__name__,
                    propagation.value,
                )

                async with atomic(session, propagation):
                    result = await func(*args, **kwargs)

                logger.debug(
                    "Transaction completed for %s",
                    func.__name__,
                )

                return result

            return async_wrapper  # type: ignore

        @functools.wraps(func)
        def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            session = session_for(args, kwargs)

            logger.debug(
                "Starting sync transaction for %s with propagation %s",
                func.__name__,
                propagation.value,
            )

            with atomic_sync(session, propagation):
                result = func(*args, **kwargs)

            logger.debug(
                "Sync transaction completed for %s",
                func.__name__,
            )

            return result

        return sync_wrapper  # type: ignore

    return decorator
|
|
|
|
|
|
def _extract_session(
    func: Callable,
    args: tuple,
    kwargs: dict,
    session_param: str,
) -> Optional[AsyncSession | Session]:
    """Extract session from function arguments.

    Keyword arguments are checked first. Otherwise the parameter's position
    is computed among the positional parameters only: positional-only and
    positional-or-keyword, which always come before ``*args``. This ensures
    a keyword-only or ``*args``-trailing parameter is never read from an
    unrelated positional value. For methods, ``self``/``cls`` appears both in
    the signature and in ``args``, so the positions line up.

    Args:
        func: The function being called
        args: Positional arguments
        kwargs: Keyword arguments
        session_param: Name of the session parameter

    Returns:
        Session instance or None if not found
    """
    import inspect

    # Check kwargs first
    if session_param in kwargs:
        return kwargs[session_param]

    sig = inspect.signature(func)
    positional_names = [
        name
        for name, param in sig.parameters.items()
        if param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD)
    ]

    if session_param in positional_names:
        idx = positional_names.index(session_param)
        if idx < len(args):
            return args[idx]

    return None
|
|
|
|
|
|
def is_in_transaction(session: AsyncSession | Session) -> bool:
    """Report whether *session* currently has a transaction in progress.

    Thin wrapper over the session's own ``in_transaction()`` method, usable
    with both sync and async SQLAlchemy sessions.

    Args:
        session: SQLAlchemy session (sync or async)

    Returns:
        The value returned by ``session.in_transaction()``
    """
    check_active = session.in_transaction
    return check_active()
|
|
|
|
|
|
def get_transaction_depth(session: AsyncSession | Session) -> int:
    """Get the current transaction nesting depth.

    Returns 0 when no transaction is in progress, 1 for a plain top-level
    transaction, and 1 + N while N SAVEPOINTs (``begin_nested()``) are open.

    For an ``AsyncSession`` the wrapped ``sync_session`` is inspected, so
    savepoints are counted for async sessions too. The previous check read
    ``session._nested_transaction`` directly, which was never true for async
    sessions and capped the result at 2.

    Args:
        session: SQLAlchemy session (sync or async)

    Returns:
        Number of nested transactions (0 if not in transaction)
    """
    if not session.in_transaction():
        return 0

    # AsyncSession wraps a sync Session; a plain Session has no ``sync_session``.
    sync_session = getattr(session, "sync_session", session)

    depth = 1
    # Walk from the innermost SAVEPOINT towards the root transaction,
    # counting each nested (SAVEPOINT) level on the way.
    transaction = sync_session.get_nested_transaction()
    while transaction is not None and transaction.nested is True:
        depth += 1
        transaction = transaction.parent
    return depth
|
|
|
|
|
|
__all__ = [
    # Propagation modes and errors
    "TransactionPropagation",
    "TransactionError",
    # Context-manager classes (sync/async transactions and savepoints)
    "TransactionContext",
    "AsyncTransactionContext",
    "SavepointContext",
    "AsyncSavepointContext",
    # Entry points
    "atomic",
    "atomic_sync",
    "transactional",
    # Introspection helpers
    "is_in_transaction",
    "get_transaction_depth",
]
|