Aniworld/src/server/database/transaction.py
Lukas 1ba67357dc Add database transaction support with atomic operations
- Create transaction.py with @transactional decorator, atomic() context manager
- Add TransactionPropagation modes: REQUIRED, REQUIRES_NEW, NESTED
- Add savepoint support for nested transactions with partial rollback
- Update connection.py with TransactionManager, get_transactional_session
- Update service.py with bulk operations (bulk_mark_downloaded, bulk_delete)
- Wrap QueueRepository.save_item() and clear_all() in atomic transactions
- Add comprehensive tests (66 transaction tests, 90% coverage)
- All 1090 tests passing
2025-12-25 18:05:33 +01:00

716 lines
22 KiB
Python

"""Transaction management utilities for SQLAlchemy.

This module provides transaction management utilities including decorators,
context managers, and helper functions for ensuring data consistency
across database operations.

Components:
    - @transactional(...): decorator factory that wraps sync or async
      functions in transaction boundaries
    - TransactionContext / AsyncTransactionContext: sync / async context
      managers for explicit transaction control with savepoint support
    - atomic() / atomic_sync(): async / sync context managers for atomic
      blocks
    - TransactionPropagation: Enum for transaction propagation modes
    - is_in_transaction() / get_transaction_depth(): session state helpers

Usage:
    # ``transactional`` is a factory: call it (note the parentheses) and
    # name the session parameter when it is not called ``db``.
    @transactional(session_param="session")
    async def compound_operation(session: AsyncSession, data: Model) -> Result:
        # Multiple write operations here
        # All succeed or all fail
        ...

    async with atomic(session) as tx:
        # Operations here
        async with tx.savepoint() as sp:
            # Nested operations with partial rollback capability
            pass
"""
from __future__ import annotations
import functools
import logging
from contextlib import asynccontextmanager, contextmanager
from enum import Enum
from typing import (
Any,
AsyncGenerator,
Callable,
Generator,
Optional,
ParamSpec,
TypeVar,
)
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
# Module-level logger: transaction/savepoint lifecycle events are logged at
# DEBUG, rollbacks caused by exceptions at WARNING.
logger = logging.getLogger(__name__)
# Type variables for generic typing
# T: return type of functions wrapped by @transactional(...)
T = TypeVar("T")
# P: parameter specification of functions wrapped by @transactional(...), so
# the wrapper keeps the wrapped function's signature for type checkers
P = ParamSpec("P")
class TransactionPropagation(Enum):
    """Transaction propagation behavior options.

    Defines how ``atomic()``, ``atomic_sync()`` and ``@transactional(...)``
    behave when called within an existing transaction context.

    Values:
        REQUIRED: Use existing transaction or create new one (default).
        REQUIRES_NEW: Intended to always create a new transaction
            (suspending the existing one). NOTE(review): ``atomic()`` and
            ``atomic_sync()`` currently handle this exactly like REQUIRED;
            the existing transaction is not suspended.
        NESTED: Create a savepoint within an existing transaction; when no
            transaction is active it behaves like REQUIRED.
    """
    REQUIRED = "required"
    REQUIRES_NEW = "requires_new"
    NESTED = "nested"
class TransactionError(Exception):
    """Signals that a transaction boundary could not be set up.

    For example, the wrappers produced by ``transactional(...)`` raise it
    when the configured session argument is missing or ``None``.
    """

    pass
