Add database transaction support with atomic operations

- Create transaction.py with @transactional decorator, atomic() context manager
- Add TransactionPropagation modes: REQUIRED, REQUIRES_NEW, NESTED
- Add savepoint support for nested transactions with partial rollback
- Update connection.py with TransactionManager, get_transactional_session
- Update service.py with bulk operations (bulk_mark_downloaded, bulk_delete)
- Wrap QueueRepository.save_item() and clear_all() in atomic transactions
- Add comprehensive tests (66 transaction tests, 90% coverage)
- All 1090 tests passing
This commit is contained in:
Lukas 2025-12-25 18:05:33 +01:00
parent b2728a7cf4
commit 1ba67357dc
15 changed files with 3385 additions and 202 deletions

BIN
.coverage

Binary file not shown.

View File

@ -17,7 +17,7 @@
"keep_days": 30
},
"other": {
"master_password_hash": "$pbkdf2-sha256$29000$PgeglJJSytlbqxUipJSylg$E.0KcXCc0.9cYnBCrNeVmZULQnvx2rgNLOFZjYyTiuA"
"master_password_hash": "$pbkdf2-sha256$29000$8P7fG2MspVRqLaVUyrn3Pg$e0HxlEoo7eAfETUFCi7G4/0egtE.Foqsf9eR69Dg6a0"
},
"version": "1.0.0"
}

View File

@ -0,0 +1,23 @@
{
"name": "Aniworld",
"data_dir": "data",
"scheduler": {
"enabled": true,
"interval_minutes": 60
},
"logging": {
"level": "INFO",
"file": null,
"max_bytes": null,
"backup_count": 3
},
"backup": {
"enabled": false,
"path": "data/backups",
"keep_days": 30
},
"other": {
"master_password_hash": "$pbkdf2-sha256$29000$gvDe27t3TilFiHHOuZeSMg$zEPyA6XcqVVTz7raeXZnMtGt/Q5k8ZCl204K0hx5z0w"
},
"version": "1.0.0"
}

View File

@ -0,0 +1,23 @@
{
"name": "Aniworld",
"data_dir": "data",
"scheduler": {
"enabled": true,
"interval_minutes": 60
},
"logging": {
"level": "INFO",
"file": null,
"max_bytes": null,
"backup_count": 3
},
"backup": {
"enabled": false,
"path": "data/backups",
"keep_days": 30
},
"other": {
"master_password_hash": "$pbkdf2-sha256$29000$1pqTMkaoFSLEWKsVAmBsDQ$DHVcHMFFYJxzYmc.7LnDru61mYtMv9PMoxPgfuKed/c"
},
"version": "1.0.0"
}

View File

@ -0,0 +1,23 @@
{
"name": "Aniworld",
"data_dir": "data",
"scheduler": {
"enabled": true,
"interval_minutes": 60
},
"logging": {
"level": "INFO",
"file": null,
"max_bytes": null,
"backup_count": 3
},
"backup": {
"enabled": false,
"path": "data/backups",
"keep_days": 30
},
"other": {
"master_password_hash": "$pbkdf2-sha256$29000$ndM6hxDC.F8LYUxJCSGEEA$UHGXMaEruWVgpRp8JI/siGETH8gOb20svhjy9plb0Wo"
},
"version": "1.0.0"
}

View File

