"""Transaction management utilities for SQLAlchemy. This module provides transaction management utilities including decorators, context managers, and helper functions for ensuring data consistency across database operations. Components: - @transactional decorator: Wraps functions in transaction boundaries - TransactionContext: Sync context manager for explicit transaction control - atomic(): Async context manager for async operations - TransactionPropagation: Enum for transaction propagation modes Usage: @transactional async def compound_operation(session: AsyncSession, data: Model) -> Result: # Multiple write operations here # All succeed or all fail pass async with atomic(session) as tx: # Operations here async with tx.savepoint() as sp: # Nested operations with partial rollback capability pass """ from __future__ import annotations import functools import logging from contextlib import asynccontextmanager, contextmanager from enum import Enum from typing import ( Any, AsyncGenerator, Callable, Generator, Optional, ParamSpec, TypeVar, ) from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session logger = logging.getLogger(__name__) # Type variables for generic typing T = TypeVar("T") P = ParamSpec("P") class TransactionPropagation(Enum): """Transaction propagation behavior options. Defines how transactions should behave when called within an existing transaction context. Values: REQUIRED: Use existing transaction or create new one (default) REQUIRES_NEW: Always create a new transaction (suspend existing) NESTED: Create a savepoint within existing transaction """ REQUIRED = "required" REQUIRES_NEW = "requires_new" NESTED = "nested" class TransactionError(Exception): """Exception raised for transaction-related errors.""" class TransactionContext: """Synchronous context manager for explicit transaction control. Provides a clean interface for managing database transactions with automatic commit/rollback semantics and savepoint support. Attributes: session: SQLAlchemy Session instance _savepoint_count: Counter for nested savepoints Example: with TransactionContext(session) as tx: # Database operations here with tx.savepoint() as sp: # Nested operations with partial rollback pass """ def __init__(self, session: Session) -> None: """Initialize transaction context. Args: session: SQLAlchemy sync session """ self.session = session self._savepoint_count = 0 self._committed = False def __enter__(self) -> "TransactionContext": """Enter transaction context. Begins a new transaction if not already in one. Returns: Self for context manager protocol """ logger.debug("Entering transaction context") # Check if session is already in a transaction if not self.session.in_transaction(): self.session.begin() logger.debug("Started new transaction") return self def __exit__( self, exc_type: Optional[type], exc_val: Optional[BaseException], exc_tb: Optional[Any], ) -> bool: """Exit transaction context. Commits on success, rolls back on exception. Args: exc_type: Exception type if raised exc_val: Exception value if raised exc_tb: Exception traceback if raised Returns: False to propagate exceptions """ if exc_type is not None: logger.warning( "Transaction rollback due to exception: %s: %s", exc_type.__name__, exc_val, ) self.session.rollback() return False if not self._committed: self.session.commit() logger.debug("Transaction committed") self._committed = True return False @contextmanager def savepoint(self, name: Optional[str] = None) -> Generator["SavepointContext", None, None]: """Create a savepoint for partial rollback capability. Savepoints allow nested transactions where inner operations can be rolled back without affecting outer operations. Args: name: Optional savepoint name (auto-generated if not provided) Yields: SavepointContext for nested transaction control Example: with tx.savepoint() as sp: # Operations here can be rolled back independently if error_condition: sp.rollback() """ self._savepoint_count += 1 savepoint_name = name or f"sp_{self._savepoint_count}" logger.debug("Creating savepoint: %s", savepoint_name) nested = self.session.begin_nested() sp_context = SavepointContext(nested, savepoint_name) try: yield sp_context if not sp_context._rolled_back: # Commit the savepoint (release it) logger.debug("Releasing savepoint: %s", savepoint_name) except Exception as e: if not sp_context._rolled_back: logger.warning( "Rolling back savepoint %s due to exception: %s", savepoint_name, e, ) nested.rollback() raise def commit(self) -> None: """Explicitly commit the transaction. Use this for early commit within the context. """ if not self._committed: self.session.commit() self._committed = True logger.debug("Transaction explicitly committed") def rollback(self) -> None: """Explicitly rollback the transaction. Use this for early rollback within the context. """ self.session.rollback() self._committed = True # Prevent double commit logger.debug("Transaction explicitly rolled back") class SavepointContext: """Context for managing a database savepoint. Provides explicit control over savepoint commit/rollback. Attributes: _nested: SQLAlchemy nested transaction object _name: Savepoint name for logging _rolled_back: Whether rollback has been called """ def __init__(self, nested: Any, name: str) -> None: """Initialize savepoint context. Args: nested: SQLAlchemy nested transaction name: Savepoint name for logging """ self._nested = nested self._name = name self._rolled_back = False def rollback(self) -> None: """Rollback to this savepoint. Undoes all changes since the savepoint was created. """ if not self._rolled_back: self._nested.rollback() self._rolled_back = True logger.debug("Savepoint %s rolled back", self._name) def commit(self) -> None: """Commit (release) this savepoint. Makes changes since the savepoint permanent within the parent transaction. """ if not self._rolled_back: # SQLAlchemy commits nested transactions automatically # when exiting the context without rollback logger.debug("Savepoint %s committed", self._name) class AsyncTransactionContext: """Asynchronous context manager for explicit transaction control. Provides async interface for managing database transactions with automatic commit/rollback semantics and savepoint support. Attributes: session: SQLAlchemy AsyncSession instance _savepoint_count: Counter for nested savepoints Example: async with AsyncTransactionContext(session) as tx: # Database operations here async with tx.savepoint() as sp: # Nested operations with partial rollback pass """ def __init__(self, session: AsyncSession) -> None: """Initialize async transaction context. Args: session: SQLAlchemy async session """ self.session = session self._savepoint_count = 0 self._committed = False async def __aenter__(self) -> "AsyncTransactionContext": """Enter async transaction context. Begins a new transaction if not already in one. Returns: Self for context manager protocol """ logger.debug("Entering async transaction context") # Check if session is already in a transaction if not self.session.in_transaction(): await self.session.begin() logger.debug("Started new async transaction") return self async def __aexit__( self, exc_type: Optional[type], exc_val: Optional[BaseException], exc_tb: Optional[Any], ) -> bool: """Exit async transaction context. Commits on success, rolls back on exception. Args: exc_type: Exception type if raised exc_val: Exception value if raised exc_tb: Exception traceback if raised Returns: False to propagate exceptions """ if exc_type is not None: logger.warning( "Async transaction rollback due to exception: %s: %s", exc_type.__name__, exc_val, ) await self.session.rollback() return False if not self._committed: await self.session.commit() logger.debug("Async transaction committed") self._committed = True return False @asynccontextmanager async def savepoint( self, name: Optional[str] = None ) -> AsyncGenerator["AsyncSavepointContext", None]: """Create an async savepoint for partial rollback capability. Args: name: Optional savepoint name (auto-generated if not provided) Yields: AsyncSavepointContext for nested transaction control """ self._savepoint_count += 1 savepoint_name = name or f"sp_{self._savepoint_count}" logger.debug("Creating async savepoint: %s", savepoint_name) nested = await self.session.begin_nested() sp_context = AsyncSavepointContext(nested, savepoint_name, self.session) try: yield sp_context if not sp_context._rolled_back: logger.debug("Releasing async savepoint: %s", savepoint_name) except Exception as e: if not sp_context._rolled_back: logger.warning( "Rolling back async savepoint %s due to exception: %s", savepoint_name, e, ) await nested.rollback() raise async def commit(self) -> None: """Explicitly commit the async transaction.""" if not self._committed: await self.session.commit() self._committed = True logger.debug("Async transaction explicitly committed") async def rollback(self) -> None: """Explicitly rollback the async transaction.""" await self.session.rollback() self._committed = True # Prevent double commit logger.debug("Async transaction explicitly rolled back") class AsyncSavepointContext: """Async context for managing a database savepoint. Attributes: _nested: SQLAlchemy nested transaction object _name: Savepoint name for logging _session: Parent session for async operations _rolled_back: Whether rollback has been called """ def __init__( self, nested: Any, name: str, session: AsyncSession ) -> None: """Initialize async savepoint context. Args: nested: SQLAlchemy nested transaction name: Savepoint name for logging session: Parent async session """ self._nested = nested self._name = name self._session = session self._rolled_back = False async def rollback(self) -> None: """Rollback to this savepoint asynchronously.""" if not self._rolled_back: await self._nested.rollback() self._rolled_back = True logger.debug("Async savepoint %s rolled back", self._name) async def commit(self) -> None: """Commit (release) this savepoint asynchronously.""" if not self._rolled_back: logger.debug("Async savepoint %s committed", self._name) @asynccontextmanager async def atomic( session: AsyncSession, propagation: TransactionPropagation = TransactionPropagation.REQUIRED, ) -> AsyncGenerator[AsyncTransactionContext, None]: """Async context manager for atomic database operations. Provides a clean interface for wrapping database operations in a transaction boundary with automatic commit/rollback. Args: session: SQLAlchemy async session propagation: Transaction propagation behavior Yields: AsyncTransactionContext for transaction control Example: async with atomic(session) as tx: await some_operation(session) await another_operation(session) # All operations committed together or rolled back async with atomic(session) as tx: await outer_operation(session) async with tx.savepoint() as sp: await risky_operation(session) if error: await sp.rollback() # Only rollback nested ops """ logger.debug( "Starting atomic block with propagation: %s", propagation.value, ) if propagation == TransactionPropagation.NESTED: # Use savepoint for nested propagation if session.in_transaction(): nested = await session.begin_nested() sp_context = AsyncSavepointContext(nested, "atomic_nested", session) try: # Create a wrapper context for consistency wrapper = AsyncTransactionContext(session) wrapper._committed = True # Parent manages commit yield wrapper if not sp_context._rolled_back: logger.debug("Releasing nested atomic savepoint") except Exception as e: if not sp_context._rolled_back: logger.warning( "Rolling back nested atomic savepoint due to: %s", e ) await nested.rollback() raise else: # No existing transaction, start new one async with AsyncTransactionContext(session) as tx: yield tx else: # REQUIRED or REQUIRES_NEW async with AsyncTransactionContext(session) as tx: yield tx @contextmanager def atomic_sync( session: Session, propagation: TransactionPropagation = TransactionPropagation.REQUIRED, ) -> Generator[TransactionContext, None, None]: """Sync context manager for atomic database operations. Args: session: SQLAlchemy sync session propagation: Transaction propagation behavior Yields: TransactionContext for transaction control """ logger.debug( "Starting sync atomic block with propagation: %s", propagation.value, ) if propagation == TransactionPropagation.NESTED: if session.in_transaction(): nested = session.begin_nested() sp_context = SavepointContext(nested, "atomic_nested") try: wrapper = TransactionContext(session) wrapper._committed = True yield wrapper if not sp_context._rolled_back: logger.debug("Releasing nested sync atomic savepoint") except Exception as e: if not sp_context._rolled_back: logger.warning( "Rolling back nested sync savepoint due to: %s", e ) nested.rollback() raise else: with TransactionContext(session) as tx: yield tx else: with TransactionContext(session) as tx: yield tx def transactional( propagation: TransactionPropagation = TransactionPropagation.REQUIRED, session_param: str = "db", ) -> Callable[[Callable[P, T]], Callable[P, T]]: """Decorator to wrap a function in a transaction boundary. Automatically handles commit on success and rollback on exception. Works with both sync and async functions. Args: propagation: Transaction propagation behavior session_param: Name of the session parameter in the function signature Returns: Decorated function wrapped in transaction Example: @transactional() async def create_user_with_profile(db: AsyncSession, data: dict): user = await create_user(db, data['user']) profile = await create_profile(db, user.id, data['profile']) return user, profile @transactional(propagation=TransactionPropagation.NESTED) async def risky_sub_operation(db: AsyncSession, data: dict): # This can be rolled back without affecting parent transaction pass """ def decorator(func: Callable[P, T]) -> Callable[P, T]: import asyncio if asyncio.iscoroutinefunction(func): @functools.wraps(func) async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # Get session from kwargs or args session = _extract_session(func, args, kwargs, session_param) if session is None: raise TransactionError( f"Could not find session parameter '{session_param}' " f"in function {func.__name__}" ) logger.debug( "Starting transaction for %s with propagation %s", func.__name__, propagation.value, ) async with atomic(session, propagation): result = await func(*args, **kwargs) logger.debug( "Transaction completed for %s", func.__name__, ) return result return async_wrapper # type: ignore else: @functools.wraps(func) def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # Get session from kwargs or args session = _extract_session(func, args, kwargs, session_param) if session is None: raise TransactionError( f"Could not find session parameter '{session_param}' " f"in function {func.__name__}" ) logger.debug( "Starting sync transaction for %s with propagation %s", func.__name__, propagation.value, ) with atomic_sync(session, propagation): result = func(*args, **kwargs) logger.debug( "Sync transaction completed for %s", func.__name__, ) return result return sync_wrapper # type: ignore return decorator def _extract_session( func: Callable, args: tuple, kwargs: dict, session_param: str, ) -> Optional[AsyncSession | Session]: """Extract session from function arguments. Args: func: The function being called args: Positional arguments kwargs: Keyword arguments session_param: Name of the session parameter Returns: Session instance or None if not found """ import inspect # Check kwargs first if session_param in kwargs: return kwargs[session_param] # Get function signature to find positional index sig = inspect.signature(func) params = list(sig.parameters.keys()) if session_param in params: idx = params.index(session_param) # Account for 'self' parameter in methods if len(args) > idx: return args[idx] return None def is_in_transaction(session: AsyncSession | Session) -> bool: """Check if session is currently in a transaction. Args: session: SQLAlchemy session (sync or async) Returns: True if session is in an active transaction """ return session.in_transaction() def get_transaction_depth(session: AsyncSession | Session) -> int: """Get the current transaction nesting depth. Args: session: SQLAlchemy session (sync or async) Returns: Number of nested transactions (0 if not in transaction) """ # SQLAlchemy doesn't expose nesting depth directly, # but we can check transaction state if not session.in_transaction(): return 0 # Check for nested transaction if hasattr(session, '_nested_transaction') and session._nested_transaction: return 2 # At least one savepoint return 1 __all__ = [ "TransactionPropagation", "TransactionError", "TransactionContext", "AsyncTransactionContext", "SavepointContext", "AsyncSavepointContext", "atomic", "atomic_sync", "transactional", "is_in_transaction", "get_transaction_depth", ]