class TransactionContext:
    """Synchronous context manager for explicit transaction control.

    Provides a clean interface for managing database transactions with
    automatic commit/rollback semantics and savepoint support.

    Behaviour:
        - On entry a transaction is begun only if the session is not
          already in one.
        - On a clean exit the session is committed, unless ``commit()`` or
          ``rollback()`` was already called on this context.
        - On an exception the session is rolled back and the exception
          propagates.

    NOTE(review): the commit on exit also happens when the session was
    already inside a transaction on entry, so a context that joins an outer
    transaction commits the outer work as well. Confirm this is intended.

    Attributes:
        session: SQLAlchemy Session instance
        _savepoint_count: Counter used to label savepoints in log output
        _committed: True once the transaction was committed or explicitly
            rolled back; suppresses the commit on exit

    Example:
        with TransactionContext(session) as tx:
            # Database operations here
            with tx.savepoint() as sp:
                # Nested operations with partial rollback
                pass
    """

    def __init__(self, session: Session) -> None:
        """Initialize transaction context.

        Args:
            session: SQLAlchemy sync session
        """
        self.session = session
        self._savepoint_count = 0
        self._committed = False

    def __enter__(self) -> "TransactionContext":
        """Enter transaction context.

        Begins a new transaction if not already in one.

        Returns:
            Self for context manager protocol
        """
        logger.debug("Entering transaction context")
        # Check if session is already in a transaction
        if not self.session.in_transaction():
            self.session.begin()
            logger.debug("Started new transaction")
        return self

    def __exit__(
        self,
        exc_type: Optional[type],
        exc_val: Optional[BaseException],
        exc_tb: Optional[Any],
    ) -> bool:
        """Exit transaction context.

        Commits on success, rolls back on exception.

        Args:
            exc_type: Exception type if raised
            exc_val: Exception value if raised
            exc_tb: Exception traceback if raised

        Returns:
            False to propagate exceptions
        """
        if exc_type is not None:
            logger.warning(
                "Transaction rollback due to exception: %s: %s",
                exc_type.__name__,
                exc_val,
            )
            self.session.rollback()
            return False
        if not self._committed:
            self.session.commit()
            logger.debug("Transaction committed")
            self._committed = True
        return False

    @contextmanager
    def savepoint(self, name: Optional[str] = None) -> Generator["SavepointContext", None, None]:
        """Create a savepoint for partial rollback capability.

        Savepoints allow nested transactions where inner operations
        can be rolled back without affecting outer operations. On a clean
        exit the savepoint is released (unless it was rolled back or is no
        longer active); on an exception it is rolled back and the exception
        propagates.

        Args:
            name: Optional savepoint name (auto-generated if not provided).
                Only used in log messages; ``begin_nested()`` is called
                without it.

        Yields:
            SavepointContext for nested transaction control

        Example:
            with tx.savepoint() as sp:
                # Operations here can be rolled back independently
                if error_condition:
                    sp.rollback()
        """
        self._savepoint_count += 1
        savepoint_name = name or f"sp_{self._savepoint_count}"
        logger.debug("Creating savepoint: %s", savepoint_name)
        nested = self.session.begin_nested()
        sp_context = SavepointContext(nested, savepoint_name)
        try:
            yield sp_context
            # begin_nested() is not used as a context manager here, so
            # nothing releases the savepoint automatically: release it
            # explicitly. Skip when the block rolled back to the savepoint or
            # the nested transaction is already closed (e.g. the whole
            # transaction was committed inside the block). A failure during
            # the release is handled by the except clause below.
            if not sp_context._rolled_back and nested.is_active:
                logger.debug("Releasing savepoint: %s", savepoint_name)
                nested.commit()
        except Exception as e:
            if not sp_context._rolled_back:
                logger.warning(
                    "Rolling back savepoint %s due to exception: %s",
                    savepoint_name,
                    e,
                )
                nested.rollback()
            raise

    def commit(self) -> None:
        """Explicitly commit the transaction.

        Use this for early commit within the context. Subsequent calls, and
        the commit on exit, become no-ops.
        """
        if not self._committed:
            self.session.commit()
            self._committed = True
            logger.debug("Transaction explicitly committed")

    def rollback(self) -> None:
        """Explicitly rollback the transaction.

        Use this for early rollback within the context. Afterwards the
        context no longer commits on exit.
        """
        self.session.rollback()
        self._committed = True  # Prevent double commit
        logger.debug("Transaction explicitly rolled back")
class SavepointContext:
    """Handle for a single SAVEPOINT opened with ``Session.begin_nested()``.

    Lets code inside a savepoint block undo its own work without touching
    the enclosing transaction.

    Attributes:
        _nested: Nested transaction object returned by ``begin_nested()``
        _name: Label used only in log messages
        _rolled_back: Becomes True once ``rollback()`` has run; the code
            that created this object reads it to decide what is left to do
    """

    def __init__(self, nested: Any, name: str) -> None:
        """Remember the nested transaction and its log label.

        Args:
            nested: Nested transaction returned by ``Session.begin_nested()``
            name: Human-readable savepoint label for log output
        """
        self._rolled_back = False
        self._name = name
        self._nested = nested

    def rollback(self) -> None:
        """Undo everything done since the savepoint was created.

        Calling it again after the first time has no effect.
        """
        if self._rolled_back:
            return
        self._nested.rollback()
        self._rolled_back = True
        logger.debug("Savepoint %s rolled back", self._name)

    def commit(self) -> None:
        """Mark the savepoint's work as accepted.

        NOTE(review): this method only logs. It does not call ``commit()``
        on the nested transaction, so no RELEASE SAVEPOINT is issued here;
        the savepoint stays open until whoever created it releases it or the
        enclosing transaction ends, and a later ``rollback()`` still undoes
        the work.
        """
        if self._rolled_back:
            return
        logger.debug("Savepoint %s committed", self._name)