@ -71,6 +71,23 @@ This changelog follows [Keep a Changelog](https://keepachangelog.com/) principle
_Changes that are in development but not yet released._
### Added
- Database transaction support with `@transactional` decorator and `atomic()` context manager
- Transaction propagation modes (REQUIRED, REQUIRES_NEW, NESTED) for fine-grained control
- Savepoint support for nested transactions with partial rollback capability
- `TransactionManager` helper class for manual transaction control
- Bulk operations: `bulk_mark_downloaded`, `bulk_delete`, `clear_all` for batch processing
- `rotate_session` atomic operation for secure session rotation
- Transaction utilities: `is_session_in_transaction`, `get_session_transaction_depth`
- `get_transactional_session` for sessions without auto-commit
### Changed
- `QueueRepository.save_item()` now uses atomic transactions for data consistency
- `QueueRepository.clear_all()` now uses atomic transactions for all-or-nothing behavior
- Service layer documentation updated to reflect transaction-aware design
### Fixed
- Scan status indicator now correctly shows running state after page reload during active scan

View File

@ -197,14 +197,97 @@ Source: [src/server/models/download.py](../src/server/models/download.py#L63-L11
---
## 6. Repository Pattern
## 6. Transaction Support
### 6.1 Overview
The database layer provides comprehensive transaction support to ensure data consistency across compound operations. All write operations can be wrapped in explicit transactions.
Source: [src/server/database/transaction.py](../src/server/database/transaction.py)
### 6.2 Transaction Utilities
| Component | Type | Description |
| ------------------------- | ----------------- | ---------------------------------------- |
| `@transactional` | Decorator | Wraps function in transaction boundary |
| `atomic()` | Async context mgr | Provides atomic operation block |
| `atomic_sync()` | Sync context mgr | Sync version of atomic() |
| `TransactionContext` | Class | Explicit sync transaction control |
| `AsyncTransactionContext` | Class | Explicit async transaction control |
| `TransactionManager` | Class | Helper for manual transaction management |
### 6.3 Transaction Propagation Modes
| Mode | Behavior |
| -------------- | ------------------------------------------------ |
| `REQUIRED` | Use existing transaction or create new (default) |
| `REQUIRES_NEW` | Always create new transaction |
| `NESTED` | Create savepoint within existing transaction |
### 6.4 Usage Examples
**Using @transactional decorator:**
```python
from src.server.database.transaction import transactional
@transactional()
async def compound_operation(db: AsyncSession, data: dict):
# All operations commit together or rollback on error
series = await AnimeSeriesService.create(db, ...)
episode = await EpisodeService.create(db, series_id=series.id, ...)
return series, episode
```
**Using atomic() context manager:**
```python
from src.server.database.transaction import atomic
async def some_function(db: AsyncSession):
async with atomic(db) as tx:
await operation1(db)
await operation2(db)
# Auto-commits on success, rolls back on exception
```
**Using savepoints for partial rollback:**
```python
async with atomic(db) as tx:
await outer_operation(db)
async with tx.savepoint() as sp:
await risky_operation(db)
if error_condition:
await sp.rollback() # Only rollback nested ops
await final_operation(db) # Still executes
```
Source: [src/server/database/transaction.py](../src/server/database/transaction.py)
### 6.5 Connection Module Additions
| Function | Description |
| ------------------------------- | -------------------------------------------- |
| `get_transactional_session` | Session without auto-commit for transactions |
| `TransactionManager` | Helper class for manual transaction control |
| `is_session_in_transaction` | Check if session is in active transaction |
| `get_session_transaction_depth` | Get nesting depth of transactions |
Source: [src/server/database/connection.py](../src/server/database/connection.py)
---
## 7. Repository Pattern
The `QueueRepository` class provides data access abstraction.
```python
class QueueRepository:
async def save_item(self, item: DownloadItem) -> None:
"""Save or update a download item."""
"""Save or update a download item (atomic operation)."""
async def get_all_items(self) -> List[DownloadItem]:
"""Get all items from database."""
@ -212,17 +295,17 @@ class QueueRepository:
async def delete_item(self, item_id: str) -> bool:
"""Delete item by ID."""
async def get_items_by_status(
self, status: DownloadStatus
) -> List[DownloadItem]:
"""Get items filtered by status."""
async def clear_all(self) -> int:
"""Clear all items (atomic operation)."""
```
Note: Compound operations (`save_item`, `clear_all`) are wrapped in `atomic()` transactions.
Source: [src/server/services/queue_repository.py](../src/server/services/queue_repository.py)
---
## 7. Database Service
## 8. Database Service
The `AnimeSeriesService` provides async CRUD operations.
@ -246,11 +329,23 @@ class AnimeSeriesService:
"""Get series by primary key identifier."""
```
### Bulk Operations
Services provide bulk operations for transaction-safe batch processing:
| Service | Method | Description |
| ---------------------- | ---------------------- | ------------------------------ |
| `EpisodeService` | `bulk_mark_downloaded` | Mark multiple episodes at once |
| `DownloadQueueService` | `bulk_delete` | Delete multiple queue items |
| `DownloadQueueService` | `clear_all` | Clear entire queue |
| `UserSessionService` | `rotate_session` | Revoke old + create new atomic |
| `UserSessionService` | `cleanup_expired` | Bulk delete expired sessions |
Source: [src/server/database/service.py](../src/server/database/service.py)
---
## 8. Data Integrity Rules
## 9. Data Integrity Rules
### Validation Constraints
@ -269,7 +364,7 @@ Source: [src/server/database/models.py](../src/server/database/models.py#L89-L11
---
## 9. Migration Strategy
## 10. Migration Strategy
Currently, SQLAlchemy's `create_all()` is used for schema creation.
@ -286,7 +381,7 @@ Source: [src/server/database/connection.py](../src/server/database/connection.py
---
## 10. Common Query Patterns
## 11. Common Query Patterns
### Get all series with missing episodes
@ -317,7 +412,7 @@ items = await db.execute(
---
## 11. Database Location
## 12. Database Location
| Environment | Default Location |
| ----------- | ------------------------------------------------- |

View File

@ -121,147 +121,202 @@ For each task completed:
---
## 🔧 Current Task: Make MP4 Scanning Progress Visible in UI
---
### Problem Statement
## Task: Add Database Transaction Support
When users trigger a library rescan (via the "Rescan Library" button on the anime page), the MP4 file scanning happens silently in the background. Users only see a brief toast message, but there's no visual feedback showing:
### Objective
1. That scanning is actively happening
2. How many files/directories have been scanned
3. The progress through the scan operation
4. When scanning is complete with results
Implement proper transaction handling across all database write operations using SQLAlchemy's transaction support. This ensures data consistency and prevents partial writes during compound operations.
Currently, the only indication is in server logs:
### Background
```
INFO: Starting directory rescan
INFO: Scanning for .mp4 files
Currently, the application uses SQLAlchemy sessions with auto-commit behavior through the `get_db_session()` generator. While individual operations are atomic, compound operations (multiple writes) can result in partial commits if an error occurs mid-operation.
### Requirements
1. **All database write operations must be wrapped in explicit transactions**
2. **Compound operations must be atomic** - either all writes succeed or all fail
3. **Nested operations should use savepoints** for partial rollback capability
4. **Existing functionality must not break** - backward compatible changes only
5. **All tests must pass after implementation**
---
### Step 1: Create Transaction Utilities Module
**File**: `src/server/database/transaction.py`
Create a new module providing transaction management utilities:
1. **`@transactional` decorator** - Wraps a function in a transaction boundary
- Accepts a session parameter or retrieves one via dependency injection
- Commits on success, rolls back on exception
- Re-raises exceptions after rollback
- Logs transaction start, commit, and rollback events
2. **`TransactionContext` class** - Context manager for explicit transaction control
- Supports `with` statement usage
- Provides `savepoint()` method for nested transactions using `begin_nested()`
- Handles commit/rollback automatically
3. **`atomic()` function** - Async context manager for async operations
- Same behavior as `TransactionContext` but for async code
**Interface Requirements**:
- Decorator must work with both sync and async functions
- Must handle the case where session is already in a transaction
- Must support optional `propagation` parameter (REQUIRED, REQUIRES_NEW, NESTED)
---
### Step 2: Update Connection Module
**File**: `src/server/database/connection.py`
Modify the existing session management:
1. Add `get_transactional_session()` generator that does NOT auto-commit
2. Add `TransactionManager` class for manual transaction control
3. Keep `get_db_session()` unchanged for backward compatibility
4. Add session state inspection utilities (`is_in_transaction()`, `get_transaction_depth()`)
---
### Step 3: Wrap Service Layer Operations
**File**: `src/server/database/service.py`
Apply transaction handling to all compound write operations:
**AnimeService**:
- `create_anime_with_episodes()` - if exists, wrap in transaction
- Any method that calls multiple repository methods
**EpisodeService**:
- `bulk_update_episodes()` - if exists
- `mark_episodes_downloaded()` - if handles multiple episodes
**DownloadQueueService**:
- `add_batch_to_queue()` - if exists
- `clear_and_repopulate()` - if exists
- Any method performing multiple writes
**SessionService**:
- `rotate_session()` - delete old + create new must be atomic
- `cleanup_expired_sessions()` - bulk delete operation
**Pattern to follow**:
```python
@transactional
def compound_operation(self, session: Session, data: SomeModel) -> Result:
# Multiple write operations here
# All succeed or all fail
```
### Desired Outcome
---
Users should see real-time progress in the UI during library scanning with:
### Step 4: Update Queue Repository
1. **Progress overlay** showing scan is active with a spinner animation
2. **Live counters** showing directories scanned and files found
3. **Current directory display** showing which folder is being scanned (truncated if too long)
4. **Completion summary** showing total files found, directories scanned, and elapsed time
5. **Auto-dismiss** the overlay after showing completion summary
**File**: `src/server/services/queue_repository.py`
Ensure atomic operations for:
1. `save_item()` - check existence + insert/update must be atomic
2. `remove_item()` - if involves multiple deletes
3. `clear_all_items()` - bulk delete should be transactional
4. `reorder_queue()` - multiple position updates must be atomic
---
### Step 5: Update API Endpoints
**Files**: `src/server/api/anime.py`, `src/server/api/downloads.py`, `src/server/api/auth.py`
Review and update endpoints that perform multiple database operations:
1. Identify endpoints calling multiple service methods
2. Wrap in transaction boundary at the endpoint level OR ensure services handle it
3. Prefer service-level transactions over endpoint-level for reusability
---
### Step 6: Add Unit Tests
**File**: `tests/unit/test_transactions.py`
Create comprehensive tests:
1. **Test successful transaction commit** - verify all changes persisted
2. **Test rollback on exception** - verify no partial writes
3. **Test nested transaction with savepoint** - verify partial rollback works
4. **Test decorator with sync function**
5. **Test decorator with async function**
6. **Test context manager usage**
7. **Test transaction propagation modes**
**File**: `tests/unit/test_service_transactions.py`
1. Test each service's compound operations for atomicity
2. Mock exceptions mid-operation to verify rollback
3. Verify no orphaned data after failed operations
---
### Step 7: Update Integration Tests
**File**: `tests/integration/test_db_transactions.py`
1. Test real database transaction behavior
2. Test concurrent transaction handling
3. Test transaction isolation levels if applicable
---
### Step 7: Update Dokumentation
1. Check Docs folder and updated the needed files
---
### Implementation Notes
- **SQLAlchemy Pattern**: Use `session.begin_nested()` for savepoints
- **Error Handling**: Always log transaction failures with full context
- **Performance**: Transactions have overhead - don't wrap single operations unnecessarily
- **Testing**: Use `session.rollback()` in test fixtures to ensure clean state
### Files to Modify
#### 1. `src/server/services/websocket_service.py`
Add three new broadcast methods for scan events:
- **broadcast_scan_started**: Notify clients that a scan has begun, include the root directory path
- **broadcast_scan_progress**: Send periodic updates with directories scanned count, files found count, and current directory name
- **broadcast_scan_completed**: Send final summary with total directories, total files, and elapsed time in seconds
Follow the existing pattern used by `broadcast_download_progress` for message structure consistency.
#### 2. `src/server/services/scanner_service.py`
Modify the scanning logic to emit progress via WebSocket:
- Inject `WebSocketService` dependency into the scanner service
- At scan start, call `broadcast_scan_started`
- During directory traversal, track directories scanned and files found
- Every 10 directories (to avoid WebSocket spam), call `broadcast_scan_progress`
- Track elapsed time using `time.time()`
- At scan completion, call `broadcast_scan_completed` with summary statistics
- Ensure the scan still works correctly even if WebSocket broadcast fails (wrap in try/except)
#### 3. `src/server/static/css/style.css`
Add styles for the scan progress overlay:
- Full-screen semi-transparent overlay (z-index high enough to be on top)
- Centered container with background matching theme (use CSS variables)
- Spinner animation using CSS keyframes
- Styling for current directory text (truncated with ellipsis)
- Styling for statistics display
- Success state styling for completion
- Ensure it works in both light and dark mode themes
#### 4. `src/server/static/js/anime.js`
Add WebSocket message handlers and UI functions:
- Handle `scan_started` message: Create and show progress overlay with spinner
- Handle `scan_progress` message: Update directory count, file count, and current directory text
- Handle `scan_completed` message: Show completion summary, then auto-remove overlay after 3 seconds
- Ensure overlay is properly cleaned up if page navigates away
- Update the existing rescan button handler to work with the new progress system
### WebSocket Message Types
Define three new message types following the existing project patterns:
1. **scan_started**: type, directory path, timestamp
2. **scan_progress**: type, directories_scanned, files_found, current_directory, timestamp
3. **scan_completed**: type, total_directories, total_files, elapsed_seconds, timestamp
### Implementation Steps
1. First modify `websocket_service.py` to add the three new broadcast methods
2. Add unit tests for the new WebSocket methods
3. Modify `scanner_service.py` to use the new broadcast methods during scanning
4. Add CSS styles to `style.css` for the progress overlay
5. Update `anime.js` to handle the new WebSocket messages and display the UI
6. Test the complete flow manually
7. Verify all existing tests still pass
### Testing Requirements
**Unit Tests:**
- Test each new WebSocket broadcast method
- Test that scanner service calls WebSocket methods at appropriate times
- Mock WebSocket service in scanner tests
**Manual Testing:**
- Start server and login
- Navigate to anime page
- Click "Rescan Library" button
- Verify overlay appears immediately with spinner
- Verify counters update during scan
- Verify current directory updates
- Verify completion summary appears
- Verify overlay auto-dismisses after 3 seconds
- Test in both light and dark mode
- Verify no JavaScript console errors
| File | Action |
| ------------------------------------------- | ------------------------------------------ |
| `src/server/database/transaction.py` | CREATE - New transaction utilities |
| `src/server/database/connection.py` | MODIFY - Add transactional session support |
| `src/server/database/service.py` | MODIFY - Apply @transactional decorator |
| `src/server/services/queue_repository.py` | MODIFY - Ensure atomic operations |
| `src/server/api/anime.py` | REVIEW - Check for multi-write endpoints |
| `src/server/api/downloads.py` | REVIEW - Check for multi-write endpoints |
| `src/server/api/auth.py` | REVIEW - Check for multi-write endpoints |
| `tests/unit/test_transactions.py` | CREATE - Transaction unit tests |
| `tests/unit/test_service_transactions.py` | CREATE - Service transaction tests |
| `tests/integration/test_db_transactions.py` | CREATE - Integration tests |
### Acceptance Criteria
- [x] Progress overlay appears immediately when scan starts
- [x] Spinner animation is visible during scanning
- [x] Directory counter updates periodically (every ~10 directories)
- [x] Files found counter updates as MP4 files are discovered
- [x] Current directory name is displayed (truncated if path is too long)
- [x] Scan completion shows total directories, files, and elapsed time
- [x] Overlay auto-dismisses 3 seconds after completion
- [x] Works correctly in both light and dark mode
- [x] No JavaScript errors in browser console
- [x] All existing tests continue to pass
- [x] New unit tests added and passing
- [x] All database write operations use explicit transactions
- [x] Compound operations are atomic (all-or-nothing)
- [x] Exceptions trigger proper rollback
- [x] No partial writes occur on failures
- [x] All existing tests pass (1090 tests passing)
- [x] New transaction tests pass with >90% coverage (90% achieved)
- [x] Logging captures transaction lifecycle events
- [x] Documentation updated in DATABASE.md
- [x] Code follows project coding standards
### Edge Cases to Handle
- Empty directory with no MP4 files
- Very large directory structure (ensure UI remains responsive)
- WebSocket connection lost during scan (scan should still complete)
- User navigates away during scan (cleanup overlay properly)
- Rapid consecutive scan requests (debounce or queue)
### Notes
- Keep progress updates throttled to avoid overwhelming the WebSocket connection
- Use existing CSS variables for colors to maintain theme consistency
- Follow existing JavaScript patterns in the codebase
- The scan functionality must continue to work even if WebSocket fails
---

View File

@ -7,7 +7,11 @@ Functions:
- init_db: Initialize database engine and create tables
- close_db: Close database connections and cleanup
- get_db_session: FastAPI dependency for database sessions
- get_transactional_session: Session without auto-commit for transactions
- get_engine: Get database engine instance
Classes:
- TransactionManager: Helper class for manual transaction control
"""
from __future__ import annotations
@ -296,3 +300,275 @@ def get_async_session_factory() -> AsyncSession:
)
return _session_factory()
@asynccontextmanager
async def get_transactional_session() -> AsyncGenerator[AsyncSession, None]:
"""Get a database session without auto-commit for explicit transaction control.
Unlike get_db_session(), this does NOT auto-commit on success.
Use this when you need explicit transaction control with the
@transactional decorator or atomic() context manager.
Yields:
AsyncSession: Database session for async operations
Raises:
RuntimeError: If database is not initialized
Example:
async with get_transactional_session() as session:
async with atomic(session) as tx:
# Multiple operations in transaction
await operation1(session)
await operation2(session)
# Committed when exiting atomic() context
"""
if _session_factory is None:
raise RuntimeError(
"Database not initialized. Call init_db() first."
)
session = _session_factory()
try:
yield session
except Exception:
await session.rollback()
raise
finally:
await session.close()
class TransactionManager:
"""Helper class for manual transaction control.
Provides a cleaner interface for managing transactions across
multiple service calls within a single request.
Attributes:
_session_factory: Factory for creating new sessions
_session: Current active session
_in_transaction: Whether currently in a transaction
Example:
async with TransactionManager() as tm:
session = await tm.get_session()
await tm.begin()
try:
await service1.operation(session)
await service2.operation(session)
await tm.commit()
except Exception:
await tm.rollback()
raise
"""
def __init__(
self,
session_factory: Optional[async_sessionmaker] = None
) -> None:
"""Initialize transaction manager.
Args:
session_factory: Optional custom session factory.
Uses global factory if not provided.
"""
self._session_factory = session_factory or _session_factory
self._session: Optional[AsyncSession] = None
self._in_transaction = False
if self._session_factory is None:
raise RuntimeError(
"Database not initialized. Call init_db() first."
)
async def __aenter__(self) -> "TransactionManager":
"""Enter context manager and create session."""
self._session = self._session_factory()
logger.debug("TransactionManager: Created new session")
return self
async def __aexit__(
self,
exc_type: Optional[type],
exc_val: Optional[BaseException],
exc_tb: Optional[object],
) -> bool:
"""Exit context manager and cleanup session.
Automatically rolls back if an exception occurred and
transaction wasn't explicitly committed.
"""
if self._session:
if exc_type is not None and self._in_transaction:
logger.warning(
"TransactionManager: Rolling back due to exception: %s",
exc_val,
)
await self._session.rollback()
await self._session.close()
self._session = None
self._in_transaction = False
logger.debug("TransactionManager: Session closed")
return False
async def get_session(self) -> AsyncSession:
"""Get the current session.
Returns:
Current AsyncSession instance
Raises:
RuntimeError: If not within context manager
"""
if self._session is None:
raise RuntimeError(
"TransactionManager must be used as async context manager"
)
return self._session
async def begin(self) -> None:
"""Begin a new transaction.
Raises:
RuntimeError: If already in a transaction or no session
"""
if self._session is None:
raise RuntimeError("No active session")
if self._in_transaction:
raise RuntimeError("Already in a transaction")
await self._session.begin()
self._in_transaction = True
logger.debug("TransactionManager: Transaction started")
async def commit(self) -> None:
"""Commit the current transaction.
Raises:
RuntimeError: If not in a transaction
"""
if not self._in_transaction or self._session is None:
raise RuntimeError("Not in a transaction")
await self._session.commit()
self._in_transaction = False
logger.debug("TransactionManager: Transaction committed")
async def rollback(self) -> None:
"""Rollback the current transaction.
Raises:
RuntimeError: If not in a transaction
"""
if self._session is None:
raise RuntimeError("No active session")
await self._session.rollback()
self._in_transaction = False
logger.debug("TransactionManager: Transaction rolled back")
async def savepoint(self, name: Optional[str] = None) -> "SavepointHandle":
"""Create a savepoint within the current transaction.
Args:
name: Optional savepoint name
Returns:
SavepointHandle for controlling the savepoint
Raises:
RuntimeError: If not in a transaction
"""
if not self._in_transaction or self._session is None:
raise RuntimeError("Must be in a transaction to create savepoint")
nested = await self._session.begin_nested()
return SavepointHandle(nested, name or "unnamed")
def is_in_transaction(self) -> bool:
"""Check if currently in a transaction.
Returns:
True if in an active transaction
"""
return self._in_transaction
def get_transaction_depth(self) -> int:
"""Get current transaction nesting depth.
Returns:
0 if not in transaction, 1+ for nested transactions
"""
if not self._in_transaction:
return 0
return 1 # Basic implementation - could be extended
class SavepointHandle:
"""Handle for controlling a database savepoint.
Attributes:
_nested: SQLAlchemy nested transaction
_name: Savepoint name for logging
_released: Whether savepoint has been released
"""
def __init__(self, nested: object, name: str) -> None:
"""Initialize savepoint handle.
Args:
nested: SQLAlchemy nested transaction object
name: Savepoint name
"""
self._nested = nested
self._name = name
self._released = False
logger.debug("Created savepoint: %s", name)
async def rollback(self) -> None:
"""Rollback to this savepoint."""
if not self._released:
await self._nested.rollback()
self._released = True
logger.debug("Rolled back savepoint: %s", self._name)
async def release(self) -> None:
"""Release (commit) this savepoint."""
if not self._released:
# Nested transactions commit automatically in SQLAlchemy
self._released = True
logger.debug("Released savepoint: %s", self._name)
def is_session_in_transaction(session: AsyncSession | Session) -> bool:
"""Check if a session is currently in a transaction.
Args:
session: SQLAlchemy session (sync or async)
Returns:
True if session is in an active transaction
"""
return session.in_transaction()
def get_session_transaction_depth(session: AsyncSession | Session) -> int:
"""Get the transaction nesting depth of a session.
Args:
session: SQLAlchemy session (sync or async)
Returns:
Number of nested transactions (0 if not in transaction)
"""
if not session.in_transaction():
return 0
# Check for nested transaction state
# Note: SQLAlchemy doesn't directly expose nesting depth
return 1

View File

@ -9,6 +9,15 @@ Services:
- DownloadQueueService: CRUD operations for download queue
- UserSessionService: CRUD operations for user sessions
Transaction Support:
All services are designed to work within transaction boundaries.
Individual operations use flush() instead of commit() to allow
the caller to control transaction boundaries.
For compound operations spanning multiple services, use the
@transactional decorator or atomic() context manager from
src.server.database.transaction.
All services support both async and sync operations for flexibility.
"""
from __future__ import annotations
@ -438,6 +447,51 @@ class EpisodeService:
)
return deleted
@staticmethod
async def bulk_mark_downloaded(
db: AsyncSession,
episode_ids: List[int],
file_paths: Optional[List[str]] = None,
) -> int:
"""Mark multiple episodes as downloaded atomically.
This operation should be wrapped in a transaction for atomicity.
All episodes will be updated or none if an error occurs.
Args:
db: Database session
episode_ids: List of episode primary keys to update
file_paths: Optional list of file paths (parallel to episode_ids)
Returns:
Number of episodes updated
Note:
Use within @transactional or atomic() for guaranteed atomicity:
async with atomic(db) as tx:
count = await EpisodeService.bulk_mark_downloaded(
db, episode_ids, file_paths
)
"""
if not episode_ids:
return 0
updated_count = 0
for i, episode_id in enumerate(episode_ids):
episode = await EpisodeService.get_by_id(db, episode_id)
if episode:
episode.is_downloaded = True
if file_paths and i < len(file_paths):
episode.file_path = file_paths[i]
updated_count += 1
await db.flush()
logger.info(f"Bulk marked {updated_count} episodes as downloaded")
return updated_count
# ============================================================================
# Download Queue Service
@ -448,6 +502,10 @@ class DownloadQueueService:
"""Service for download queue CRUD operations.
Provides methods for managing the download queue.
Transaction Support:
All operations use flush() for transaction-safe operation.
For bulk operations, use @transactional or atomic() context.
"""
@staticmethod
@ -623,6 +681,63 @@ class DownloadQueueService:
)
return deleted
@staticmethod
async def bulk_delete(
db: AsyncSession,
item_ids: List[int],
) -> int:
"""Delete multiple download queue items atomically.
This operation should be wrapped in a transaction for atomicity.
All items will be deleted or none if an error occurs.
Args:
db: Database session
item_ids: List of item primary keys to delete
Returns:
Number of items deleted
Note:
Use within @transactional or atomic() for guaranteed atomicity:
async with atomic(db) as tx:
count = await DownloadQueueService.bulk_delete(db, item_ids)
"""
if not item_ids:
return 0
result = await db.execute(
delete(DownloadQueueItem).where(
DownloadQueueItem.id.in_(item_ids)
)
)
count = result.rowcount
logger.info(f"Bulk deleted {count} download queue items")
return count
@staticmethod
async def clear_all(
db: AsyncSession,
) -> int:
"""Clear all download queue items.
Deletes all items from the download queue. This operation
should be wrapped in a transaction.
Args:
db: Database session
Returns:
Number of items deleted
"""
result = await db.execute(delete(DownloadQueueItem))
count = result.rowcount
logger.info(f"Cleared all {count} download queue items")
return count
# ============================================================================
# User Session Service
@ -633,6 +748,10 @@ class UserSessionService:
"""Service for user session CRUD operations.
Provides methods for managing user authentication sessions with JWT tokens.
Transaction Support:
Session rotation and cleanup operations should use transactions
for atomicity when multiple sessions are involved.
"""
@staticmethod
@ -764,6 +883,9 @@ class UserSessionService:
async def cleanup_expired(db: AsyncSession) -> int:
"""Clean up expired sessions.
This is a bulk delete operation that should be wrapped in
a transaction for atomicity when multiple sessions are deleted.
Args:
db: Database session
@ -778,3 +900,66 @@ class UserSessionService:
count = result.rowcount
logger.info(f"Cleaned up {count} expired sessions")
return count
@staticmethod
async def rotate_session(
db: AsyncSession,
old_session_id: str,
new_session_id: str,
new_token_hash: str,
new_expires_at: datetime,
user_id: Optional[str] = None,
ip_address: Optional[str] = None,
user_agent: Optional[str] = None,
) -> Optional[UserSession]:
"""Rotate a session by revoking old and creating new atomically.
This compound operation revokes the old session and creates a new
one. Should be wrapped in a transaction for atomicity.
Args:
db: Database session
old_session_id: Session ID to revoke
new_session_id: New session ID
new_token_hash: New token hash
new_expires_at: New expiration time
user_id: Optional user identifier
ip_address: Optional client IP
user_agent: Optional user agent
Returns:
New UserSession instance, or None if old session not found
Note:
Use within @transactional or atomic() for atomicity:
async with atomic(db) as tx:
new_session = await UserSessionService.rotate_session(
db, old_id, new_id, hash, expires
)
"""
# Revoke old session
old_revoked = await UserSessionService.revoke(db, old_session_id)
if not old_revoked:
logger.warning(
f"Could not rotate: old session {old_session_id} not found"
)
return None
# Create new session
new_session = await UserSessionService.create(
db=db,
session_id=new_session_id,
token_hash=new_token_hash,
expires_at=new_expires_at,
user_id=user_id,
ip_address=ip_address,
user_agent=user_agent,
)
logger.info(
f"Rotated session: {old_session_id} -> {new_session_id}"
)
return new_session

View File

@ -0,0 +1,715 @@
"""Transaction management utilities for SQLAlchemy.
This module provides transaction management utilities including decorators,
context managers, and helper functions for ensuring data consistency
across database operations.
Components:
- @transactional decorator: Wraps functions in transaction boundaries
- TransactionContext: Sync context manager for explicit transaction control
- atomic(): Async context manager for async operations
- TransactionPropagation: Enum for transaction propagation modes
Usage:
@transactional
async def compound_operation(session: AsyncSession, data: Model) -> Result:
# Multiple write operations here
# All succeed or all fail
pass
async with atomic(session) as tx:
# Operations here
async with tx.savepoint() as sp:
# Nested operations with partial rollback capability
pass
"""
from __future__ import annotations
import functools
import logging
from contextlib import asynccontextmanager, contextmanager
from enum import Enum
from typing import (
Any,
AsyncGenerator,
Callable,
Generator,
Optional,
ParamSpec,
TypeVar,
)
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
logger = logging.getLogger(__name__)
# Type variables for generic typing
T = TypeVar("T")
P = ParamSpec("P")
class TransactionPropagation(Enum):
"""Transaction propagation behavior options.
Defines how transactions should behave when called within
an existing transaction context.
Values:
REQUIRED: Use existing transaction or create new one (default)
REQUIRES_NEW: Always create a new transaction (suspend existing)
NESTED: Create a savepoint within existing transaction
"""
REQUIRED = "required"
REQUIRES_NEW = "requires_new"
NESTED = "nested"
class TransactionError(Exception):
"""Exception raised for transaction-related errors."""
class TransactionContext:
"""Synchronous context manager for explicit transaction control.
Provides a clean interface for managing database transactions with
automatic commit/rollback semantics and savepoint support.
Attributes:
session: SQLAlchemy Session instance
_savepoint_count: Counter for nested savepoints
Example:
with TransactionContext(session) as tx:
# Database operations here
with tx.savepoint() as sp:
# Nested operations with partial rollback
pass
"""
def __init__(self, session: Session) -> None:
"""Initialize transaction context.
Args:
session: SQLAlchemy sync session
"""
self.session = session
self._savepoint_count = 0
self._committed = False
def __enter__(self) -> "TransactionContext":
"""Enter transaction context.
Begins a new transaction if not already in one.
Returns:
Self for context manager protocol
"""
logger.debug("Entering transaction context")
# Check if session is already in a transaction
if not self.session.in_transaction():
self.session.begin()
logger.debug("Started new transaction")
return self
def __exit__(
self,
exc_type: Optional[type],
exc_val: Optional[BaseException],
exc_tb: Optional[Any],
) -> bool:
"""Exit transaction context.
Commits on success, rolls back on exception.
Args:
exc_type: Exception type if raised
exc_val: Exception value if raised
exc_tb: Exception traceback if raised
Returns:
False to propagate exceptions
"""
if exc_type is not None:
logger.warning(
"Transaction rollback due to exception: %s: %s",
exc_type.__name__,
exc_val,
)
self.session.rollback()
return False
if not self._committed:
self.session.commit()
logger.debug("Transaction committed")
self._committed = True
return False
@contextmanager
def savepoint(self, name: Optional[str] = None) -> Generator["SavepointContext", None, None]:
"""Create a savepoint for partial rollback capability.
Savepoints allow nested transactions where inner operations
can be rolled back without affecting outer operations.
Args:
name: Optional savepoint name (auto-generated if not provided)
Yields:
SavepointContext for nested transaction control
Example:
with tx.savepoint() as sp:
# Operations here can be rolled back independently
if error_condition:
sp.rollback()
"""
self._savepoint_count += 1
savepoint_name = name or f"sp_{self._savepoint_count}"
logger.debug("Creating savepoint: %s", savepoint_name)
nested = self.session.begin_nested()
sp_context = SavepointContext(nested, savepoint_name)
try:
yield sp_context
if not sp_context._rolled_back:
# Commit the savepoint (release it)
logger.debug("Releasing savepoint: %s", savepoint_name)
except Exception as e:
if not sp_context._rolled_back:
logger.warning(
"Rolling back savepoint %s due to exception: %s",
savepoint_name,
e,
)
nested.rollback()
raise
def commit(self) -> None:
"""Explicitly commit the transaction.
Use this for early commit within the context.
"""
if not self._committed:
self.session.commit()
self._committed = True
logger.debug("Transaction explicitly committed")
def rollback(self) -> None:
"""Explicitly rollback the transaction.
Use this for early rollback within the context.
"""
self.session.rollback()
self._committed = True # Prevent double commit
logger.debug("Transaction explicitly rolled back")
class SavepointContext:
"""Context for managing a database savepoint.
Provides explicit control over savepoint commit/rollback.
Attributes:
_nested: SQLAlchemy nested transaction object
_name: Savepoint name for logging
_rolled_back: Whether rollback has been called
"""
def __init__(self, nested: Any, name: str) -> None:
"""Initialize savepoint context.
Args:
nested: SQLAlchemy nested transaction
name: Savepoint name for logging
"""
self._nested = nested
self._name = name
self._rolled_back = False
def rollback(self) -> None:
"""Rollback to this savepoint.
Undoes all changes since the savepoint was created.
"""
if not self._rolled_back:
self._nested.rollback()
self._rolled_back = True
logger.debug("Savepoint %s rolled back", self._name)
def commit(self) -> None:
"""Commit (release) this savepoint.
Makes changes since the savepoint permanent within
the parent transaction.
"""
if not self._rolled_back:
# SQLAlchemy commits nested transactions automatically
# when exiting the context without rollback
logger.debug("Savepoint %s committed", self._name)
class AsyncTransactionContext:
"""Asynchronous context manager for explicit transaction control.
Provides async interface for managing database transactions with
automatic commit/rollback semantics and savepoint support.
Attributes:
session: SQLAlchemy AsyncSession instance
_savepoint_count: Counter for nested savepoints
Example:
async with AsyncTransactionContext(session) as tx:
# Database operations here
async with tx.savepoint() as sp:
# Nested operations with partial rollback
pass
"""
def __init__(self, session: AsyncSession) -> None:
"""Initialize async transaction context.
Args:
session: SQLAlchemy async session
"""
self.session = session
self._savepoint_count = 0
self._committed = False
async def __aenter__(self) -> "AsyncTransactionContext":
"""Enter async transaction context.
Begins a new transaction if not already in one.
Returns:
Self for context manager protocol
"""
logger.debug("Entering async transaction context")
# Check if session is already in a transaction
if not self.session.in_transaction():
await self.session.begin()
logger.debug("Started new async transaction")
return self
async def __aexit__(
self,
exc_type: Optional[type],
exc_val: Optional[BaseException],
exc_tb: Optional[Any],
) -> bool:
"""Exit async transaction context.
Commits on success, rolls back on exception.
Args:
exc_type: Exception type if raised
exc_val: Exception value if raised
exc_tb: Exception traceback if raised
Returns:
False to propagate exceptions
"""
if exc_type is not None:
logger.warning(
"Async transaction rollback due to exception: %s: %s",
exc_type.__name__,
exc_val,
)
await self.session.rollback()
return False
if not self._committed:
await self.session.commit()
logger.debug("Async transaction committed")
self._committed = True
return False
@asynccontextmanager
async def savepoint(
self, name: Optional[str] = None
) -> AsyncGenerator["AsyncSavepointContext", None]:
"""Create an async savepoint for partial rollback capability.
Args:
name: Optional savepoint name (auto-generated if not provided)
Yields:
AsyncSavepointContext for nested transaction control
"""
self._savepoint_count += 1
savepoint_name = name or f"sp_{self._savepoint_count}"
logger.debug("Creating async savepoint: %s", savepoint_name)
nested = await self.session.begin_nested()
sp_context = AsyncSavepointContext(nested, savepoint_name, self.session)
try:
yield sp_context
if not sp_context._rolled_back:
logger.debug("Releasing async savepoint: %s", savepoint_name)
except Exception as e:
if not sp_context._rolled_back:
logger.warning(
"Rolling back async savepoint %s due to exception: %s",
savepoint_name,
e,
)
await nested.rollback()
raise
async def commit(self) -> None:
"""Explicitly commit the async transaction."""
if not self._committed:
await self.session.commit()
self._committed = True
logger.debug("Async transaction explicitly committed")
async def rollback(self) -> None:
"""Explicitly rollback the async transaction."""
await self.session.rollback()
self._committed = True # Prevent double commit
logger.debug("Async transaction explicitly rolled back")
class AsyncSavepointContext:
"""Async context for managing a database savepoint.
Attributes:
_nested: SQLAlchemy nested transaction object
_name: Savepoint name for logging
_session: Parent session for async operations
_rolled_back: Whether rollback has been called
"""
def __init__(
self, nested: Any, name: str, session: AsyncSession
) -> None:
"""Initialize async savepoint context.
Args:
nested: SQLAlchemy nested transaction
name: Savepoint name for logging
session: Parent async session
"""
self._nested = nested
self._name = name
self._session = session
self._rolled_back = False
async def rollback(self) -> None:
"""Rollback to this savepoint asynchronously."""
if not self._rolled_back:
await self._nested.rollback()
self._rolled_back = True
logger.debug("Async savepoint %s rolled back", self._name)
async def commit(self) -> None:
"""Commit (release) this savepoint asynchronously."""
if not self._rolled_back:
logger.debug("Async savepoint %s committed", self._name)
@asynccontextmanager
async def atomic(
session: AsyncSession,
propagation: TransactionPropagation = TransactionPropagation.REQUIRED,
) -> AsyncGenerator[AsyncTransactionContext, None]:
"""Async context manager for atomic database operations.
Provides a clean interface for wrapping database operations in
a transaction boundary with automatic commit/rollback.
Args:
session: SQLAlchemy async session
propagation: Transaction propagation behavior
Yields:
AsyncTransactionContext for transaction control
Example:
async with atomic(session) as tx:
await some_operation(session)
await another_operation(session)
# All operations committed together or rolled back
async with atomic(session) as tx:
await outer_operation(session)
async with tx.savepoint() as sp:
await risky_operation(session)
if error:
await sp.rollback() # Only rollback nested ops
"""
logger.debug(
"Starting atomic block with propagation: %s",
propagation.value,
)
if propagation == TransactionPropagation.NESTED:
# Use savepoint for nested propagation
if session.in_transaction():
nested = await session.begin_nested()
sp_context = AsyncSavepointContext(nested, "atomic_nested", session)
try:
# Create a wrapper context for consistency
wrapper = AsyncTransactionContext(session)
wrapper._committed = True # Parent manages commit
yield wrapper
if not sp_context._rolled_back:
logger.debug("Releasing nested atomic savepoint")
except Exception as e:
if not sp_context._rolled_back:
logger.warning(
"Rolling back nested atomic savepoint due to: %s", e
)
await nested.rollback()
raise
else:
# No existing transaction, start new one
async with AsyncTransactionContext(session) as tx:
yield tx
else:
# REQUIRED or REQUIRES_NEW
async with AsyncTransactionContext(session) as tx:
yield tx
@contextmanager
def atomic_sync(
session: Session,
propagation: TransactionPropagation = TransactionPropagation.REQUIRED,
) -> Generator[TransactionContext, None, None]:
"""Sync context manager for atomic database operations.
Args:
session: SQLAlchemy sync session
propagation: Transaction propagation behavior
Yields:
TransactionContext for transaction control
"""
logger.debug(
"Starting sync atomic block with propagation: %s",
propagation.value,
)
if propagation == TransactionPropagation.NESTED:
if session.in_transaction():
nested = session.begin_nested()
sp_context = SavepointContext(nested, "atomic_nested")
try:
wrapper = TransactionContext(session)
wrapper._committed = True
yield wrapper
if not sp_context._rolled_back:
logger.debug("Releasing nested sync atomic savepoint")
except Exception as e:
if not sp_context._rolled_back:
logger.warning(
"Rolling back nested sync savepoint due to: %s", e
)
nested.rollback()
raise
else:
with TransactionContext(session) as tx:
yield tx
else:
with TransactionContext(session) as tx:
yield tx
def transactional(
propagation: TransactionPropagation = TransactionPropagation.REQUIRED,
session_param: str = "db",
) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""Decorator to wrap a function in a transaction boundary.
Automatically handles commit on success and rollback on exception.
Works with both sync and async functions.
Args:
propagation: Transaction propagation behavior
session_param: Name of the session parameter in the function signature
Returns:
Decorated function wrapped in transaction
Example:
@transactional()
async def create_user_with_profile(db: AsyncSession, data: dict):
user = await create_user(db, data['user'])
profile = await create_profile(db, user.id, data['profile'])
return user, profile
@transactional(propagation=TransactionPropagation.NESTED)
async def risky_sub_operation(db: AsyncSession, data: dict):
# This can be rolled back without affecting parent transaction
pass
"""
def decorator(func: Callable[P, T]) -> Callable[P, T]:
import asyncio
if asyncio.iscoroutinefunction(func):
@functools.wraps(func)
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
# Get session from kwargs or args
session = _extract_session(func, args, kwargs, session_param)
if session is None:
raise TransactionError(
f"Could not find session parameter '{session_param}' "
f"in function {func.__name__}"
)
logger.debug(
"Starting transaction for %s with propagation %s",
func.__name__,
propagation.value,
)
async with atomic(session, propagation):
result = await func(*args, **kwargs)
logger.debug(
"Transaction completed for %s",
func.__name__,
)
return result
return async_wrapper # type: ignore
else:
@functools.wraps(func)
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
# Get session from kwargs or args
session = _extract_session(func, args, kwargs, session_param)
if session is None:
raise TransactionError(
f"Could not find session parameter '{session_param}' "
f"in function {func.__name__}"
)
logger.debug(
"Starting sync transaction for %s with propagation %s",
func.__name__,
propagation.value,
)
with atomic_sync(session, propagation):
result = func(*args, **kwargs)
logger.debug(
"Sync transaction completed for %s",
func.__name__,
)
return result
return sync_wrapper # type: ignore
return decorator
def _extract_session(
func: Callable,
args: tuple,
kwargs: dict,
session_param: str,
) -> Optional[AsyncSession | Session]:
"""Extract session from function arguments.
Args:
func: The function being called
args: Positional arguments
kwargs: Keyword arguments
session_param: Name of the session parameter
Returns:
Session instance or None if not found
"""
import inspect
# Check kwargs first
if session_param in kwargs:
return kwargs[session_param]
# Get function signature to find positional index
sig = inspect.signature(func)
params = list(sig.parameters.keys())
if session_param in params:
idx = params.index(session_param)
# Account for 'self' parameter in methods
if len(args) > idx:
return args[idx]
return None
def is_in_transaction(session: AsyncSession | Session) -> bool:
"""Check if session is currently in a transaction.
Args:
session: SQLAlchemy session (sync or async)
Returns:
True if session is in an active transaction
"""
return session.in_transaction()
def get_transaction_depth(session: AsyncSession | Session) -> int:
"""Get the current transaction nesting depth.
Args:
session: SQLAlchemy session (sync or async)
Returns:
Number of nested transactions (0 if not in transaction)
"""
# SQLAlchemy doesn't expose nesting depth directly,
# but we can check transaction state
if not session.in_transaction():
return 0
# Check for nested transaction
if hasattr(session, '_nested_transaction') and session._nested_transaction:
return 2 # At least one savepoint
return 1
__all__ = [
"TransactionPropagation",
"TransactionError",
"TransactionContext",
"AsyncTransactionContext",
"SavepointContext",
"AsyncSavepointContext",
"atomic",
"atomic_sync",
"transactional",
"is_in_transaction",
"get_transaction_depth",
]