class AsyncTransactionContext:
    """Asynchronous context manager for explicit transaction control.

    Provides async interface for managing database transactions with
    automatic commit/rollback semantics and savepoint support.

    Behaviour:
        - On entry a transaction is begun only if the session is not
          already in one.
        - On a clean exit the session is committed, unless ``commit()`` or
          ``rollback()`` was already called on this context.
        - On an exception the session is rolled back and the exception
          propagates.

    NOTE(review): the commit on exit also happens when the session was
    already inside a transaction on entry, so a context that joins an outer
    transaction commits the outer work as well. Confirm this is intended.

    Attributes:
        session: SQLAlchemy AsyncSession instance
        _savepoint_count: Counter used to label savepoints in log output
        _committed: True once the transaction was committed or explicitly
            rolled back; suppresses the commit on exit

    Example:
        async with AsyncTransactionContext(session) as tx:
            # Database operations here
            async with tx.savepoint() as sp:
                # Nested operations with partial rollback
                pass
    """

    def __init__(self, session: AsyncSession) -> None:
        """Initialize async transaction context.

        Args:
            session: SQLAlchemy async session
        """
        self.session = session
        self._savepoint_count = 0
        self._committed = False

    async def __aenter__(self) -> "AsyncTransactionContext":
        """Enter async transaction context.

        Begins a new transaction if not already in one.

        Returns:
            Self for context manager protocol
        """
        logger.debug("Entering async transaction context")
        # Check if session is already in a transaction
        if not self.session.in_transaction():
            await self.session.begin()
            logger.debug("Started new async transaction")
        return self

    async def __aexit__(
        self,
        exc_type: Optional[type],
        exc_val: Optional[BaseException],
        exc_tb: Optional[Any],
    ) -> bool:
        """Exit async transaction context.

        Commits on success, rolls back on exception.

        Args:
            exc_type: Exception type if raised
            exc_val: Exception value if raised
            exc_tb: Exception traceback if raised

        Returns:
            False to propagate exceptions
        """
        if exc_type is not None:
            logger.warning(
                "Async transaction rollback due to exception: %s: %s",
                exc_type.__name__,
                exc_val,
            )
            await self.session.rollback()
            return False
        if not self._committed:
            await self.session.commit()
            logger.debug("Async transaction committed")
            self._committed = True
        return False

    @asynccontextmanager
    async def savepoint(
        self, name: Optional[str] = None
    ) -> AsyncGenerator["AsyncSavepointContext", None]:
        """Create an async savepoint for partial rollback capability.

        On a clean exit the savepoint is released (unless it was rolled back
        or is no longer active); on an exception it is rolled back and the
        exception propagates.

        Args:
            name: Optional savepoint name (auto-generated if not provided).
                Only used in log messages; ``begin_nested()`` is called
                without it.

        Yields:
            AsyncSavepointContext for nested transaction control
        """
        self._savepoint_count += 1
        savepoint_name = name or f"sp_{self._savepoint_count}"
        logger.debug("Creating async savepoint: %s", savepoint_name)
        nested = await self.session.begin_nested()
        sp_context = AsyncSavepointContext(nested, savepoint_name, self.session)
        try:
            yield sp_context
            # begin_nested() is not used as a context manager here, so
            # nothing releases the savepoint automatically: release it
            # explicitly. Skip when the block rolled back to the savepoint or
            # the nested transaction is already closed. A failure during the
            # release is handled by the except clause below.
            if not sp_context._rolled_back and nested.is_active:
                logger.debug("Releasing async savepoint: %s", savepoint_name)
                await nested.commit()
        except Exception as e:
            if not sp_context._rolled_back:
                logger.warning(
                    "Rolling back async savepoint %s due to exception: %s",
                    savepoint_name,
                    e,
                )
                await nested.rollback()
            raise

    async def commit(self) -> None:
        """Explicitly commit the async transaction.

        Subsequent calls, and the commit on exit, become no-ops.
        """
        if not self._committed:
            await self.session.commit()
            self._committed = True
            logger.debug("Async transaction explicitly committed")

    async def rollback(self) -> None:
        """Explicitly rollback the async transaction.

        Afterwards the context no longer commits on exit.
        """
        await self.session.rollback()
        self._committed = True  # Prevent double commit
        logger.debug("Async transaction explicitly rolled back")
class AsyncSavepointContext:
    """Async handle for a single SAVEPOINT opened with ``begin_nested()``.

    Attributes:
        _nested: Nested transaction returned by ``AsyncSession.begin_nested()``
        _name: Label used only in log messages
        _session: The AsyncSession that owns the savepoint (stored for
            callers; this class's methods do not use it)
        _rolled_back: Becomes True once ``rollback()`` has run
    """

    def __init__(
        self, nested: Any, name: str, session: AsyncSession
    ) -> None:
        """Remember the nested transaction, its log label and its session.

        Args:
            nested: Nested transaction returned by ``begin_nested()``
            name: Human-readable savepoint label for log output
            session: Parent async session
        """
        self._rolled_back = False
        self._session = session
        self._name = name
        self._nested = nested

    async def rollback(self) -> None:
        """Undo everything done since the savepoint was created.

        Calling it again after the first time has no effect.
        """
        if self._rolled_back:
            return
        await self._nested.rollback()
        self._rolled_back = True
        logger.debug("Async savepoint %s rolled back", self._name)

    async def commit(self) -> None:
        """Mark the savepoint's work as accepted.

        NOTE(review): this method only logs. It does not call ``commit()``
        on the nested transaction, so no RELEASE SAVEPOINT is issued here,
        and a later ``rollback()`` still undoes the work.
        """
        if self._rolled_back:
            return
        logger.debug("Async savepoint %s committed", self._name)
class _NestedAsyncTransactionContext(AsyncTransactionContext):
    """Transaction handle yielded by ``atomic(..., NESTED)`` inside a transaction.

    The enclosing transaction owns the real COMMIT, so ``commit()`` is a
    no-op here (``_committed`` starts out True). ``rollback()`` rolls back
    only to the savepoint opened by ``atomic()``. A plain
    AsyncTransactionContext would instead call ``session.rollback()`` and
    discard the whole enclosing transaction.
    """

    def __init__(
        self, session: AsyncSession, savepoint: AsyncSavepointContext
    ) -> None:
        super().__init__(session)
        self._committed = True  # Parent manages commit
        self._savepoint = savepoint

    async def rollback(self) -> None:
        """Roll back to the nested savepoint; the outer transaction survives."""
        await self._savepoint.rollback()


@asynccontextmanager
async def atomic(
    session: AsyncSession,
    propagation: TransactionPropagation = TransactionPropagation.REQUIRED,
) -> AsyncGenerator[AsyncTransactionContext, None]:
    """Async context manager for atomic database operations.

    Provides a clean interface for wrapping database operations in
    a transaction boundary with automatic commit/rollback.

    Propagation handling:
        - REQUIRED: runs in an ``AsyncTransactionContext`` (begins a
          transaction if none is active, commits on success, rolls back on
          error).
        - REQUIRES_NEW: currently handled exactly like REQUIRED; the
          existing transaction is not suspended.
        - NESTED: inside an active transaction, opens a SAVEPOINT that is
          released on success and rolled back on error. The yielded
          context's ``rollback()`` rolls back only to that savepoint, and its
          ``commit()`` is a no-op because the outer transaction owns the
          commit. Without an active transaction it behaves like REQUIRED.

    Args:
        session: SQLAlchemy async session
        propagation: Transaction propagation behavior

    Yields:
        AsyncTransactionContext for transaction control

    Example:
        async with atomic(session) as tx:
            await some_operation(session)
            await another_operation(session)
            # All operations committed together or rolled back

        async with atomic(session) as tx:
            await outer_operation(session)
            async with tx.savepoint() as sp:
                await risky_operation(session)
                if error:
                    await sp.rollback()  # Only rollback nested ops
    """
    logger.debug(
        "Starting atomic block with propagation: %s",
        propagation.value,
    )
    if propagation == TransactionPropagation.NESTED and session.in_transaction():
        # Join the running transaction through a SAVEPOINT so this block can
        # be undone without discarding the outer work.
        nested = await session.begin_nested()
        sp_context = AsyncSavepointContext(nested, "atomic_nested", session)
        try:
            yield _NestedAsyncTransactionContext(session, sp_context)
            # begin_nested() is not used as a context manager, so release the
            # savepoint explicitly unless it was rolled back or already closed.
            if not sp_context._rolled_back and nested.is_active:
                logger.debug("Releasing nested atomic savepoint")
                await nested.commit()
        except Exception as e:
            if not sp_context._rolled_back:
                logger.warning(
                    "Rolling back nested atomic savepoint due to: %s", e
                )
                await nested.rollback()
            raise
    else:
        # REQUIRED, REQUIRES_NEW, or NESTED with no active transaction.
        async with AsyncTransactionContext(session) as tx:
            yield tx