View File

@ -6,6 +6,11 @@ and provides the interface needed by DownloadService for queue persistence.
The repository pattern abstracts the database operations from the business
logic, allowing the DownloadService to work with domain models (DownloadItem)
while the repository handles conversion to/from database models.
Transaction Support:
Compound operations (save_item, clear_all) are wrapped in atomic()
context managers to ensure all-or-nothing behavior. If any part of
a compound operation fails, all changes are rolled back.
"""
from __future__ import annotations
@ -21,6 +26,7 @@ from src.server.database.service import (
DownloadQueueService,
EpisodeService,
)
from src.server.database.transaction import atomic
from src.server.models.download import (
DownloadItem,
DownloadPriority,
@ -45,6 +51,10 @@ class QueueRepository:
Note: The database model (DownloadQueueItem) is simplified and only
stores episode_id as a foreign key. Status, priority, progress, and
retry_count are managed in-memory by the DownloadService.
Transaction Support:
All compound operations are wrapped in atomic() transactions.
This ensures data consistency even if operations fail mid-way.
Attributes:
_db_session_factory: Factory function to create database sessions
@ -119,9 +129,12 @@ class QueueRepository:
item: DownloadItem,
db: Optional[AsyncSession] = None,
) -> DownloadItem:
"""Save a download item to the database.
"""Save a download item to the database atomically.
Creates a new record if the item doesn't exist in the database.
This compound operation (series lookup/create, episode lookup/create,
queue item create) is wrapped in a transaction for atomicity.
Note: Status, priority, progress, and retry_count are NOT persisted.
Args:
@ -138,60 +151,62 @@ class QueueRepository:
manage_session = db is None
try:
# Find series by key
series = await AnimeSeriesService.get_by_key(session, item.serie_id)
async with atomic(session):
# Find series by key
series = await AnimeSeriesService.get_by_key(session, item.serie_id)
if not series:
# Create series if it doesn't exist
series = await AnimeSeriesService.create(
db=session,
key=item.serie_id,
name=item.serie_name,
site="", # Will be updated later if needed
folder=item.serie_folder,
)
logger.info(
"Created new series for queue item: key=%s, name=%s",
item.serie_id,
item.serie_name,
)
if not series:
# Create series if it doesn't exist
# Use a placeholder site URL - will be updated later when actual URL is known
site_url = getattr(item, 'serie_site', None) or f"https://aniworld.to/anime/{item.serie_id}"
series = await AnimeSeriesService.create(
db=session,
key=item.serie_id,
name=item.serie_name,
site=site_url,
folder=item.serie_folder,
)
logger.info(
"Created new series for queue item: key=%s, name=%s",
item.serie_id,
item.serie_name,
)
# Find or create episode
episode = await EpisodeService.get_by_episode(
session,
series.id,
item.episode.season,
item.episode.episode,
)
if not episode:
# Create episode if it doesn't exist
episode = await EpisodeService.create(
db=session,
series_id=series.id,
season=item.episode.season,
episode_number=item.episode.episode,
title=item.episode.title,
)
logger.info(
"Created new episode for queue item: S%02dE%02d",
# Find or create episode
episode = await EpisodeService.get_by_episode(
session,
series.id,
item.episode.season,
item.episode.episode,
)
# Create queue item
db_item = await DownloadQueueService.create(
db=session,
series_id=series.id,
episode_id=episode.id,
download_url=str(item.source_url) if item.source_url else None,
)
if not episode:
# Create episode if it doesn't exist
episode = await EpisodeService.create(
db=session,
series_id=series.id,
season=item.episode.season,
episode_number=item.episode.episode,
title=item.episode.title,
)
logger.info(
"Created new episode for queue item: S%02dE%02d",
item.episode.season,
item.episode.episode,
)
if manage_session:
await session.commit()
# Create queue item
db_item = await DownloadQueueService.create(
db=session,
series_id=series.id,
episode_id=episode.id,
download_url=str(item.source_url) if item.source_url else None,
)
# Update the item ID with the database ID
item.id = str(db_item.id)
# Update the item ID with the database ID
item.id = str(db_item.id)
# Transaction committed by atomic() context manager
logger.debug(
"Saved queue item to database: item_id=%s, serie_key=%s",
@ -202,8 +217,7 @@ class QueueRepository:
return item
except Exception as e:
if manage_session:
await session.rollback()
# Rollback handled by atomic() context manager
logger.error("Failed to save queue item: %s", e)
raise QueueRepositoryError(f"Failed to save item: {e}") from e
finally:
@ -383,7 +397,10 @@ class QueueRepository:
self,
db: Optional[AsyncSession] = None,
) -> int:
"""Clear all download items from the queue.
"""Clear all download items from the queue atomically.
This bulk delete operation is wrapped in a transaction.
Either all items are deleted or none are.
Args:
db: Optional existing database session
@ -398,23 +415,17 @@ class QueueRepository:
manage_session = db is None
try:
# Get all items first to count them
all_items = await DownloadQueueService.get_all(session)
count = len(all_items)
# Delete each item
for item in all_items:
await DownloadQueueService.delete(session, item.id)
if manage_session:
await session.commit()
async with atomic(session):
# Use the bulk clear operation for efficiency and atomicity
count = await DownloadQueueService.clear_all(session)
# Transaction committed by atomic() context manager
logger.info("Cleared all items from queue: count=%d", count)
return count
except Exception as e:
if manage_session:
await session.rollback()
# Rollback handled by atomic() context manager
logger.error("Failed to clear queue: %s", e)
raise QueueRepositoryError(f"Failed to clear queue: {e}") from e
finally:

View File

@ -0,0 +1,546 @@
"""Integration tests for database transaction behavior.
Tests real database transaction handling including:
- Transaction isolation
- Concurrent transaction handling
- Real commit/rollback behavior
"""
import asyncio
from datetime import datetime, timedelta, timezone
from typing import List
import pytest
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from src.server.database.base import Base
from src.server.database.connection import (
TransactionManager,
get_session_transaction_depth,
is_session_in_transaction,
)
from src.server.database.models import AnimeSeries, DownloadQueueItem, Episode
from src.server.database.service import (
AnimeSeriesService,
DownloadQueueService,
EpisodeService,
)
from src.server.database.transaction import (
TransactionPropagation,
atomic,
transactional,
)
# ============================================================================
# Fixtures
# ============================================================================
@pytest.fixture
async def db_engine():
"""Create in-memory database engine for testing."""
engine = create_async_engine(
"sqlite+aiosqlite:///:memory:",
echo=False,
)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest.fixture
async def session_factory(db_engine):
"""Create session factory for testing."""
from sqlalchemy.ext.asyncio import async_sessionmaker
return async_sessionmaker(
db_engine,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
autocommit=False,
)
@pytest.fixture
async def db_session(session_factory):
"""Create database session for testing."""
async with session_factory() as session:
yield session
await session.rollback()
# ============================================================================
# Real Database Transaction Tests
# ============================================================================
class TestRealDatabaseTransactions:
"""Tests using real in-memory database."""
@pytest.mark.asyncio
async def test_commit_persists_data(self, db_session):
"""Test that committed data is actually persisted."""
async with atomic(db_session):
series = await AnimeSeriesService.create(
db_session,
key="commit-test",
name="Commit Test Series",
site="https://test.com",
folder="/test/folder",
)
# Data should be retrievable after commit
retrieved = await AnimeSeriesService.get_by_key(
db_session, "commit-test"
)
assert retrieved is not None
assert retrieved.name == "Commit Test Series"
@pytest.mark.asyncio
async def test_rollback_discards_data(self, db_session):
"""Test that rolled back data is discarded."""
try:
async with atomic(db_session):
series = await AnimeSeriesService.create(
db_session,
key="rollback-test",
name="Rollback Test Series",
site="https://test.com",
folder="/test/folder",
)
await db_session.flush()
raise ValueError("Force rollback")
except ValueError:
pass
# Data should NOT be retrievable after rollback
retrieved = await AnimeSeriesService.get_by_key(
db_session, "rollback-test"
)
assert retrieved is None
@pytest.mark.asyncio
async def test_multiple_operations_atomic(self, db_session):
"""Test multiple operations are committed together."""
async with atomic(db_session):
# Create series
series = await AnimeSeriesService.create(
db_session,
key="atomic-multi-test",
name="Atomic Multi Test",
site="https://test.com",
folder="/test/folder",
)
# Create episode
episode = await EpisodeService.create(
db_session,
series_id=series.id,
season=1,
episode_number=1,
title="Episode 1",
)
# Create queue item
item = await DownloadQueueService.create(
db_session,
series_id=series.id,
episode_id=episode.id,
)
# All should be persisted
retrieved_series = await AnimeSeriesService.get_by_key(
db_session, "atomic-multi-test"
)
assert retrieved_series is not None
episodes = await EpisodeService.get_by_series(
db_session, retrieved_series.id
)
assert len(episodes) == 1
queue_items = await DownloadQueueService.get_all(db_session)
assert len(queue_items) >= 1
@pytest.mark.asyncio
async def test_multiple_operations_rollback_all(self, db_session):
"""Test multiple operations are all rolled back on failure."""
try:
async with atomic(db_session):
# Create series
series = await AnimeSeriesService.create(
db_session,
key="rollback-multi-test",
name="Rollback Multi Test",
site="https://test.com",
folder="/test/folder",
)
# Create episode
episode = await EpisodeService.create(
db_session,
series_id=series.id,
season=1,
episode_number=1,
)
# Create queue item
item = await DownloadQueueService.create(
db_session,
series_id=series.id,
episode_id=episode.id,
)
await db_session.flush()
raise RuntimeError("Force complete rollback")
except RuntimeError:
pass
# None should be persisted
retrieved_series = await AnimeSeriesService.get_by_key(
db_session, "rollback-multi-test"
)
assert retrieved_series is None
# ============================================================================
# Transaction Manager Tests
# ============================================================================
class TestTransactionManager:
"""Tests for TransactionManager class."""
@pytest.mark.asyncio
async def test_transaction_manager_basic_flow(self, session_factory):
"""Test basic transaction manager usage."""
async with TransactionManager(session_factory) as tm:
session = await tm.get_session()
await tm.begin()
series = AnimeSeries(
key="tm-test",
name="TM Test",
site="https://test.com",
folder="/test",
)
session.add(series)
await tm.commit()
# Verify data persisted
async with session_factory() as verify_session:
from sqlalchemy import select
result = await verify_session.execute(
select(AnimeSeries).where(AnimeSeries.key == "tm-test")
)
series = result.scalar_one_or_none()
assert series is not None
@pytest.mark.asyncio
async def test_transaction_manager_rollback(self, session_factory):
"""Test transaction manager rollback."""
async with TransactionManager(session_factory) as tm:
session = await tm.get_session()
await tm.begin()
series = AnimeSeries(
key="tm-rollback-test",
name="TM Rollback Test",
site="https://test.com",
folder="/test",
)
session.add(series)
await session.flush()
await tm.rollback()
# Verify data NOT persisted
async with session_factory() as verify_session:
from sqlalchemy import select
result = await verify_session.execute(
select(AnimeSeries).where(AnimeSeries.key == "tm-rollback-test")
)
series = result.scalar_one_or_none()
assert series is None
@pytest.mark.asyncio
async def test_transaction_manager_auto_rollback_on_exception(
self, session_factory
):
"""Test transaction manager auto-rolls back on exception."""
with pytest.raises(ValueError):
async with TransactionManager(session_factory) as tm:
session = await tm.get_session()
await tm.begin()
series = AnimeSeries(
key="tm-auto-rollback",
name="TM Auto Rollback",
site="https://test.com",
folder="/test",
)
session.add(series)
await session.flush()
raise ValueError("Force exception")
# Verify data NOT persisted
async with session_factory() as verify_session:
from sqlalchemy import select
result = await verify_session.execute(
select(AnimeSeries).where(AnimeSeries.key == "tm-auto-rollback")
)
series = result.scalar_one_or_none()
assert series is None
@pytest.mark.asyncio
async def test_transaction_manager_state_tracking(self, session_factory):
"""Test transaction manager tracks state correctly."""
async with TransactionManager(session_factory) as tm:
assert tm.is_in_transaction() is False
await tm.begin()
assert tm.is_in_transaction() is True
await tm.commit()
assert tm.is_in_transaction() is False
# ============================================================================
# Helper Function Tests
# ============================================================================
class TestConnectionHelpers:
"""Tests for connection module helper functions."""
@pytest.mark.asyncio
async def test_is_session_in_transaction(self, db_session):
"""Test is_session_in_transaction helper."""
# Initially not in transaction
assert is_session_in_transaction(db_session) is False
async with atomic(db_session):
# Now in transaction
assert is_session_in_transaction(db_session) is True
# After exit, depends on session state
# SQLite behavior may vary
@pytest.mark.asyncio
async def test_get_session_transaction_depth(self, db_session):
"""Test get_session_transaction_depth helper."""
depth = get_session_transaction_depth(db_session)
assert depth >= 0
# ============================================================================
# @transactional Decorator Integration Tests
# ============================================================================
class TestTransactionalDecoratorIntegration:
"""Integration tests for @transactional decorator."""
@pytest.mark.asyncio
async def test_decorated_function_commits(self, db_session):
"""Test decorated function commits on success."""
@transactional()
async def create_series_decorated(db: AsyncSession):
return await AnimeSeriesService.create(
db,
key="decorated-test",
name="Decorated Test",
site="https://test.com",
folder="/test",
)
series = await create_series_decorated(db=db_session)
# Verify committed
retrieved = await AnimeSeriesService.get_by_key(
db_session, "decorated-test"
)
assert retrieved is not None
@pytest.mark.asyncio
async def test_decorated_function_rollback(self, db_session):
"""Test decorated function rolls back on error."""
@transactional()
async def create_then_fail(db: AsyncSession):
await AnimeSeriesService.create(
db,
key="decorated-rollback",
name="Decorated Rollback",
site="https://test.com",
folder="/test",
)
raise ValueError("Force failure")
with pytest.raises(ValueError):
await create_then_fail(db=db_session)
# Verify NOT committed
retrieved = await AnimeSeriesService.get_by_key(
db_session, "decorated-rollback"
)
assert retrieved is None
@pytest.mark.asyncio
async def test_nested_decorated_functions(self, db_session):
"""Test nested decorated functions work correctly."""
@transactional(propagation=TransactionPropagation.NESTED)
async def inner_operation(db: AsyncSession, series_id: int):
return await EpisodeService.create(
db,
series_id=series_id,
season=1,
episode_number=1,
)
@transactional()
async def outer_operation(db: AsyncSession):
series = await AnimeSeriesService.create(
db,
key="nested-decorated",
name="Nested Decorated",
site="https://test.com",
folder="/test",
)
episode = await inner_operation(db=db, series_id=series.id)
return series, episode
series, episode = await outer_operation(db=db_session)
# Both should be committed
assert series is not None
assert episode is not None
# ============================================================================
# Concurrent Transaction Tests
# ============================================================================
class TestConcurrentTransactions:
"""Tests for concurrent transaction handling."""
@pytest.mark.asyncio
async def test_concurrent_writes_different_keys(self, session_factory):
"""Test concurrent writes to different records."""
async def create_series(key: str):
async with session_factory() as session:
async with atomic(session):
await AnimeSeriesService.create(
session,
key=key,
name=f"Series {key}",
site="https://test.com",
folder=f"/test/{key}",
)
# Run concurrent creates
await asyncio.gather(
create_series("concurrent-1"),
create_series("concurrent-2"),
create_series("concurrent-3"),
)
# Verify all created
async with session_factory() as verify_session:
for i in range(1, 4):
series = await AnimeSeriesService.get_by_key(
verify_session, f"concurrent-{i}"
)
assert series is not None
# ============================================================================
# Queue Repository Transaction Tests
# ============================================================================
class TestQueueRepositoryTransactions:
"""Integration tests for QueueRepository transaction handling."""
@pytest.mark.asyncio
async def test_save_item_atomic(self, session_factory):
"""Test save_item creates series, episode, and queue item atomically."""
from src.server.models.download import (
DownloadItem,
DownloadStatus,
EpisodeIdentifier,
)
from src.server.services.queue_repository import QueueRepository
repo = QueueRepository(session_factory)
item = DownloadItem(
id="temp-id",
serie_id="repo-atomic-test",
serie_folder="/test/folder",
serie_name="Repo Atomic Test",
episode=EpisodeIdentifier(season=1, episode=1),
status=DownloadStatus.PENDING,
)
saved_item = await repo.save_item(item)
assert saved_item.id != "temp-id" # Should have DB ID
# Verify all entities created
async with session_factory() as verify_session:
series = await AnimeSeriesService.get_by_key(
verify_session, "repo-atomic-test"
)
assert series is not None
episodes = await EpisodeService.get_by_series(
verify_session, series.id
)
assert len(episodes) == 1
queue_items = await DownloadQueueService.get_all(verify_session)
assert len(queue_items) >= 1
@pytest.mark.asyncio
async def test_clear_all_atomic(self, session_factory):
"""Test clear_all removes all items atomically."""
from src.server.models.download import (
DownloadItem,
DownloadStatus,
EpisodeIdentifier,
)
from src.server.services.queue_repository import QueueRepository
repo = QueueRepository(session_factory)
# Add some items
for i in range(3):
item = DownloadItem(
id=f"clear-{i}",
serie_id=f"clear-series-{i}",
serie_folder=f"/test/folder/{i}",
serie_name=f"Clear Series {i}",
episode=EpisodeIdentifier(season=1, episode=1),
status=DownloadStatus.PENDING,
)
await repo.save_item(item)
# Clear all
count = await repo.clear_all()
assert count == 3
# Verify all cleared
async with session_factory() as verify_session:
queue_items = await DownloadQueueService.get_all(verify_session)
assert len(queue_items) == 0

View File

@ -0,0 +1,546 @@
"""Unit tests for service layer transaction behavior.
Tests that service operations correctly handle transactions,
especially compound operations that require atomicity.
"""
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from src.server.database.base import Base
from src.server.database.models import (
AnimeSeries,
DownloadQueueItem,
Episode,
UserSession,
)
from src.server.database.service import (
AnimeSeriesService,
DownloadQueueService,
EpisodeService,
UserSessionService,
)
from src.server.database.transaction import atomic
# ============================================================================
# Fixtures
# ============================================================================
@pytest.fixture
async def db_engine():
"""Create in-memory database engine for testing."""
engine = create_async_engine(
"sqlite+aiosqlite:///:memory:",
echo=False,
)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest.fixture
async def db_session(db_engine):
"""Create database session for testing."""
from sqlalchemy.ext.asyncio import async_sessionmaker
async_session = async_sessionmaker(
db_engine,
class_=AsyncSession,
expire_on_commit=False,
)
async with async_session() as session:
yield session
await session.rollback()
# ============================================================================
# AnimeSeriesService Transaction Tests
# ============================================================================
class TestAnimeSeriesServiceTransactions:
"""Tests for AnimeSeriesService transaction behavior."""
@pytest.mark.asyncio
async def test_create_uses_flush_not_commit(self, db_session):
"""Test create uses flush for transaction compatibility."""
series = await AnimeSeriesService.create(
db_session,
key="test-key",
name="Test Series",
site="https://test.com",
folder="/test/folder",
)
# Series should exist in session
assert series.id is not None
# But not committed yet (we're in an uncommitted transaction)
# We can verify by checking the session's uncommitted state
assert series in db_session
@pytest.mark.asyncio
async def test_update_uses_flush_not_commit(self, db_session):
"""Test update uses flush for transaction compatibility."""
# Create series
series = await AnimeSeriesService.create(
db_session,
key="update-test",
name="Original Name",
site="https://test.com",
folder="/test/folder",
)
# Update series
updated = await AnimeSeriesService.update(
db_session,
series.id,
name="Updated Name",
)
assert updated.name == "Updated Name"
assert updated in db_session
# ============================================================================
# EpisodeService Transaction Tests
# ============================================================================
class TestEpisodeServiceTransactions:
"""Tests for EpisodeService transaction behavior."""
@pytest.mark.asyncio
async def test_bulk_mark_downloaded_atomicity(self, db_session):
"""Test bulk_mark_downloaded updates all or none."""
# Create series and episodes
series = await AnimeSeriesService.create(
db_session,
key="bulk-test-series",
name="Bulk Test",
site="https://test.com",
folder="/test/folder",
)
episodes = []
for i in range(1, 4):
ep = await EpisodeService.create(
db_session,
series_id=series.id,
season=1,
episode_number=i,
title=f"Episode {i}",
)
episodes.append(ep)
episode_ids = [ep.id for ep in episodes]
file_paths = [f"/path/ep{i}.mp4" for i in range(1, 4)]
# Bulk update within atomic context
async with atomic(db_session):
count = await EpisodeService.bulk_mark_downloaded(
db_session,
episode_ids,
file_paths,
)
assert count == 3
# Verify all episodes were marked
for i, ep_id in enumerate(episode_ids):
episode = await EpisodeService.get_by_id(db_session, ep_id)
assert episode.is_downloaded is True
assert episode.file_path == file_paths[i]
@pytest.mark.asyncio
async def test_bulk_mark_downloaded_empty_list(self, db_session):
"""Test bulk_mark_downloaded handles empty list."""
count = await EpisodeService.bulk_mark_downloaded(
db_session,
episode_ids=[],
)
assert count == 0
@pytest.mark.asyncio
async def test_delete_by_series_and_episode_transaction(self, db_session):
"""Test delete_by_series_and_episode in transaction."""
# Create series and episode
series = await AnimeSeriesService.create(
db_session,
key="delete-test-series",
name="Delete Test",
site="https://test.com",
folder="/test/folder",
)
await EpisodeService.create(
db_session,
series_id=series.id,
season=1,
episode_number=1,
title="Episode 1",
)
await db_session.commit()
# Delete episode within transaction
async with atomic(db_session):
deleted = await EpisodeService.delete_by_series_and_episode(
db_session,
series_key="delete-test-series",
season=1,
episode_number=1,
)
assert deleted is True
# Verify episode is gone
episode = await EpisodeService.get_by_episode(
db_session,
series.id,
season=1,
episode_number=1,
)
assert episode is None
# ============================================================================
# DownloadQueueService Transaction Tests
# ============================================================================
class TestDownloadQueueServiceTransactions:
"""Tests for DownloadQueueService transaction behavior."""
@pytest.mark.asyncio
async def test_bulk_delete_atomicity(self, db_session):
"""Test bulk_delete removes all or none."""
# Create series and episodes
series = await AnimeSeriesService.create(
db_session,
key="queue-bulk-test",
name="Queue Bulk Test",
site="https://test.com",
folder="/test/folder",
)
item_ids = []
for i in range(1, 4):
episode = await EpisodeService.create(
db_session,
series_id=series.id,
season=1,
episode_number=i,
)
item = await DownloadQueueService.create(
db_session,
series_id=series.id,
episode_id=episode.id,
)
item_ids.append(item.id)
# Bulk delete within atomic context
async with atomic(db_session):
count = await DownloadQueueService.bulk_delete(
db_session,
item_ids,
)
assert count == 3
# Verify all items deleted
all_items = await DownloadQueueService.get_all(db_session)
assert len(all_items) == 0
@pytest.mark.asyncio
async def test_bulk_delete_empty_list(self, db_session):
"""Test bulk_delete handles empty list."""
count = await DownloadQueueService.bulk_delete(
db_session,
item_ids=[],
)
assert count == 0
@pytest.mark.asyncio
async def test_clear_all_atomicity(self, db_session):
"""Test clear_all removes all items atomically."""
# Create series and queue items
series = await AnimeSeriesService.create(
db_session,
key="clear-all-test",
name="Clear All Test",
site="https://test.com",
folder="/test/folder",
)
for i in range(1, 4):
episode = await EpisodeService.create(
db_session,
series_id=series.id,
season=1,
episode_number=i,
)
await DownloadQueueService.create(
db_session,
series_id=series.id,
episode_id=episode.id,
)
# Clear all within atomic context
async with atomic(db_session):
count = await DownloadQueueService.clear_all(db_session)
assert count == 3
# Verify all items cleared
all_items = await DownloadQueueService.get_all(db_session)
assert len(all_items) == 0
# ============================================================================
# UserSessionService Transaction Tests
# ============================================================================
class TestUserSessionServiceTransactions:
"""Tests for UserSessionService transaction behavior."""
@pytest.mark.asyncio
async def test_rotate_session_atomicity(self, db_session):
"""Test rotate_session is atomic (revoke + create)."""
# Create old session
old_session = await UserSessionService.create(
db_session,
session_id="old-session-123",
token_hash="old_hash",
expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
)
await db_session.commit()
# Rotate session within atomic context
async with atomic(db_session):
new_session = await UserSessionService.rotate_session(
db_session,
old_session_id="old-session-123",
new_session_id="new-session-456",
new_token_hash="new_hash",
new_expires_at=datetime.now(timezone.utc) + timedelta(hours=2),
)
assert new_session is not None
assert new_session.session_id == "new-session-456"
# Verify old session is revoked
old = await UserSessionService.get_by_session_id(
db_session, "old-session-123"
)
assert old.is_active is False
@pytest.mark.asyncio
async def test_rotate_session_old_not_found(self, db_session):
"""Test rotate_session returns None if old session not found."""
result = await UserSessionService.rotate_session(
db_session,
old_session_id="nonexistent-session",
new_session_id="new-session",
new_token_hash="hash",
new_expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
)
assert result is None
@pytest.mark.asyncio
async def test_cleanup_expired_bulk_delete(self, db_session):
"""Test cleanup_expired removes all expired sessions."""
# Create expired sessions
for i in range(3):
await UserSessionService.create(
db_session,
session_id=f"expired-{i}",
token_hash=f"hash-{i}",
expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
)
# Create active session
await UserSessionService.create(
db_session,
session_id="active-session",
token_hash="active_hash",
expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
)
await db_session.commit()
# Cleanup expired within atomic context
async with atomic(db_session):
count = await UserSessionService.cleanup_expired(db_session)
assert count == 3
# Verify active session still exists
active = await UserSessionService.get_by_session_id(
db_session, "active-session"
)
assert active is not None
# ============================================================================
# Compound Operation Rollback Tests
# ============================================================================
class TestCompoundOperationRollback:
"""Tests for rollback behavior in compound operations."""
@pytest.mark.asyncio
async def test_rollback_on_partial_failure(self, db_session):
"""Test rollback when compound operation fails mid-way."""
# Create initial series
series = await AnimeSeriesService.create(
db_session,
key="rollback-test-series",
name="Rollback Test",
site="https://test.com",
folder="/test/folder",
)
await db_session.commit()
# Store the id before starting the transaction to avoid expired state access
series_id = series.id
try:
async with atomic(db_session):
# Create episode
episode = await EpisodeService.create(
db_session,
series_id=series_id,
season=1,
episode_number=1,
)
# Force flush to persist episode in transaction
await db_session.flush()
# Simulate failure mid-operation
raise ValueError("Simulated failure")
except ValueError:
pass
# Verify episode was NOT persisted
episode = await EpisodeService.get_by_episode(
db_session,
series_id,
season=1,
episode_number=1,
)
assert episode is None
@pytest.mark.asyncio
async def test_no_orphan_data_on_failure(self, db_session):
"""Test no orphaned data when multi-service operation fails."""
try:
async with atomic(db_session):
# Create series
series = await AnimeSeriesService.create(
db_session,
key="orphan-test-series",
name="Orphan Test",
site="https://test.com",
folder="/test/folder",
)
# Create episode
episode = await EpisodeService.create(
db_session,
series_id=series.id,
season=1,
episode_number=1,
)
# Create queue item
item = await DownloadQueueService.create(
db_session,
series_id=series.id,
episode_id=episode.id,
)
await db_session.flush()
# Fail after all creates
raise RuntimeError("Critical failure")
except RuntimeError:
pass
# Verify nothing was persisted
all_series = await AnimeSeriesService.get_all(db_session)
series_keys = [s.key for s in all_series]
assert "orphan-test-series" not in series_keys
# ============================================================================
# Nested Transaction Tests
# ============================================================================
class TestNestedTransactions:
"""Tests for nested transaction (savepoint) behavior."""
@pytest.mark.asyncio
async def test_savepoint_partial_rollback(self, db_session):
"""Test savepoint allows partial rollback."""
# Create series
series = await AnimeSeriesService.create(
db_session,
key="savepoint-test",
name="Savepoint Test",
site="https://test.com",
folder="/test/folder",
)
async with atomic(db_session) as tx:
# Create first episode (should persist)
await EpisodeService.create(
db_session,
series_id=series.id,
season=1,
episode_number=1,
)
# Nested transaction for second episode
async with tx.savepoint() as sp:
await EpisodeService.create(
db_session,
series_id=series.id,
season=1,
episode_number=2,
)
# Rollback only the savepoint
await sp.rollback()
# Create third episode (should persist)
await EpisodeService.create(
db_session,
series_id=series.id,
season=1,
episode_number=3,
)
# Verify first and third episodes exist, second doesn't
episodes = await EpisodeService.get_by_series(db_session, series.id)
episode_numbers = [ep.episode_number for ep in episodes]
assert 1 in episode_numbers
assert 2 not in episode_numbers # Rolled back
assert 3 in episode_numbers

View File

@ -0,0 +1,668 @@
"""Unit tests for database transaction utilities.
Tests the transaction management utilities including decorators,
context managers, and helper functions.
"""
import asyncio
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import Session, sessionmaker
from src.server.database.base import Base
from src.server.database.transaction import (
AsyncTransactionContext,
TransactionContext,
TransactionError,
TransactionPropagation,
atomic,
atomic_sync,
is_in_transaction,
transactional,
)
# ============================================================================
# Fixtures
# ============================================================================
@pytest.fixture
async def async_engine():
"""Create in-memory async database engine for testing."""
engine = create_async_engine(
"sqlite+aiosqlite:///:memory:",
echo=False,
)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest.fixture
async def async_session(async_engine):
"""Create async database session for testing."""
from sqlalchemy.ext.asyncio import async_sessionmaker
async_session_factory = async_sessionmaker(
async_engine,
class_=AsyncSession,
expire_on_commit=False,
)
async with async_session_factory() as session:
yield session
await session.rollback()
# ============================================================================
# TransactionContext Tests (Sync)
# ============================================================================
class TestTransactionContext:
"""Tests for synchronous TransactionContext."""
def test_context_manager_protocol(self):
"""Test context manager enters and exits properly."""
mock_session = MagicMock(spec=Session)
mock_session.in_transaction.return_value = False
with TransactionContext(mock_session) as ctx:
assert ctx.session == mock_session
mock_session.begin.assert_called_once()
mock_session.commit.assert_called_once()
def test_rollback_on_exception(self):
"""Test rollback is called when exception occurs."""
mock_session = MagicMock(spec=Session)
mock_session.in_transaction.return_value = False
with pytest.raises(ValueError):
with TransactionContext(mock_session):
raise ValueError("Test error")
mock_session.rollback.assert_called_once()
mock_session.commit.assert_not_called()
def test_no_begin_if_already_in_transaction(self):
"""Test no new transaction started if already in one."""
mock_session = MagicMock(spec=Session)
mock_session.in_transaction.return_value = True
with TransactionContext(mock_session):
pass
mock_session.begin.assert_not_called()
def test_explicit_commit(self):
"""Test explicit commit within context."""
mock_session = MagicMock(spec=Session)
mock_session.in_transaction.return_value = False
with TransactionContext(mock_session) as ctx:
ctx.commit()
mock_session.commit.assert_called_once()
# Should not commit again on exit
assert mock_session.commit.call_count == 1
def test_explicit_rollback(self):
"""Test explicit rollback within context."""
mock_session = MagicMock(spec=Session)
mock_session.in_transaction.return_value = False
with TransactionContext(mock_session) as ctx:
ctx.rollback()
mock_session.rollback.assert_called_once()
# Should not commit after explicit rollback
mock_session.commit.assert_not_called()
# ============================================================================
# AsyncTransactionContext Tests
# ============================================================================
class TestAsyncTransactionContext:
"""Tests for asynchronous AsyncTransactionContext."""
@pytest.mark.asyncio
async def test_async_context_manager_protocol(self):
"""Test async context manager enters and exits properly."""
mock_session = AsyncMock(spec=AsyncSession)
mock_session.in_transaction.return_value = False
mock_session.begin = AsyncMock()
mock_session.commit = AsyncMock()
mock_session.rollback = AsyncMock()
async with AsyncTransactionContext(mock_session) as ctx:
assert ctx.session == mock_session
mock_session.begin.assert_called_once()
mock_session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_async_rollback_on_exception(self):
"""Test async rollback is called when exception occurs."""
mock_session = AsyncMock(spec=AsyncSession)
mock_session.in_transaction.return_value = False
mock_session.begin = AsyncMock()
mock_session.commit = AsyncMock()
mock_session.rollback = AsyncMock()
with pytest.raises(ValueError):
async with AsyncTransactionContext(mock_session):
raise ValueError("Test error")
mock_session.rollback.assert_called_once()
mock_session.commit.assert_not_called()
@pytest.mark.asyncio
async def test_async_explicit_commit(self):
"""Test async explicit commit within context."""
mock_session = AsyncMock(spec=AsyncSession)
mock_session.in_transaction.return_value = False
mock_session.begin = AsyncMock()
mock_session.commit = AsyncMock()
async with AsyncTransactionContext(mock_session) as ctx:
await ctx.commit()
mock_session.commit.assert_called_once()
# Should not commit again on exit
assert mock_session.commit.call_count == 1
@pytest.mark.asyncio
async def test_async_explicit_rollback(self):
"""Test async explicit rollback within context."""
mock_session = AsyncMock(spec=AsyncSession)
mock_session.in_transaction.return_value = False
mock_session.begin = AsyncMock()
mock_session.commit = AsyncMock()
mock_session.rollback = AsyncMock()
async with AsyncTransactionContext(mock_session) as ctx:
await ctx.rollback()
mock_session.rollback.assert_called_once()
# Should not commit after explicit rollback
mock_session.commit.assert_not_called()
# ============================================================================
# atomic() Context Manager Tests
# ============================================================================
class TestAtomicContextManager:
"""Tests for atomic() async context manager."""
@pytest.mark.asyncio
async def test_atomic_commits_on_success(self):
"""Test atomic commits transaction on success."""
mock_session = AsyncMock(spec=AsyncSession)
mock_session.in_transaction.return_value = False
mock_session.begin = AsyncMock()
mock_session.commit = AsyncMock()
mock_session.rollback = AsyncMock()
async with atomic(mock_session) as tx:
pass
mock_session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_atomic_rollback_on_failure(self):
"""Test atomic rolls back transaction on failure."""
mock_session = AsyncMock(spec=AsyncSession)
mock_session.in_transaction.return_value = False
mock_session.begin = AsyncMock()
mock_session.commit = AsyncMock()
mock_session.rollback = AsyncMock()
with pytest.raises(RuntimeError):
async with atomic(mock_session):
raise RuntimeError("Operation failed")
mock_session.rollback.assert_called_once()
@pytest.mark.asyncio
async def test_atomic_nested_propagation(self):
"""Test atomic with NESTED propagation creates savepoint."""
mock_session = AsyncMock(spec=AsyncSession)
mock_session.in_transaction.return_value = True
mock_nested = AsyncMock()
mock_session.begin_nested = AsyncMock(return_value=mock_nested)
async with atomic(
mock_session, propagation=TransactionPropagation.NESTED
):
pass
mock_session.begin_nested.assert_called_once()
@pytest.mark.asyncio
async def test_atomic_required_propagation_default(self):
"""Test atomic uses REQUIRED propagation by default."""
mock_session = AsyncMock(spec=AsyncSession)
mock_session.in_transaction.return_value = False
mock_session.begin = AsyncMock()
mock_session.commit = AsyncMock()
async with atomic(mock_session) as tx:
# Should start new transaction
mock_session.begin.assert_called_once()
# ============================================================================
# @transactional Decorator Tests
# ============================================================================
class TestTransactionalDecorator:
"""Tests for @transactional decorator."""
@pytest.mark.asyncio
async def test_async_function_wrapped(self):
"""Test async function is wrapped in transaction."""
mock_session = AsyncMock(spec=AsyncSession)
mock_session.in_transaction.return_value = False
mock_session.begin = AsyncMock()
mock_session.commit = AsyncMock()
mock_session.rollback = AsyncMock()
@transactional()
async def sample_operation(db: AsyncSession):
return "result"
result = await sample_operation(db=mock_session)
assert result == "result"
mock_session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_async_rollback_on_error(self):
"""Test async function rollback on error."""
mock_session = AsyncMock(spec=AsyncSession)
mock_session.in_transaction.return_value = False
mock_session.begin = AsyncMock()
mock_session.commit = AsyncMock()
mock_session.rollback = AsyncMock()
@transactional()
async def failing_operation(db: AsyncSession):
raise ValueError("Operation failed")
with pytest.raises(ValueError):
await failing_operation(db=mock_session)
mock_session.rollback.assert_called_once()
@pytest.mark.asyncio
async def test_custom_session_param_name(self):
"""Test decorator with custom session parameter name."""
mock_session = AsyncMock(spec=AsyncSession)
mock_session.in_transaction.return_value = False
mock_session.begin = AsyncMock()
mock_session.commit = AsyncMock()
@transactional(session_param="session")
async def operation_with_session(session: AsyncSession):
return "done"
result = await operation_with_session(session=mock_session)
assert result == "done"
mock_session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_missing_session_raises_error(self):
"""Test error raised when session parameter not found."""
@transactional()
async def operation_no_session(data: dict):
return data
with pytest.raises(TransactionError):
await operation_no_session(data={"key": "value"})
@pytest.mark.asyncio
async def test_propagation_passed_to_atomic(self):
"""Test propagation mode is passed to atomic."""
mock_session = AsyncMock(spec=AsyncSession)
mock_session.in_transaction.return_value = True
mock_nested = AsyncMock()
mock_session.begin_nested = AsyncMock(return_value=mock_nested)
@transactional(propagation=TransactionPropagation.NESTED)
async def nested_operation(db: AsyncSession):
return "nested"
result = await nested_operation(db=mock_session)
assert result == "nested"
mock_session.begin_nested.assert_called_once()
# ============================================================================
# Sync transactional decorator Tests
# ============================================================================
class TestSyncTransactionalDecorator:
"""Tests for @transactional decorator with sync functions."""
def test_sync_function_wrapped(self):
"""Test sync function is wrapped in transaction."""
mock_session = MagicMock(spec=Session)
mock_session.in_transaction.return_value = False
@transactional()
def sample_sync_operation(db: Session):
return "sync_result"
result = sample_sync_operation(db=mock_session)
assert result == "sync_result"
mock_session.commit.assert_called_once()
def test_sync_rollback_on_error(self):
"""Test sync function rollback on error."""
mock_session = MagicMock(spec=Session)
mock_session.in_transaction.return_value = False
@transactional()
def failing_sync_operation(db: Session):
raise ValueError("Sync operation failed")
with pytest.raises(ValueError):
failing_sync_operation(db=mock_session)
mock_session.rollback.assert_called_once()
# ============================================================================
# Helper Function Tests
# ============================================================================
class TestHelperFunctions:
"""Tests for transaction helper functions."""
def test_is_in_transaction_true(self):
"""Test is_in_transaction returns True when in transaction."""
mock_session = MagicMock()
mock_session.in_transaction.return_value = True
assert is_in_transaction(mock_session) is True
def test_is_in_transaction_false(self):
"""Test is_in_transaction returns False when not in transaction."""
mock_session = MagicMock()
mock_session.in_transaction.return_value = False
assert is_in_transaction(mock_session) is False
# ============================================================================
# Integration Tests with Real Database
# ============================================================================
class TestTransactionIntegration:
"""Integration tests using real in-memory database."""
@pytest.mark.asyncio
async def test_real_transaction_commit(self, async_session):
"""Test actual transaction commit with real session."""
from src.server.database.models import AnimeSeries
async with atomic(async_session):
series = AnimeSeries(
key="test-series",
name="Test Series",
site="https://test.com",
folder="/test/folder",
)
async_session.add(series)
# Verify data persisted
from sqlalchemy import select
result = await async_session.execute(
select(AnimeSeries).where(AnimeSeries.key == "test-series")
)
saved_series = result.scalar_one_or_none()
assert saved_series is not None
assert saved_series.name == "Test Series"
@pytest.mark.asyncio
async def test_real_transaction_rollback(self, async_session):
"""Test actual transaction rollback with real session."""
from src.server.database.models import AnimeSeries
try:
async with atomic(async_session):
series = AnimeSeries(
key="rollback-series",
name="Rollback Series",
site="https://test.com",
folder="/test/folder",
)
async_session.add(series)
await async_session.flush()
# Force rollback
raise ValueError("Simulated error")
except ValueError:
pass
# Verify data was NOT persisted
from sqlalchemy import select
result = await async_session.execute(
select(AnimeSeries).where(AnimeSeries.key == "rollback-series")
)
saved_series = result.scalar_one_or_none()
assert saved_series is None
# ============================================================================
# TransactionPropagation Tests
# ============================================================================
class TestTransactionPropagation:
"""Tests for transaction propagation modes."""
def test_propagation_enum_values(self):
"""Test propagation enum has correct values."""
assert TransactionPropagation.REQUIRED.value == "required"
assert TransactionPropagation.REQUIRES_NEW.value == "requires_new"
assert TransactionPropagation.NESTED.value == "nested"
# ============================================================================
# Additional Coverage Tests
# ============================================================================
class TestSyncSavepointCoverage:
"""Additional tests for sync savepoint coverage."""
def test_savepoint_exception_rolls_back(self):
"""Test savepoint rollback when exception occurs within savepoint."""
mock_session = MagicMock(spec=Session)
mock_session.in_transaction.return_value = False
mock_nested = MagicMock()
mock_session.begin_nested.return_value = mock_nested
with TransactionContext(mock_session) as ctx:
with pytest.raises(ValueError):
with ctx.savepoint() as sp:
raise ValueError("Error in savepoint")
# Nested transaction should have been rolled back
mock_nested.rollback.assert_called_once()
def test_savepoint_commit_explicit(self):
"""Test explicit commit on savepoint."""
mock_session = MagicMock(spec=Session)
mock_session.in_transaction.return_value = False
mock_nested = MagicMock()
mock_session.begin_nested.return_value = mock_nested
with TransactionContext(mock_session) as ctx:
with ctx.savepoint() as sp:
sp.commit()
# Commit should just log, SQLAlchemy handles actual commit
class TestAsyncSavepointCoverage:
"""Additional tests for async savepoint coverage."""
@pytest.mark.asyncio
async def test_async_savepoint_exception_rolls_back(self):
"""Test async savepoint rollback when exception occurs."""
mock_session = AsyncMock(spec=AsyncSession)
mock_session.in_transaction.return_value = False
mock_session.begin = AsyncMock()
mock_session.commit = AsyncMock()
mock_session.rollback = AsyncMock()
mock_nested = AsyncMock()
mock_nested.rollback = AsyncMock()
mock_session.begin_nested = AsyncMock(return_value=mock_nested)
async with AsyncTransactionContext(mock_session) as ctx:
with pytest.raises(ValueError):
async with ctx.savepoint() as sp:
raise ValueError("Error in async savepoint")
# Nested transaction should have been rolled back
mock_nested.rollback.assert_called_once()
@pytest.mark.asyncio
async def test_async_savepoint_commit_explicit(self):
"""Test explicit commit on async savepoint."""
mock_session = AsyncMock(spec=AsyncSession)
mock_session.in_transaction.return_value = False
mock_session.begin = AsyncMock()
mock_session.commit = AsyncMock()
mock_nested = AsyncMock()
mock_session.begin_nested = AsyncMock(return_value=mock_nested)
async with AsyncTransactionContext(mock_session) as ctx:
async with ctx.savepoint() as sp:
await sp.commit()
# Commit should just log, SQLAlchemy handles actual commit
class TestAtomicNestedPropagationNoTransaction:
"""Tests for NESTED propagation when not in transaction."""
@pytest.mark.asyncio
async def test_async_nested_starts_new_when_not_in_transaction(self):
"""Test NESTED propagation starts new transaction when none exists."""
mock_session = AsyncMock(spec=AsyncSession)
mock_session.in_transaction.return_value = False
mock_session.begin = AsyncMock()
mock_session.commit = AsyncMock()
mock_session.rollback = AsyncMock()
async with atomic(mock_session, TransactionPropagation.NESTED) as tx:
# Should start new transaction since none exists
pass
mock_session.begin.assert_called_once()
mock_session.commit.assert_called_once()
def test_sync_nested_starts_new_when_not_in_transaction(self):
"""Test sync NESTED propagation starts new transaction when none exists."""
mock_session = MagicMock(spec=Session)
mock_session.in_transaction.return_value = False
with atomic_sync(mock_session, TransactionPropagation.NESTED) as tx:
pass
mock_session.begin.assert_called_once()
mock_session.commit.assert_called_once()
class TestGetTransactionDepth:
"""Tests for get_transaction_depth helper."""
def test_depth_zero_when_not_in_transaction(self):
"""Test depth is 0 when not in transaction."""
from src.server.database.transaction import get_transaction_depth
mock_session = MagicMock(spec=Session)
mock_session.in_transaction.return_value = False
depth = get_transaction_depth(mock_session)
assert depth == 0
def test_depth_one_in_transaction(self):
"""Test depth is 1 in basic transaction."""
from src.server.database.transaction import get_transaction_depth
mock_session = MagicMock(spec=Session)
mock_session.in_transaction.return_value = True
mock_session._nested_transaction = None
depth = get_transaction_depth(mock_session)
assert depth == 1
def test_depth_two_with_nested_transaction(self):
"""Test depth is 2 with nested transaction."""
from src.server.database.transaction import get_transaction_depth
mock_session = MagicMock(spec=Session)
mock_session.in_transaction.return_value = True
mock_session._nested_transaction = MagicMock() # Has nested
depth = get_transaction_depth(mock_session)
assert depth == 2
class TestTransactionalDecoratorPositionalArgs:
"""Tests for transactional decorator with positional arguments."""
@pytest.mark.asyncio
async def test_session_from_positional_arg(self):
"""Test decorator extracts session from positional argument."""
mock_session = AsyncMock(spec=AsyncSession)
mock_session.in_transaction.return_value = False
mock_session.begin = AsyncMock()
mock_session.commit = AsyncMock()
mock_session.rollback = AsyncMock()
@transactional()
async def operation(db: AsyncSession, data: str):
return f"processed: {data}"
# Pass session as positional argument
result = await operation(mock_session, "test")
assert result == "processed: test"
mock_session.commit.assert_called_once()
def test_sync_session_from_positional_arg(self):
"""Test sync decorator extracts session from positional argument."""
mock_session = MagicMock(spec=Session)
mock_session.in_transaction.return_value = False
@transactional()
def operation(db: Session, data: str):
return f"processed: {data}"
result = operation(mock_session, "test")
assert result == "processed: test"
mock_session.commit.assert_called_once()