class _NestedTransactionContext(TransactionContext):
    """Transaction handle yielded by ``atomic_sync(..., NESTED)`` inside a transaction.

    The enclosing transaction owns the real COMMIT, so ``commit()`` is a
    no-op here (``_committed`` starts out True). ``rollback()`` rolls back
    only to the savepoint opened by ``atomic_sync()``. A plain
    TransactionContext would instead call ``session.rollback()`` and
    discard the whole enclosing transaction.
    """

    def __init__(self, session: Session, savepoint: SavepointContext) -> None:
        super().__init__(session)
        self._committed = True  # Parent manages commit
        self._savepoint = savepoint

    def rollback(self) -> None:
        """Roll back to the nested savepoint; the outer transaction survives."""
        self._savepoint.rollback()


@contextmanager
def atomic_sync(
    session: Session,
    propagation: TransactionPropagation = TransactionPropagation.REQUIRED,
) -> Generator[TransactionContext, None, None]:
    """Sync context manager for atomic database operations.

    Propagation handling mirrors ``atomic()``:
        - REQUIRED: runs in a ``TransactionContext``.
        - REQUIRES_NEW: currently handled exactly like REQUIRED.
        - NESTED: inside an active transaction, opens a SAVEPOINT that is
          released on success and rolled back on error. The yielded
          context's ``rollback()`` targets only that savepoint, and its
          ``commit()`` is a no-op. Without an active transaction it behaves
          like REQUIRED.

    Args:
        session: SQLAlchemy sync session
        propagation: Transaction propagation behavior

    Yields:
        TransactionContext for transaction control
    """
    logger.debug(
        "Starting sync atomic block with propagation: %s",
        propagation.value,
    )
    if propagation == TransactionPropagation.NESTED and session.in_transaction():
        # Join the running transaction through a SAVEPOINT so this block can
        # be undone without discarding the outer work.
        nested = session.begin_nested()
        sp_context = SavepointContext(nested, "atomic_nested")
        try:
            yield _NestedTransactionContext(session, sp_context)
            # begin_nested() is not used as a context manager, so release the
            # savepoint explicitly unless it was rolled back or already closed.
            if not sp_context._rolled_back and nested.is_active:
                logger.debug("Releasing nested sync atomic savepoint")
                nested.commit()
        except Exception as e:
            if not sp_context._rolled_back:
                logger.warning(
                    "Rolling back nested sync savepoint due to: %s", e
                )
                nested.rollback()
            raise
    else:
        # REQUIRED, REQUIRES_NEW, or NESTED with no active transaction.
        with TransactionContext(session) as tx:
            yield tx
def transactional(
    propagation: TransactionPropagation = TransactionPropagation.REQUIRED,
    session_param: str = "db",
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """Decorator factory that wraps a function in a transaction boundary.

    Automatically handles commit on success and rollback on exception.
    Works with both sync and async functions: coroutine functions run inside
    ``atomic()``, plain functions inside ``atomic_sync()``.

    This is a factory, so it must be applied with parentheses
    (``@transactional()``), not as a bare ``@transactional``.

    Args:
        propagation: Transaction propagation behavior
        session_param: Name of the session parameter in the function
            signature; it may be passed positionally or by keyword

    Returns:
        A decorator that wraps the function in a transaction

    Raises:
        TransactionError: raised by the wrapped function at call time when
            the session argument cannot be found or is None

    Example:
        @transactional()
        async def create_user_with_profile(db: AsyncSession, data: dict):
            user = await create_user(db, data['user'])
            profile = await create_profile(db, user.id, data['profile'])
            return user, profile

        @transactional(propagation=TransactionPropagation.NESTED)
        async def risky_sub_operation(db: AsyncSession, data: dict):
            # This can be rolled back without affecting parent transaction
            pass
    """
    import inspect

    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        # inspect.iscoroutinefunction instead of asyncio.iscoroutinefunction,
        # which is deprecated as of Python 3.14.
        if inspect.iscoroutinefunction(func):

            @functools.wraps(func)
            async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
                # Get session from kwargs or args
                session = _extract_session(func, args, kwargs, session_param)
                if session is None:
                    raise TransactionError(
                        f"Could not find session parameter '{session_param}' "
                        f"in function {func.__name__}"
                    )
                logger.debug(
                    "Starting transaction for %s with propagation %s",
                    func.__name__,
                    propagation.value,
                )
                async with atomic(session, propagation):
                    result = await func(*args, **kwargs)
                # Logged only after atomic() has exited, i.e. after its
                # commit/release handling ran without raising.
                logger.debug(
                    "Transaction completed for %s",
                    func.__name__,
                )
                return result

            return async_wrapper  # type: ignore

        @functools.wraps(func)
        def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            # Get session from kwargs or args
            session = _extract_session(func, args, kwargs, session_param)
            if session is None:
                raise TransactionError(
                    f"Could not find session parameter '{session_param}' "
                    f"in function {func.__name__}"
                )
            logger.debug(
                "Starting sync transaction for %s with propagation %s",
                func.__name__,
                propagation.value,
            )
            with atomic_sync(session, propagation):
                result = func(*args, **kwargs)
            # Logged only after atomic_sync() has exited, i.e. after its
            # commit/release handling ran without raising.
            logger.debug(
                "Sync transaction completed for %s",
                func.__name__,
            )
            return result

        return sync_wrapper  # type: ignore

    return decorator
def _extract_session(
func: Callable,
args: tuple,
kwargs: dict,
session_param: str,
) -> Optional[AsyncSession | Session]:
"""Extract session from function arguments.
Args:
func: The function being called
args: Positional arguments
kwargs: Keyword arguments
session_param: Name of the session parameter
Returns:
Session instance or None if not found
"""
import inspect
# Check kwargs first
if session_param in kwargs:
return kwargs[session_param]
# Get function signature to find positional index
sig = inspect.signature(func)
params = list(sig.parameters.keys())
if session_param in params:
idx = params.index(session_param)
# Account for 'self' parameter in methods
if len(args) > idx:
return args[idx]
return None
def is_in_transaction(session: AsyncSession | Session) -> bool:
    """Report whether *session* currently has an active transaction.

    Simply delegates to ``session.in_transaction()``, which both the sync
    ``Session`` and ``AsyncSession`` provide.

    Args:
        session: SQLAlchemy session (sync or async)

    Returns:
        The value reported by ``session.in_transaction()``
    """
    active = session.in_transaction()
    return active
def get_transaction_depth(session: AsyncSession | Session) -> int:
    """Get a coarse transaction nesting depth for *session*.

    Args:
        session: SQLAlchemy session (sync or async)

    Returns:
        0 when no transaction is active, 1 for a plain transaction, and 2
        when at least one SAVEPOINT is active (deeper nesting is not counted
        separately).
    """
    if not session.in_transaction():
        return 0
    # in_nested_transaction() is public API on both Session and AsyncSession.
    # The previously used private ``_nested_transaction`` attribute is not
    # exposed on AsyncSession, so async savepoints were reported as depth 1.
    if session.in_nested_transaction():
        return 2  # At least one savepoint
    return 1
# Public API of this module (what ``from ... import *`` exports). Private
# helpers such as ``_extract_session`` are intentionally not listed.
__all__ = [
    "TransactionPropagation",
    "TransactionError",
    "TransactionContext",
    "AsyncTransactionContext",
    "SavepointContext",
    "AsyncSavepointContext",
    "atomic",
    "atomic_sync",
    "transactional",
    "is_in_transaction",
    "get_transaction_depth",
]