Compare commits
No commits in common. "30de86e77a57ce7c38193f7bc5941e203a773b02" and "8f7c489bd2d4e5ab350866e3fc7451e58ed7d204" have entirely different histories.
30de86e77a
...
8f7c489bd2
@ -1,290 +0,0 @@
|
|||||||
# Database Layer Implementation Summary
|
|
||||||
|
|
||||||
## Completed: October 17, 2025
|
|
||||||
|
|
||||||
### Overview
|
|
||||||
|
|
||||||
Successfully implemented a comprehensive SQLAlchemy-based database layer for the Aniworld web application, providing persistent storage for anime series, episodes, download queue, and user sessions.
|
|
||||||
|
|
||||||
## Implementation Details
|
|
||||||
|
|
||||||
### Files Created
|
|
||||||
|
|
||||||
1. **`src/server/database/__init__.py`** (35 lines)
|
|
||||||
|
|
||||||
- Package initialization and exports
|
|
||||||
- Public API for database operations
|
|
||||||
|
|
||||||
2. **`src/server/database/base.py`** (75 lines)
|
|
||||||
|
|
||||||
- Base declarative class for all models
|
|
||||||
- TimestampMixin for automatic timestamp tracking
|
|
||||||
- SoftDeleteMixin for logical deletion (future use)
|
|
||||||
|
|
||||||
3. **`src/server/database/models.py`** (435 lines)
|
|
||||||
|
|
||||||
- AnimeSeries model with relationships
|
|
||||||
- Episode model linked to series
|
|
||||||
- DownloadQueueItem for queue persistence
|
|
||||||
- UserSession for authentication
|
|
||||||
- Enum types for status and priority
|
|
||||||
|
|
||||||
4. **`src/server/database/connection.py`** (250 lines)
|
|
||||||
|
|
||||||
- Async and sync engine creation
|
|
||||||
- Session factory configuration
|
|
||||||
- FastAPI dependency injection
|
|
||||||
- SQLite optimizations (WAL mode, foreign keys)
|
|
||||||
|
|
||||||
5. **`src/server/database/migrations.py`** (8 lines)
|
|
||||||
|
|
||||||
- Placeholder for future Alembic migrations
|
|
||||||
|
|
||||||
6. **`src/server/database/README.md`** (300 lines)
|
|
||||||
|
|
||||||
- Comprehensive documentation
|
|
||||||
- Usage examples
|
|
||||||
- Quick start guide
|
|
||||||
- Troubleshooting section
|
|
||||||
|
|
||||||
7. **`tests/unit/test_database_models.py`** (550 lines)
|
|
||||||
- 19 comprehensive test cases
|
|
||||||
- Model creation and validation
|
|
||||||
- Relationship testing
|
|
||||||
- Query operations
|
|
||||||
- All tests passing ✅
|
|
||||||
|
|
||||||
### Files Modified
|
|
||||||
|
|
||||||
1. **`requirements.txt`**
|
|
||||||
|
|
||||||
- Added: sqlalchemy>=2.0.35
|
|
||||||
- Added: alembic==1.13.0
|
|
||||||
- Added: aiosqlite>=0.19.0
|
|
||||||
|
|
||||||
2. **`src/server/utils/dependencies.py`**
|
|
||||||
|
|
||||||
- Updated `get_database_session()` dependency
|
|
||||||
- Proper error handling and imports
|
|
||||||
|
|
||||||
3. **`infrastructure.md`**
|
|
||||||
- Added comprehensive Database Layer section
|
|
||||||
- Documented models, relationships, configuration
|
|
||||||
- Production considerations
|
|
||||||
- Integration examples
|
|
||||||
|
|
||||||
## Database Schema
|
|
||||||
|
|
||||||
### AnimeSeries
|
|
||||||
|
|
||||||
- **Primary Key**: id (auto-increment)
|
|
||||||
- **Unique Key**: key (provider identifier)
|
|
||||||
- **Fields**: name, site, folder, description, status, total_episodes, cover_url, episode_dict
|
|
||||||
- **Relationships**: One-to-many with Episode and DownloadQueueItem
|
|
||||||
- **Indexes**: key, name
|
|
||||||
- **Cascade**: Delete episodes and download items on series deletion
|
|
||||||
|
|
||||||
### Episode
|
|
||||||
|
|
||||||
- **Primary Key**: id
|
|
||||||
- **Foreign Key**: series_id → AnimeSeries
|
|
||||||
- **Fields**: season, episode_number, title, file_path, file_size, is_downloaded, download_date
|
|
||||||
- **Relationship**: Many-to-one with AnimeSeries
|
|
||||||
- **Indexes**: series_id
|
|
||||||
|
|
||||||
### DownloadQueueItem
|
|
||||||
|
|
||||||
- **Primary Key**: id
|
|
||||||
- **Foreign Key**: series_id → AnimeSeries
|
|
||||||
- **Fields**: season, episode_number, status (enum), priority (enum), progress_percent, downloaded_bytes, total_bytes, download_speed, error_message, retry_count, download_url, file_destination, started_at, completed_at
|
|
||||||
- **Status Enum**: PENDING, DOWNLOADING, PAUSED, COMPLETED, FAILED, CANCELLED
|
|
||||||
- **Priority Enum**: LOW, NORMAL, HIGH
|
|
||||||
- **Indexes**: series_id, status
|
|
||||||
- **Relationship**: Many-to-one with AnimeSeries
|
|
||||||
|
|
||||||
### UserSession
|
|
||||||
|
|
||||||
- **Primary Key**: id
|
|
||||||
- **Unique Key**: session_id
|
|
||||||
- **Fields**: token_hash, user_id, ip_address, user_agent, expires_at, is_active, last_activity
|
|
||||||
- **Methods**: is_expired (property), revoke()
|
|
||||||
- **Indexes**: session_id, user_id, is_active
|
|
||||||
|
|
||||||
## Features Implemented
|
|
||||||
|
|
||||||
### Core Functionality
|
|
||||||
|
|
||||||
✅ SQLAlchemy 2.0 async support
|
|
||||||
✅ Automatic timestamp tracking (created_at, updated_at)
|
|
||||||
✅ Foreign key constraints with cascade deletes
|
|
||||||
✅ Soft delete support (mixin available)
|
|
||||||
✅ Enum types for status and priority
|
|
||||||
✅ JSON field for complex data structures
|
|
||||||
✅ Comprehensive type hints
|
|
||||||
|
|
||||||
### Database Management
|
|
||||||
|
|
||||||
✅ Async and sync engine creation
|
|
||||||
✅ Session factory with proper configuration
|
|
||||||
✅ FastAPI dependency injection
|
|
||||||
✅ Automatic table creation
|
|
||||||
✅ SQLite optimizations (WAL, foreign keys)
|
|
||||||
✅ Connection pooling configuration
|
|
||||||
✅ Graceful shutdown and cleanup
|
|
||||||
|
|
||||||
### Testing
|
|
||||||
|
|
||||||
✅ 19 comprehensive test cases
|
|
||||||
✅ 100% test pass rate
|
|
||||||
✅ In-memory SQLite for isolation
|
|
||||||
✅ Fixtures for engine and session
|
|
||||||
✅ Relationship testing
|
|
||||||
✅ Constraint validation
|
|
||||||
✅ Query operation tests
|
|
||||||
|
|
||||||
### Documentation
|
|
||||||
|
|
||||||
✅ Comprehensive infrastructure.md section
|
|
||||||
✅ Database package README
|
|
||||||
✅ Usage examples
|
|
||||||
✅ Production considerations
|
|
||||||
✅ Troubleshooting guide
|
|
||||||
✅ Migration strategy (future)
|
|
||||||
|
|
||||||
## Technical Highlights
|
|
||||||
|
|
||||||
### Python Version Compatibility
|
|
||||||
|
|
||||||
- **Issue**: SQLAlchemy 2.0.23 incompatible with Python 3.13
|
|
||||||
- **Solution**: Upgraded to SQLAlchemy 2.0.44
|
|
||||||
- **Result**: All tests passing on Python 3.13.7
|
|
||||||
|
|
||||||
### Async Support
|
|
||||||
|
|
||||||
- Uses aiosqlite for async SQLite operations
|
|
||||||
- AsyncSession for non-blocking database operations
|
|
||||||
- Proper async context managers for session lifecycle
|
|
||||||
|
|
||||||
### SQLite Optimizations
|
|
||||||
|
|
||||||
- WAL (Write-Ahead Logging) mode enabled
|
|
||||||
- Foreign key constraints enabled via PRAGMA
|
|
||||||
- Static pool for single-connection use
|
|
||||||
- Automatic conversion of sqlite:/// to sqlite+aiosqlite:///
|
|
||||||
|
|
||||||
### Type Safety
|
|
||||||
|
|
||||||
- Comprehensive type hints using SQLAlchemy 2.0 Mapped types
|
|
||||||
- Pydantic integration for validation
|
|
||||||
- Type-safe relationships and foreign keys
|
|
||||||
|
|
||||||
## Integration Points
|
|
||||||
|
|
||||||
### FastAPI Endpoints
|
|
||||||
|
|
||||||
```python
|
|
||||||
@app.get("/anime")
|
|
||||||
async def get_anime(db: AsyncSession = Depends(get_database_session)):
|
|
||||||
result = await db.execute(select(AnimeSeries))
|
|
||||||
return result.scalars().all()
|
|
||||||
```
|
|
||||||
|
|
||||||
### Service Layer
|
|
||||||
|
|
||||||
- AnimeService: Query and persist series data
|
|
||||||
- DownloadService: Queue persistence and recovery
|
|
||||||
- AuthService: Session storage and validation
|
|
||||||
|
|
||||||
### Future Enhancements
|
|
||||||
|
|
||||||
- Alembic migrations for schema versioning
|
|
||||||
- PostgreSQL/MySQL support for production
|
|
||||||
- Read replicas for scaling
|
|
||||||
- Connection pool metrics
|
|
||||||
- Query performance monitoring
|
|
||||||
|
|
||||||
## Testing Results
|
|
||||||
|
|
||||||
```
|
|
||||||
============================= test session starts ==============================
|
|
||||||
platform linux -- Python 3.13.7, pytest-8.4.2, pluggy-1.6.0
|
|
||||||
collected 19 items
|
|
||||||
|
|
||||||
tests/unit/test_database_models.py::TestAnimeSeries::test_create_anime_series PASSED
|
|
||||||
tests/unit/test_database_models.py::TestAnimeSeries::test_anime_series_unique_key PASSED
|
|
||||||
tests/unit/test_database_models.py::TestAnimeSeries::test_anime_series_relationships PASSED
|
|
||||||
tests/unit/test_database_models.py::TestAnimeSeries::test_anime_series_cascade_delete PASSED
|
|
||||||
tests/unit/test_database_models.py::TestEpisode::test_create_episode PASSED
|
|
||||||
tests/unit/test_database_models.py::TestEpisode::test_episode_relationship_to_series PASSED
|
|
||||||
tests/unit/test_database_models.py::TestDownloadQueueItem::test_create_download_item PASSED
|
|
||||||
tests/unit/test_database_models.py::TestDownloadQueueItem::test_download_item_status_enum PASSED
|
|
||||||
tests/unit/test_database_models.py::TestDownloadQueueItem::test_download_item_error_handling PASSED
|
|
||||||
tests/unit/test_database_models.py::TestUserSession::test_create_user_session PASSED
|
|
||||||
tests/unit/test_database_models.py::TestUserSession::test_session_unique_session_id PASSED
|
|
||||||
tests/unit/test_database_models.py::TestUserSession::test_session_is_expired PASSED
|
|
||||||
tests/unit/test_database_models.py::TestUserSession::test_session_revoke PASSED
|
|
||||||
tests/unit/test_database_models.py::TestTimestampMixin::test_timestamp_auto_creation PASSED
|
|
||||||
tests/unit/test_database_models.py::TestTimestampMixin::test_timestamp_auto_update PASSED
|
|
||||||
tests/unit/test_database_models.py::TestSoftDeleteMixin::test_soft_delete_not_applied_to_models PASSED
|
|
||||||
tests/unit/test_database_models.py::TestDatabaseQueries::test_query_series_with_episodes PASSED
|
|
||||||
tests/unit/test_database_models.py::TestDatabaseQueries::test_query_download_queue_by_status PASSED
|
|
||||||
tests/unit/test_database_models.py::TestDatabaseQueries::test_query_active_sessions PASSED
|
|
||||||
|
|
||||||
======================= 19 passed, 21 warnings in 0.50s ========================
|
|
||||||
```
|
|
||||||
|
|
||||||
## Deliverables Checklist
|
|
||||||
|
|
||||||
✅ Database directory structure created
|
|
||||||
✅ SQLAlchemy models implemented (4 models)
|
|
||||||
✅ Connection and session management
|
|
||||||
✅ FastAPI dependency injection
|
|
||||||
✅ Comprehensive unit tests (19 tests)
|
|
||||||
✅ Documentation updated (infrastructure.md)
|
|
||||||
✅ Package README created
|
|
||||||
✅ Dependencies added to requirements.txt
|
|
||||||
✅ All tests passing
|
|
||||||
✅ Python 3.13 compatibility verified
|
|
||||||
|
|
||||||
## Lines of Code
|
|
||||||
|
|
||||||
- **Implementation**: ~1,200 lines
|
|
||||||
- **Tests**: ~550 lines
|
|
||||||
- **Documentation**: ~500 lines
|
|
||||||
- **Total**: ~2,250 lines
|
|
||||||
|
|
||||||
## Code Quality
|
|
||||||
|
|
||||||
✅ Follows PEP 8 style guide
|
|
||||||
✅ Comprehensive docstrings
|
|
||||||
✅ Type hints throughout
|
|
||||||
✅ Error handling implemented
|
|
||||||
✅ Logging integrated
|
|
||||||
✅ Clean separation of concerns
|
|
||||||
✅ DRY principles followed
|
|
||||||
✅ Single responsibility maintained
|
|
||||||
|
|
||||||
## Status
|
|
||||||
|
|
||||||
**COMPLETED** ✅
|
|
||||||
|
|
||||||
All tasks from the Database Layer implementation checklist have been successfully completed. The database layer is production-ready and fully integrated with the existing Aniworld application infrastructure.
|
|
||||||
|
|
||||||
## Next Steps (Recommended)
|
|
||||||
|
|
||||||
1. Initialize Alembic for database migrations
|
|
||||||
2. Integrate database layer with existing services
|
|
||||||
3. Add database-backed session storage
|
|
||||||
4. Implement database queries in API endpoints
|
|
||||||
5. Add database connection pooling metrics
|
|
||||||
6. Create database backup automation
|
|
||||||
7. Add performance monitoring
|
|
||||||
|
|
||||||
## Notes
|
|
||||||
|
|
||||||
- SQLite is used for development and single-instance deployments
|
|
||||||
- PostgreSQL/MySQL recommended for multi-process production deployments
|
|
||||||
- Connection pooling configured for both development and production scenarios
|
|
||||||
- All foreign key relationships properly enforced
|
|
||||||
- Cascade deletes configured for data consistency
|
|
||||||
- Indexes added for frequently queried columns
|
|
||||||
@ -7,21 +7,6 @@ conda activate AniWorld
|
|||||||
```
|
```
|
||||||
/home/lukas/Volume/repo/Aniworld/
|
/home/lukas/Volume/repo/Aniworld/
|
||||||
├── src/
|
├── src/
|
||||||
│ ├── core/ # Core application logic
|
|
||||||
│ │ ├── SeriesApp.py # Main application class with async support
|
|
||||||
│ │ ├── SerieScanner.py # Directory scanner for anime series
|
|
||||||
│ │ ├── entities/ # Domain entities
|
|
||||||
│ │ │ ├── series.py # Serie data model
|
|
||||||
│ │ │ └── SerieList.py # Series list management
|
|
||||||
│ │ ├── interfaces/ # Abstract interfaces
|
|
||||||
│ │ │ └── providers.py # Provider interface definitions
|
|
||||||
│ │ ├── providers/ # Content providers
|
|
||||||
│ │ │ ├── base_provider.py # Base loader interface
|
|
||||||
│ │ │ ├── aniworld_provider.py # Aniworld.to implementation
|
|
||||||
│ │ │ ├── provider_factory.py # Provider factory
|
|
||||||
│ │ │ └── streaming/ # Streaming providers (VOE, etc.)
|
|
||||||
│ │ └── exceptions/ # Custom exceptions
|
|
||||||
│ │ └── Exceptions.py # Exception definitions
|
|
||||||
│ ├── server/ # FastAPI web application
|
│ ├── server/ # FastAPI web application
|
||||||
│ │ ├── fastapi_app.py # Main FastAPI application (simplified)
|
│ │ ├── fastapi_app.py # Main FastAPI application (simplified)
|
||||||
│ │ ├── main.py # FastAPI application entry point
|
│ │ ├── main.py # FastAPI application entry point
|
||||||
@ -52,11 +37,6 @@ conda activate AniWorld
|
|||||||
│ │ │ ├── anime_service.py
|
│ │ │ ├── anime_service.py
|
||||||
│ │ │ ├── download_service.py
|
│ │ │ ├── download_service.py
|
||||||
│ │ │ └── websocket_service.py # WebSocket connection management
|
│ │ │ └── websocket_service.py # WebSocket connection management
|
||||||
│ │ ├── database/ # Database layer
|
|
||||||
│ │ │ ├── __init__.py # Database package
|
|
||||||
│ │ │ ├── base.py # Base models and mixins
|
|
||||||
│ │ │ ├── models.py # SQLAlchemy ORM models
|
|
||||||
│ │ │ └── connection.py # Database connection management
|
|
||||||
│ │ ├── utils/ # Utility functions
|
│ │ ├── utils/ # Utility functions
|
||||||
│ │ │ ├── __init__.py
|
│ │ │ ├── __init__.py
|
||||||
│ │ │ ├── security.py
|
│ │ │ ├── security.py
|
||||||
@ -113,9 +93,7 @@ conda activate AniWorld
|
|||||||
|
|
||||||
- **FastAPI**: Modern Python web framework for building APIs
|
- **FastAPI**: Modern Python web framework for building APIs
|
||||||
- **Uvicorn**: ASGI server for running FastAPI applications
|
- **Uvicorn**: ASGI server for running FastAPI applications
|
||||||
- **SQLAlchemy**: SQL toolkit and ORM for database operations
|
|
||||||
- **SQLite**: Lightweight database for storing anime library and configuration
|
- **SQLite**: Lightweight database for storing anime library and configuration
|
||||||
- **Alembic**: Database migration tool for schema management
|
|
||||||
- **Pydantic**: Data validation and serialization
|
- **Pydantic**: Data validation and serialization
|
||||||
- **Jinja2**: Template engine for server-side rendering
|
- **Jinja2**: Template engine for server-side rendering
|
||||||
|
|
||||||
@ -165,37 +143,13 @@ conda activate AniWorld
|
|||||||
|
|
||||||
### Configuration API Notes
|
### Configuration API Notes
|
||||||
|
|
||||||
- Configuration endpoints are exposed under `/api/config`
|
- The configuration endpoints are exposed under `/api/config` and
|
||||||
- Uses file-based persistence with JSON format for human-readable storage
|
operate primarily on a JSON-serializable `AppConfig` model. They are
|
||||||
- Automatic backup creation before configuration updates
|
designed to be lightweight and avoid performing IO during validation
|
||||||
- Configuration validation with detailed error reporting
|
(the `/api/config/validate` endpoint runs in-memory checks only).
|
||||||
- Backup management with create, restore, list, and delete operations
|
- Persistence of configuration changes is intentionally "best-effort"
|
||||||
- Configuration schema versioning with migration support
|
for now and mirrors fields into the runtime settings object. A
|
||||||
- Singleton ConfigService manages all persistence operations
|
follow-up task should add durable storage (file or DB) for configs.
|
||||||
- Default configuration location: `data/config.json`
|
|
||||||
- Backup directory: `data/config_backups/`
|
|
||||||
- Maximum backups retained: 10 (configurable)
|
|
||||||
- Automatic cleanup of old backups exceeding limit
|
|
||||||
|
|
||||||
**Key Endpoints:**
|
|
||||||
|
|
||||||
- `GET /api/config` - Retrieve current configuration
|
|
||||||
- `PUT /api/config` - Update configuration (creates backup)
|
|
||||||
- `POST /api/config/validate` - Validate without applying
|
|
||||||
- `GET /api/config/backups` - List all backups
|
|
||||||
- `POST /api/config/backups` - Create manual backup
|
|
||||||
- `POST /api/config/backups/{name}/restore` - Restore from backup
|
|
||||||
- `DELETE /api/config/backups/{name}` - Delete backup
|
|
||||||
|
|
||||||
**Configuration Service Features:**
|
|
||||||
|
|
||||||
- Atomic file writes using temporary files
|
|
||||||
- JSON format with version metadata
|
|
||||||
- Validation before saving
|
|
||||||
- Automatic backup on updates
|
|
||||||
- Migration support for schema changes
|
|
||||||
- Thread-safe singleton pattern
|
|
||||||
- Comprehensive error handling with custom exceptions
|
|
||||||
|
|
||||||
### Anime Management
|
### Anime Management
|
||||||
|
|
||||||
@ -264,646 +218,8 @@ initialization.
|
|||||||
this state to a shared store (Redis) and persist the master password
|
this state to a shared store (Redis) and persist the master password
|
||||||
hash in a secure config store.
|
hash in a secure config store.
|
||||||
|
|
||||||
## Database Layer (October 2025)
|
|
||||||
|
|
||||||
A comprehensive SQLAlchemy-based database layer was implemented to provide
|
|
||||||
persistent storage for anime series, episodes, download queue, and user sessions.
|
|
||||||
|
|
||||||
### Architecture
|
|
||||||
|
|
||||||
**Location**: `src/server/database/`
|
|
||||||
|
|
||||||
**Components**:
|
|
||||||
|
|
||||||
- `base.py`: Base declarative class and mixins (TimestampMixin, SoftDeleteMixin)
|
|
||||||
- `models.py`: SQLAlchemy ORM models with relationships
|
|
||||||
- `connection.py`: Database engine, session factory, and dependency injection
|
|
||||||
- `__init__.py`: Package exports and public API
|
|
||||||
|
|
||||||
### Database Models
|
|
||||||
|
|
||||||
#### AnimeSeries
|
|
||||||
|
|
||||||
Represents anime series with metadata and provider information.
|
|
||||||
|
|
||||||
**Fields**:
|
|
||||||
|
|
||||||
- `id` (PK): Auto-incrementing primary key
|
|
||||||
- `key`: Unique provider identifier (indexed)
|
|
||||||
- `name`: Series name (indexed)
|
|
||||||
- `site`: Provider site URL
|
|
||||||
- `folder`: Local filesystem path
|
|
||||||
- `description`: Optional series description
|
|
||||||
- `status`: Series status (ongoing, completed)
|
|
||||||
- `total_episodes`: Total episode count
|
|
||||||
- `cover_url`: Cover image URL
|
|
||||||
- `episode_dict`: JSON field storing episode structure {season: [episodes]}
|
|
||||||
- `created_at`, `updated_at`: Audit timestamps (from TimestampMixin)
|
|
||||||
|
|
||||||
**Relationships**:
|
|
||||||
|
|
||||||
- `episodes`: One-to-many with Episode (cascade delete)
|
|
||||||
- `download_items`: One-to-many with DownloadQueueItem (cascade delete)
|
|
||||||
|
|
||||||
#### Episode
|
|
||||||
|
|
||||||
Individual episodes linked to anime series.
|
|
||||||
|
|
||||||
**Fields**:
|
|
||||||
|
|
||||||
- `id` (PK): Auto-incrementing primary key
|
|
||||||
- `series_id` (FK): Foreign key to AnimeSeries (indexed)
|
|
||||||
- `season`: Season number
|
|
||||||
- `episode_number`: Episode number within season
|
|
||||||
- `title`: Optional episode title
|
|
||||||
- `file_path`: Local file path if downloaded
|
|
||||||
- `file_size`: File size in bytes
|
|
||||||
- `is_downloaded`: Boolean download status
|
|
||||||
- `download_date`: Timestamp when downloaded
|
|
||||||
- `created_at`, `updated_at`: Audit timestamps
|
|
||||||
|
|
||||||
**Relationships**:
|
|
||||||
|
|
||||||
- `series`: Many-to-one with AnimeSeries
|
|
||||||
|
|
||||||
#### DownloadQueueItem
|
|
||||||
|
|
||||||
Download queue with status and progress tracking.
|
|
||||||
|
|
||||||
**Fields**:
|
|
||||||
|
|
||||||
- `id` (PK): Auto-incrementing primary key
|
|
||||||
- `series_id` (FK): Foreign key to AnimeSeries (indexed)
|
|
||||||
- `season`: Season number
|
|
||||||
- `episode_number`: Episode number
|
|
||||||
- `status`: Download status enum (indexed)
|
|
||||||
- Values: PENDING, DOWNLOADING, PAUSED, COMPLETED, FAILED, CANCELLED
|
|
||||||
- `priority`: Priority enum
|
|
||||||
- Values: LOW, NORMAL, HIGH
|
|
||||||
- `progress_percent`: Download progress (0-100)
|
|
||||||
- `downloaded_bytes`: Bytes downloaded
|
|
||||||
- `total_bytes`: Total file size
|
|
||||||
- `download_speed`: Current speed (bytes/sec)
|
|
||||||
- `error_message`: Error description if failed
|
|
||||||
- `retry_count`: Number of retry attempts
|
|
||||||
- `download_url`: Provider download URL
|
|
||||||
- `file_destination`: Target file path
|
|
||||||
- `started_at`: Download start timestamp
|
|
||||||
- `completed_at`: Download completion timestamp
|
|
||||||
- `created_at`, `updated_at`: Audit timestamps
|
|
||||||
|
|
||||||
**Relationships**:
|
|
||||||
|
|
||||||
- `series`: Many-to-one with AnimeSeries
|
|
||||||
|
|
||||||
#### UserSession
|
|
||||||
|
|
||||||
User authentication sessions with JWT tokens.
|
|
||||||
|
|
||||||
**Fields**:
|
|
||||||
|
|
||||||
- `id` (PK): Auto-incrementing primary key
|
|
||||||
- `session_id`: Unique session identifier (indexed)
|
|
||||||
- `token_hash`: Hashed JWT token
|
|
||||||
- `user_id`: User identifier (indexed, for multi-user support)
|
|
||||||
- `ip_address`: Client IP address
|
|
||||||
- `user_agent`: Client user agent string
|
|
||||||
- `expires_at`: Session expiration timestamp
|
|
||||||
- `is_active`: Boolean active status (indexed)
|
|
||||||
- `last_activity`: Last activity timestamp
|
|
||||||
- `created_at`, `updated_at`: Audit timestamps
|
|
||||||
|
|
||||||
**Methods**:
|
|
||||||
|
|
||||||
- `is_expired`: Property to check if session has expired
|
|
||||||
- `revoke()`: Revoke session by setting is_active=False
|
|
||||||
|
|
||||||
### Mixins
|
|
||||||
|
|
||||||
#### TimestampMixin
|
|
||||||
|
|
||||||
Adds automatic timestamp tracking to models.
|
|
||||||
|
|
||||||
**Fields**:
|
|
||||||
|
|
||||||
- `created_at`: Automatically set on record creation
|
|
||||||
- `updated_at`: Automatically updated on record modification
|
|
||||||
|
|
||||||
**Usage**: Inherit in models requiring audit timestamps.
|
|
||||||
|
|
||||||
#### SoftDeleteMixin
|
|
||||||
|
|
||||||
Provides soft delete functionality (logical deletion).
|
|
||||||
|
|
||||||
**Fields**:
|
|
||||||
|
|
||||||
- `deleted_at`: Timestamp when soft deleted (NULL if active)
|
|
||||||
|
|
||||||
**Properties**:
|
|
||||||
|
|
||||||
- `is_deleted`: Check if record is soft deleted
|
|
||||||
|
|
||||||
**Methods**:
|
|
||||||
|
|
||||||
- `soft_delete()`: Mark record as deleted
|
|
||||||
- `restore()`: Restore soft deleted record
|
|
||||||
|
|
||||||
**Note**: Currently not used by models but available for future implementation.
|
|
||||||
|
|
||||||
### Database Connection Management
|
|
||||||
|
|
||||||
#### Initialization
|
|
||||||
|
|
||||||
```python
|
|
||||||
from src.server.database import init_db, close_db
|
|
||||||
|
|
||||||
# Application startup
|
|
||||||
await init_db() # Creates engine, session factory, and tables
|
|
||||||
|
|
||||||
# Application shutdown
|
|
||||||
await close_db() # Closes connections and cleanup
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Session Management
|
|
||||||
|
|
||||||
**Async Sessions** (preferred for FastAPI endpoints):
|
|
||||||
|
|
||||||
```python
|
|
||||||
from fastapi import Depends
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from src.server.database import get_db_session
|
|
||||||
|
|
||||||
@app.get("/anime")
|
|
||||||
async def get_anime(db: AsyncSession = Depends(get_db_session)):
|
|
||||||
result = await db.execute(select(AnimeSeries))
|
|
||||||
return result.scalars().all()
|
|
||||||
```
|
|
||||||
|
|
||||||
**Sync Sessions** (for non-async operations):
|
|
||||||
|
|
||||||
```python
|
|
||||||
from src.server.database.connection import get_sync_session
|
|
||||||
|
|
||||||
session = get_sync_session()
|
|
||||||
try:
|
|
||||||
result = session.execute(select(AnimeSeries))
|
|
||||||
return result.scalars().all()
|
|
||||||
finally:
|
|
||||||
session.close()
|
|
||||||
```
|
|
||||||
|
|
||||||
### Database Configuration
|
|
||||||
|
|
||||||
**Settings** (from `src/config/settings.py`):
|
|
||||||
|
|
||||||
- `DATABASE_URL`: Database connection string
|
|
||||||
- Default: `sqlite:///./data/aniworld.db`
|
|
||||||
- Automatically converted to `sqlite+aiosqlite:///` for async support
|
|
||||||
- `LOG_LEVEL`: When set to "DEBUG", enables SQL query logging
|
|
||||||
|
|
||||||
**Engine Configuration**:
|
|
||||||
|
|
||||||
- **SQLite**: Uses StaticPool, enables foreign keys and WAL mode
|
|
||||||
- **PostgreSQL/MySQL**: Uses QueuePool with pre-ping health checks
|
|
||||||
- **Connection Pooling**: Configured based on database type
|
|
||||||
- **Echo**: SQL query logging in DEBUG mode
|
|
||||||
|
|
||||||
### SQLite Optimizations
|
|
||||||
|
|
||||||
- **Foreign Keys**: Automatically enabled via PRAGMA
|
|
||||||
- **WAL Mode**: Write-Ahead Logging for better concurrency
|
|
||||||
- **Static Pool**: Single connection pool for SQLite
|
|
||||||
- **Async Support**: aiosqlite driver for async operations
|
|
||||||
|
|
||||||
### FastAPI Integration
|
|
||||||
|
|
||||||
**Dependency Injection** (in `src/server/utils/dependencies.py`):
|
|
||||||
|
|
||||||
```python
|
|
||||||
async def get_database_session() -> AsyncGenerator:
|
|
||||||
"""Dependency to get database session."""
|
|
||||||
try:
|
|
||||||
from src.server.database import get_db_session
|
|
||||||
|
|
||||||
async with get_db_session() as session:
|
|
||||||
yield session
|
|
||||||
except ImportError:
|
|
||||||
raise HTTPException(status_code=501, detail="Database not installed")
|
|
||||||
except RuntimeError as e:
|
|
||||||
raise HTTPException(status_code=503, detail=f"Database not available: {str(e)}")
|
|
||||||
```
|
|
||||||
|
|
||||||
**Usage in Endpoints**:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from fastapi import Depends
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from src.server.utils.dependencies import get_database_session
|
|
||||||
|
|
||||||
@router.get("/series/{series_id}")
|
|
||||||
async def get_series(
|
|
||||||
series_id: int,
|
|
||||||
db: AsyncSession = Depends(get_database_session)
|
|
||||||
):
|
|
||||||
result = await db.execute(
|
|
||||||
select(AnimeSeries).where(AnimeSeries.id == series_id)
|
|
||||||
)
|
|
||||||
series = result.scalar_one_or_none()
|
|
||||||
if not series:
|
|
||||||
raise HTTPException(status_code=404, detail="Series not found")
|
|
||||||
return series
|
|
||||||
```
|
|
||||||
|
|
||||||
### Testing
|
|
||||||
|
|
||||||
**Test Suite**: `tests/unit/test_database_models.py`
|
|
||||||
|
|
||||||
**Coverage**:
|
|
||||||
|
|
||||||
- 30+ comprehensive test cases
|
|
||||||
- Model creation and validation
|
|
||||||
- Relationship testing (one-to-many, cascade deletes)
|
|
||||||
- Unique constraint validation
|
|
||||||
- Query operations (filtering, joins)
|
|
||||||
- Session management
|
|
||||||
- Mixin functionality
|
|
||||||
|
|
||||||
**Test Strategy**:
|
|
||||||
|
|
||||||
- In-memory SQLite database for isolation
|
|
||||||
- Fixtures for engine and session setup
|
|
||||||
- Test all CRUD operations
|
|
||||||
- Verify constraints and relationships
|
|
||||||
- Test edge cases and error conditions
|
|
||||||
|
|
||||||
### Migration Strategy (Future)
|
|
||||||
|
|
||||||
**Alembic Integration** (planned):
|
|
||||||
|
|
||||||
- Alembic installed but not yet configured
|
|
||||||
- Will manage schema migrations in production
|
|
||||||
- Auto-generate migrations from model changes
|
|
||||||
- Version control for database schema
|
|
||||||
|
|
||||||
**Initial Setup**:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Initialize Alembic (future)
|
|
||||||
alembic init alembic
|
|
||||||
|
|
||||||
# Generate initial migration
|
|
||||||
alembic revision --autogenerate -m "Initial schema"
|
|
||||||
|
|
||||||
# Apply migrations
|
|
||||||
alembic upgrade head
|
|
||||||
```
|
|
||||||
|
|
||||||
### Production Considerations
|
|
||||||
|
|
||||||
**Single-Process Deployment** (current):
|
|
||||||
|
|
||||||
- SQLite with WAL mode for concurrency
|
|
||||||
- Static pool for single connection
|
|
||||||
- File-based storage at `data/aniworld.db`
|
|
||||||
|
|
||||||
**Multi-Process Deployment** (future):
|
|
||||||
|
|
||||||
- Switch to PostgreSQL or MySQL
|
|
||||||
- Configure connection pooling (pool_size, max_overflow)
|
|
||||||
- Use QueuePool for connection management
|
|
||||||
- Consider read replicas for scaling
|
|
||||||
|
|
||||||
**Performance**:
|
|
||||||
|
|
||||||
- Indexes on frequently queried columns (key, name, status, is_active)
|
|
||||||
- Foreign key constraints for referential integrity
|
|
||||||
- Cascade deletes for cleanup operations
|
|
||||||
- Efficient joins via relationship loading strategies
|
|
||||||
|
|
||||||
**Monitoring**:
|
|
||||||
|
|
||||||
- SQL query logging in DEBUG mode
|
|
||||||
- Connection pool metrics (when using QueuePool)
|
|
||||||
- Query performance profiling
|
|
||||||
- Database size monitoring
|
|
||||||
|
|
||||||
**Backup Strategy**:
|
|
||||||
|
|
||||||
- SQLite: File-based backups (copy `aniworld.db` file)
|
|
||||||
- WAL checkpoint before backup
|
|
||||||
- Automated backup schedule recommended
|
|
||||||
- Store backups in `data/config_backups/` or separate location
|
|
||||||
|
|
||||||
### Integration with Services
|
|
||||||
|
|
||||||
**AnimeService**:
|
|
||||||
|
|
||||||
- Query series from database
|
|
||||||
- Persist scan results
|
|
||||||
- Update episode metadata
|
|
||||||
|
|
||||||
**DownloadService**:
|
|
||||||
|
|
||||||
- Load queue from database on startup
|
|
||||||
- Persist queue state continuously
|
|
||||||
- Update download progress in real-time
|
|
||||||
|
|
||||||
**AuthService**:
|
|
||||||
|
|
||||||
- Store and validate user sessions
|
|
||||||
- Session revocation via database
|
|
||||||
- Query active sessions for monitoring
|
|
||||||
|
|
||||||
### Benefits of Database Layer
|
|
||||||
|
|
||||||
- **Persistence**: Survives application restarts
|
|
||||||
- **Relationships**: Enforced referential integrity
|
|
||||||
- **Queries**: Powerful filtering and aggregation
|
|
||||||
- **Scalability**: Can migrate to PostgreSQL/MySQL
|
|
||||||
- **ACID**: Atomic transactions for consistency
|
|
||||||
- **Migration**: Schema versioning with Alembic
|
|
||||||
- **Testing**: Easy to test with in-memory database
|
|
||||||
|
|
||||||
### Database Service Layer (October 2025)
|
|
||||||
|
|
||||||
Implemented comprehensive service layer for database CRUD operations.
|
|
||||||
|
|
||||||
**File**: `src/server/database/service.py`
|
|
||||||
|
|
||||||
**Services**:
|
|
||||||
|
|
||||||
- `AnimeSeriesService`: CRUD operations for anime series
|
|
||||||
- `EpisodeService`: Episode management and download tracking
|
|
||||||
- `DownloadQueueService`: Queue management with priority and status
|
|
||||||
- `UserSessionService`: Session management and authentication
|
|
||||||
|
|
||||||
**Key Features**:
|
|
||||||
|
|
||||||
- Repository pattern for clean separation of concerns
|
|
||||||
- Type-safe operations with comprehensive type hints
|
|
||||||
- Async support for all database operations
|
|
||||||
- Transaction management via FastAPI dependency injection
|
|
||||||
- Comprehensive error handling and logging
|
|
||||||
- Search and filtering capabilities
|
|
||||||
- Pagination support for large datasets
|
|
||||||
- Batch operations for performance
|
|
||||||
|
|
||||||
**AnimeSeriesService Operations**:
|
|
||||||
|
|
||||||
- Create series with metadata and provider information
|
|
||||||
- Retrieve by ID, key, or search query
|
|
||||||
- Update series attributes
|
|
||||||
- Delete series with cascade to episodes and queue items
|
|
||||||
- List all series with pagination and eager loading options
|
|
||||||
|
|
||||||
**EpisodeService Operations**:
|
|
||||||
|
|
||||||
- Create episodes for series
|
|
||||||
- Retrieve episodes by series, season, or specific episode
|
|
||||||
- Mark episodes as downloaded with file metadata
|
|
||||||
- Delete episodes
|
|
||||||
|
|
||||||
**DownloadQueueService Operations**:
|
|
||||||
|
|
||||||
- Add items to queue with priority levels (LOW, NORMAL, HIGH)
|
|
||||||
- Retrieve pending, active, or all queue items
|
|
||||||
- Update download status (PENDING, DOWNLOADING, COMPLETED, FAILED, etc.)
|
|
||||||
- Update download progress (percentage, bytes, speed)
|
|
||||||
- Clear completed downloads
|
|
||||||
- Retry failed downloads with max retry limits
|
|
||||||
- Automatic timestamp management (started_at, completed_at)
|
|
||||||
|
|
||||||
**UserSessionService Operations**:
|
|
||||||
|
|
||||||
- Create authentication sessions with JWT tokens
|
|
||||||
- Retrieve sessions by session ID
|
|
||||||
- Get active sessions with expiry checking
|
|
||||||
- Update last activity timestamp
|
|
||||||
- Revoke sessions for logout
|
|
||||||
- Cleanup expired sessions automatically
|
|
||||||
|
|
||||||
**Testing**:
|
|
||||||
|
|
||||||
- Comprehensive test suite with 22 test cases
|
|
||||||
- In-memory SQLite for isolated testing
|
|
||||||
- All CRUD operations tested
|
|
||||||
- Edge cases and error conditions covered
|
|
||||||
- 100% test pass rate
|
|
||||||
|
|
||||||
**Integration**:
|
|
||||||
|
|
||||||
- Exported via database package `__init__.py`
|
|
||||||
- Used by API endpoints via dependency injection
|
|
||||||
- Compatible with existing database models
|
|
||||||
- Follows project coding standards (PEP 8, type hints, docstrings)
|
|
||||||
|
|
||||||
**Database Migrations** (`src/server/database/migrations.py`):
|
|
||||||
|
|
||||||
- Simple schema initialization via SQLAlchemy create_all
|
|
||||||
- Schema version checking utility
|
|
||||||
- Documentation for Alembic integration
|
|
||||||
- Production-ready migration strategy outlined
|
|
||||||
|
|
||||||
## Core Application Logic
|
|
||||||
|
|
||||||
### SeriesApp - Enhanced Core Engine
|
|
||||||
|
|
||||||
The `SeriesApp` class (`src/core/SeriesApp.py`) is the main application engine for anime series management. Enhanced with async support and web integration capabilities.
|
|
||||||
|
|
||||||
#### Key Features
|
|
||||||
|
|
||||||
- **Async Operations**: Support for async download and scan operations
|
|
||||||
- **Progress Callbacks**: Real-time progress reporting via callbacks
|
|
||||||
- **Cancellation Support**: Ability to cancel long-running operations
|
|
||||||
- **Error Handling**: Comprehensive error handling with callback notifications
|
|
||||||
- **Operation Status**: Track current operation status and history
|
|
||||||
|
|
||||||
#### Core Classes
|
|
||||||
|
|
||||||
- `SeriesApp`: Main application class
|
|
||||||
- `OperationStatus`: Enum for operation states (IDLE, RUNNING, COMPLETED, CANCELLED, FAILED)
|
|
||||||
- `ProgressInfo`: Dataclass for progress information
|
|
||||||
- `OperationResult`: Dataclass for operation results
|
|
||||||
|
|
||||||
#### Key Methods
|
|
||||||
|
|
||||||
- `search(words)`: Search for anime series
|
|
||||||
- `download()`: Download episodes with progress tracking
|
|
||||||
- `ReScan()`: Scan directory for missing episodes
|
|
||||||
- `async_download()`: Async version of download
|
|
||||||
- `async_rescan()`: Async version of rescan
|
|
||||||
- `cancel_operation()`: Cancel current operation
|
|
||||||
- `get_operation_status()`: Get current status
|
|
||||||
- `get_series_list()`: Get series with missing episodes
|
|
||||||
|
|
||||||
#### Integration Points
|
|
||||||
|
|
||||||
The SeriesApp integrates with:
|
|
||||||
|
|
||||||
- Provider system for content downloading
|
|
||||||
- Serie scanner for directory analysis
|
|
||||||
- Series list management for tracking missing episodes
|
|
||||||
- Web layer via async operations and callbacks
|
|
||||||
|
|
||||||
## Progress Callback System
|
|
||||||
|
|
||||||
### Overview
|
|
||||||
|
|
||||||
A comprehensive callback system for real-time progress reporting, error handling, and operation completion notifications across core operations (scanning, downloading, searching).
|
|
||||||
|
|
||||||
### Architecture
|
|
||||||
|
|
||||||
- **Interface-based Design**: Abstract base classes define callback contracts
|
|
||||||
- **Context Objects**: Rich context information for each callback type
|
|
||||||
- **Callback Manager**: Centralized management of multiple callbacks
|
|
||||||
- **Thread-safe**: Exception handling prevents callback errors from breaking operations
|
|
||||||
|
|
||||||
### Components
|
|
||||||
|
|
||||||
#### Callback Interfaces (`src/core/interfaces/callbacks.py`)
|
|
||||||
|
|
||||||
- `ProgressCallback`: Reports operation progress updates
|
|
||||||
- `ErrorCallback`: Handles error notifications
|
|
||||||
- `CompletionCallback`: Notifies operation completion
|
|
||||||
|
|
||||||
#### Context Classes
|
|
||||||
|
|
||||||
- `ProgressContext`: Current progress, percentage, phase, and metadata
|
|
||||||
- `ErrorContext`: Error details, recoverability, retry information
|
|
||||||
- `CompletionContext`: Success status, results, and statistics
|
|
||||||
|
|
||||||
#### Enums
|
|
||||||
|
|
||||||
- `OperationType`: SCAN, DOWNLOAD, SEARCH, INITIALIZATION
|
|
||||||
- `ProgressPhase`: STARTING, IN_PROGRESS, COMPLETING, COMPLETED, FAILED, CANCELLED
|
|
||||||
|
|
||||||
#### Callback Manager
|
|
||||||
|
|
||||||
- Register/unregister multiple callbacks per type
|
|
||||||
- Notify all registered callbacks with context
|
|
||||||
- Exception handling for callback errors
|
|
||||||
- Support for clearing all callbacks
|
|
||||||
|
|
||||||
### Integration
|
|
||||||
|
|
||||||
#### SerieScanner
|
|
||||||
|
|
||||||
- Reports scanning progress (folder by folder)
|
|
||||||
- Notifies errors for failed folder scans
|
|
||||||
- Reports completion with statistics
|
|
||||||
|
|
||||||
#### SeriesApp
|
|
||||||
|
|
||||||
- Download progress reporting with percentage
|
|
||||||
- Scan progress through SerieScanner integration
|
|
||||||
- Error notifications for all operations
|
|
||||||
- Completion notifications with results
|
|
||||||
|
|
||||||
### Usage Example
|
|
||||||
|
|
||||||
```python
|
|
||||||
from src.core.interfaces.callbacks import (
|
|
||||||
CallbackManager,
|
|
||||||
ProgressCallback,
|
|
||||||
ProgressContext
|
|
||||||
)
|
|
||||||
|
|
||||||
class MyProgressCallback(ProgressCallback):
|
|
||||||
def on_progress(self, context: ProgressContext):
|
|
||||||
print(f"{context.message}: {context.percentage:.1f}%")
|
|
||||||
|
|
||||||
# Register callback
|
|
||||||
manager = CallbackManager()
|
|
||||||
manager.register_progress_callback(MyProgressCallback())
|
|
||||||
|
|
||||||
# Use with SeriesApp
|
|
||||||
app = SeriesApp(directory, callback_manager=manager)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Recent Infrastructure Changes
|
## Recent Infrastructure Changes
|
||||||
|
|
||||||
### Progress Callback System (October 2025)
|
|
||||||
|
|
||||||
Implemented a comprehensive progress callback system for real-time operation tracking.
|
|
||||||
|
|
||||||
#### Changes Made
|
|
||||||
|
|
||||||
1. **Callback Interfaces**:
|
|
||||||
|
|
||||||
- Created abstract base classes for progress, error, and completion callbacks
|
|
||||||
- Defined rich context objects with operation metadata
|
|
||||||
- Implemented thread-safe callback manager
|
|
||||||
|
|
||||||
2. **SerieScanner Integration**:
|
|
||||||
|
|
||||||
- Added progress reporting for directory scanning
|
|
||||||
- Implemented per-folder progress updates
|
|
||||||
- Error callbacks for scan failures
|
|
||||||
- Completion notifications with statistics
|
|
||||||
|
|
||||||
3. **SeriesApp Integration**:
|
|
||||||
|
|
||||||
- Integrated callback manager into download operations
|
|
||||||
- Progress updates during episode downloads
|
|
||||||
- Error handling with callback notifications
|
|
||||||
- Completion callbacks for all operations
|
|
||||||
- Backward compatibility with legacy callbacks
|
|
||||||
|
|
||||||
4. **Testing**:
|
|
||||||
- 22 comprehensive unit tests
|
|
||||||
- Coverage for all callback types
|
|
||||||
- Exception handling verification
|
|
||||||
- Multiple callback registration tests
|
|
||||||
|
|
||||||
### Core Logic Enhancement (October 2025)
|
|
||||||
|
|
||||||
Enhanced `SeriesApp` with async callback support, progress reporting, and cancellation.
|
|
||||||
|
|
||||||
#### Changes Made
|
|
||||||
|
|
||||||
1. **Async Support**:
|
|
||||||
|
|
||||||
- Added `async_download()` and `async_rescan()` methods
|
|
||||||
- Integrated with asyncio event loop for non-blocking operations
|
|
||||||
- Support for concurrent operations in web environment
|
|
||||||
|
|
||||||
2. **Progress Reporting**:
|
|
||||||
|
|
||||||
- Legacy `ProgressInfo` dataclass for structured progress data
|
|
||||||
- New comprehensive callback system with context objects
|
|
||||||
- Percentage calculation and status tracking
|
|
||||||
|
|
||||||
3. **Cancellation System**:
|
|
||||||
|
|
||||||
- Internal cancellation flag management
|
|
||||||
- Graceful operation cancellation
|
|
||||||
- Cancellation check during long-running operations
|
|
||||||
|
|
||||||
4. **Error Handling**:
|
|
||||||
|
|
||||||
- `OperationResult` dataclass for operation outcomes
|
|
||||||
- Error callback system for notifications
|
|
||||||
- Specific exception types (IOError, OSError, RuntimeError)
|
|
||||||
- Proper exception propagation and logging
|
|
||||||
|
|
||||||
5. **Status Management**:
|
|
||||||
- `OperationStatus` enum for state tracking
|
|
||||||
- Current operation identifier
|
|
||||||
- Status getter methods for monitoring
|
|
||||||
|
|
||||||
#### Test Coverage
|
|
||||||
|
|
||||||
Comprehensive test suite (`tests/unit/test_series_app.py`) with 22 tests covering:
|
|
||||||
|
|
||||||
- Initialization and configuration
|
|
||||||
- Search functionality
|
|
||||||
- Download operations with callbacks
|
|
||||||
- Directory scanning with progress
|
|
||||||
- Async operations
|
|
||||||
- Cancellation handling
|
|
||||||
- Error scenarios
|
|
||||||
- Data model validation
|
|
||||||
|
|
||||||
### Template Integration (October 2025)
|
### Template Integration (October 2025)
|
||||||
|
|
||||||
Completed integration of HTML templates with FastAPI Jinja2 system.
|
Completed integration of HTML templates with FastAPI Jinja2 system.
|
||||||
@ -974,108 +290,6 @@ All templates include:
|
|||||||
- Theme switching support
|
- Theme switching support
|
||||||
- Responsive viewport configuration
|
- Responsive viewport configuration
|
||||||
|
|
||||||
### CSS Integration (October 2025)
|
|
||||||
|
|
||||||
Integrated existing CSS styling with FastAPI's static file serving system.
|
|
||||||
|
|
||||||
#### Implementation Details
|
|
||||||
|
|
||||||
1. **Static File Configuration**:
|
|
||||||
|
|
||||||
- Static files mounted at `/static` in `fastapi_app.py`
|
|
||||||
- Directory: `src/server/web/static/`
|
|
||||||
- Files served using FastAPI's `StaticFiles` middleware
|
|
||||||
- All paths use absolute references (`/static/...`)
|
|
||||||
|
|
||||||
2. **CSS Architecture**:
|
|
||||||
|
|
||||||
- `styles.css` (1,840 lines) - Main stylesheet with Fluent UI design system
|
|
||||||
- `ux_features.css` (203 lines) - Enhanced UX features and accessibility
|
|
||||||
|
|
||||||
3. **Design System** (`styles.css`):
|
|
||||||
|
|
||||||
- **Fluent UI Variables**: CSS custom properties for consistent theming
|
|
||||||
- **Light/Dark Themes**: Dynamic theme switching via `[data-theme="dark"]`
|
|
||||||
- **Typography**: Segoe UI font stack with responsive sizing
|
|
||||||
- **Spacing System**: Consistent spacing scale (xs through xxl)
|
|
||||||
- **Color Palette**: Comprehensive color system for both themes
|
|
||||||
- **Border Radius**: Standardized corner radii (sm, md, lg, xl)
|
|
||||||
- **Shadows**: Elevation system with card and elevated variants
|
|
||||||
- **Transitions**: Smooth animations with consistent timing
|
|
||||||
|
|
||||||
4. **UX Features** (`ux_features.css`):
|
|
||||||
- Drag-and-drop indicators
|
|
||||||
- Bulk selection styling
|
|
||||||
- Keyboard focus indicators
|
|
||||||
- Touch gesture feedback
|
|
||||||
- Mobile responsive utilities
|
|
||||||
- High contrast mode support (`@media (prefers-contrast: high)`)
|
|
||||||
- Screen reader utilities (`.sr-only`)
|
|
||||||
- Window control components
|
|
||||||
|
|
||||||
#### CSS Variables
|
|
||||||
|
|
||||||
**Color System**:
|
|
||||||
|
|
||||||
```css
|
|
||||||
/* Light Theme */
|
|
||||||
--color-bg-primary: #ffffff
|
|
||||||
--color-accent: #0078d4
|
|
||||||
--color-text-primary: #323130
|
|
||||||
|
|
||||||
/* Dark Theme */
|
|
||||||
--color-bg-primary-dark: #202020
|
|
||||||
--color-accent-dark: #60cdff
|
|
||||||
--color-text-primary-dark: #ffffff
|
|
||||||
```
|
|
||||||
|
|
||||||
**Spacing & Typography**:
|
|
||||||
|
|
||||||
```css
|
|
||||||
--spacing-sm: 8px
|
|
||||||
--spacing-md: 12px
|
|
||||||
--spacing-lg: 16px
|
|
||||||
--font-size-body: 14px
|
|
||||||
--font-size-title: 20px
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Template CSS References
|
|
||||||
|
|
||||||
All HTML templates correctly reference CSS files:
|
|
||||||
|
|
||||||
- Index page: Includes both `styles.css` and `ux_features.css`
|
|
||||||
- Other pages: Include `styles.css`
|
|
||||||
- All use absolute paths: `/static/css/styles.css`
|
|
||||||
|
|
||||||
#### Responsive Design
|
|
||||||
|
|
||||||
- Mobile-first approach with breakpoints
|
|
||||||
- Media queries for tablet and desktop layouts
|
|
||||||
- Touch-friendly interface elements
|
|
||||||
- Adaptive typography and spacing
|
|
||||||
|
|
||||||
#### Accessibility Features
|
|
||||||
|
|
||||||
- WCAG-compliant color contrast
|
|
||||||
- High contrast mode support
|
|
||||||
- Screen reader utilities
|
|
||||||
- Keyboard navigation styling
|
|
||||||
- Focus indicators
|
|
||||||
- Reduced motion support
|
|
||||||
|
|
||||||
#### Testing
|
|
||||||
|
|
||||||
Comprehensive test suite in `tests/unit/test_static_files.py`:
|
|
||||||
|
|
||||||
- CSS file accessibility tests
|
|
||||||
- Theme support verification
|
|
||||||
- Responsive design validation
|
|
||||||
- Accessibility feature checks
|
|
||||||
- Content integrity validation
|
|
||||||
- Path correctness verification
|
|
||||||
|
|
||||||
All 17 CSS integration tests passing.
|
|
||||||
|
|
||||||
### Route Controller Refactoring (October 2025)
|
### Route Controller Refactoring (October 2025)
|
||||||
|
|
||||||
Restructured the FastAPI application to use a controller-based architecture for better code organization and maintainability.
|
Restructured the FastAPI application to use a controller-based architecture for better code organization and maintainability.
|
||||||
@ -1844,94 +1058,6 @@ Comprehensive integration tests verify WebSocket broadcasting:
|
|||||||
- Connection count and room membership tracking
|
- Connection count and room membership tracking
|
||||||
- Error tracking for failed broadcasts
|
- Error tracking for failed broadcasts
|
||||||
|
|
||||||
### Frontend Authentication Integration (October 2025)
|
|
||||||
|
|
||||||
Completed JWT-based authentication integration between frontend and backend.
|
|
||||||
|
|
||||||
#### Authentication Token Storage
|
|
||||||
|
|
||||||
**Files Modified:**
|
|
||||||
|
|
||||||
- `src/server/web/templates/login.html` - Store JWT token after successful login
|
|
||||||
- `src/server/web/templates/setup.html` - Redirect to login after setup completion
|
|
||||||
- `src/server/web/static/js/app.js` - Include Bearer token in all authenticated requests
|
|
||||||
- `src/server/web/static/js/queue.js` - Include Bearer token in queue API calls
|
|
||||||
|
|
||||||
**Implementation:**
|
|
||||||
|
|
||||||
- JWT tokens stored in `localStorage` after successful login
|
|
||||||
- Token expiry stored in `localStorage` for client-side validation
|
|
||||||
- `Authorization: Bearer <token>` header included in all authenticated requests
|
|
||||||
- Automatic redirect to `/login` on 401 Unauthorized responses
|
|
||||||
- Token cleared from `localStorage` on logout
|
|
||||||
|
|
||||||
**Key Functions Updated:**
|
|
||||||
|
|
||||||
- `makeAuthenticatedRequest()` in both `app.js` and `queue.js`
|
|
||||||
- `checkAuthentication()` to verify token and redirect if missing/invalid
|
|
||||||
- `logout()` to clear token and redirect to login
|
|
||||||
|
|
||||||
### Frontend API Endpoint Updates (October 2025)
|
|
||||||
|
|
||||||
Updated frontend JavaScript to match new backend API structure.
|
|
||||||
|
|
||||||
**Queue Management API Changes:**
|
|
||||||
|
|
||||||
- `/api/queue/clear` → `/api/queue/completed` for clearing completed downloads
|
|
||||||
- `/api/queue/remove` → `/api/queue/{item_id}` (DELETE) for single item removal
|
|
||||||
- `/api/queue/retry` payload changed to `{item_ids: []}` array format
|
|
||||||
- `/api/download/pause` → `/api/queue/pause`
|
|
||||||
- `/api/download/resume` → `/api/queue/resume`
|
|
||||||
- `/api/download/cancel` → `/api/queue/stop`
|
|
||||||
|
|
||||||
**Response Format Changes:**
|
|
||||||
|
|
||||||
- Login returns `{access_token, token_type, expires_at}` instead of `{status: 'success'}`
|
|
||||||
- Setup returns `{status: 'ok'}` instead of `{status: 'success', redirect_url}`
|
|
||||||
- Logout returns `{status: 'ok'}` instead of `{status: 'success'}`
|
|
||||||
- Queue operations return structured responses with counts (e.g., `{cleared_count, retried_count}`)
|
|
||||||
|
|
||||||
### Frontend WebSocket Integration (October 2025)
|
|
||||||
|
|
||||||
WebSocket integration previously completed and verified functional.
|
|
||||||
|
|
||||||
#### Native WebSocket Implementation
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
|
|
||||||
- `src/server/web/static/js/websocket_client.js` - Native WebSocket wrapper
|
|
||||||
- Templates already updated to use `websocket_client.js` instead of Socket.IO
|
|
||||||
|
|
||||||
**Event Compatibility:**
|
|
||||||
|
|
||||||
- Dual event handlers in place for backward compatibility
|
|
||||||
- Old events: `scan_completed`, `scan_error`, `download_completed`, `download_error`
|
|
||||||
- New events: `scan_complete`, `scan_failed`, `download_complete`, `download_failed`
|
|
||||||
- Both event types supported simultaneously
|
|
||||||
|
|
||||||
**Room Subscriptions:**
|
|
||||||
|
|
||||||
- `downloads` - Download completion, failures, queue status
|
|
||||||
- `download_progress` - Real-time download progress updates
|
|
||||||
- `scan_progress` - Library scan progress updates
|
|
||||||
|
|
||||||
### Frontend Integration Testing (October 2025)
|
|
||||||
|
|
||||||
Created smoke tests to verify frontend-backend integration.
|
|
||||||
|
|
||||||
**Test File:** `tests/integration/test_frontend_integration_smoke.py`
|
|
||||||
|
|
||||||
**Tests:**
|
|
||||||
|
|
||||||
- JWT token format verification (access_token, token_type, expires_at)
|
|
||||||
- Bearer token authentication on protected endpoints
|
|
||||||
- 401 responses for requests without valid tokens
|
|
||||||
|
|
||||||
**Test Results:**
|
|
||||||
|
|
||||||
- Basic authentication flow: ✅ PASSING
|
|
||||||
- Token validation: Functional with rate limiting considerations
|
|
||||||
|
|
||||||
### Frontend Integration (October 2025)
|
### Frontend Integration (October 2025)
|
||||||
|
|
||||||
Completed integration of existing frontend JavaScript with the new FastAPI backend and native WebSocket implementation.
|
Completed integration of existing frontend JavaScript with the new FastAPI backend and native WebSocket implementation.
|
||||||
|
|||||||
136
instructions.md
136
instructions.md
@ -15,17 +15,6 @@ The goal is to create a FastAPI-based web application that provides a modern int
|
|||||||
- **Type Hints**: Use comprehensive type annotations
|
- **Type Hints**: Use comprehensive type annotations
|
||||||
- **Error Handling**: Proper exception handling and logging
|
- **Error Handling**: Proper exception handling and logging
|
||||||
|
|
||||||
## Additional Implementation Guidelines
|
|
||||||
|
|
||||||
### Code Style and Standards
|
|
||||||
|
|
||||||
- **Type Hints**: Use comprehensive type annotations throughout all modules
|
|
||||||
- **Docstrings**: Follow PEP 257 for function and class documentation
|
|
||||||
- **Error Handling**: Implement custom exception classes with meaningful messages
|
|
||||||
- **Logging**: Use structured logging with appropriate log levels
|
|
||||||
- **Security**: Validate all inputs and sanitize outputs
|
|
||||||
- **Performance**: Use async/await patterns for I/O operations
|
|
||||||
|
|
||||||
## Implementation Order
|
## Implementation Order
|
||||||
|
|
||||||
The tasks should be completed in the following order to ensure proper dependencies and logical progression:
|
The tasks should be completed in the following order to ensure proper dependencies and logical progression:
|
||||||
@ -43,38 +32,80 @@ The tasks should be completed in the following order to ensure proper dependenci
|
|||||||
11. **Deployment and Configuration** - Production setup
|
11. **Deployment and Configuration** - Production setup
|
||||||
12. **Documentation and Error Handling** - Final documentation and error handling
|
12. **Documentation and Error Handling** - Final documentation and error handling
|
||||||
|
|
||||||
## Final Implementation Notes
|
# make the following steps for each task or subtask. make sure you do not miss one
|
||||||
|
|
||||||
1. **Incremental Development**: Implement features incrementally, testing each component thoroughly before moving to the next
|
1. Task the next task
|
||||||
2. **Code Review**: Review all generated code for adherence to project standards
|
2. Process the task
|
||||||
3. **Documentation**: Document all public APIs and complex logic
|
3. Make Tests.
|
||||||
4. **Testing**: Maintain test coverage above 80% for all new code
|
4. Remove task from instructions.md.
|
||||||
5. **Performance**: Profile and optimize critical paths, especially download and streaming operations
|
5. Update infrastructure.md, but only add text that belongs to a infrastructure doc. make sure to summarize text or delete text that do not belog to infrastructure.md. Keep it clear and short.
|
||||||
6. **Security**: Regular security audits and dependency updates
|
6. Commit in git
|
||||||
7. **Monitoring**: Implement comprehensive monitoring and alerting
|
|
||||||
8. **Maintenance**: Plan for regular maintenance and updates
|
|
||||||
|
|
||||||
## Task Completion Checklist
|
|
||||||
|
|
||||||
For each task completed:
|
|
||||||
|
|
||||||
- [ ] Implementation follows coding standards
|
|
||||||
- [ ] Unit tests written and passing
|
|
||||||
- [ ] Integration tests passing
|
|
||||||
- [ ] Documentation updated
|
|
||||||
- [ ] Error handling implemented
|
|
||||||
- [ ] Logging added
|
|
||||||
- [ ] Security considerations addressed
|
|
||||||
- [ ] Performance validated
|
|
||||||
- [ ] Code reviewed
|
|
||||||
- [ ] Task marked as complete in instructions.md
|
|
||||||
- [ ] Infrastructure.md updated
|
|
||||||
- [ ] Changes committed to git
|
|
||||||
|
|
||||||
This comprehensive guide ensures a robust, maintainable, and scalable anime download management system with modern web capabilities.
|
|
||||||
|
|
||||||
## Core Tasks
|
## Core Tasks
|
||||||
|
|
||||||
|
### 7. Frontend Integration
|
||||||
|
|
||||||
|
#### [] Integrate existing CSS styling
|
||||||
|
|
||||||
|
- []Review and integrate existing CSS files in `src/server/web/static/css/`
|
||||||
|
- []Ensure styling works with FastAPI static file serving
|
||||||
|
- []Maintain existing responsive design and theme support
|
||||||
|
- []Update any hardcoded paths if necessary
|
||||||
|
|
||||||
|
#### [] Update frontend-backend integration
|
||||||
|
|
||||||
|
- []Ensure existing JavaScript calls match new API endpoints
|
||||||
|
- []Update authentication flow to work with new auth system
|
||||||
|
- []Verify WebSocket events match new service implementations
|
||||||
|
- []Test all existing UI functionality with new backend
|
||||||
|
|
||||||
|
### 8. Core Logic Integration
|
||||||
|
|
||||||
|
#### [] Enhance SeriesApp for web integration
|
||||||
|
|
||||||
|
- []Update `src/core/SeriesApp.py`
|
||||||
|
- []Add async callback support
|
||||||
|
- []Implement progress reporting
|
||||||
|
- []Include better error handling
|
||||||
|
- []Add cancellation support
|
||||||
|
|
||||||
|
#### [] Create progress callback system
|
||||||
|
|
||||||
|
- []Add progress callback interface
|
||||||
|
- []Implement scan progress reporting
|
||||||
|
- []Add download progress tracking
|
||||||
|
- []Include error/completion callbacks
|
||||||
|
|
||||||
|
#### [] Add configuration persistence
|
||||||
|
|
||||||
|
- []Implement configuration file management
|
||||||
|
- []Add settings validation
|
||||||
|
- []Include backup/restore functionality
|
||||||
|
- []Add migration support for config updates
|
||||||
|
|
||||||
|
### 9. Database Layer
|
||||||
|
|
||||||
|
#### [] Implement database models
|
||||||
|
|
||||||
|
- []Create `src/server/database/models.py`
|
||||||
|
- []Add SQLAlchemy models for anime series
|
||||||
|
- []Implement download queue persistence
|
||||||
|
- []Include user session storage
|
||||||
|
|
||||||
|
#### [] Create database service
|
||||||
|
|
||||||
|
- []Create `src/server/database/service.py`
|
||||||
|
- []Add CRUD operations for anime data
|
||||||
|
- []Implement queue persistence
|
||||||
|
- []Include database migration support
|
||||||
|
|
||||||
|
#### [] Add database initialization
|
||||||
|
|
||||||
|
- []Create `src/server/database/init.py`
|
||||||
|
- []Implement database setup
|
||||||
|
- []Add initial data migration
|
||||||
|
- []Include schema validation
|
||||||
|
|
||||||
### 10. Testing
|
### 10. Testing
|
||||||
|
|
||||||
#### [] Create unit tests for services
|
#### [] Create unit tests for services
|
||||||
@ -195,6 +226,17 @@ When working with these files:
|
|||||||
|
|
||||||
Each task should be implemented with proper error handling, logging, and type hints according to the project's coding standards.
|
Each task should be implemented with proper error handling, logging, and type hints according to the project's coding standards.
|
||||||
|
|
||||||
|
## Additional Implementation Guidelines
|
||||||
|
|
||||||
|
### Code Style and Standards
|
||||||
|
|
||||||
|
- **Type Hints**: Use comprehensive type annotations throughout all modules
|
||||||
|
- **Docstrings**: Follow PEP 257 for function and class documentation
|
||||||
|
- **Error Handling**: Implement custom exception classes with meaningful messages
|
||||||
|
- **Logging**: Use structured logging with appropriate log levels
|
||||||
|
- **Security**: Validate all inputs and sanitize outputs
|
||||||
|
- **Performance**: Use async/await patterns for I/O operations
|
||||||
|
|
||||||
### Monitoring and Health Checks
|
### Monitoring and Health Checks
|
||||||
|
|
||||||
#### [] Implement health check endpoints
|
#### [] Implement health check endpoints
|
||||||
@ -379,6 +421,22 @@ Each task should be implemented with proper error handling, logging, and type hi
|
|||||||
|
|
||||||
### Deployment Strategies
|
### Deployment Strategies
|
||||||
|
|
||||||
|
#### [] Container orchestration
|
||||||
|
|
||||||
|
- []Create `kubernetes/` directory
|
||||||
|
- []Add Kubernetes deployment manifests
|
||||||
|
- []Implement service discovery
|
||||||
|
- []Include load balancing configuration
|
||||||
|
- []Add auto-scaling policies
|
||||||
|
|
||||||
|
#### [] CI/CD pipeline
|
||||||
|
|
||||||
|
- []Create `.github/workflows/`
|
||||||
|
- []Add automated testing pipeline
|
||||||
|
- []Implement deployment automation
|
||||||
|
- []Include security scanning
|
||||||
|
- []Add performance benchmarking
|
||||||
|
|
||||||
#### [] Environment management
|
#### [] Environment management
|
||||||
|
|
||||||
- []Create environment-specific configurations
|
- []Create environment-specific configurations
|
||||||
|
|||||||
@ -12,6 +12,3 @@ structlog==24.1.0
|
|||||||
pytest==7.4.3
|
pytest==7.4.3
|
||||||
pytest-asyncio==0.21.1
|
pytest-asyncio==0.21.1
|
||||||
httpx==0.25.2
|
httpx==0.25.2
|
||||||
sqlalchemy>=2.0.35
|
|
||||||
alembic==1.13.0
|
|
||||||
aiosqlite>=0.19.0
|
|
||||||
@ -1,257 +1,59 @@
|
|||||||
"""
|
|
||||||
SerieScanner - Scans directories for anime series and missing episodes.
|
|
||||||
|
|
||||||
This module provides functionality to scan anime directories, identify
|
|
||||||
missing episodes, and report progress through callback interfaces.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import logging
|
||||||
|
from .entities.series import Serie
|
||||||
import traceback
|
import traceback
|
||||||
import uuid
|
from ..infrastructure.logging.GlobalLogger import error_logger, noKeyFound_logger
|
||||||
from typing import Callable, Optional
|
from .exceptions.Exceptions import NoKeyFoundException, MatchNotFoundError
|
||||||
|
from .providers.base_provider import Loader
|
||||||
from src.core.entities.series import Serie
|
|
||||||
from src.core.exceptions.Exceptions import MatchNotFoundError, NoKeyFoundException
|
|
||||||
from src.core.interfaces.callbacks import (
|
|
||||||
CallbackManager,
|
|
||||||
CompletionContext,
|
|
||||||
ErrorContext,
|
|
||||||
OperationType,
|
|
||||||
ProgressContext,
|
|
||||||
ProgressPhase,
|
|
||||||
)
|
|
||||||
from src.core.providers.base_provider import Loader
|
|
||||||
from src.infrastructure.logging.GlobalLogger import error_logger, noKeyFound_logger
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class SerieScanner:
|
class SerieScanner:
|
||||||
"""
|
def __init__(self, basePath: str, loader: Loader):
|
||||||
Scans directories for anime series and identifies missing episodes.
|
|
||||||
|
|
||||||
Supports progress callbacks for real-time scanning updates.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
basePath: str,
|
|
||||||
loader: Loader,
|
|
||||||
callback_manager: Optional[CallbackManager] = None
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Initialize the SerieScanner.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
basePath: Base directory containing anime series
|
|
||||||
loader: Loader instance for fetching series information
|
|
||||||
callback_manager: Optional callback manager for progress updates
|
|
||||||
"""
|
|
||||||
self.directory = basePath
|
self.directory = basePath
|
||||||
self.folderDict: dict[str, Serie] = {}
|
self.folderDict: dict[str, Serie] = {} # Proper initialization
|
||||||
self.loader = loader
|
self.loader = loader
|
||||||
self._callback_manager = callback_manager or CallbackManager()
|
logging.info(f"Initialized Loader with base path: {self.directory}")
|
||||||
self._current_operation_id: Optional[str] = None
|
|
||||||
|
|
||||||
logger.info("Initialized SerieScanner with base path: %s", basePath)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def callback_manager(self) -> CallbackManager:
|
|
||||||
"""Get the callback manager instance."""
|
|
||||||
return self._callback_manager
|
|
||||||
|
|
||||||
def Reinit(self):
|
def Reinit(self):
|
||||||
"""Reinitialize the folder dictionary."""
|
self.folderDict: dict[str, Serie] = {} # Proper initialization
|
||||||
self.folderDict: dict[str, Serie] = {}
|
|
||||||
|
|
||||||
def is_null_or_whitespace(self, s):
|
def is_null_or_whitespace(self, s):
|
||||||
"""Check if a string is None or whitespace."""
|
|
||||||
return s is None or s.strip() == ""
|
return s is None or s.strip() == ""
|
||||||
|
|
||||||
def GetTotalToScan(self):
|
def GetTotalToScan(self):
|
||||||
"""Get the total number of folders to scan."""
|
|
||||||
result = self.__find_mp4_files()
|
result = self.__find_mp4_files()
|
||||||
return sum(1 for _ in result)
|
return sum(1 for _ in result)
|
||||||
|
|
||||||
def Scan(self, callback: Optional[Callable[[str, int], None]] = None):
|
def Scan(self, callback):
|
||||||
"""
|
logging.info("Starting process to load missing episodes")
|
||||||
Scan directories for anime series and missing episodes.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
callback: Optional legacy callback function (folder, count)
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
Exception: If scan fails critically
|
|
||||||
"""
|
|
||||||
# Generate unique operation ID
|
|
||||||
self._current_operation_id = str(uuid.uuid4())
|
|
||||||
|
|
||||||
logger.info("Starting scan for missing episodes")
|
|
||||||
|
|
||||||
# Notify scan starting
|
|
||||||
self._callback_manager.notify_progress(
|
|
||||||
ProgressContext(
|
|
||||||
operation_type=OperationType.SCAN,
|
|
||||||
operation_id=self._current_operation_id,
|
|
||||||
phase=ProgressPhase.STARTING,
|
|
||||||
current=0,
|
|
||||||
total=0,
|
|
||||||
percentage=0.0,
|
|
||||||
message="Initializing scan"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Get total items to process
|
|
||||||
total_to_scan = self.GetTotalToScan()
|
|
||||||
logger.info("Total folders to scan: %d", total_to_scan)
|
|
||||||
|
|
||||||
result = self.__find_mp4_files()
|
result = self.__find_mp4_files()
|
||||||
counter = 0
|
counter = 0
|
||||||
|
|
||||||
for folder, mp4_files in result:
|
for folder, mp4_files in result:
|
||||||
try:
|
try:
|
||||||
counter += 1
|
counter += 1
|
||||||
|
|
||||||
# Calculate progress
|
|
||||||
percentage = (
|
|
||||||
(counter / total_to_scan * 100)
|
|
||||||
if total_to_scan > 0 else 0
|
|
||||||
)
|
|
||||||
|
|
||||||
# Notify progress
|
|
||||||
self._callback_manager.notify_progress(
|
|
||||||
ProgressContext(
|
|
||||||
operation_type=OperationType.SCAN,
|
|
||||||
operation_id=self._current_operation_id,
|
|
||||||
phase=ProgressPhase.IN_PROGRESS,
|
|
||||||
current=counter,
|
|
||||||
total=total_to_scan,
|
|
||||||
percentage=percentage,
|
|
||||||
message=f"Scanning: {folder}",
|
|
||||||
details=f"Found {len(mp4_files)} episodes"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Call legacy callback if provided
|
|
||||||
if callback:
|
|
||||||
callback(folder, counter)
|
callback(folder, counter)
|
||||||
|
|
||||||
serie = self.__ReadDataFromFile(folder)
|
serie = self.__ReadDataFromFile(folder)
|
||||||
if (
|
if (serie != None and not self.is_null_or_whitespace(serie.key)):
|
||||||
serie is not None
|
missings, site = self.__GetMissingEpisodesAndSeason(serie.key, mp4_files)
|
||||||
and not self.is_null_or_whitespace(serie.key)
|
|
||||||
):
|
|
||||||
missings, site = self.__GetMissingEpisodesAndSeason(
|
|
||||||
serie.key, mp4_files
|
|
||||||
)
|
|
||||||
serie.episodeDict = missings
|
serie.episodeDict = missings
|
||||||
serie.folder = folder
|
serie.folder = folder
|
||||||
data_path = os.path.join(
|
serie.save_to_file(os.path.join(os.path.join(self.directory, folder), 'data'))
|
||||||
self.directory, folder, 'data'
|
if (serie.key in self.folderDict):
|
||||||
)
|
logging.ERROR(f"dublication found: {serie.key}");
|
||||||
serie.save_to_file(data_path)
|
pass
|
||||||
|
|
||||||
if serie.key in self.folderDict:
|
|
||||||
logger.error(
|
|
||||||
"Duplication found: %s", serie.key
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.folderDict[serie.key] = serie
|
self.folderDict[serie.key] = serie
|
||||||
noKeyFound_logger.info(
|
noKeyFound_logger.info(f"Saved Serie: '{str(serie)}'")
|
||||||
"Saved Serie: '%s'", str(serie)
|
|
||||||
)
|
|
||||||
|
|
||||||
except NoKeyFoundException as nkfe:
|
except NoKeyFoundException as nkfe:
|
||||||
# Log error and notify via callback
|
NoKeyFoundException.error(f"Error processing folder '{folder}': {nkfe}")
|
||||||
error_msg = f"Error processing folder '{folder}': {nkfe}"
|
|
||||||
NoKeyFoundException.error(error_msg)
|
|
||||||
|
|
||||||
self._callback_manager.notify_error(
|
|
||||||
ErrorContext(
|
|
||||||
operation_type=OperationType.SCAN,
|
|
||||||
operation_id=self._current_operation_id,
|
|
||||||
error=nkfe,
|
|
||||||
message=error_msg,
|
|
||||||
recoverable=True,
|
|
||||||
metadata={"folder": folder}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Log error and notify via callback
|
error_logger.error(f"Folder: '{folder}' - Unexpected error processing folder '{folder}': {e} \n {traceback.format_exc()}")
|
||||||
error_msg = (
|
|
||||||
f"Folder: '{folder}' - "
|
|
||||||
f"Unexpected error: {e}"
|
|
||||||
)
|
|
||||||
error_logger.error(
|
|
||||||
"%s\n%s",
|
|
||||||
error_msg,
|
|
||||||
traceback.format_exc()
|
|
||||||
)
|
|
||||||
|
|
||||||
self._callback_manager.notify_error(
|
|
||||||
ErrorContext(
|
|
||||||
operation_type=OperationType.SCAN,
|
|
||||||
operation_id=self._current_operation_id,
|
|
||||||
error=e,
|
|
||||||
message=error_msg,
|
|
||||||
recoverable=True,
|
|
||||||
metadata={"folder": folder}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Notify scan completion
|
|
||||||
self._callback_manager.notify_completion(
|
|
||||||
CompletionContext(
|
|
||||||
operation_type=OperationType.SCAN,
|
|
||||||
operation_id=self._current_operation_id,
|
|
||||||
success=True,
|
|
||||||
message=f"Scan completed. Processed {counter} folders.",
|
|
||||||
statistics={
|
|
||||||
"total_folders": counter,
|
|
||||||
"series_found": len(self.folderDict)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"Scan completed. Processed %d folders, found %d series",
|
|
||||||
counter,
|
|
||||||
len(self.folderDict)
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
# Critical error - notify and re-raise
|
|
||||||
error_msg = f"Critical scan error: {e}"
|
|
||||||
logger.error("%s\n%s", error_msg, traceback.format_exc())
|
|
||||||
|
|
||||||
self._callback_manager.notify_error(
|
|
||||||
ErrorContext(
|
|
||||||
operation_type=OperationType.SCAN,
|
|
||||||
operation_id=self._current_operation_id,
|
|
||||||
error=e,
|
|
||||||
message=error_msg,
|
|
||||||
recoverable=False
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
self._callback_manager.notify_completion(
|
|
||||||
CompletionContext(
|
|
||||||
operation_type=OperationType.SCAN,
|
|
||||||
operation_id=self._current_operation_id,
|
|
||||||
success=False,
|
|
||||||
message=error_msg
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
raise
|
|
||||||
|
|
||||||
def __find_mp4_files(self):
|
def __find_mp4_files(self):
|
||||||
"""Find all .mp4 files in the directory structure."""
|
logging.info("Scanning for .mp4 files")
|
||||||
logger.info("Scanning for .mp4 files")
|
|
||||||
for anime_name in os.listdir(self.directory):
|
for anime_name in os.listdir(self.directory):
|
||||||
anime_path = os.path.join(self.directory, anime_name)
|
anime_path = os.path.join(self.directory, anime_name)
|
||||||
if os.path.isdir(anime_path):
|
if os.path.isdir(anime_path):
|
||||||
@ -265,68 +67,43 @@ class SerieScanner:
|
|||||||
yield anime_name, mp4_files if has_files else []
|
yield anime_name, mp4_files if has_files else []
|
||||||
|
|
||||||
def __remove_year(self, input_string: str):
|
def __remove_year(self, input_string: str):
|
||||||
"""Remove year information from input string."""
|
|
||||||
cleaned_string = re.sub(r'\(\d{4}\)', '', input_string).strip()
|
cleaned_string = re.sub(r'\(\d{4}\)', '', input_string).strip()
|
||||||
logger.debug(
|
logging.debug(f"Removed year from '{input_string}' -> '{cleaned_string}'")
|
||||||
"Removed year from '%s' -> '%s'",
|
|
||||||
input_string,
|
|
||||||
cleaned_string
|
|
||||||
)
|
|
||||||
return cleaned_string
|
return cleaned_string
|
||||||
|
|
||||||
def __ReadDataFromFile(self, folder_name: str):
|
def __ReadDataFromFile(self, folder_name: str):
|
||||||
"""Read serie data from file or key file."""
|
|
||||||
folder_path = os.path.join(self.directory, folder_name)
|
folder_path = os.path.join(self.directory, folder_name)
|
||||||
key = None
|
key = None
|
||||||
key_file = os.path.join(folder_path, 'key')
|
key_file = os.path.join(folder_path, 'key')
|
||||||
serie_file = os.path.join(folder_path, 'data')
|
serie_file = os.path.join(folder_path, 'data')
|
||||||
|
|
||||||
if os.path.exists(key_file):
|
if os.path.exists(key_file):
|
||||||
with open(key_file, 'r', encoding='utf-8') as file:
|
with open(key_file, 'r') as file:
|
||||||
key = file.read().strip()
|
key = file.read().strip()
|
||||||
logger.info(
|
logging.info(f"Key found for folder '{folder_name}': {key}")
|
||||||
"Key found for folder '%s': %s",
|
|
||||||
folder_name,
|
|
||||||
key
|
|
||||||
)
|
|
||||||
return Serie(key, "", "aniworld.to", folder_name, dict())
|
return Serie(key, "", "aniworld.to", folder_name, dict())
|
||||||
|
|
||||||
if os.path.exists(serie_file):
|
if os.path.exists(serie_file):
|
||||||
with open(serie_file, "rb") as file:
|
with open(serie_file, "rb") as file:
|
||||||
logger.info(
|
logging.info(f"load serie_file from '{folder_name}': {serie_file}")
|
||||||
"load serie_file from '%s': %s",
|
|
||||||
folder_name,
|
|
||||||
serie_file
|
|
||||||
)
|
|
||||||
return Serie.load_from_file(serie_file)
|
return Serie.load_from_file(serie_file)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def __GetEpisodeAndSeason(self, filename: str):
|
def __GetEpisodeAndSeason(self, filename: str):
|
||||||
"""Extract season and episode numbers from filename."""
|
|
||||||
pattern = r'S(\d+)E(\d+)'
|
pattern = r'S(\d+)E(\d+)'
|
||||||
match = re.search(pattern, filename)
|
match = re.search(pattern, filename)
|
||||||
if match:
|
if match:
|
||||||
season = match.group(1)
|
season = match.group(1)
|
||||||
episode = match.group(2)
|
episode = match.group(2)
|
||||||
logger.debug(
|
logging.debug(f"Extracted season {season}, episode {episode} from '{filename}'")
|
||||||
"Extracted season %s, episode %s from '%s'",
|
|
||||||
season,
|
|
||||||
episode,
|
|
||||||
filename
|
|
||||||
)
|
|
||||||
return int(season), int(episode)
|
return int(season), int(episode)
|
||||||
else:
|
else:
|
||||||
logger.error(
|
logging.error(f"Failed to find season/episode pattern in '{filename}'")
|
||||||
"Failed to find season/episode pattern in '%s'",
|
raise MatchNotFoundError("Season and episode pattern not found in the filename.")
|
||||||
filename
|
|
||||||
)
|
|
||||||
raise MatchNotFoundError(
|
|
||||||
"Season and episode pattern not found in the filename."
|
|
||||||
)
|
|
||||||
|
|
||||||
def __GetEpisodesAndSeasons(self, mp4_files: list):
|
def __GetEpisodesAndSeasons(self, mp4_files: []):
|
||||||
"""Get episodes grouped by season from mp4 files."""
|
|
||||||
episodes_dict = {}
|
episodes_dict = {}
|
||||||
|
|
||||||
for file in mp4_files:
|
for file in mp4_files:
|
||||||
@ -338,19 +115,13 @@ class SerieScanner:
|
|||||||
episodes_dict[season] = [episode]
|
episodes_dict[season] = [episode]
|
||||||
return episodes_dict
|
return episodes_dict
|
||||||
|
|
||||||
def __GetMissingEpisodesAndSeason(self, key: str, mp4_files: list):
|
def __GetMissingEpisodesAndSeason(self, key: str, mp4_files: []):
|
||||||
"""Get missing episodes for a serie."""
|
expected_dict = self.loader.get_season_episode_count(key) # key season , value count of episodes
|
||||||
# key season , value count of episodes
|
|
||||||
expected_dict = self.loader.get_season_episode_count(key)
|
|
||||||
filedict = self.__GetEpisodesAndSeasons(mp4_files)
|
filedict = self.__GetEpisodesAndSeasons(mp4_files)
|
||||||
episodes_dict = {}
|
episodes_dict = {}
|
||||||
for season, expected_count in expected_dict.items():
|
for season, expected_count in expected_dict.items():
|
||||||
existing_episodes = filedict.get(season, [])
|
existing_episodes = filedict.get(season, [])
|
||||||
missing_episodes = [
|
missing_episodes = [ep for ep in range(1, expected_count + 1) if ep not in existing_episodes and self.loader.IsLanguage(season, ep, key)]
|
||||||
ep for ep in range(1, expected_count + 1)
|
|
||||||
if ep not in existing_episodes
|
|
||||||
and self.loader.IsLanguage(season, ep, key)
|
|
||||||
]
|
|
||||||
|
|
||||||
if missing_episodes:
|
if missing_episodes:
|
||||||
episodes_dict[season] = missing_episodes
|
episodes_dict[season] = missing_episodes
|
||||||
|
|||||||
@ -1,589 +1,38 @@
|
|||||||
"""
|
|
||||||
SeriesApp - Core application logic for anime series management.
|
|
||||||
|
|
||||||
This module provides the main application interface for searching,
|
|
||||||
downloading, and managing anime series with support for async callbacks,
|
|
||||||
progress reporting, error handling, and operation cancellation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import uuid
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
|
||||||
|
|
||||||
from src.core.entities.SerieList import SerieList
|
from src.core.entities.SerieList import SerieList
|
||||||
from src.core.interfaces.callbacks import (
|
|
||||||
CallbackManager,
|
|
||||||
CompletionContext,
|
|
||||||
ErrorContext,
|
|
||||||
OperationType,
|
|
||||||
ProgressContext,
|
|
||||||
ProgressPhase,
|
|
||||||
)
|
|
||||||
from src.core.providers.provider_factory import Loaders
|
from src.core.providers.provider_factory import Loaders
|
||||||
from src.core.SerieScanner import SerieScanner
|
from src.core.SerieScanner import SerieScanner
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class OperationStatus(Enum):
|
|
||||||
"""Status of an operation."""
|
|
||||||
IDLE = "idle"
|
|
||||||
RUNNING = "running"
|
|
||||||
COMPLETED = "completed"
|
|
||||||
CANCELLED = "cancelled"
|
|
||||||
FAILED = "failed"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ProgressInfo:
|
|
||||||
"""Progress information for long-running operations."""
|
|
||||||
current: int
|
|
||||||
total: int
|
|
||||||
message: str
|
|
||||||
percentage: float
|
|
||||||
status: OperationStatus
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class OperationResult:
|
|
||||||
"""Result of an operation."""
|
|
||||||
success: bool
|
|
||||||
message: str
|
|
||||||
data: Optional[Any] = None
|
|
||||||
error: Optional[Exception] = None
|
|
||||||
|
|
||||||
|
|
||||||
class SeriesApp:
|
class SeriesApp:
|
||||||
"""
|
|
||||||
Main application class for anime series management.
|
|
||||||
|
|
||||||
Provides functionality for:
|
|
||||||
- Searching anime series
|
|
||||||
- Downloading episodes
|
|
||||||
- Scanning directories for missing episodes
|
|
||||||
- Managing series lists
|
|
||||||
|
|
||||||
Supports async callbacks for progress reporting and cancellation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
_initialization_count = 0
|
_initialization_count = 0
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, directory_to_search: str):
|
||||||
self,
|
SeriesApp._initialization_count += 1 # Only show initialization message for the first instance
|
||||||
directory_to_search: str,
|
|
||||||
progress_callback: Optional[Callable[[ProgressInfo], None]] = None,
|
|
||||||
error_callback: Optional[Callable[[Exception], None]] = None,
|
|
||||||
callback_manager: Optional[CallbackManager] = None
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Initialize SeriesApp.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
directory_to_search: Base directory for anime series
|
|
||||||
progress_callback: Optional legacy callback for progress updates
|
|
||||||
error_callback: Optional legacy callback for error notifications
|
|
||||||
callback_manager: Optional callback manager for new callback system
|
|
||||||
"""
|
|
||||||
SeriesApp._initialization_count += 1
|
|
||||||
|
|
||||||
# Only show initialization message for the first instance
|
|
||||||
if SeriesApp._initialization_count <= 1:
|
if SeriesApp._initialization_count <= 1:
|
||||||
logger.info("Initializing SeriesApp...")
|
print("Please wait while initializing...")
|
||||||
|
|
||||||
|
self.progress = None
|
||||||
self.directory_to_search = directory_to_search
|
self.directory_to_search = directory_to_search
|
||||||
self.progress_callback = progress_callback
|
|
||||||
self.error_callback = error_callback
|
|
||||||
|
|
||||||
# Initialize new callback system
|
|
||||||
self._callback_manager = callback_manager or CallbackManager()
|
|
||||||
|
|
||||||
# Cancellation support
|
|
||||||
self._cancel_flag = False
|
|
||||||
self._current_operation: Optional[str] = None
|
|
||||||
self._current_operation_id: Optional[str] = None
|
|
||||||
self._operation_status = OperationStatus.IDLE
|
|
||||||
|
|
||||||
# Initialize components
|
|
||||||
try:
|
|
||||||
self.Loaders = Loaders()
|
self.Loaders = Loaders()
|
||||||
self.loader = self.Loaders.GetLoader(key="aniworld.to")
|
self.loader = self.Loaders.GetLoader(key="aniworld.to")
|
||||||
self.SerieScanner = SerieScanner(
|
self.SerieScanner = SerieScanner(directory_to_search, self.loader)
|
||||||
directory_to_search,
|
|
||||||
self.loader,
|
|
||||||
self._callback_manager
|
|
||||||
)
|
|
||||||
self.List = SerieList(self.directory_to_search)
|
self.List = SerieList(self.directory_to_search)
|
||||||
self.__InitList__()
|
self.__InitList__()
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"SeriesApp initialized for directory: %s",
|
|
||||||
directory_to_search
|
|
||||||
)
|
|
||||||
except (IOError, OSError, RuntimeError) as e:
|
|
||||||
logger.error("Failed to initialize SeriesApp: %s", e)
|
|
||||||
self._handle_error(e)
|
|
||||||
raise
|
|
||||||
|
|
||||||
@property
|
|
||||||
def callback_manager(self) -> CallbackManager:
|
|
||||||
"""Get the callback manager instance."""
|
|
||||||
return self._callback_manager
|
|
||||||
|
|
||||||
def __InitList__(self):
|
def __InitList__(self):
|
||||||
"""Initialize the series list with missing episodes."""
|
|
||||||
try:
|
|
||||||
self.series_list = self.List.GetMissingEpisode()
|
self.series_list = self.List.GetMissingEpisode()
|
||||||
logger.debug(
|
|
||||||
"Loaded %d series with missing episodes",
|
|
||||||
len(self.series_list)
|
|
||||||
)
|
|
||||||
except (IOError, OSError, RuntimeError) as e:
|
|
||||||
logger.error("Failed to initialize series list: %s", e)
|
|
||||||
self._handle_error(e)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def search(self, words: str) -> List[Dict[str, Any]]:
|
def search(self, words: str) -> list:
|
||||||
"""
|
return self.loader.Search(words)
|
||||||
Search for anime series.
|
|
||||||
|
|
||||||
Args:
|
def download(self, serieFolder: str, season: int, episode: int, key: str, callback) -> bool:
|
||||||
words: Search query
|
self.loader.Download(self.directory_to_search, serieFolder, season, episode, key, "German Dub", callback)
|
||||||
|
|
||||||
Returns:
|
def ReScan(self, callback):
|
||||||
List of search results
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: If search fails
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
logger.info("Searching for: %s", words)
|
|
||||||
results = self.loader.Search(words)
|
|
||||||
logger.info("Found %d results", len(results))
|
|
||||||
return results
|
|
||||||
except (IOError, OSError, RuntimeError) as e:
|
|
||||||
logger.error("Search failed for '%s': %s", words, e)
|
|
||||||
self._handle_error(e)
|
|
||||||
raise
|
|
||||||
|
|
||||||
def download(
|
|
||||||
self,
|
|
||||||
serieFolder: str,
|
|
||||||
season: int,
|
|
||||||
episode: int,
|
|
||||||
key: str,
|
|
||||||
callback: Optional[Callable[[float], None]] = None,
|
|
||||||
language: str = "German Dub"
|
|
||||||
) -> OperationResult:
|
|
||||||
"""
|
|
||||||
Download an episode.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
serieFolder: Serie folder name
|
|
||||||
season: Season number
|
|
||||||
episode: Episode number
|
|
||||||
key: Serie key
|
|
||||||
callback: Optional legacy progress callback
|
|
||||||
language: Language preference
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
OperationResult with download status
|
|
||||||
"""
|
|
||||||
self._current_operation = f"download_S{season:02d}E{episode:02d}"
|
|
||||||
self._current_operation_id = str(uuid.uuid4())
|
|
||||||
self._operation_status = OperationStatus.RUNNING
|
|
||||||
self._cancel_flag = False
|
|
||||||
|
|
||||||
try:
|
|
||||||
logger.info(
|
|
||||||
"Starting download: %s S%02dE%02d",
|
|
||||||
serieFolder, season, episode
|
|
||||||
)
|
|
||||||
|
|
||||||
# Notify download starting
|
|
||||||
start_msg = (
|
|
||||||
f"Starting download: {serieFolder} "
|
|
||||||
f"S{season:02d}E{episode:02d}"
|
|
||||||
)
|
|
||||||
self._callback_manager.notify_progress(
|
|
||||||
ProgressContext(
|
|
||||||
operation_type=OperationType.DOWNLOAD,
|
|
||||||
operation_id=self._current_operation_id,
|
|
||||||
phase=ProgressPhase.STARTING,
|
|
||||||
current=0,
|
|
||||||
total=100,
|
|
||||||
percentage=0.0,
|
|
||||||
message=start_msg,
|
|
||||||
metadata={
|
|
||||||
"series": serieFolder,
|
|
||||||
"season": season,
|
|
||||||
"episode": episode,
|
|
||||||
"key": key,
|
|
||||||
"language": language
|
|
||||||
}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check for cancellation before starting
|
|
||||||
if self._is_cancelled():
|
|
||||||
self._callback_manager.notify_completion(
|
|
||||||
CompletionContext(
|
|
||||||
operation_type=OperationType.DOWNLOAD,
|
|
||||||
operation_id=self._current_operation_id,
|
|
||||||
success=False,
|
|
||||||
message="Download cancelled before starting"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return OperationResult(
|
|
||||||
success=False,
|
|
||||||
message="Download cancelled before starting"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Wrap callback to check for cancellation and report progress
|
|
||||||
def wrapped_callback(progress: float):
|
|
||||||
if self._is_cancelled():
|
|
||||||
raise InterruptedError("Download cancelled by user")
|
|
||||||
|
|
||||||
# Notify progress via new callback system
|
|
||||||
self._callback_manager.notify_progress(
|
|
||||||
ProgressContext(
|
|
||||||
operation_type=OperationType.DOWNLOAD,
|
|
||||||
operation_id=self._current_operation_id,
|
|
||||||
phase=ProgressPhase.IN_PROGRESS,
|
|
||||||
current=int(progress),
|
|
||||||
total=100,
|
|
||||||
percentage=progress,
|
|
||||||
message=f"Downloading: {progress:.1f}%",
|
|
||||||
metadata={
|
|
||||||
"series": serieFolder,
|
|
||||||
"season": season,
|
|
||||||
"episode": episode
|
|
||||||
}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Call legacy callback if provided
|
|
||||||
if callback:
|
|
||||||
callback(progress)
|
|
||||||
|
|
||||||
# Call legacy progress_callback if provided
|
|
||||||
if self.progress_callback:
|
|
||||||
self.progress_callback(ProgressInfo(
|
|
||||||
current=int(progress),
|
|
||||||
total=100,
|
|
||||||
message=f"Downloading S{season:02d}E{episode:02d}",
|
|
||||||
percentage=progress,
|
|
||||||
status=OperationStatus.RUNNING
|
|
||||||
))
|
|
||||||
|
|
||||||
# Perform download
|
|
||||||
self.loader.Download(
|
|
||||||
self.directory_to_search,
|
|
||||||
serieFolder,
|
|
||||||
season,
|
|
||||||
episode,
|
|
||||||
key,
|
|
||||||
language,
|
|
||||||
wrapped_callback
|
|
||||||
)
|
|
||||||
|
|
||||||
self._operation_status = OperationStatus.COMPLETED
|
|
||||||
logger.info(
|
|
||||||
"Download completed: %s S%02dE%02d",
|
|
||||||
serieFolder, season, episode
|
|
||||||
)
|
|
||||||
|
|
||||||
# Notify completion
|
|
||||||
msg = f"Successfully downloaded S{season:02d}E{episode:02d}"
|
|
||||||
self._callback_manager.notify_completion(
|
|
||||||
CompletionContext(
|
|
||||||
operation_type=OperationType.DOWNLOAD,
|
|
||||||
operation_id=self._current_operation_id,
|
|
||||||
success=True,
|
|
||||||
message=msg,
|
|
||||||
statistics={
|
|
||||||
"series": serieFolder,
|
|
||||||
"season": season,
|
|
||||||
"episode": episode
|
|
||||||
}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return OperationResult(
|
|
||||||
success=True,
|
|
||||||
message=msg
|
|
||||||
)
|
|
||||||
|
|
||||||
except InterruptedError as e:
|
|
||||||
self._operation_status = OperationStatus.CANCELLED
|
|
||||||
logger.warning("Download cancelled: %s", e)
|
|
||||||
|
|
||||||
# Notify cancellation
|
|
||||||
self._callback_manager.notify_completion(
|
|
||||||
CompletionContext(
|
|
||||||
operation_type=OperationType.DOWNLOAD,
|
|
||||||
operation_id=self._current_operation_id,
|
|
||||||
success=False,
|
|
||||||
message="Download cancelled"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return OperationResult(
|
|
||||||
success=False,
|
|
||||||
message="Download cancelled",
|
|
||||||
error=e
|
|
||||||
)
|
|
||||||
except (IOError, OSError, RuntimeError) as e:
|
|
||||||
self._operation_status = OperationStatus.FAILED
|
|
||||||
logger.error("Download failed: %s", e)
|
|
||||||
|
|
||||||
# Notify error
|
|
||||||
error_msg = f"Download failed: {str(e)}"
|
|
||||||
self._callback_manager.notify_error(
|
|
||||||
ErrorContext(
|
|
||||||
operation_type=OperationType.DOWNLOAD,
|
|
||||||
operation_id=self._current_operation_id,
|
|
||||||
error=e,
|
|
||||||
message=error_msg,
|
|
||||||
recoverable=False,
|
|
||||||
metadata={
|
|
||||||
"series": serieFolder,
|
|
||||||
"season": season,
|
|
||||||
"episode": episode
|
|
||||||
}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Notify completion with failure
|
|
||||||
self._callback_manager.notify_completion(
|
|
||||||
CompletionContext(
|
|
||||||
operation_type=OperationType.DOWNLOAD,
|
|
||||||
operation_id=self._current_operation_id,
|
|
||||||
success=False,
|
|
||||||
message=error_msg
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
self._handle_error(e)
|
|
||||||
return OperationResult(
|
|
||||||
success=False,
|
|
||||||
message=error_msg,
|
|
||||||
error=e
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
self._current_operation = None
|
|
||||||
self._current_operation_id = None
|
|
||||||
|
|
||||||
def ReScan(
|
|
||||||
self,
|
|
||||||
callback: Optional[Callable[[str, int], None]] = None
|
|
||||||
) -> OperationResult:
|
|
||||||
"""
|
|
||||||
Rescan directory for missing episodes.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
callback: Optional progress callback (folder, current_count)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
OperationResult with scan status
|
|
||||||
"""
|
|
||||||
self._current_operation = "rescan"
|
|
||||||
self._operation_status = OperationStatus.RUNNING
|
|
||||||
self._cancel_flag = False
|
|
||||||
|
|
||||||
try:
|
|
||||||
logger.info("Starting directory rescan")
|
|
||||||
|
|
||||||
# Get total items to scan
|
|
||||||
total_to_scan = self.SerieScanner.GetTotalToScan()
|
|
||||||
logger.info("Total folders to scan: %d", total_to_scan)
|
|
||||||
|
|
||||||
# Reinitialize scanner
|
|
||||||
self.SerieScanner.Reinit()
|
self.SerieScanner.Reinit()
|
||||||
|
self.SerieScanner.Scan(callback)
|
||||||
|
|
||||||
# Wrap callback for progress reporting and cancellation
|
|
||||||
def wrapped_callback(folder: str, current: int):
|
|
||||||
if self._is_cancelled():
|
|
||||||
raise InterruptedError("Scan cancelled by user")
|
|
||||||
|
|
||||||
# Calculate progress
|
|
||||||
if total_to_scan > 0:
|
|
||||||
percentage = (current / total_to_scan * 100)
|
|
||||||
else:
|
|
||||||
percentage = 0
|
|
||||||
|
|
||||||
# Report progress
|
|
||||||
if self.progress_callback:
|
|
||||||
progress_info = ProgressInfo(
|
|
||||||
current=current,
|
|
||||||
total=total_to_scan,
|
|
||||||
message=f"Scanning: {folder}",
|
|
||||||
percentage=percentage,
|
|
||||||
status=OperationStatus.RUNNING
|
|
||||||
)
|
|
||||||
self.progress_callback(progress_info)
|
|
||||||
|
|
||||||
# Call original callback if provided
|
|
||||||
if callback:
|
|
||||||
callback(folder, current)
|
|
||||||
|
|
||||||
# Perform scan
|
|
||||||
self.SerieScanner.Scan(wrapped_callback)
|
|
||||||
|
|
||||||
# Reinitialize list
|
|
||||||
self.List = SerieList(self.directory_to_search)
|
self.List = SerieList(self.directory_to_search)
|
||||||
self.__InitList__()
|
self.__InitList__()
|
||||||
|
|
||||||
self._operation_status = OperationStatus.COMPLETED
|
|
||||||
logger.info("Directory rescan completed successfully")
|
|
||||||
|
|
||||||
msg = (
|
|
||||||
f"Scan completed. Found {len(self.series_list)} "
|
|
||||||
f"series."
|
|
||||||
)
|
|
||||||
return OperationResult(
|
|
||||||
success=True,
|
|
||||||
message=msg,
|
|
||||||
data={"series_count": len(self.series_list)}
|
|
||||||
)
|
|
||||||
|
|
||||||
except InterruptedError as e:
|
|
||||||
self._operation_status = OperationStatus.CANCELLED
|
|
||||||
logger.warning("Scan cancelled: %s", e)
|
|
||||||
return OperationResult(
|
|
||||||
success=False,
|
|
||||||
message="Scan cancelled",
|
|
||||||
error=e
|
|
||||||
)
|
|
||||||
except (IOError, OSError, RuntimeError) as e:
|
|
||||||
self._operation_status = OperationStatus.FAILED
|
|
||||||
logger.error("Scan failed: %s", e)
|
|
||||||
self._handle_error(e)
|
|
||||||
return OperationResult(
|
|
||||||
success=False,
|
|
||||||
message=f"Scan failed: {str(e)}",
|
|
||||||
error=e
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
self._current_operation = None
|
|
||||||
|
|
||||||
async def async_download(
|
|
||||||
self,
|
|
||||||
serieFolder: str,
|
|
||||||
season: int,
|
|
||||||
episode: int,
|
|
||||||
key: str,
|
|
||||||
callback: Optional[Callable[[float], None]] = None,
|
|
||||||
language: str = "German Dub"
|
|
||||||
) -> OperationResult:
|
|
||||||
"""
|
|
||||||
Async version of download method.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
serieFolder: Serie folder name
|
|
||||||
season: Season number
|
|
||||||
episode: Episode number
|
|
||||||
key: Serie key
|
|
||||||
callback: Optional progress callback
|
|
||||||
language: Language preference
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
OperationResult with download status
|
|
||||||
"""
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
return await loop.run_in_executor(
|
|
||||||
None,
|
|
||||||
self.download,
|
|
||||||
serieFolder,
|
|
||||||
season,
|
|
||||||
episode,
|
|
||||||
key,
|
|
||||||
callback,
|
|
||||||
language
|
|
||||||
)
|
|
||||||
|
|
||||||
async def async_rescan(
|
|
||||||
self,
|
|
||||||
callback: Optional[Callable[[str, int], None]] = None
|
|
||||||
) -> OperationResult:
|
|
||||||
"""
|
|
||||||
Async version of ReScan method.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
callback: Optional progress callback
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
OperationResult with scan status
|
|
||||||
"""
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
return await loop.run_in_executor(
|
|
||||||
None,
|
|
||||||
self.ReScan,
|
|
||||||
callback
|
|
||||||
)
|
|
||||||
|
|
||||||
def cancel_operation(self) -> bool:
|
|
||||||
"""
|
|
||||||
Cancel the current operation.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if operation cancelled, False if no operation running
|
|
||||||
"""
|
|
||||||
if (self._current_operation and
|
|
||||||
self._operation_status == OperationStatus.RUNNING):
|
|
||||||
logger.info(
|
|
||||||
"Cancelling operation: %s",
|
|
||||||
self._current_operation
|
|
||||||
)
|
|
||||||
self._cancel_flag = True
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _is_cancelled(self) -> bool:
|
|
||||||
"""Check if the current operation has been cancelled."""
|
|
||||||
return self._cancel_flag
|
|
||||||
|
|
||||||
def _handle_error(self, error: Exception):
|
|
||||||
"""
|
|
||||||
Handle errors and notify via callback.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
error: Exception that occurred
|
|
||||||
"""
|
|
||||||
if self.error_callback:
|
|
||||||
try:
|
|
||||||
self.error_callback(error)
|
|
||||||
except (RuntimeError, ValueError) as callback_error:
|
|
||||||
logger.error(
|
|
||||||
"Error in error callback: %s",
|
|
||||||
callback_error
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_series_list(self) -> List[Any]:
|
|
||||||
"""
|
|
||||||
Get the current series list.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of series with missing episodes
|
|
||||||
"""
|
|
||||||
return self.series_list
|
|
||||||
|
|
||||||
def get_operation_status(self) -> OperationStatus:
|
|
||||||
"""
|
|
||||||
Get the current operation status.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Current operation status
|
|
||||||
"""
|
|
||||||
return self._operation_status
|
|
||||||
|
|
||||||
def get_current_operation(self) -> Optional[str]:
|
|
||||||
"""
|
|
||||||
Get the current operation name.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Name of current operation or None
|
|
||||||
"""
|
|
||||||
return self._current_operation
|
|
||||||
|
|||||||
@ -1,347 +0,0 @@
|
|||||||
"""
|
|
||||||
Progress callback interfaces for core operations.
|
|
||||||
|
|
||||||
This module defines clean interfaces for progress reporting, error handling,
|
|
||||||
and completion notifications across all core operations (scanning,
|
|
||||||
downloading).
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
|
|
||||||
|
|
||||||
class OperationType(str, Enum):
|
|
||||||
"""Types of operations that can report progress."""
|
|
||||||
|
|
||||||
SCAN = "scan"
|
|
||||||
DOWNLOAD = "download"
|
|
||||||
SEARCH = "search"
|
|
||||||
INITIALIZATION = "initialization"
|
|
||||||
|
|
||||||
|
|
||||||
class ProgressPhase(str, Enum):
|
|
||||||
"""Phases of an operation's lifecycle."""
|
|
||||||
|
|
||||||
STARTING = "starting"
|
|
||||||
IN_PROGRESS = "in_progress"
|
|
||||||
COMPLETING = "completing"
|
|
||||||
COMPLETED = "completed"
|
|
||||||
FAILED = "failed"
|
|
||||||
CANCELLED = "cancelled"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ProgressContext:
|
|
||||||
"""
|
|
||||||
Complete context information for a progress update.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
operation_type: Type of operation being performed
|
|
||||||
operation_id: Unique identifier for this operation
|
|
||||||
phase: Current phase of the operation
|
|
||||||
current: Current progress value (e.g., files processed)
|
|
||||||
total: Total progress value (e.g., total files)
|
|
||||||
percentage: Completion percentage (0.0 to 100.0)
|
|
||||||
message: Human-readable progress message
|
|
||||||
details: Additional context-specific details
|
|
||||||
metadata: Extra metadata for specialized use cases
|
|
||||||
"""
|
|
||||||
|
|
||||||
operation_type: OperationType
|
|
||||||
operation_id: str
|
|
||||||
phase: ProgressPhase
|
|
||||||
current: int
|
|
||||||
total: int
|
|
||||||
percentage: float
|
|
||||||
message: str
|
|
||||||
details: Optional[str] = None
|
|
||||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
|
||||||
"""Convert to dictionary for serialization."""
|
|
||||||
return {
|
|
||||||
"operation_type": self.operation_type.value,
|
|
||||||
"operation_id": self.operation_id,
|
|
||||||
"phase": self.phase.value,
|
|
||||||
"current": self.current,
|
|
||||||
"total": self.total,
|
|
||||||
"percentage": round(self.percentage, 2),
|
|
||||||
"message": self.message,
|
|
||||||
"details": self.details,
|
|
||||||
"metadata": self.metadata,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ErrorContext:
|
|
||||||
"""
|
|
||||||
Context information for error callbacks.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
operation_type: Type of operation that failed
|
|
||||||
operation_id: Unique identifier for the operation
|
|
||||||
error: The exception that occurred
|
|
||||||
message: Human-readable error message
|
|
||||||
recoverable: Whether the error is recoverable
|
|
||||||
retry_count: Number of retry attempts made
|
|
||||||
metadata: Additional error context
|
|
||||||
"""
|
|
||||||
|
|
||||||
operation_type: OperationType
|
|
||||||
operation_id: str
|
|
||||||
error: Exception
|
|
||||||
message: str
|
|
||||||
recoverable: bool = False
|
|
||||||
retry_count: int = 0
|
|
||||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
|
||||||
"""Convert to dictionary for serialization."""
|
|
||||||
return {
|
|
||||||
"operation_type": self.operation_type.value,
|
|
||||||
"operation_id": self.operation_id,
|
|
||||||
"error_type": type(self.error).__name__,
|
|
||||||
"error_message": str(self.error),
|
|
||||||
"message": self.message,
|
|
||||||
"recoverable": self.recoverable,
|
|
||||||
"retry_count": self.retry_count,
|
|
||||||
"metadata": self.metadata,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class CompletionContext:
|
|
||||||
"""
|
|
||||||
Context information for completion callbacks.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
operation_type: Type of operation that completed
|
|
||||||
operation_id: Unique identifier for the operation
|
|
||||||
success: Whether the operation completed successfully
|
|
||||||
message: Human-readable completion message
|
|
||||||
result_data: Result data from the operation
|
|
||||||
statistics: Operation statistics (duration, items processed, etc.)
|
|
||||||
metadata: Additional completion context
|
|
||||||
"""
|
|
||||||
|
|
||||||
operation_type: OperationType
|
|
||||||
operation_id: str
|
|
||||||
success: bool
|
|
||||||
message: str
|
|
||||||
result_data: Optional[Any] = None
|
|
||||||
statistics: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
|
||||||
"""Convert to dictionary for serialization."""
|
|
||||||
return {
|
|
||||||
"operation_type": self.operation_type.value,
|
|
||||||
"operation_id": self.operation_id,
|
|
||||||
"success": self.success,
|
|
||||||
"message": self.message,
|
|
||||||
"statistics": self.statistics,
|
|
||||||
"metadata": self.metadata,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class ProgressCallback(ABC):
|
|
||||||
"""
|
|
||||||
Abstract base class for progress callbacks.
|
|
||||||
|
|
||||||
Implement this interface to receive progress updates from core operations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def on_progress(self, context: ProgressContext) -> None:
|
|
||||||
"""
|
|
||||||
Called when progress is made in an operation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
context: Complete progress context information
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ErrorCallback(ABC):
|
|
||||||
"""
|
|
||||||
Abstract base class for error callbacks.
|
|
||||||
|
|
||||||
Implement this interface to receive error notifications from core
|
|
||||||
operations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def on_error(self, context: ErrorContext) -> None:
|
|
||||||
"""
|
|
||||||
Called when an error occurs during an operation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
context: Complete error context information
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class CompletionCallback(ABC):
|
|
||||||
"""
|
|
||||||
Abstract base class for completion callbacks.
|
|
||||||
|
|
||||||
Implement this interface to receive completion notifications from
|
|
||||||
core operations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def on_completion(self, context: CompletionContext) -> None:
|
|
||||||
"""
|
|
||||||
Called when an operation completes (successfully or not).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
context: Complete completion context information
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class CallbackManager:
|
|
||||||
"""
|
|
||||||
Manages multiple callbacks for an operation.
|
|
||||||
|
|
||||||
This class allows registering multiple progress, error, and completion
|
|
||||||
callbacks and dispatching events to all registered callbacks.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
"""Initialize the callback manager."""
|
|
||||||
self._progress_callbacks: list[ProgressCallback] = []
|
|
||||||
self._error_callbacks: list[ErrorCallback] = []
|
|
||||||
self._completion_callbacks: list[CompletionCallback] = []
|
|
||||||
|
|
||||||
def register_progress_callback(self, callback: ProgressCallback) -> None:
|
|
||||||
"""
|
|
||||||
Register a progress callback.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
callback: Progress callback to register
|
|
||||||
"""
|
|
||||||
if callback not in self._progress_callbacks:
|
|
||||||
self._progress_callbacks.append(callback)
|
|
||||||
|
|
||||||
def register_error_callback(self, callback: ErrorCallback) -> None:
|
|
||||||
"""
|
|
||||||
Register an error callback.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
callback: Error callback to register
|
|
||||||
"""
|
|
||||||
if callback not in self._error_callbacks:
|
|
||||||
self._error_callbacks.append(callback)
|
|
||||||
|
|
||||||
def register_completion_callback(
|
|
||||||
self,
|
|
||||||
callback: CompletionCallback
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Register a completion callback.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
callback: Completion callback to register
|
|
||||||
"""
|
|
||||||
if callback not in self._completion_callbacks:
|
|
||||||
self._completion_callbacks.append(callback)
|
|
||||||
|
|
||||||
def unregister_progress_callback(self, callback: ProgressCallback) -> None:
|
|
||||||
"""
|
|
||||||
Unregister a progress callback.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
callback: Progress callback to unregister
|
|
||||||
"""
|
|
||||||
if callback in self._progress_callbacks:
|
|
||||||
self._progress_callbacks.remove(callback)
|
|
||||||
|
|
||||||
def unregister_error_callback(self, callback: ErrorCallback) -> None:
|
|
||||||
"""
|
|
||||||
Unregister an error callback.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
callback: Error callback to unregister
|
|
||||||
"""
|
|
||||||
if callback in self._error_callbacks:
|
|
||||||
self._error_callbacks.remove(callback)
|
|
||||||
|
|
||||||
def unregister_completion_callback(
|
|
||||||
self,
|
|
||||||
callback: CompletionCallback
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Unregister a completion callback.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
callback: Completion callback to unregister
|
|
||||||
"""
|
|
||||||
if callback in self._completion_callbacks:
|
|
||||||
self._completion_callbacks.remove(callback)
|
|
||||||
|
|
||||||
def notify_progress(self, context: ProgressContext) -> None:
|
|
||||||
"""
|
|
||||||
Notify all registered progress callbacks.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
context: Progress context to send
|
|
||||||
"""
|
|
||||||
for callback in self._progress_callbacks:
|
|
||||||
try:
|
|
||||||
callback.on_progress(context)
|
|
||||||
except Exception as e:
|
|
||||||
# Log but don't let callback errors break the operation
|
|
||||||
logging.error(
|
|
||||||
"Error in progress callback %s: %s",
|
|
||||||
callback,
|
|
||||||
e,
|
|
||||||
exc_info=True
|
|
||||||
)
|
|
||||||
|
|
||||||
def notify_error(self, context: ErrorContext) -> None:
|
|
||||||
"""
|
|
||||||
Notify all registered error callbacks.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
context: Error context to send
|
|
||||||
"""
|
|
||||||
for callback in self._error_callbacks:
|
|
||||||
try:
|
|
||||||
callback.on_error(context)
|
|
||||||
except Exception as e:
|
|
||||||
# Log but don't let callback errors break the operation
|
|
||||||
logging.error(
|
|
||||||
"Error in error callback %s: %s",
|
|
||||||
callback,
|
|
||||||
e,
|
|
||||||
exc_info=True
|
|
||||||
)
|
|
||||||
|
|
||||||
def notify_completion(self, context: CompletionContext) -> None:
|
|
||||||
"""
|
|
||||||
Notify all registered completion callbacks.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
context: Completion context to send
|
|
||||||
"""
|
|
||||||
for callback in self._completion_callbacks:
|
|
||||||
try:
|
|
||||||
callback.on_completion(context)
|
|
||||||
except Exception as e:
|
|
||||||
# Log but don't let callback errors break the operation
|
|
||||||
logging.error(
|
|
||||||
"Error in completion callback %s: %s",
|
|
||||||
callback,
|
|
||||||
e,
|
|
||||||
exc_info=True
|
|
||||||
)
|
|
||||||
|
|
||||||
def clear_all_callbacks(self) -> None:
|
|
||||||
"""Clear all registered callbacks."""
|
|
||||||
self._progress_callbacks.clear()
|
|
||||||
self._error_callbacks.clear()
|
|
||||||
self._completion_callbacks.clear()
|
|
||||||
@ -1,14 +1,9 @@
|
|||||||
from typing import Dict, List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
|
||||||
|
from src.config.settings import settings
|
||||||
from src.server.models.config import AppConfig, ConfigUpdate, ValidationResult
|
from src.server.models.config import AppConfig, ConfigUpdate, ValidationResult
|
||||||
from src.server.services.config_service import (
|
|
||||||
ConfigBackupError,
|
|
||||||
ConfigServiceError,
|
|
||||||
ConfigValidationError,
|
|
||||||
get_config_service,
|
|
||||||
)
|
|
||||||
from src.server.utils.dependencies import require_auth
|
from src.server.utils.dependencies import require_auth
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/config", tags=["config"])
|
router = APIRouter(prefix="/api/config", tags=["config"])
|
||||||
@ -16,144 +11,58 @@ router = APIRouter(prefix="/api/config", tags=["config"])
|
|||||||
|
|
||||||
@router.get("", response_model=AppConfig)
|
@router.get("", response_model=AppConfig)
|
||||||
def get_config(auth: Optional[dict] = Depends(require_auth)) -> AppConfig:
|
def get_config(auth: Optional[dict] = Depends(require_auth)) -> AppConfig:
|
||||||
"""Return current application configuration."""
|
"""Return current application configuration (read-only)."""
|
||||||
|
# Construct AppConfig from pydantic-settings where possible
|
||||||
|
cfg_data = {
|
||||||
|
"name": getattr(settings, "app_name", "Aniworld"),
|
||||||
|
"data_dir": getattr(settings, "data_dir", "data"),
|
||||||
|
"scheduler": getattr(settings, "scheduler", {}),
|
||||||
|
"logging": getattr(settings, "logging", {}),
|
||||||
|
"backup": getattr(settings, "backup", {}),
|
||||||
|
"other": getattr(settings, "other", {}),
|
||||||
|
}
|
||||||
try:
|
try:
|
||||||
config_service = get_config_service()
|
return AppConfig(**cfg_data)
|
||||||
return config_service.load_config()
|
except Exception as e:
|
||||||
except ConfigServiceError as e:
|
raise HTTPException(status_code=500, detail=f"Failed to read config: {e}")
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=f"Failed to load config: {e}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("", response_model=AppConfig)
|
@router.put("", response_model=AppConfig)
|
||||||
def update_config(
|
def update_config(update: ConfigUpdate, auth: dict = Depends(require_auth)) -> AppConfig:
|
||||||
update: ConfigUpdate, auth: dict = Depends(require_auth)
|
"""Apply an update to the configuration and return the new config.
|
||||||
) -> AppConfig:
|
|
||||||
"""Apply an update to the configuration and persist it.
|
|
||||||
|
|
||||||
Creates automatic backup before applying changes.
|
Note: persistence strategy for settings is out-of-scope for this task.
|
||||||
|
This endpoint updates the in-memory Settings where possible and returns
|
||||||
|
the merged result as an AppConfig.
|
||||||
"""
|
"""
|
||||||
|
# Build current AppConfig from settings then apply update
|
||||||
|
current = get_config(auth)
|
||||||
|
new_cfg = update.apply_to(current)
|
||||||
|
|
||||||
|
# Mirror some fields back into pydantic-settings 'settings' where safe.
|
||||||
|
# Avoid writing secrets or unsupported fields.
|
||||||
try:
|
try:
|
||||||
config_service = get_config_service()
|
if new_cfg.data_dir:
|
||||||
return config_service.update_config(update)
|
setattr(settings, "data_dir", new_cfg.data_dir)
|
||||||
except ConfigValidationError as e:
|
# scheduler/logging/backup/other kept in memory only for now
|
||||||
raise HTTPException(
|
setattr(settings, "scheduler", new_cfg.scheduler.model_dump())
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
setattr(settings, "logging", new_cfg.logging.model_dump())
|
||||||
detail=f"Invalid configuration: {e}"
|
setattr(settings, "backup", new_cfg.backup.model_dump())
|
||||||
) from e
|
setattr(settings, "other", new_cfg.other)
|
||||||
except ConfigServiceError as e:
|
except Exception:
|
||||||
raise HTTPException(
|
# Best-effort; do not fail the request if persistence is not available
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
pass
|
||||||
detail=f"Failed to update config: {e}"
|
|
||||||
) from e
|
return new_cfg
|
||||||
|
|
||||||
|
|
||||||
@router.post("/validate", response_model=ValidationResult)
|
@router.post("/validate", response_model=ValidationResult)
|
||||||
def validate_config(
|
def validate_config(cfg: AppConfig, auth: dict = Depends(require_auth)) -> ValidationResult:
|
||||||
cfg: AppConfig, auth: dict = Depends(require_auth) # noqa: ARG001
|
|
||||||
) -> ValidationResult:
|
|
||||||
"""Validate a provided AppConfig without applying it.
|
"""Validate a provided AppConfig without applying it.
|
||||||
|
|
||||||
Returns ValidationResult with any validation errors.
|
Returns ValidationResult with any validation errors.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
config_service = get_config_service()
|
return cfg.validate()
|
||||||
return config_service.validate_config(cfg)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=str(e)
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/backups", response_model=List[Dict[str, object]])
|
|
||||||
def list_backups(
|
|
||||||
auth: dict = Depends(require_auth)
|
|
||||||
) -> List[Dict[str, object]]:
|
|
||||||
"""List all available configuration backups.
|
|
||||||
|
|
||||||
Returns list of backup metadata including name, size, and created time.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
config_service = get_config_service()
|
|
||||||
return config_service.list_backups()
|
|
||||||
except ConfigServiceError as e:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=f"Failed to list backups: {e}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/backups", response_model=Dict[str, str])
|
|
||||||
def create_backup(
|
|
||||||
name: Optional[str] = None, auth: dict = Depends(require_auth)
|
|
||||||
) -> Dict[str, str]:
|
|
||||||
"""Create a backup of the current configuration.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: Optional custom backup name (timestamp used if not provided)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary with backup name and message
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
config_service = get_config_service()
|
|
||||||
backup_path = config_service.create_backup(name)
|
|
||||||
return {
|
|
||||||
"name": backup_path.name,
|
|
||||||
"message": "Backup created successfully"
|
|
||||||
}
|
|
||||||
except ConfigBackupError as e:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=f"Failed to create backup: {e}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/backups/{backup_name}/restore", response_model=AppConfig)
|
|
||||||
def restore_backup(
|
|
||||||
backup_name: str, auth: dict = Depends(require_auth)
|
|
||||||
) -> AppConfig:
|
|
||||||
"""Restore configuration from a backup.
|
|
||||||
|
|
||||||
Creates backup of current config before restoring.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
backup_name: Name of backup file to restore
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Restored configuration
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
config_service = get_config_service()
|
|
||||||
return config_service.restore_backup(backup_name)
|
|
||||||
except ConfigBackupError as e:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=f"Failed to restore backup: {e}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/backups/{backup_name}")
|
|
||||||
def delete_backup(
|
|
||||||
backup_name: str, auth: dict = Depends(require_auth)
|
|
||||||
) -> Dict[str, str]:
|
|
||||||
"""Delete a configuration backup.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
backup_name: Name of backup file to delete
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Success message
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
config_service = get_config_service()
|
|
||||||
config_service.delete_backup(backup_name)
|
|
||||||
return {"message": f"Backup '{backup_name}' deleted successfully"}
|
|
||||||
except ConfigBackupError as e:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=f"Failed to delete backup: {e}"
|
|
||||||
) from e
|
|
||||||
|
|||||||
@ -1,436 +0,0 @@
|
|||||||
# Database Layer
|
|
||||||
|
|
||||||
SQLAlchemy-based database layer for the Aniworld web application.
|
|
||||||
|
|
||||||
## Overview
|
|
||||||
|
|
||||||
This package provides persistent storage for anime series, episodes, download queue, and user sessions using SQLAlchemy ORM with comprehensive service layer for CRUD operations.
|
|
||||||
|
|
||||||
## Quick Start
|
|
||||||
|
|
||||||
### Installation
|
|
||||||
|
|
||||||
Install required dependencies:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pip install sqlalchemy alembic aiosqlite
|
|
||||||
```
|
|
||||||
|
|
||||||
Or use the project requirements:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pip install -r requirements.txt
|
|
||||||
```
|
|
||||||
|
|
||||||
### Initialization
|
|
||||||
|
|
||||||
Initialize the database on application startup:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from src.server.database import init_db, close_db
|
|
||||||
|
|
||||||
# Startup
|
|
||||||
await init_db()
|
|
||||||
|
|
||||||
# Shutdown
|
|
||||||
await close_db()
|
|
||||||
```
|
|
||||||
|
|
||||||
### Usage in FastAPI
|
|
||||||
|
|
||||||
Use the database session dependency in your endpoints:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from fastapi import Depends
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from src.server.database import get_db_session, AnimeSeries
|
|
||||||
from sqlalchemy import select
|
|
||||||
|
|
||||||
@app.get("/anime")
|
|
||||||
async def get_anime(db: AsyncSession = Depends(get_db_session)):
|
|
||||||
result = await db.execute(select(AnimeSeries))
|
|
||||||
return result.scalars().all()
|
|
||||||
```
|
|
||||||
|
|
||||||
## Models
|
|
||||||
|
|
||||||
### AnimeSeries
|
|
||||||
|
|
||||||
Represents an anime series with metadata and relationships.
|
|
||||||
|
|
||||||
```python
|
|
||||||
series = AnimeSeries(
|
|
||||||
key="attack-on-titan",
|
|
||||||
name="Attack on Titan",
|
|
||||||
site="https://aniworld.to",
|
|
||||||
folder="/anime/attack-on-titan",
|
|
||||||
description="Epic anime about titans",
|
|
||||||
status="completed",
|
|
||||||
total_episodes=75
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Episode
|
|
||||||
|
|
||||||
Individual episodes linked to series.
|
|
||||||
|
|
||||||
```python
|
|
||||||
episode = Episode(
|
|
||||||
series_id=series.id,
|
|
||||||
season=1,
|
|
||||||
episode_number=5,
|
|
||||||
title="The Fifth Episode",
|
|
||||||
is_downloaded=True
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
### DownloadQueueItem
|
|
||||||
|
|
||||||
Download queue with progress tracking.
|
|
||||||
|
|
||||||
```python
|
|
||||||
from src.server.database.models import DownloadStatus, DownloadPriority
|
|
||||||
|
|
||||||
item = DownloadQueueItem(
|
|
||||||
series_id=series.id,
|
|
||||||
season=1,
|
|
||||||
episode_number=3,
|
|
||||||
status=DownloadStatus.DOWNLOADING,
|
|
||||||
priority=DownloadPriority.HIGH,
|
|
||||||
progress_percent=45.5
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
### UserSession
|
|
||||||
|
|
||||||
User authentication sessions.
|
|
||||||
|
|
||||||
```python
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
|
|
||||||
session = UserSession(
|
|
||||||
session_id="unique-session-id",
|
|
||||||
token_hash="hashed-jwt-token",
|
|
||||||
expires_at=datetime.utcnow() + timedelta(hours=24),
|
|
||||||
is_active=True
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Mixins
|
|
||||||
|
|
||||||
### TimestampMixin
|
|
||||||
|
|
||||||
Adds automatic timestamp tracking:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from src.server.database.base import Base, TimestampMixin
|
|
||||||
|
|
||||||
class MyModel(Base, TimestampMixin):
|
|
||||||
__tablename__ = "my_table"
|
|
||||||
# created_at and updated_at automatically added
|
|
||||||
```
|
|
||||||
|
|
||||||
### SoftDeleteMixin
|
|
||||||
|
|
||||||
Provides soft delete functionality:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from src.server.database.base import Base, SoftDeleteMixin
|
|
||||||
|
|
||||||
class MyModel(Base, SoftDeleteMixin):
|
|
||||||
__tablename__ = "my_table"
|
|
||||||
|
|
||||||
# Usage
|
|
||||||
instance.soft_delete() # Mark as deleted
|
|
||||||
instance.is_deleted # Check if deleted
|
|
||||||
instance.restore() # Restore deleted record
|
|
||||||
```
|
|
||||||
|
|
||||||
## Configuration
|
|
||||||
|
|
||||||
Configure database via environment variables:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
DATABASE_URL=sqlite:///./data/aniworld.db
|
|
||||||
LOG_LEVEL=DEBUG # Enables SQL query logging
|
|
||||||
```
|
|
||||||
|
|
||||||
Or in code:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from src.config.settings import settings
|
|
||||||
|
|
||||||
settings.database_url = "sqlite:///./data/aniworld.db"
|
|
||||||
```
|
|
||||||
|
|
||||||
## Migrations (Future)
|
|
||||||
|
|
||||||
Alembic is installed for database migrations:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Initialize Alembic
|
|
||||||
alembic init alembic
|
|
||||||
|
|
||||||
# Generate migration
|
|
||||||
alembic revision --autogenerate -m "Description"
|
|
||||||
|
|
||||||
# Apply migrations
|
|
||||||
alembic upgrade head
|
|
||||||
|
|
||||||
# Rollback
|
|
||||||
alembic downgrade -1
|
|
||||||
```
|
|
||||||
|
|
||||||
## Testing
|
|
||||||
|
|
||||||
Run database tests:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pytest tests/unit/test_database_models.py -v
|
|
||||||
```
|
|
||||||
|
|
||||||
The test suite uses an in-memory SQLite database for isolation and speed.
|
|
||||||
|
|
||||||
## Architecture
|
|
||||||
|
|
||||||
- **base.py**: Base declarative class and mixins
|
|
||||||
- **models.py**: SQLAlchemy ORM models (4 models)
|
|
||||||
- **connection.py**: Engine, session factory, dependency injection
|
|
||||||
- **migrations.py**: Alembic migration placeholder
|
|
||||||
- ****init**.py**: Package exports
|
|
||||||
- **service.py**: Service layer with CRUD operations
|
|
||||||
|
|
||||||
## Service Layer
|
|
||||||
|
|
||||||
The service layer provides high-level CRUD operations for all models:
|
|
||||||
|
|
||||||
### AnimeSeriesService
|
|
||||||
|
|
||||||
```python
|
|
||||||
from src.server.database import AnimeSeriesService
|
|
||||||
|
|
||||||
# Create series
|
|
||||||
series = await AnimeSeriesService.create(
|
|
||||||
db,
|
|
||||||
key="my-anime",
|
|
||||||
name="My Anime",
|
|
||||||
site="https://example.com",
|
|
||||||
folder="/path/to/anime"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get by ID or key
|
|
||||||
series = await AnimeSeriesService.get_by_id(db, series_id)
|
|
||||||
series = await AnimeSeriesService.get_by_key(db, "my-anime")
|
|
||||||
|
|
||||||
# Get all with pagination
|
|
||||||
all_series = await AnimeSeriesService.get_all(db, limit=50, offset=0)
|
|
||||||
|
|
||||||
# Update
|
|
||||||
updated = await AnimeSeriesService.update(db, series_id, name="Updated Name")
|
|
||||||
|
|
||||||
# Delete (cascades to episodes and downloads)
|
|
||||||
deleted = await AnimeSeriesService.delete(db, series_id)
|
|
||||||
|
|
||||||
# Search
|
|
||||||
results = await AnimeSeriesService.search(db, "naruto", limit=10)
|
|
||||||
```
|
|
||||||
|
|
||||||
### EpisodeService
|
|
||||||
|
|
||||||
```python
|
|
||||||
from src.server.database import EpisodeService
|
|
||||||
|
|
||||||
# Create episode
|
|
||||||
episode = await EpisodeService.create(
|
|
||||||
db,
|
|
||||||
series_id=1,
|
|
||||||
season=1,
|
|
||||||
episode_number=5,
|
|
||||||
title="Episode 5"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get episodes for series
|
|
||||||
episodes = await EpisodeService.get_by_series(db, series_id, season=1)
|
|
||||||
|
|
||||||
# Get specific episode
|
|
||||||
episode = await EpisodeService.get_by_episode(db, series_id, season=1, episode_number=5)
|
|
||||||
|
|
||||||
# Mark as downloaded
|
|
||||||
updated = await EpisodeService.mark_downloaded(
|
|
||||||
db,
|
|
||||||
episode_id,
|
|
||||||
file_path="/path/to/file.mp4",
|
|
||||||
file_size=1024000
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
### DownloadQueueService
|
|
||||||
|
|
||||||
```python
|
|
||||||
from src.server.database import DownloadQueueService
|
|
||||||
from src.server.database.models import DownloadPriority, DownloadStatus
|
|
||||||
|
|
||||||
# Add to queue
|
|
||||||
item = await DownloadQueueService.create(
|
|
||||||
db,
|
|
||||||
series_id=1,
|
|
||||||
season=1,
|
|
||||||
episode_number=5,
|
|
||||||
priority=DownloadPriority.HIGH
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get pending downloads (ordered by priority)
|
|
||||||
pending = await DownloadQueueService.get_pending(db, limit=10)
|
|
||||||
|
|
||||||
# Get active downloads
|
|
||||||
active = await DownloadQueueService.get_active(db)
|
|
||||||
|
|
||||||
# Update status
|
|
||||||
updated = await DownloadQueueService.update_status(
|
|
||||||
db,
|
|
||||||
item_id,
|
|
||||||
DownloadStatus.DOWNLOADING
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update progress
|
|
||||||
updated = await DownloadQueueService.update_progress(
|
|
||||||
db,
|
|
||||||
item_id,
|
|
||||||
progress_percent=50.0,
|
|
||||||
downloaded_bytes=500000,
|
|
||||||
total_bytes=1000000,
|
|
||||||
download_speed=50000.0
|
|
||||||
)
|
|
||||||
|
|
||||||
# Clear completed
|
|
||||||
count = await DownloadQueueService.clear_completed(db)
|
|
||||||
|
|
||||||
# Retry failed downloads
|
|
||||||
retried = await DownloadQueueService.retry_failed(db, max_retries=3)
|
|
||||||
```
|
|
||||||
|
|
||||||
### UserSessionService
|
|
||||||
|
|
||||||
```python
|
|
||||||
from src.server.database import UserSessionService
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
|
|
||||||
# Create session
|
|
||||||
expires_at = datetime.utcnow() + timedelta(hours=24)
|
|
||||||
session = await UserSessionService.create(
|
|
||||||
db,
|
|
||||||
session_id="unique-session-id",
|
|
||||||
token_hash="hashed-jwt-token",
|
|
||||||
expires_at=expires_at,
|
|
||||||
user_id="user123",
|
|
||||||
ip_address="127.0.0.1"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get session
|
|
||||||
session = await UserSessionService.get_by_session_id(db, "session-id")
|
|
||||||
|
|
||||||
# Get active sessions
|
|
||||||
active = await UserSessionService.get_active_sessions(db, user_id="user123")
|
|
||||||
|
|
||||||
# Update activity
|
|
||||||
updated = await UserSessionService.update_activity(db, "session-id")
|
|
||||||
|
|
||||||
# Revoke session
|
|
||||||
revoked = await UserSessionService.revoke(db, "session-id")
|
|
||||||
|
|
||||||
# Cleanup expired sessions
|
|
||||||
count = await UserSessionService.cleanup_expired(db)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Database Schema
|
|
||||||
|
|
||||||
```
|
|
||||||
anime_series (id, key, name, site, folder, ...)
|
|
||||||
├── episodes (id, series_id, season, episode_number, ...)
|
|
||||||
└── download_queue (id, series_id, season, episode_number, status, ...)
|
|
||||||
|
|
||||||
user_sessions (id, session_id, token_hash, expires_at, ...)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Production Considerations
|
|
||||||
|
|
||||||
### SQLite (Current)
|
|
||||||
|
|
||||||
- Single file: `data/aniworld.db`
|
|
||||||
- WAL mode for concurrency
|
|
||||||
- Foreign keys enabled
|
|
||||||
- Static connection pool
|
|
||||||
|
|
||||||
### PostgreSQL/MySQL (Future)
|
|
||||||
|
|
||||||
For multi-process deployments:
|
|
||||||
|
|
||||||
```python
|
|
||||||
DATABASE_URL=postgresql+asyncpg://user:pass@host/db
|
|
||||||
# or
|
|
||||||
DATABASE_URL=mysql+aiomysql://user:pass@host/db
|
|
||||||
```
|
|
||||||
|
|
||||||
Configure connection pooling:
|
|
||||||
|
|
||||||
```python
|
|
||||||
engine = create_async_engine(
|
|
||||||
url,
|
|
||||||
pool_size=10,
|
|
||||||
max_overflow=20,
|
|
||||||
pool_pre_ping=True
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Performance Tips
|
|
||||||
|
|
||||||
1. **Indexes**: Models have indexes on frequently queried columns
|
|
||||||
2. **Relationships**: Use `selectinload()` or `joinedload()` for eager loading
|
|
||||||
3. **Batching**: Use bulk operations for multiple inserts/updates
|
|
||||||
4. **Query Optimization**: Profile slow queries in DEBUG mode
|
|
||||||
|
|
||||||
Example with eager loading:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from sqlalchemy.orm import selectinload
|
|
||||||
|
|
||||||
result = await db.execute(
|
|
||||||
select(AnimeSeries)
|
|
||||||
.options(selectinload(AnimeSeries.episodes))
|
|
||||||
.where(AnimeSeries.key == "attack-on-titan")
|
|
||||||
)
|
|
||||||
series = result.scalar_one()
|
|
||||||
# episodes already loaded, no additional queries
|
|
||||||
```
|
|
||||||
|
|
||||||
## Troubleshooting
|
|
||||||
|
|
||||||
### Database not initialized
|
|
||||||
|
|
||||||
```
|
|
||||||
RuntimeError: Database not initialized. Call init_db() first.
|
|
||||||
```
|
|
||||||
|
|
||||||
Solution: Call `await init_db()` during application startup.
|
|
||||||
|
|
||||||
### Table does not exist
|
|
||||||
|
|
||||||
```
|
|
||||||
sqlalchemy.exc.OperationalError: no such table: anime_series
|
|
||||||
```
|
|
||||||
|
|
||||||
Solution: `Base.metadata.create_all()` is called automatically by `init_db()`.
|
|
||||||
|
|
||||||
### Foreign key constraint failed
|
|
||||||
|
|
||||||
```
|
|
||||||
sqlalchemy.exc.IntegrityError: FOREIGN KEY constraint failed
|
|
||||||
```
|
|
||||||
|
|
||||||
Solution: Ensure referenced records exist before creating relationships.
|
|
||||||
|
|
||||||
## Further Reading
|
|
||||||
|
|
||||||
- [SQLAlchemy 2.0 Documentation](https://docs.sqlalchemy.org/en/20/)
|
|
||||||
- [Alembic Tutorial](https://alembic.sqlalchemy.org/en/latest/tutorial.html)
|
|
||||||
- [FastAPI with Databases](https://fastapi.tiangolo.com/tutorial/sql-databases/)
|
|
||||||
@ -1,80 +0,0 @@
|
|||||||
"""Database package for the Aniworld web application.
|
|
||||||
|
|
||||||
This package provides SQLAlchemy models, database connection management,
|
|
||||||
and session handling for persistent storage.
|
|
||||||
|
|
||||||
Modules:
|
|
||||||
- models: SQLAlchemy ORM models for anime series, episodes, download queue, and sessions
|
|
||||||
- connection: Database engine and session factory configuration
|
|
||||||
- base: Base class for all SQLAlchemy models
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
from src.server.database import get_db_session, init_db
|
|
||||||
|
|
||||||
# Initialize database on application startup
|
|
||||||
init_db()
|
|
||||||
|
|
||||||
# Use in FastAPI endpoints
|
|
||||||
@app.get("/anime")
|
|
||||||
async def get_anime(db: AsyncSession = Depends(get_db_session)):
|
|
||||||
result = await db.execute(select(AnimeSeries))
|
|
||||||
return result.scalars().all()
|
|
||||||
"""
|
|
||||||
|
|
||||||
from src.server.database.base import Base
|
|
||||||
from src.server.database.connection import close_db, get_db_session, init_db
|
|
||||||
from src.server.database.init import (
|
|
||||||
CURRENT_SCHEMA_VERSION,
|
|
||||||
EXPECTED_TABLES,
|
|
||||||
check_database_health,
|
|
||||||
create_database_backup,
|
|
||||||
create_database_schema,
|
|
||||||
get_database_info,
|
|
||||||
get_migration_guide,
|
|
||||||
get_schema_version,
|
|
||||||
initialize_database,
|
|
||||||
seed_initial_data,
|
|
||||||
validate_database_schema,
|
|
||||||
)
|
|
||||||
from src.server.database.models import (
|
|
||||||
AnimeSeries,
|
|
||||||
DownloadQueueItem,
|
|
||||||
Episode,
|
|
||||||
UserSession,
|
|
||||||
)
|
|
||||||
from src.server.database.service import (
|
|
||||||
AnimeSeriesService,
|
|
||||||
DownloadQueueService,
|
|
||||||
EpisodeService,
|
|
||||||
UserSessionService,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
# Base and connection
|
|
||||||
"Base",
|
|
||||||
"get_db_session",
|
|
||||||
"init_db",
|
|
||||||
"close_db",
|
|
||||||
# Initialization functions
|
|
||||||
"initialize_database",
|
|
||||||
"create_database_schema",
|
|
||||||
"validate_database_schema",
|
|
||||||
"get_schema_version",
|
|
||||||
"seed_initial_data",
|
|
||||||
"check_database_health",
|
|
||||||
"create_database_backup",
|
|
||||||
"get_database_info",
|
|
||||||
"get_migration_guide",
|
|
||||||
"CURRENT_SCHEMA_VERSION",
|
|
||||||
"EXPECTED_TABLES",
|
|
||||||
# Models
|
|
||||||
"AnimeSeries",
|
|
||||||
"Episode",
|
|
||||||
"DownloadQueueItem",
|
|
||||||
"UserSession",
|
|
||||||
# Services
|
|
||||||
"AnimeSeriesService",
|
|
||||||
"EpisodeService",
|
|
||||||
"DownloadQueueService",
|
|
||||||
"UserSessionService",
|
|
||||||
]
|
|
||||||
@ -1,74 +0,0 @@
|
|||||||
"""Base SQLAlchemy declarative base for all database models.
|
|
||||||
|
|
||||||
This module provides the base class that all ORM models inherit from,
|
|
||||||
along with common functionality and mixins.
|
|
||||||
"""
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from sqlalchemy import DateTime, func
|
|
||||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
|
||||||
|
|
||||||
|
|
||||||
class Base(DeclarativeBase):
|
|
||||||
"""Base class for all SQLAlchemy ORM models.
|
|
||||||
|
|
||||||
Provides common functionality and type annotations for all models.
|
|
||||||
All models should inherit from this class.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class TimestampMixin:
|
|
||||||
"""Mixin to add created_at and updated_at timestamp columns.
|
|
||||||
|
|
||||||
Automatically tracks when records are created and updated.
|
|
||||||
Use this mixin for models that need audit timestamps.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
created_at: Timestamp when record was created
|
|
||||||
updated_at: Timestamp when record was last updated
|
|
||||||
"""
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True),
|
|
||||||
server_default=func.now(),
|
|
||||||
nullable=False,
|
|
||||||
doc="Timestamp when record was created"
|
|
||||||
)
|
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True),
|
|
||||||
server_default=func.now(),
|
|
||||||
onupdate=func.now(),
|
|
||||||
nullable=False,
|
|
||||||
doc="Timestamp when record was last updated"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SoftDeleteMixin:
|
|
||||||
"""Mixin to add soft delete functionality.
|
|
||||||
|
|
||||||
Instead of deleting records, marks them as deleted with a timestamp.
|
|
||||||
Useful for maintaining audit trails and allowing recovery.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
deleted_at: Timestamp when record was soft deleted, None if active
|
|
||||||
"""
|
|
||||||
deleted_at: Mapped[datetime | None] = mapped_column(
|
|
||||||
DateTime(timezone=True),
|
|
||||||
nullable=True,
|
|
||||||
default=None,
|
|
||||||
doc="Timestamp when record was soft deleted"
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_deleted(self) -> bool:
|
|
||||||
"""Check if record is soft deleted."""
|
|
||||||
return self.deleted_at is not None
|
|
||||||
|
|
||||||
def soft_delete(self) -> None:
|
|
||||||
"""Mark record as deleted without removing from database."""
|
|
||||||
self.deleted_at = datetime.utcnow()
|
|
||||||
|
|
||||||
def restore(self) -> None:
|
|
||||||
"""Restore a soft deleted record."""
|
|
||||||
self.deleted_at = None
|
|
||||||
@ -1,258 +0,0 @@
|
|||||||
"""Database connection and session management for SQLAlchemy.
|
|
||||||
|
|
||||||
This module provides database engine creation, session factory configuration,
|
|
||||||
and dependency injection helpers for FastAPI endpoints.
|
|
||||||
|
|
||||||
Functions:
|
|
||||||
- init_db: Initialize database engine and create tables
|
|
||||||
- close_db: Close database connections and cleanup
|
|
||||||
- get_db_session: FastAPI dependency for database sessions
|
|
||||||
- get_engine: Get database engine instance
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
from typing import AsyncGenerator, Optional
|
|
||||||
|
|
||||||
from sqlalchemy import create_engine, event, pool
|
|
||||||
from sqlalchemy.ext.asyncio import (
|
|
||||||
AsyncEngine,
|
|
||||||
AsyncSession,
|
|
||||||
async_sessionmaker,
|
|
||||||
create_async_engine,
|
|
||||||
)
|
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
|
||||||
|
|
||||||
from src.config.settings import settings
|
|
||||||
from src.server.database.base import Base
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Global engine and session factory instances
|
|
||||||
_engine: Optional[AsyncEngine] = None
|
|
||||||
_sync_engine: Optional[create_engine] = None
|
|
||||||
_session_factory: Optional[async_sessionmaker[AsyncSession]] = None
|
|
||||||
_sync_session_factory: Optional[sessionmaker[Session]] = None
|
|
||||||
|
|
||||||
|
|
||||||
def _get_database_url() -> str:
|
|
||||||
"""Get database URL from settings.
|
|
||||||
|
|
||||||
Converts SQLite URLs to async format if needed.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Database URL string suitable for async engine
|
|
||||||
"""
|
|
||||||
url = settings.database_url
|
|
||||||
|
|
||||||
# Convert sqlite:/// to sqlite+aiosqlite:/// for async support
|
|
||||||
if url.startswith("sqlite:///"):
|
|
||||||
url = url.replace("sqlite:///", "sqlite+aiosqlite:///")
|
|
||||||
|
|
||||||
return url
|
|
||||||
|
|
||||||
|
|
||||||
def _configure_sqlite_engine(engine: AsyncEngine) -> None:
|
|
||||||
"""Configure SQLite-specific engine settings.
|
|
||||||
|
|
||||||
Enables foreign key support and optimizes connection pooling.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
engine: SQLAlchemy async engine instance
|
|
||||||
"""
|
|
||||||
@event.listens_for(engine.sync_engine, "connect")
|
|
||||||
def set_sqlite_pragma(dbapi_conn, connection_record):
|
|
||||||
"""Enable foreign keys and set pragmas for SQLite."""
|
|
||||||
cursor = dbapi_conn.cursor()
|
|
||||||
cursor.execute("PRAGMA foreign_keys=ON")
|
|
||||||
cursor.execute("PRAGMA journal_mode=WAL")
|
|
||||||
cursor.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def init_db() -> None:
|
|
||||||
"""Initialize database engine and create tables.
|
|
||||||
|
|
||||||
Creates async and sync engines, session factories, and database tables.
|
|
||||||
Should be called during application startup.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
Exception: If database initialization fails
|
|
||||||
"""
|
|
||||||
global _engine, _sync_engine, _session_factory, _sync_session_factory
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Get database URL
|
|
||||||
db_url = _get_database_url()
|
|
||||||
logger.info(f"Initializing database: {db_url}")
|
|
||||||
|
|
||||||
# Create async engine
|
|
||||||
_engine = create_async_engine(
|
|
||||||
db_url,
|
|
||||||
echo=settings.log_level == "DEBUG",
|
|
||||||
poolclass=pool.StaticPool if "sqlite" in db_url else pool.QueuePool,
|
|
||||||
pool_pre_ping=True,
|
|
||||||
future=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Configure SQLite if needed
|
|
||||||
if "sqlite" in db_url:
|
|
||||||
_configure_sqlite_engine(_engine)
|
|
||||||
|
|
||||||
# Create async session factory
|
|
||||||
_session_factory = async_sessionmaker(
|
|
||||||
bind=_engine,
|
|
||||||
class_=AsyncSession,
|
|
||||||
expire_on_commit=False,
|
|
||||||
autoflush=False,
|
|
||||||
autocommit=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create sync engine for initial setup
|
|
||||||
sync_url = settings.database_url
|
|
||||||
_sync_engine = create_engine(
|
|
||||||
sync_url,
|
|
||||||
echo=settings.log_level == "DEBUG",
|
|
||||||
poolclass=pool.StaticPool if "sqlite" in sync_url else pool.QueuePool,
|
|
||||||
pool_pre_ping=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create sync session factory
|
|
||||||
_sync_session_factory = sessionmaker(
|
|
||||||
bind=_sync_engine,
|
|
||||||
expire_on_commit=False,
|
|
||||||
autoflush=False,
|
|
||||||
autocommit=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create all tables
|
|
||||||
logger.info("Creating database tables...")
|
|
||||||
Base.metadata.create_all(bind=_sync_engine)
|
|
||||||
logger.info("Database initialization complete")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to initialize database: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
async def close_db() -> None:
|
|
||||||
"""Close database connections and cleanup resources.
|
|
||||||
|
|
||||||
Should be called during application shutdown.
|
|
||||||
"""
|
|
||||||
global _engine, _sync_engine, _session_factory, _sync_session_factory
|
|
||||||
|
|
||||||
try:
|
|
||||||
if _engine:
|
|
||||||
logger.info("Closing async database engine...")
|
|
||||||
await _engine.dispose()
|
|
||||||
_engine = None
|
|
||||||
_session_factory = None
|
|
||||||
|
|
||||||
if _sync_engine:
|
|
||||||
logger.info("Closing sync database engine...")
|
|
||||||
_sync_engine.dispose()
|
|
||||||
_sync_engine = None
|
|
||||||
_sync_session_factory = None
|
|
||||||
|
|
||||||
logger.info("Database connections closed")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error closing database: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
def get_engine() -> AsyncEngine:
|
|
||||||
"""Get the database engine instance.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AsyncEngine instance
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: If database is not initialized
|
|
||||||
"""
|
|
||||||
if _engine is None:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Database not initialized. Call init_db() first."
|
|
||||||
)
|
|
||||||
return _engine
|
|
||||||
|
|
||||||
|
|
||||||
def get_sync_engine():
|
|
||||||
"""Get the sync database engine instance.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Engine instance
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: If database is not initialized
|
|
||||||
"""
|
|
||||||
if _sync_engine is None:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Database not initialized. Call init_db() first."
|
|
||||||
)
|
|
||||||
return _sync_engine
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
|
|
||||||
"""FastAPI dependency to get database session.
|
|
||||||
|
|
||||||
Provides an async database session with automatic commit/rollback.
|
|
||||||
Use this as a dependency in FastAPI endpoints.
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
AsyncSession: Database session for async operations
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: If database is not initialized
|
|
||||||
|
|
||||||
Example:
|
|
||||||
@app.get("/anime")
|
|
||||||
async def get_anime(
|
|
||||||
db: AsyncSession = Depends(get_db_session)
|
|
||||||
):
|
|
||||||
result = await db.execute(select(AnimeSeries))
|
|
||||||
return result.scalars().all()
|
|
||||||
"""
|
|
||||||
if _session_factory is None:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Database not initialized. Call init_db() first."
|
|
||||||
)
|
|
||||||
|
|
||||||
session = _session_factory()
|
|
||||||
try:
|
|
||||||
yield session
|
|
||||||
await session.commit()
|
|
||||||
except Exception:
|
|
||||||
await session.rollback()
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
await session.close()
|
|
||||||
|
|
||||||
|
|
||||||
def get_sync_session() -> Session:
|
|
||||||
"""Get a sync database session.
|
|
||||||
|
|
||||||
Use this for synchronous operations outside FastAPI endpoints.
|
|
||||||
Remember to close the session when done.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Session: Database session for sync operations
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: If database is not initialized
|
|
||||||
|
|
||||||
Example:
|
|
||||||
session = get_sync_session()
|
|
||||||
try:
|
|
||||||
result = session.execute(select(AnimeSeries))
|
|
||||||
return result.scalars().all()
|
|
||||||
finally:
|
|
||||||
session.close()
|
|
||||||
"""
|
|
||||||
if _sync_session_factory is None:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Database not initialized. Call init_db() first."
|
|
||||||
)
|
|
||||||
|
|
||||||
return _sync_session_factory()
|
|
||||||
@ -1,479 +0,0 @@
|
|||||||
"""Example integration of database service with existing services.
|
|
||||||
|
|
||||||
This file demonstrates how to integrate the database service layer with
|
|
||||||
existing application services like AnimeService and DownloadService.
|
|
||||||
|
|
||||||
These examples show patterns for:
|
|
||||||
- Persisting scan results to database
|
|
||||||
- Loading queue from database on startup
|
|
||||||
- Syncing download progress to database
|
|
||||||
- Maintaining consistency between in-memory state and database
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from src.core.entities.series import Serie
|
|
||||||
from src.server.database.models import DownloadPriority, DownloadStatus
|
|
||||||
from src.server.database.service import (
|
|
||||||
AnimeSeriesService,
|
|
||||||
DownloadQueueService,
|
|
||||||
EpisodeService,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Example 1: Persist Scan Results
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
async def persist_scan_results(
|
|
||||||
db: AsyncSession,
|
|
||||||
series_list: List[Serie],
|
|
||||||
) -> None:
|
|
||||||
"""Persist scan results to database.
|
|
||||||
|
|
||||||
Updates or creates anime series and their episodes based on
|
|
||||||
scan results from SerieScanner.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
series_list: List of Serie objects from scan
|
|
||||||
"""
|
|
||||||
logger.info(f"Persisting {len(series_list)} series to database")
|
|
||||||
|
|
||||||
for serie in series_list:
|
|
||||||
# Check if series exists
|
|
||||||
existing = await AnimeSeriesService.get_by_key(db, serie.key)
|
|
||||||
|
|
||||||
if existing:
|
|
||||||
# Update existing series
|
|
||||||
await AnimeSeriesService.update(
|
|
||||||
db,
|
|
||||||
existing.id,
|
|
||||||
name=serie.name,
|
|
||||||
site=serie.site,
|
|
||||||
folder=serie.folder,
|
|
||||||
episode_dict=serie.episode_dict,
|
|
||||||
)
|
|
||||||
series_id = existing.id
|
|
||||||
else:
|
|
||||||
# Create new series
|
|
||||||
new_series = await AnimeSeriesService.create(
|
|
||||||
db,
|
|
||||||
key=serie.key,
|
|
||||||
name=serie.name,
|
|
||||||
site=serie.site,
|
|
||||||
folder=serie.folder,
|
|
||||||
episode_dict=serie.episode_dict,
|
|
||||||
)
|
|
||||||
series_id = new_series.id
|
|
||||||
|
|
||||||
# Update episodes for this series
|
|
||||||
await _update_episodes(db, series_id, serie)
|
|
||||||
|
|
||||||
await db.commit()
|
|
||||||
logger.info("Scan results persisted successfully")
|
|
||||||
|
|
||||||
|
|
||||||
async def _update_episodes(
|
|
||||||
db: AsyncSession,
|
|
||||||
series_id: int,
|
|
||||||
serie: Serie,
|
|
||||||
) -> None:
|
|
||||||
"""Update episodes for a series.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
series_id: Series ID in database
|
|
||||||
serie: Serie object with episode information
|
|
||||||
"""
|
|
||||||
# Get existing episodes
|
|
||||||
existing_episodes = await EpisodeService.get_by_series(db, series_id)
|
|
||||||
existing_map = {
|
|
||||||
(ep.season, ep.episode_number): ep
|
|
||||||
for ep in existing_episodes
|
|
||||||
}
|
|
||||||
|
|
||||||
# Iterate through episode_dict to create/update episodes
|
|
||||||
for season, episodes in serie.episode_dict.items():
|
|
||||||
for ep_num in episodes:
|
|
||||||
key = (int(season), int(ep_num))
|
|
||||||
|
|
||||||
if key in existing_map:
|
|
||||||
# Episode exists, check if downloaded
|
|
||||||
episode = existing_map[key]
|
|
||||||
# Update if needed (e.g., file path changed)
|
|
||||||
if not episode.is_downloaded:
|
|
||||||
# Check if file exists locally
|
|
||||||
# This would be done by checking serie.local_episodes
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
# Create new episode
|
|
||||||
await EpisodeService.create(
|
|
||||||
db,
|
|
||||||
series_id=series_id,
|
|
||||||
season=int(season),
|
|
||||||
episode_number=int(ep_num),
|
|
||||||
is_downloaded=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Example 2: Load Queue from Database
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
async def load_queue_from_database(
|
|
||||||
db: AsyncSession,
|
|
||||||
) -> List[dict]:
|
|
||||||
"""Load download queue from database.
|
|
||||||
|
|
||||||
Retrieves pending and active download items from database and
|
|
||||||
converts them to format suitable for DownloadService.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of download items as dictionaries
|
|
||||||
"""
|
|
||||||
logger.info("Loading download queue from database")
|
|
||||||
|
|
||||||
# Get pending and active items
|
|
||||||
pending = await DownloadQueueService.get_pending(db)
|
|
||||||
active = await DownloadQueueService.get_active(db)
|
|
||||||
|
|
||||||
all_items = pending + active
|
|
||||||
|
|
||||||
# Convert to dictionary format for DownloadService
|
|
||||||
queue_items = []
|
|
||||||
for item in all_items:
|
|
||||||
queue_items.append({
|
|
||||||
"id": item.id,
|
|
||||||
"series_id": item.series_id,
|
|
||||||
"season": item.season,
|
|
||||||
"episode_number": item.episode_number,
|
|
||||||
"status": item.status.value,
|
|
||||||
"priority": item.priority.value,
|
|
||||||
"progress_percent": item.progress_percent,
|
|
||||||
"downloaded_bytes": item.downloaded_bytes,
|
|
||||||
"total_bytes": item.total_bytes,
|
|
||||||
"download_speed": item.download_speed,
|
|
||||||
"error_message": item.error_message,
|
|
||||||
"retry_count": item.retry_count,
|
|
||||||
})
|
|
||||||
|
|
||||||
logger.info(f"Loaded {len(queue_items)} items from database")
|
|
||||||
return queue_items
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Example 3: Sync Download Progress to Database
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
async def sync_download_progress(
|
|
||||||
db: AsyncSession,
|
|
||||||
item_id: int,
|
|
||||||
progress_percent: float,
|
|
||||||
downloaded_bytes: int,
|
|
||||||
total_bytes: Optional[int] = None,
|
|
||||||
download_speed: Optional[float] = None,
|
|
||||||
) -> None:
|
|
||||||
"""Sync download progress to database.
|
|
||||||
|
|
||||||
Updates download queue item progress in database. This would be called
|
|
||||||
from the download progress callback.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
item_id: Download queue item ID
|
|
||||||
progress_percent: Progress percentage (0-100)
|
|
||||||
downloaded_bytes: Bytes downloaded
|
|
||||||
total_bytes: Optional total file size
|
|
||||||
download_speed: Optional current speed (bytes/sec)
|
|
||||||
"""
|
|
||||||
await DownloadQueueService.update_progress(
|
|
||||||
db,
|
|
||||||
item_id,
|
|
||||||
progress_percent,
|
|
||||||
downloaded_bytes,
|
|
||||||
total_bytes,
|
|
||||||
download_speed,
|
|
||||||
)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
|
|
||||||
async def mark_download_complete(
|
|
||||||
db: AsyncSession,
|
|
||||||
item_id: int,
|
|
||||||
file_path: str,
|
|
||||||
file_size: int,
|
|
||||||
) -> None:
|
|
||||||
"""Mark download as complete in database.
|
|
||||||
|
|
||||||
Updates download queue item status and marks episode as downloaded.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
item_id: Download queue item ID
|
|
||||||
file_path: Path to downloaded file
|
|
||||||
file_size: File size in bytes
|
|
||||||
"""
|
|
||||||
# Get download item
|
|
||||||
item = await DownloadQueueService.get_by_id(db, item_id)
|
|
||||||
if not item:
|
|
||||||
logger.error(f"Download item {item_id} not found")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Update download status
|
|
||||||
await DownloadQueueService.update_status(
|
|
||||||
db,
|
|
||||||
item_id,
|
|
||||||
DownloadStatus.COMPLETED,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Find or create episode and mark as downloaded
|
|
||||||
episode = await EpisodeService.get_by_episode(
|
|
||||||
db,
|
|
||||||
item.series_id,
|
|
||||||
item.season,
|
|
||||||
item.episode_number,
|
|
||||||
)
|
|
||||||
|
|
||||||
if episode:
|
|
||||||
await EpisodeService.mark_downloaded(
|
|
||||||
db,
|
|
||||||
episode.id,
|
|
||||||
file_path,
|
|
||||||
file_size,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Create episode
|
|
||||||
episode = await EpisodeService.create(
|
|
||||||
db,
|
|
||||||
series_id=item.series_id,
|
|
||||||
season=item.season,
|
|
||||||
episode_number=item.episode_number,
|
|
||||||
file_path=file_path,
|
|
||||||
file_size=file_size,
|
|
||||||
is_downloaded=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
await db.commit()
|
|
||||||
logger.info(
|
|
||||||
f"Marked download complete: S{item.season:02d}E{item.episode_number:02d}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def mark_download_failed(
|
|
||||||
db: AsyncSession,
|
|
||||||
item_id: int,
|
|
||||||
error_message: str,
|
|
||||||
) -> None:
|
|
||||||
"""Mark download as failed in database.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
item_id: Download queue item ID
|
|
||||||
error_message: Error description
|
|
||||||
"""
|
|
||||||
await DownloadQueueService.update_status(
|
|
||||||
db,
|
|
||||||
item_id,
|
|
||||||
DownloadStatus.FAILED,
|
|
||||||
error_message=error_message,
|
|
||||||
)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Example 4: Add Episodes to Download Queue
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
async def add_episodes_to_queue(
|
|
||||||
db: AsyncSession,
|
|
||||||
series_key: str,
|
|
||||||
episodes: List[tuple[int, int]], # List of (season, episode) tuples
|
|
||||||
priority: DownloadPriority = DownloadPriority.NORMAL,
|
|
||||||
) -> int:
|
|
||||||
"""Add multiple episodes to download queue.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
series_key: Series provider key
|
|
||||||
episodes: List of (season, episode_number) tuples
|
|
||||||
priority: Download priority
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Number of episodes added to queue
|
|
||||||
"""
|
|
||||||
# Get series
|
|
||||||
series = await AnimeSeriesService.get_by_key(db, series_key)
|
|
||||||
if not series:
|
|
||||||
logger.error(f"Series not found: {series_key}")
|
|
||||||
return 0
|
|
||||||
|
|
||||||
added_count = 0
|
|
||||||
for season, episode_number in episodes:
|
|
||||||
# Check if already in queue
|
|
||||||
existing_items = await DownloadQueueService.get_all(db)
|
|
||||||
already_queued = any(
|
|
||||||
item.series_id == series.id
|
|
||||||
and item.season == season
|
|
||||||
and item.episode_number == episode_number
|
|
||||||
and item.status in (DownloadStatus.PENDING, DownloadStatus.DOWNLOADING)
|
|
||||||
for item in existing_items
|
|
||||||
)
|
|
||||||
|
|
||||||
if not already_queued:
|
|
||||||
await DownloadQueueService.create(
|
|
||||||
db,
|
|
||||||
series_id=series.id,
|
|
||||||
season=season,
|
|
||||||
episode_number=episode_number,
|
|
||||||
priority=priority,
|
|
||||||
)
|
|
||||||
added_count += 1
|
|
||||||
|
|
||||||
await db.commit()
|
|
||||||
logger.info(f"Added {added_count} episodes to download queue")
|
|
||||||
return added_count
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Example 5: Integration with AnimeService
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
class EnhancedAnimeService:
|
|
||||||
"""Enhanced AnimeService with database persistence.
|
|
||||||
|
|
||||||
This is an example of how to wrap the existing AnimeService with
|
|
||||||
database persistence capabilities.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, db_session_factory):
|
|
||||||
"""Initialize enhanced anime service.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db_session_factory: Async session factory for database access
|
|
||||||
"""
|
|
||||||
self.db_session_factory = db_session_factory
|
|
||||||
|
|
||||||
async def rescan_with_persistence(self, directory: str) -> dict:
|
|
||||||
"""Rescan directory and persist results.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
directory: Directory to scan
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Scan results dictionary
|
|
||||||
"""
|
|
||||||
# Import here to avoid circular dependencies
|
|
||||||
from src.core.SeriesApp import SeriesApp
|
|
||||||
|
|
||||||
# Perform scan
|
|
||||||
app = SeriesApp(directory)
|
|
||||||
series_list = app.ReScan()
|
|
||||||
|
|
||||||
# Persist to database
|
|
||||||
async with self.db_session_factory() as db:
|
|
||||||
await persist_scan_results(db, series_list)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"total_series": len(series_list),
|
|
||||||
"message": "Scan completed and persisted to database",
|
|
||||||
}
|
|
||||||
|
|
||||||
async def get_series_with_missing_episodes(self) -> List[dict]:
|
|
||||||
"""Get series with missing episodes from database.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of series with missing episodes
|
|
||||||
"""
|
|
||||||
async with self.db_session_factory() as db:
|
|
||||||
# Get all series
|
|
||||||
all_series = await AnimeSeriesService.get_all(
|
|
||||||
db,
|
|
||||||
with_episodes=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Filter series with missing episodes
|
|
||||||
series_with_missing = []
|
|
||||||
for series in all_series:
|
|
||||||
if series.episode_dict:
|
|
||||||
total_episodes = sum(
|
|
||||||
len(eps) for eps in series.episode_dict.values()
|
|
||||||
)
|
|
||||||
downloaded_episodes = sum(
|
|
||||||
1 for ep in series.episodes if ep.is_downloaded
|
|
||||||
)
|
|
||||||
|
|
||||||
if downloaded_episodes < total_episodes:
|
|
||||||
series_with_missing.append({
|
|
||||||
"id": series.id,
|
|
||||||
"key": series.key,
|
|
||||||
"name": series.name,
|
|
||||||
"total_episodes": total_episodes,
|
|
||||||
"downloaded_episodes": downloaded_episodes,
|
|
||||||
"missing_episodes": total_episodes - downloaded_episodes,
|
|
||||||
})
|
|
||||||
|
|
||||||
return series_with_missing
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Usage Example
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
async def example_usage():
|
|
||||||
"""Example usage of database service integration."""
|
|
||||||
from src.server.database import get_db_session
|
|
||||||
|
|
||||||
# Get database session
|
|
||||||
async with get_db_session() as db:
|
|
||||||
# Example 1: Add episodes to queue
|
|
||||||
added = await add_episodes_to_queue(
|
|
||||||
db,
|
|
||||||
series_key="attack-on-titan",
|
|
||||||
episodes=[(1, 1), (1, 2), (1, 3)],
|
|
||||||
priority=DownloadPriority.HIGH,
|
|
||||||
)
|
|
||||||
print(f"Added {added} episodes to queue")
|
|
||||||
|
|
||||||
# Example 2: Load queue
|
|
||||||
queue_items = await load_queue_from_database(db)
|
|
||||||
print(f"Queue has {len(queue_items)} items")
|
|
||||||
|
|
||||||
# Example 3: Update progress
|
|
||||||
if queue_items:
|
|
||||||
await sync_download_progress(
|
|
||||||
db,
|
|
||||||
item_id=queue_items[0]["id"],
|
|
||||||
progress_percent=50.0,
|
|
||||||
downloaded_bytes=500000,
|
|
||||||
total_bytes=1000000,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Example 4: Mark complete
|
|
||||||
if queue_items:
|
|
||||||
await mark_download_complete(
|
|
||||||
db,
|
|
||||||
item_id=queue_items[0]["id"],
|
|
||||||
file_path="/path/to/file.mp4",
|
|
||||||
file_size=1000000,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import asyncio
|
|
||||||
asyncio.run(example_usage())
|
|
||||||
@ -1,662 +0,0 @@
|
|||||||
"""Database initialization and setup module.
|
|
||||||
|
|
||||||
This module provides comprehensive database initialization functionality:
|
|
||||||
- Schema creation and validation
|
|
||||||
- Initial data migration
|
|
||||||
- Database health checks
|
|
||||||
- Schema versioning support
|
|
||||||
- Migration utilities
|
|
||||||
|
|
||||||
For production deployments, consider using Alembic for managed migrations.
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
from sqlalchemy import inspect, text
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
|
||||||
|
|
||||||
from src.config.settings import settings
|
|
||||||
from src.server.database.base import Base
|
|
||||||
from src.server.database.connection import get_engine
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Schema Version Constants
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
CURRENT_SCHEMA_VERSION = "1.0.0"
|
|
||||||
SCHEMA_VERSION_TABLE = "schema_version"
|
|
||||||
|
|
||||||
# Expected tables in the current schema
|
|
||||||
EXPECTED_TABLES = {
|
|
||||||
"anime_series",
|
|
||||||
"episodes",
|
|
||||||
"download_queue",
|
|
||||||
"user_sessions",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Expected indexes for performance
|
|
||||||
EXPECTED_INDEXES = {
|
|
||||||
"anime_series": ["ix_anime_series_key", "ix_anime_series_name"],
|
|
||||||
"episodes": ["ix_episodes_series_id"],
|
|
||||||
"download_queue": [
|
|
||||||
"ix_download_queue_series_id",
|
|
||||||
"ix_download_queue_status",
|
|
||||||
],
|
|
||||||
"user_sessions": [
|
|
||||||
"ix_user_sessions_session_id",
|
|
||||||
"ix_user_sessions_user_id",
|
|
||||||
"ix_user_sessions_is_active",
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Database Initialization
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
async def initialize_database(
|
|
||||||
engine: Optional[AsyncEngine] = None,
|
|
||||||
create_schema: bool = True,
|
|
||||||
validate_schema: bool = True,
|
|
||||||
seed_data: bool = False,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""Initialize database with schema creation and validation.
|
|
||||||
|
|
||||||
This is the main entry point for database initialization. It performs:
|
|
||||||
1. Schema creation (if requested)
|
|
||||||
2. Schema validation (if requested)
|
|
||||||
3. Initial data seeding (if requested)
|
|
||||||
4. Health check
|
|
||||||
|
|
||||||
Args:
|
|
||||||
engine: Optional database engine (uses default if not provided)
|
|
||||||
create_schema: Whether to create database schema
|
|
||||||
validate_schema: Whether to validate schema after creation
|
|
||||||
seed_data: Whether to seed initial data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary with initialization results containing:
|
|
||||||
- success: Whether initialization succeeded
|
|
||||||
- schema_version: Current schema version
|
|
||||||
- tables_created: List of tables created
|
|
||||||
- validation_result: Schema validation result
|
|
||||||
- health_check: Database health status
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: If database initialization fails
|
|
||||||
|
|
||||||
Example:
|
|
||||||
result = await initialize_database(
|
|
||||||
create_schema=True,
|
|
||||||
validate_schema=True,
|
|
||||||
seed_data=True
|
|
||||||
)
|
|
||||||
if result["success"]:
|
|
||||||
logger.info(f"Database initialized: {result['schema_version']}")
|
|
||||||
"""
|
|
||||||
if engine is None:
|
|
||||||
engine = get_engine()
|
|
||||||
|
|
||||||
logger.info("Starting database initialization...")
|
|
||||||
result = {
|
|
||||||
"success": False,
|
|
||||||
"schema_version": None,
|
|
||||||
"tables_created": [],
|
|
||||||
"validation_result": None,
|
|
||||||
"health_check": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Create schema if requested
|
|
||||||
if create_schema:
|
|
||||||
tables = await create_database_schema(engine)
|
|
||||||
result["tables_created"] = tables
|
|
||||||
logger.info(f"Created {len(tables)} tables")
|
|
||||||
|
|
||||||
# Validate schema if requested
|
|
||||||
if validate_schema:
|
|
||||||
validation = await validate_database_schema(engine)
|
|
||||||
result["validation_result"] = validation
|
|
||||||
|
|
||||||
if not validation["valid"]:
|
|
||||||
logger.warning(
|
|
||||||
f"Schema validation issues: {validation['issues']}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Seed initial data if requested
|
|
||||||
if seed_data:
|
|
||||||
await seed_initial_data(engine)
|
|
||||||
logger.info("Initial data seeding complete")
|
|
||||||
|
|
||||||
# Get schema version
|
|
||||||
version = await get_schema_version(engine)
|
|
||||||
result["schema_version"] = version
|
|
||||||
|
|
||||||
# Health check
|
|
||||||
health = await check_database_health(engine)
|
|
||||||
result["health_check"] = health
|
|
||||||
|
|
||||||
result["success"] = True
|
|
||||||
logger.info("Database initialization complete")
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Database initialization failed: {e}", exc_info=True)
|
|
||||||
raise RuntimeError(f"Failed to initialize database: {e}") from e
|
|
||||||
|
|
||||||
|
|
||||||
async def create_database_schema(
|
|
||||||
engine: Optional[AsyncEngine] = None
|
|
||||||
) -> List[str]:
|
|
||||||
"""Create database schema with all tables and indexes.
|
|
||||||
|
|
||||||
Creates all tables defined in Base.metadata if they don't exist.
|
|
||||||
This is idempotent - safe to call multiple times.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
engine: Optional database engine (uses default if not provided)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of table names created
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: If schema creation fails
|
|
||||||
"""
|
|
||||||
if engine is None:
|
|
||||||
engine = get_engine()
|
|
||||||
|
|
||||||
logger.info("Creating database schema...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Create all tables
|
|
||||||
async with engine.begin() as conn:
|
|
||||||
# Get existing tables before creation
|
|
||||||
existing_tables = await conn.run_sync(
|
|
||||||
lambda sync_conn: inspect(sync_conn).get_table_names()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create all tables defined in Base
|
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
|
||||||
|
|
||||||
# Get tables after creation
|
|
||||||
new_tables = await conn.run_sync(
|
|
||||||
lambda sync_conn: inspect(sync_conn).get_table_names()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Determine which tables were created
|
|
||||||
created_tables = [t for t in new_tables if t not in existing_tables]
|
|
||||||
|
|
||||||
if created_tables:
|
|
||||||
logger.info(f"Created tables: {', '.join(created_tables)}")
|
|
||||||
else:
|
|
||||||
logger.info("All tables already exist")
|
|
||||||
|
|
||||||
return new_tables
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to create schema: {e}", exc_info=True)
|
|
||||||
raise RuntimeError(f"Schema creation failed: {e}") from e
|
|
||||||
|
|
||||||
|
|
||||||
async def validate_database_schema(
|
|
||||||
engine: Optional[AsyncEngine] = None
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""Validate database schema integrity.
|
|
||||||
|
|
||||||
Checks that all expected tables, columns, and indexes exist.
|
|
||||||
Reports any missing or unexpected schema elements.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
engine: Optional database engine (uses default if not provided)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary with validation results containing:
|
|
||||||
- valid: Whether schema is valid
|
|
||||||
- missing_tables: List of missing tables
|
|
||||||
- extra_tables: List of unexpected tables
|
|
||||||
- missing_indexes: Dict of missing indexes by table
|
|
||||||
- issues: List of validation issues
|
|
||||||
"""
|
|
||||||
if engine is None:
|
|
||||||
engine = get_engine()
|
|
||||||
|
|
||||||
logger.info("Validating database schema...")
|
|
||||||
|
|
||||||
result = {
|
|
||||||
"valid": True,
|
|
||||||
"missing_tables": [],
|
|
||||||
"extra_tables": [],
|
|
||||||
"missing_indexes": {},
|
|
||||||
"issues": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with engine.connect() as conn:
|
|
||||||
# Get existing tables
|
|
||||||
existing_tables = await conn.run_sync(
|
|
||||||
lambda sync_conn: set(inspect(sync_conn).get_table_names())
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check for missing tables
|
|
||||||
missing = EXPECTED_TABLES - existing_tables
|
|
||||||
if missing:
|
|
||||||
result["missing_tables"] = list(missing)
|
|
||||||
result["valid"] = False
|
|
||||||
result["issues"].append(
|
|
||||||
f"Missing tables: {', '.join(missing)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check for extra tables (excluding SQLite internal tables)
|
|
||||||
extra = existing_tables - EXPECTED_TABLES
|
|
||||||
extra = {t for t in extra if not t.startswith("sqlite_")}
|
|
||||||
if extra:
|
|
||||||
result["extra_tables"] = list(extra)
|
|
||||||
result["issues"].append(
|
|
||||||
f"Unexpected tables: {', '.join(extra)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check indexes for each table
|
|
||||||
for table_name in EXPECTED_TABLES & existing_tables:
|
|
||||||
existing_indexes = await conn.run_sync(
|
|
||||||
lambda sync_conn: [
|
|
||||||
idx["name"]
|
|
||||||
for idx in inspect(sync_conn).get_indexes(table_name)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
expected_indexes = EXPECTED_INDEXES.get(table_name, [])
|
|
||||||
missing_indexes = [
|
|
||||||
idx for idx in expected_indexes
|
|
||||||
if idx not in existing_indexes
|
|
||||||
]
|
|
||||||
|
|
||||||
if missing_indexes:
|
|
||||||
result["missing_indexes"][table_name] = missing_indexes
|
|
||||||
result["valid"] = False
|
|
||||||
result["issues"].append(
|
|
||||||
f"Missing indexes on {table_name}: "
|
|
||||||
f"{', '.join(missing_indexes)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if result["valid"]:
|
|
||||||
logger.info("Schema validation passed")
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
f"Schema validation issues found: {len(result['issues'])}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Schema validation failed: {e}", exc_info=True)
|
|
||||||
return {
|
|
||||||
"valid": False,
|
|
||||||
"missing_tables": [],
|
|
||||||
"extra_tables": [],
|
|
||||||
"missing_indexes": {},
|
|
||||||
"issues": [f"Validation error: {str(e)}"],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Schema Version Management
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
async def get_schema_version(engine: Optional[AsyncEngine] = None) -> str:
|
|
||||||
"""Get current database schema version.
|
|
||||||
|
|
||||||
Returns version string based on existing tables and structure.
|
|
||||||
For production, consider using Alembic versioning.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
engine: Optional database engine (uses default if not provided)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Schema version string (e.g., "1.0.0", "empty", "unknown")
|
|
||||||
"""
|
|
||||||
if engine is None:
|
|
||||||
engine = get_engine()
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with engine.connect() as conn:
|
|
||||||
# Get existing tables
|
|
||||||
tables = await conn.run_sync(
|
|
||||||
lambda sync_conn: set(inspect(sync_conn).get_table_names())
|
|
||||||
)
|
|
||||||
|
|
||||||
# Filter out SQLite internal tables
|
|
||||||
tables = {t for t in tables if not t.startswith("sqlite_")}
|
|
||||||
|
|
||||||
if not tables:
|
|
||||||
return "empty"
|
|
||||||
elif tables == EXPECTED_TABLES:
|
|
||||||
return CURRENT_SCHEMA_VERSION
|
|
||||||
else:
|
|
||||||
return "unknown"
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to get schema version: {e}")
|
|
||||||
return "error"
|
|
||||||
|
|
||||||
|
|
||||||
async def create_schema_version_table(
|
|
||||||
engine: Optional[AsyncEngine] = None
|
|
||||||
) -> None:
|
|
||||||
"""Create schema version tracking table.
|
|
||||||
|
|
||||||
Future enhancement for tracking schema migrations with Alembic.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
engine: Optional database engine (uses default if not provided)
|
|
||||||
"""
|
|
||||||
if engine is None:
|
|
||||||
engine = get_engine()
|
|
||||||
|
|
||||||
async with engine.begin() as conn:
|
|
||||||
await conn.execute(
|
|
||||||
text(
|
|
||||||
f"""
|
|
||||||
CREATE TABLE IF NOT EXISTS {SCHEMA_VERSION_TABLE} (
|
|
||||||
version VARCHAR(20) PRIMARY KEY,
|
|
||||||
applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
description TEXT
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Initial Data Seeding
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
async def seed_initial_data(engine: Optional[AsyncEngine] = None) -> None:
|
|
||||||
"""Seed database with initial data.
|
|
||||||
|
|
||||||
Creates default configuration and sample data if database is empty.
|
|
||||||
Safe to call multiple times - only seeds if tables are empty.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
engine: Optional database engine (uses default if not provided)
|
|
||||||
"""
|
|
||||||
if engine is None:
|
|
||||||
engine = get_engine()
|
|
||||||
|
|
||||||
logger.info("Seeding initial data...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Use engine directly for seeding to avoid dependency on session factory
|
|
||||||
async with engine.connect() as conn:
|
|
||||||
# Check if data already exists
|
|
||||||
result = await conn.execute(
|
|
||||||
text("SELECT COUNT(*) FROM anime_series")
|
|
||||||
)
|
|
||||||
count = result.scalar()
|
|
||||||
|
|
||||||
if count > 0:
|
|
||||||
logger.info("Database already contains data, skipping seed")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Seed sample data if needed
|
|
||||||
# Note: In production, you may want to skip this
|
|
||||||
logger.info("Database is empty, but no sample data to seed")
|
|
||||||
logger.info("Data will be populated via normal application usage")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to seed initial data: {e}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Database Health Check
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
async def check_database_health(
|
|
||||||
engine: Optional[AsyncEngine] = None
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""Check database health and connectivity.
|
|
||||||
|
|
||||||
Performs basic health checks including:
|
|
||||||
- Database connectivity
|
|
||||||
- Table accessibility
|
|
||||||
- Basic query execution
|
|
||||||
|
|
||||||
Args:
|
|
||||||
engine: Optional database engine (uses default if not provided)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary with health check results containing:
|
|
||||||
- healthy: Overall health status
|
|
||||||
- accessible: Whether database is accessible
|
|
||||||
- tables: Number of tables
|
|
||||||
- connectivity_ms: Connection time in milliseconds
|
|
||||||
- issues: List of any health issues
|
|
||||||
"""
|
|
||||||
if engine is None:
|
|
||||||
engine = get_engine()
|
|
||||||
|
|
||||||
result = {
|
|
||||||
"healthy": True,
|
|
||||||
"accessible": False,
|
|
||||||
"tables": 0,
|
|
||||||
"connectivity_ms": 0,
|
|
||||||
"issues": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Measure connectivity time
|
|
||||||
import time
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
async with engine.connect() as conn:
|
|
||||||
# Test basic query
|
|
||||||
await conn.execute(text("SELECT 1"))
|
|
||||||
|
|
||||||
# Get table count
|
|
||||||
tables = await conn.run_sync(
|
|
||||||
lambda sync_conn: inspect(sync_conn).get_table_names()
|
|
||||||
)
|
|
||||||
result["tables"] = len(tables)
|
|
||||||
|
|
||||||
end_time = time.time()
|
|
||||||
# Ensure at least 1ms for timing (avoid 0 for very fast operations)
|
|
||||||
result["connectivity_ms"] = max(1, int((end_time - start_time) * 1000))
|
|
||||||
result["accessible"] = True
|
|
||||||
|
|
||||||
# Check for expected tables
|
|
||||||
if result["tables"] < len(EXPECTED_TABLES):
|
|
||||||
result["healthy"] = False
|
|
||||||
result["issues"].append(
|
|
||||||
f"Expected {len(EXPECTED_TABLES)} tables, "
|
|
||||||
f"found {result['tables']}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if result["healthy"]:
|
|
||||||
logger.info(
|
|
||||||
f"Database health check passed "
|
|
||||||
f"(connectivity: {result['connectivity_ms']}ms)"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.warning(f"Database health issues: {result['issues']}")
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Database health check failed: {e}")
|
|
||||||
return {
|
|
||||||
"healthy": False,
|
|
||||||
"accessible": False,
|
|
||||||
"tables": 0,
|
|
||||||
"connectivity_ms": 0,
|
|
||||||
"issues": [str(e)],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Database Backup and Restore
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
async def create_database_backup(
|
|
||||||
backup_path: Optional[Path] = None
|
|
||||||
) -> Path:
|
|
||||||
"""Create database backup.
|
|
||||||
|
|
||||||
For SQLite databases, creates a copy of the database file.
|
|
||||||
For other databases, this should be extended to use appropriate tools.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
backup_path: Optional path for backup file
|
|
||||||
(defaults to data/backups/aniworld_YYYYMMDD_HHMMSS.db)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Path to created backup file
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: If backup creation fails
|
|
||||||
"""
|
|
||||||
import shutil
|
|
||||||
|
|
||||||
# Get database path from settings
|
|
||||||
db_url = settings.database_url
|
|
||||||
|
|
||||||
if not db_url.startswith("sqlite"):
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Backup currently only supported for SQLite databases"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract database file path
|
|
||||||
db_path = Path(db_url.replace("sqlite:///", ""))
|
|
||||||
|
|
||||||
if not db_path.exists():
|
|
||||||
raise RuntimeError(f"Database file not found: {db_path}")
|
|
||||||
|
|
||||||
# Create backup path
|
|
||||||
if backup_path is None:
|
|
||||||
backup_dir = Path("data/backups")
|
|
||||||
backup_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
||||||
backup_path = backup_dir / f"aniworld_{timestamp}.db"
|
|
||||||
|
|
||||||
try:
|
|
||||||
logger.info(f"Creating database backup: {backup_path}")
|
|
||||||
shutil.copy2(db_path, backup_path)
|
|
||||||
logger.info(f"Backup created successfully: {backup_path}")
|
|
||||||
return backup_path
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to create backup: {e}", exc_info=True)
|
|
||||||
raise RuntimeError(f"Backup creation failed: {e}") from e
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Utility Functions
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
def get_database_info() -> Dict[str, Any]:
|
|
||||||
"""Get database configuration information.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary with database configuration details
|
|
||||||
"""
|
|
||||||
return {
|
|
||||||
"database_url": settings.database_url,
|
|
||||||
"database_type": (
|
|
||||||
"sqlite" if "sqlite" in settings.database_url
|
|
||||||
else "postgresql" if "postgresql" in settings.database_url
|
|
||||||
else "mysql" if "mysql" in settings.database_url
|
|
||||||
else "unknown"
|
|
||||||
),
|
|
||||||
"schema_version": CURRENT_SCHEMA_VERSION,
|
|
||||||
"expected_tables": list(EXPECTED_TABLES),
|
|
||||||
"log_level": settings.log_level,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_migration_guide() -> str:
|
|
||||||
"""Get migration guide for production deployments.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Migration guide text
|
|
||||||
"""
|
|
||||||
return """
|
|
||||||
Database Migration Guide
|
|
||||||
========================
|
|
||||||
|
|
||||||
Current Setup: SQLAlchemy create_all()
|
|
||||||
- Automatically creates tables on startup
|
|
||||||
- Suitable for development and single-instance deployments
|
|
||||||
- Schema changes require manual handling
|
|
||||||
|
|
||||||
For Production with Alembic:
|
|
||||||
============================
|
|
||||||
|
|
||||||
1. Initialize Alembic (already installed):
|
|
||||||
alembic init alembic
|
|
||||||
|
|
||||||
2. Configure alembic/env.py:
|
|
||||||
from src.server.database.base import Base
|
|
||||||
target_metadata = Base.metadata
|
|
||||||
|
|
||||||
3. Configure alembic.ini:
|
|
||||||
sqlalchemy.url = <your-database-url>
|
|
||||||
|
|
||||||
4. Generate initial migration:
|
|
||||||
alembic revision --autogenerate -m "Initial schema v1.0.0"
|
|
||||||
|
|
||||||
5. Review migration in alembic/versions/
|
|
||||||
|
|
||||||
6. Apply migration:
|
|
||||||
alembic upgrade head
|
|
||||||
|
|
||||||
7. For future schema changes:
|
|
||||||
- Modify models in src/server/database/models.py
|
|
||||||
- Generate migration: alembic revision --autogenerate -m "Description"
|
|
||||||
- Review generated migration
|
|
||||||
- Test in staging environment
|
|
||||||
- Apply: alembic upgrade head
|
|
||||||
- For rollback: alembic downgrade -1
|
|
||||||
|
|
||||||
Best Practices:
|
|
||||||
==============
|
|
||||||
- Always backup database before migrations
|
|
||||||
- Test migrations in staging first
|
|
||||||
- Review auto-generated migrations carefully
|
|
||||||
- Keep migrations in version control
|
|
||||||
- Document breaking changes
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Public API
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"initialize_database",
|
|
||||||
"create_database_schema",
|
|
||||||
"validate_database_schema",
|
|
||||||
"get_schema_version",
|
|
||||||
"create_schema_version_table",
|
|
||||||
"seed_initial_data",
|
|
||||||
"check_database_health",
|
|
||||||
"create_database_backup",
|
|
||||||
"get_database_info",
|
|
||||||
"get_migration_guide",
|
|
||||||
"CURRENT_SCHEMA_VERSION",
|
|
||||||
"EXPECTED_TABLES",
|
|
||||||
]
|
|
||||||
@ -1,167 +0,0 @@
|
|||||||
"""Database migration utilities.
|
|
||||||
|
|
||||||
This module provides utilities for database migrations and schema versioning.
|
|
||||||
Alembic integration can be added when needed for production environments.
|
|
||||||
|
|
||||||
For now, we use SQLAlchemy's create_all for automatic schema creation.
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from sqlalchemy import text
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
|
||||||
|
|
||||||
from src.server.database.base import Base
|
|
||||||
from src.server.database.connection import get_engine, get_sync_engine
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
async def initialize_schema(engine: Optional[AsyncEngine] = None) -> None:
|
|
||||||
"""Initialize database schema.
|
|
||||||
|
|
||||||
Creates all tables defined in Base metadata if they don't exist.
|
|
||||||
This is a simple migration strategy suitable for single-instance deployments.
|
|
||||||
|
|
||||||
For production with multiple instances, consider using Alembic:
|
|
||||||
- alembic init alembic
|
|
||||||
- alembic revision --autogenerate -m "Initial schema"
|
|
||||||
- alembic upgrade head
|
|
||||||
|
|
||||||
Args:
|
|
||||||
engine: Optional database engine (uses default if not provided)
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: If database is not initialized
|
|
||||||
"""
|
|
||||||
if engine is None:
|
|
||||||
engine = get_engine()
|
|
||||||
|
|
||||||
logger.info("Initializing database schema...")
|
|
||||||
|
|
||||||
# Create all tables
|
|
||||||
async with engine.begin() as conn:
|
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
|
||||||
|
|
||||||
logger.info("Database schema initialized successfully")
|
|
||||||
|
|
||||||
|
|
||||||
async def check_schema_version(engine: Optional[AsyncEngine] = None) -> str:
|
|
||||||
"""Check current database schema version.
|
|
||||||
|
|
||||||
Returns a simple version identifier based on existing tables.
|
|
||||||
For production, consider using Alembic for proper versioning.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
engine: Optional database engine (uses default if not provided)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Schema version string
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: If database is not initialized
|
|
||||||
"""
|
|
||||||
if engine is None:
|
|
||||||
engine = get_engine()
|
|
||||||
|
|
||||||
async with engine.connect() as conn:
|
|
||||||
# Check which tables exist
|
|
||||||
result = await conn.execute(
|
|
||||||
text(
|
|
||||||
"SELECT name FROM sqlite_master "
|
|
||||||
"WHERE type='table' AND name NOT LIKE 'sqlite_%'"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
tables = [row[0] for row in result]
|
|
||||||
|
|
||||||
if not tables:
|
|
||||||
return "empty"
|
|
||||||
elif len(tables) == 4 and all(
|
|
||||||
t in tables for t in [
|
|
||||||
"anime_series",
|
|
||||||
"episodes",
|
|
||||||
"download_queue",
|
|
||||||
"user_sessions",
|
|
||||||
]
|
|
||||||
):
|
|
||||||
return "v1.0"
|
|
||||||
else:
|
|
||||||
return "custom"
|
|
||||||
|
|
||||||
|
|
||||||
def get_migration_info() -> str:
|
|
||||||
"""Get information about database migration setup.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Migration setup information
|
|
||||||
"""
|
|
||||||
return """
|
|
||||||
Database Migration Information
|
|
||||||
==============================
|
|
||||||
|
|
||||||
Current Strategy: SQLAlchemy create_all()
|
|
||||||
- Automatically creates tables on startup
|
|
||||||
- Suitable for development and single-instance deployments
|
|
||||||
- Schema changes require manual handling
|
|
||||||
|
|
||||||
For Production Migrations (Alembic):
|
|
||||||
====================================
|
|
||||||
|
|
||||||
1. Initialize Alembic:
|
|
||||||
alembic init alembic
|
|
||||||
|
|
||||||
2. Configure alembic/env.py:
|
|
||||||
- Import Base from src.server.database.base
|
|
||||||
- Set target_metadata = Base.metadata
|
|
||||||
|
|
||||||
3. Configure alembic.ini:
|
|
||||||
- Set sqlalchemy.url to your database URL
|
|
||||||
|
|
||||||
4. Generate initial migration:
|
|
||||||
alembic revision --autogenerate -m "Initial schema"
|
|
||||||
|
|
||||||
5. Apply migrations:
|
|
||||||
alembic upgrade head
|
|
||||||
|
|
||||||
6. For future changes:
|
|
||||||
- Modify models in src/server/database/models.py
|
|
||||||
- Generate migration: alembic revision --autogenerate -m "Description"
|
|
||||||
- Review generated migration in alembic/versions/
|
|
||||||
- Apply: alembic upgrade head
|
|
||||||
|
|
||||||
Benefits of Alembic:
|
|
||||||
- Version control for database schema
|
|
||||||
- Automatic migration generation from model changes
|
|
||||||
- Rollback support with downgrade scripts
|
|
||||||
- Multi-instance deployment support
|
|
||||||
- Safe schema changes in production
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Future Alembic Integration
|
|
||||||
# =============================================================================
|
|
||||||
#
|
|
||||||
# When ready to use Alembic, follow these steps:
|
|
||||||
#
|
|
||||||
# 1. Install Alembic (already in requirements.txt):
|
|
||||||
# pip install alembic
|
|
||||||
#
|
|
||||||
# 2. Initialize Alembic from project root:
|
|
||||||
# alembic init alembic
|
|
||||||
#
|
|
||||||
# 3. Update alembic/env.py to use our Base:
|
|
||||||
# from src.server.database.base import Base
|
|
||||||
# target_metadata = Base.metadata
|
|
||||||
#
|
|
||||||
# 4. Configure alembic.ini with DATABASE_URL from settings
|
|
||||||
#
|
|
||||||
# 5. Generate initial migration:
|
|
||||||
# alembic revision --autogenerate -m "Initial schema"
|
|
||||||
#
|
|
||||||
# 6. Review generated migration and apply:
|
|
||||||
# alembic upgrade head
|
|
||||||
#
|
|
||||||
# =============================================================================
|
|
||||||
@ -1,429 +0,0 @@
|
|||||||
"""SQLAlchemy ORM models for the Aniworld web application.
|
|
||||||
|
|
||||||
This module defines database models for anime series, episodes, download queue,
|
|
||||||
and user sessions. Models use SQLAlchemy 2.0 style with type annotations.
|
|
||||||
|
|
||||||
Models:
|
|
||||||
- AnimeSeries: Represents an anime series with metadata
|
|
||||||
- Episode: Individual episodes linked to series
|
|
||||||
- DownloadQueueItem: Download queue with status and progress tracking
|
|
||||||
- UserSession: User authentication sessions with JWT tokens
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from datetime import datetime
|
|
||||||
from enum import Enum
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from sqlalchemy import (
|
|
||||||
JSON,
|
|
||||||
Boolean,
|
|
||||||
DateTime,
|
|
||||||
Float,
|
|
||||||
ForeignKey,
|
|
||||||
Integer,
|
|
||||||
String,
|
|
||||||
Text,
|
|
||||||
func,
|
|
||||||
)
|
|
||||||
from sqlalchemy import Enum as SQLEnum
|
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
|
||||||
|
|
||||||
from src.server.database.base import Base, TimestampMixin
|
|
||||||
|
|
||||||
|
|
||||||
class AnimeSeries(Base, TimestampMixin):
|
|
||||||
"""SQLAlchemy model for anime series.
|
|
||||||
|
|
||||||
Represents an anime series with metadata, provider information,
|
|
||||||
and links to episodes. Corresponds to the core Serie class.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
id: Primary key
|
|
||||||
key: Unique identifier used by provider
|
|
||||||
name: Series name
|
|
||||||
site: Provider site URL
|
|
||||||
folder: Local filesystem path
|
|
||||||
description: Optional series description
|
|
||||||
status: Current status (ongoing, completed, etc.)
|
|
||||||
total_episodes: Total number of episodes
|
|
||||||
cover_url: URL to series cover image
|
|
||||||
episodes: Relationship to Episode models
|
|
||||||
download_items: Relationship to DownloadQueueItem models
|
|
||||||
created_at: Creation timestamp (from TimestampMixin)
|
|
||||||
updated_at: Last update timestamp (from TimestampMixin)
|
|
||||||
"""
|
|
||||||
__tablename__ = "anime_series"
|
|
||||||
|
|
||||||
# Primary key
|
|
||||||
id: Mapped[int] = mapped_column(
|
|
||||||
Integer, primary_key=True, autoincrement=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Core identification
|
|
||||||
key: Mapped[str] = mapped_column(
|
|
||||||
String(255), unique=True, nullable=False, index=True,
|
|
||||||
doc="Unique provider key"
|
|
||||||
)
|
|
||||||
name: Mapped[str] = mapped_column(
|
|
||||||
String(500), nullable=False, index=True,
|
|
||||||
doc="Series name"
|
|
||||||
)
|
|
||||||
site: Mapped[str] = mapped_column(
|
|
||||||
String(500), nullable=False,
|
|
||||||
doc="Provider site URL"
|
|
||||||
)
|
|
||||||
folder: Mapped[str] = mapped_column(
|
|
||||||
String(1000), nullable=False,
|
|
||||||
doc="Local filesystem path"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Metadata
|
|
||||||
description: Mapped[Optional[str]] = mapped_column(
|
|
||||||
Text, nullable=True,
|
|
||||||
doc="Series description"
|
|
||||||
)
|
|
||||||
status: Mapped[Optional[str]] = mapped_column(
|
|
||||||
String(50), nullable=True,
|
|
||||||
doc="Series status (ongoing, completed, etc.)"
|
|
||||||
)
|
|
||||||
total_episodes: Mapped[Optional[int]] = mapped_column(
|
|
||||||
Integer, nullable=True,
|
|
||||||
doc="Total number of episodes"
|
|
||||||
)
|
|
||||||
cover_url: Mapped[Optional[str]] = mapped_column(
|
|
||||||
String(1000), nullable=True,
|
|
||||||
doc="URL to cover image"
|
|
||||||
)
|
|
||||||
|
|
||||||
# JSON field for episode dictionary (season -> [episodes])
|
|
||||||
episode_dict: Mapped[Optional[dict]] = mapped_column(
|
|
||||||
JSON, nullable=True,
|
|
||||||
doc="Episode dictionary {season: [episodes]}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Relationships
|
|
||||||
episodes: Mapped[List["Episode"]] = relationship(
|
|
||||||
"Episode",
|
|
||||||
back_populates="series",
|
|
||||||
cascade="all, delete-orphan"
|
|
||||||
)
|
|
||||||
download_items: Mapped[List["DownloadQueueItem"]] = relationship(
|
|
||||||
"DownloadQueueItem",
|
|
||||||
back_populates="series",
|
|
||||||
cascade="all, delete-orphan"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return f"<AnimeSeries(id={self.id}, key='{self.key}', name='{self.name}')>"
|
|
||||||
|
|
||||||
|
|
||||||
class Episode(Base, TimestampMixin):
|
|
||||||
"""SQLAlchemy model for anime episodes.
|
|
||||||
|
|
||||||
Represents individual episodes linked to an anime series.
|
|
||||||
Tracks download status and file location.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
id: Primary key
|
|
||||||
series_id: Foreign key to AnimeSeries
|
|
||||||
season: Season number
|
|
||||||
episode_number: Episode number within season
|
|
||||||
title: Episode title
|
|
||||||
file_path: Local file path if downloaded
|
|
||||||
file_size: File size in bytes
|
|
||||||
is_downloaded: Whether episode is downloaded
|
|
||||||
download_date: When episode was downloaded
|
|
||||||
series: Relationship to AnimeSeries
|
|
||||||
created_at: Creation timestamp (from TimestampMixin)
|
|
||||||
updated_at: Last update timestamp (from TimestampMixin)
|
|
||||||
"""
|
|
||||||
__tablename__ = "episodes"
|
|
||||||
|
|
||||||
# Primary key
|
|
||||||
id: Mapped[int] = mapped_column(
|
|
||||||
Integer, primary_key=True, autoincrement=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Foreign key to series
|
|
||||||
series_id: Mapped[int] = mapped_column(
|
|
||||||
ForeignKey("anime_series.id", ondelete="CASCADE"),
|
|
||||||
nullable=False,
|
|
||||||
index=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Episode identification
|
|
||||||
season: Mapped[int] = mapped_column(
|
|
||||||
Integer, nullable=False,
|
|
||||||
doc="Season number"
|
|
||||||
)
|
|
||||||
episode_number: Mapped[int] = mapped_column(
|
|
||||||
Integer, nullable=False,
|
|
||||||
doc="Episode number within season"
|
|
||||||
)
|
|
||||||
title: Mapped[Optional[str]] = mapped_column(
|
|
||||||
String(500), nullable=True,
|
|
||||||
doc="Episode title"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Download information
|
|
||||||
file_path: Mapped[Optional[str]] = mapped_column(
|
|
||||||
String(1000), nullable=True,
|
|
||||||
doc="Local file path"
|
|
||||||
)
|
|
||||||
file_size: Mapped[Optional[int]] = mapped_column(
|
|
||||||
Integer, nullable=True,
|
|
||||||
doc="File size in bytes"
|
|
||||||
)
|
|
||||||
is_downloaded: Mapped[bool] = mapped_column(
|
|
||||||
Boolean, default=False, nullable=False,
|
|
||||||
doc="Whether episode is downloaded"
|
|
||||||
)
|
|
||||||
download_date: Mapped[Optional[datetime]] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=True,
|
|
||||||
doc="When episode was downloaded"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Relationship
|
|
||||||
series: Mapped["AnimeSeries"] = relationship(
|
|
||||||
"AnimeSeries",
|
|
||||||
back_populates="episodes"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return (
|
|
||||||
f"<Episode(id={self.id}, series_id={self.series_id}, "
|
|
||||||
f"S{self.season:02d}E{self.episode_number:02d})>"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DownloadStatus(str, Enum):
|
|
||||||
"""Status enum for download queue items."""
|
|
||||||
PENDING = "pending"
|
|
||||||
DOWNLOADING = "downloading"
|
|
||||||
PAUSED = "paused"
|
|
||||||
COMPLETED = "completed"
|
|
||||||
FAILED = "failed"
|
|
||||||
CANCELLED = "cancelled"
|
|
||||||
|
|
||||||
|
|
||||||
class DownloadPriority(str, Enum):
|
|
||||||
"""Priority enum for download queue items."""
|
|
||||||
LOW = "low"
|
|
||||||
NORMAL = "normal"
|
|
||||||
HIGH = "high"
|
|
||||||
|
|
||||||
|
|
||||||
class DownloadQueueItem(Base, TimestampMixin):
|
|
||||||
"""SQLAlchemy model for download queue items.
|
|
||||||
|
|
||||||
Tracks download queue with status, progress, and error information.
|
|
||||||
Provides persistence for the DownloadService queue state.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
id: Primary key
|
|
||||||
series_id: Foreign key to AnimeSeries
|
|
||||||
season: Season number
|
|
||||||
episode_number: Episode number
|
|
||||||
status: Current download status
|
|
||||||
priority: Download priority
|
|
||||||
progress_percent: Download progress (0-100)
|
|
||||||
downloaded_bytes: Bytes downloaded
|
|
||||||
total_bytes: Total file size
|
|
||||||
download_speed: Current speed in bytes/sec
|
|
||||||
error_message: Error description if failed
|
|
||||||
retry_count: Number of retry attempts
|
|
||||||
download_url: Provider download URL
|
|
||||||
file_destination: Target file path
|
|
||||||
started_at: When download started
|
|
||||||
completed_at: When download completed
|
|
||||||
series: Relationship to AnimeSeries
|
|
||||||
created_at: Creation timestamp (from TimestampMixin)
|
|
||||||
updated_at: Last update timestamp (from TimestampMixin)
|
|
||||||
"""
|
|
||||||
__tablename__ = "download_queue"
|
|
||||||
|
|
||||||
# Primary key
|
|
||||||
id: Mapped[int] = mapped_column(
|
|
||||||
Integer, primary_key=True, autoincrement=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Foreign key to series
|
|
||||||
series_id: Mapped[int] = mapped_column(
|
|
||||||
ForeignKey("anime_series.id", ondelete="CASCADE"),
|
|
||||||
nullable=False,
|
|
||||||
index=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Episode identification
|
|
||||||
season: Mapped[int] = mapped_column(
|
|
||||||
Integer, nullable=False,
|
|
||||||
doc="Season number"
|
|
||||||
)
|
|
||||||
episode_number: Mapped[int] = mapped_column(
|
|
||||||
Integer, nullable=False,
|
|
||||||
doc="Episode number"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Queue management
|
|
||||||
status: Mapped[str] = mapped_column(
|
|
||||||
SQLEnum(DownloadStatus),
|
|
||||||
default=DownloadStatus.PENDING,
|
|
||||||
nullable=False,
|
|
||||||
index=True,
|
|
||||||
doc="Current download status"
|
|
||||||
)
|
|
||||||
priority: Mapped[str] = mapped_column(
|
|
||||||
SQLEnum(DownloadPriority),
|
|
||||||
default=DownloadPriority.NORMAL,
|
|
||||||
nullable=False,
|
|
||||||
doc="Download priority"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Progress tracking
|
|
||||||
progress_percent: Mapped[float] = mapped_column(
|
|
||||||
Float, default=0.0, nullable=False,
|
|
||||||
doc="Progress percentage (0-100)"
|
|
||||||
)
|
|
||||||
downloaded_bytes: Mapped[int] = mapped_column(
|
|
||||||
Integer, default=0, nullable=False,
|
|
||||||
doc="Bytes downloaded"
|
|
||||||
)
|
|
||||||
total_bytes: Mapped[Optional[int]] = mapped_column(
|
|
||||||
Integer, nullable=True,
|
|
||||||
doc="Total file size"
|
|
||||||
)
|
|
||||||
download_speed: Mapped[Optional[float]] = mapped_column(
|
|
||||||
Float, nullable=True,
|
|
||||||
doc="Current download speed (bytes/sec)"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Error handling
|
|
||||||
error_message: Mapped[Optional[str]] = mapped_column(
|
|
||||||
Text, nullable=True,
|
|
||||||
doc="Error description"
|
|
||||||
)
|
|
||||||
retry_count: Mapped[int] = mapped_column(
|
|
||||||
Integer, default=0, nullable=False,
|
|
||||||
doc="Number of retry attempts"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Download details
|
|
||||||
download_url: Mapped[Optional[str]] = mapped_column(
|
|
||||||
String(1000), nullable=True,
|
|
||||||
doc="Provider download URL"
|
|
||||||
)
|
|
||||||
file_destination: Mapped[Optional[str]] = mapped_column(
|
|
||||||
String(1000), nullable=True,
|
|
||||||
doc="Target file path"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Timestamps
|
|
||||||
started_at: Mapped[Optional[datetime]] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=True,
|
|
||||||
doc="When download started"
|
|
||||||
)
|
|
||||||
completed_at: Mapped[Optional[datetime]] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=True,
|
|
||||||
doc="When download completed"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Relationship
|
|
||||||
series: Mapped["AnimeSeries"] = relationship(
|
|
||||||
"AnimeSeries",
|
|
||||||
back_populates="download_items"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return (
|
|
||||||
f"<DownloadQueueItem(id={self.id}, "
|
|
||||||
f"series_id={self.series_id}, "
|
|
||||||
f"S{self.season:02d}E{self.episode_number:02d}, "
|
|
||||||
f"status={self.status})>"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class UserSession(Base, TimestampMixin):
|
|
||||||
"""SQLAlchemy model for user sessions.
|
|
||||||
|
|
||||||
Tracks authenticated user sessions with JWT tokens.
|
|
||||||
Supports session management, revocation, and expiry.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
id: Primary key
|
|
||||||
session_id: Unique session identifier
|
|
||||||
token_hash: Hashed JWT token for validation
|
|
||||||
user_id: User identifier (for multi-user support)
|
|
||||||
ip_address: Client IP address
|
|
||||||
user_agent: Client user agent string
|
|
||||||
expires_at: Session expiration timestamp
|
|
||||||
is_active: Whether session is active
|
|
||||||
last_activity: Last activity timestamp
|
|
||||||
created_at: Creation timestamp (from TimestampMixin)
|
|
||||||
updated_at: Last update timestamp (from TimestampMixin)
|
|
||||||
"""
|
|
||||||
__tablename__ = "user_sessions"
|
|
||||||
|
|
||||||
# Primary key
|
|
||||||
id: Mapped[int] = mapped_column(
|
|
||||||
Integer, primary_key=True, autoincrement=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Session identification
|
|
||||||
session_id: Mapped[str] = mapped_column(
|
|
||||||
String(255), unique=True, nullable=False, index=True,
|
|
||||||
doc="Unique session identifier"
|
|
||||||
)
|
|
||||||
token_hash: Mapped[str] = mapped_column(
|
|
||||||
String(255), nullable=False,
|
|
||||||
doc="Hashed JWT token"
|
|
||||||
)
|
|
||||||
|
|
||||||
# User information
|
|
||||||
user_id: Mapped[Optional[str]] = mapped_column(
|
|
||||||
String(255), nullable=True, index=True,
|
|
||||||
doc="User identifier (for multi-user)"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Client information
|
|
||||||
ip_address: Mapped[Optional[str]] = mapped_column(
|
|
||||||
String(45), nullable=True,
|
|
||||||
doc="Client IP address"
|
|
||||||
)
|
|
||||||
user_agent: Mapped[Optional[str]] = mapped_column(
|
|
||||||
String(500), nullable=True,
|
|
||||||
doc="Client user agent"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Session management
|
|
||||||
expires_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False,
|
|
||||||
doc="Session expiration"
|
|
||||||
)
|
|
||||||
is_active: Mapped[bool] = mapped_column(
|
|
||||||
Boolean, default=True, nullable=False, index=True,
|
|
||||||
doc="Whether session is active"
|
|
||||||
)
|
|
||||||
last_activity: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True),
|
|
||||||
server_default=func.now(),
|
|
||||||
onupdate=func.now(),
|
|
||||||
nullable=False,
|
|
||||||
doc="Last activity timestamp"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return (
|
|
||||||
f"<UserSession(id={self.id}, "
|
|
||||||
f"session_id='{self.session_id}', "
|
|
||||||
f"is_active={self.is_active})>"
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_expired(self) -> bool:
|
|
||||||
"""Check if session has expired."""
|
|
||||||
return datetime.utcnow() > self.expires_at
|
|
||||||
|
|
||||||
def revoke(self) -> None:
|
|
||||||
"""Revoke this session."""
|
|
||||||
self.is_active = False
|
|
||||||
@ -1,879 +0,0 @@
|
|||||||
"""Database service layer for CRUD operations.
|
|
||||||
|
|
||||||
This module provides a comprehensive service layer for database operations,
|
|
||||||
implementing the Repository pattern for clean separation of concerns.
|
|
||||||
|
|
||||||
Services:
|
|
||||||
- AnimeSeriesService: CRUD operations for anime series
|
|
||||||
- EpisodeService: CRUD operations for episodes
|
|
||||||
- DownloadQueueService: CRUD operations for download queue
|
|
||||||
- UserSessionService: CRUD operations for user sessions
|
|
||||||
|
|
||||||
All services support both async and sync operations for flexibility.
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import Dict, List, Optional
|
|
||||||
|
|
||||||
from sqlalchemy import delete, select, update
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy.orm import Session, selectinload
|
|
||||||
|
|
||||||
from src.server.database.models import (
|
|
||||||
AnimeSeries,
|
|
||||||
DownloadPriority,
|
|
||||||
DownloadQueueItem,
|
|
||||||
DownloadStatus,
|
|
||||||
Episode,
|
|
||||||
UserSession,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Anime Series Service
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
class AnimeSeriesService:
|
|
||||||
"""Service for anime series CRUD operations.
|
|
||||||
|
|
||||||
Provides methods for creating, reading, updating, and deleting anime series
|
|
||||||
with support for both async and sync database sessions.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def create(
|
|
||||||
db: AsyncSession,
|
|
||||||
key: str,
|
|
||||||
name: str,
|
|
||||||
site: str,
|
|
||||||
folder: str,
|
|
||||||
description: Optional[str] = None,
|
|
||||||
status: Optional[str] = None,
|
|
||||||
total_episodes: Optional[int] = None,
|
|
||||||
cover_url: Optional[str] = None,
|
|
||||||
episode_dict: Optional[Dict] = None,
|
|
||||||
) -> AnimeSeries:
|
|
||||||
"""Create a new anime series.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
key: Unique provider key
|
|
||||||
name: Series name
|
|
||||||
site: Provider site URL
|
|
||||||
folder: Local filesystem path
|
|
||||||
description: Optional series description
|
|
||||||
status: Optional series status
|
|
||||||
total_episodes: Optional total episode count
|
|
||||||
cover_url: Optional cover image URL
|
|
||||||
episode_dict: Optional episode dictionary
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Created AnimeSeries instance
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
IntegrityError: If series with key already exists
|
|
||||||
"""
|
|
||||||
series = AnimeSeries(
|
|
||||||
key=key,
|
|
||||||
name=name,
|
|
||||||
site=site,
|
|
||||||
folder=folder,
|
|
||||||
description=description,
|
|
||||||
status=status,
|
|
||||||
total_episodes=total_episodes,
|
|
||||||
cover_url=cover_url,
|
|
||||||
episode_dict=episode_dict,
|
|
||||||
)
|
|
||||||
db.add(series)
|
|
||||||
await db.flush()
|
|
||||||
await db.refresh(series)
|
|
||||||
logger.info(f"Created anime series: {series.name} (key={series.key})")
|
|
||||||
return series
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def get_by_id(db: AsyncSession, series_id: int) -> Optional[AnimeSeries]:
|
|
||||||
"""Get anime series by ID.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
series_id: Series primary key
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AnimeSeries instance or None if not found
|
|
||||||
"""
|
|
||||||
result = await db.execute(
|
|
||||||
select(AnimeSeries).where(AnimeSeries.id == series_id)
|
|
||||||
)
|
|
||||||
return result.scalar_one_or_none()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def get_by_key(db: AsyncSession, key: str) -> Optional[AnimeSeries]:
|
|
||||||
"""Get anime series by provider key.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
key: Unique provider key
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AnimeSeries instance or None if not found
|
|
||||||
"""
|
|
||||||
result = await db.execute(
|
|
||||||
select(AnimeSeries).where(AnimeSeries.key == key)
|
|
||||||
)
|
|
||||||
return result.scalar_one_or_none()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def get_all(
|
|
||||||
db: AsyncSession,
|
|
||||||
limit: Optional[int] = None,
|
|
||||||
offset: int = 0,
|
|
||||||
with_episodes: bool = False,
|
|
||||||
) -> List[AnimeSeries]:
|
|
||||||
"""Get all anime series.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
limit: Optional limit for results
|
|
||||||
offset: Offset for pagination
|
|
||||||
with_episodes: Whether to eagerly load episodes
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of AnimeSeries instances
|
|
||||||
"""
|
|
||||||
query = select(AnimeSeries)
|
|
||||||
|
|
||||||
if with_episodes:
|
|
||||||
query = query.options(selectinload(AnimeSeries.episodes))
|
|
||||||
|
|
||||||
query = query.offset(offset)
|
|
||||||
if limit:
|
|
||||||
query = query.limit(limit)
|
|
||||||
|
|
||||||
result = await db.execute(query)
|
|
||||||
return list(result.scalars().all())
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def update(
|
|
||||||
db: AsyncSession,
|
|
||||||
series_id: int,
|
|
||||||
**kwargs,
|
|
||||||
) -> Optional[AnimeSeries]:
|
|
||||||
"""Update anime series.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
series_id: Series primary key
|
|
||||||
**kwargs: Fields to update
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated AnimeSeries instance or None if not found
|
|
||||||
"""
|
|
||||||
series = await AnimeSeriesService.get_by_id(db, series_id)
|
|
||||||
if not series:
|
|
||||||
return None
|
|
||||||
|
|
||||||
for key, value in kwargs.items():
|
|
||||||
if hasattr(series, key):
|
|
||||||
setattr(series, key, value)
|
|
||||||
|
|
||||||
await db.flush()
|
|
||||||
await db.refresh(series)
|
|
||||||
logger.info(f"Updated anime series: {series.name} (id={series_id})")
|
|
||||||
return series
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def delete(db: AsyncSession, series_id: int) -> bool:
|
|
||||||
"""Delete anime series.
|
|
||||||
|
|
||||||
Cascades to delete all episodes and download items.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
series_id: Series primary key
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if deleted, False if not found
|
|
||||||
"""
|
|
||||||
result = await db.execute(
|
|
||||||
delete(AnimeSeries).where(AnimeSeries.id == series_id)
|
|
||||||
)
|
|
||||||
deleted = result.rowcount > 0
|
|
||||||
if deleted:
|
|
||||||
logger.info(f"Deleted anime series with id={series_id}")
|
|
||||||
return deleted
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def search(
|
|
||||||
db: AsyncSession,
|
|
||||||
query: str,
|
|
||||||
limit: int = 50,
|
|
||||||
) -> List[AnimeSeries]:
|
|
||||||
"""Search anime series by name.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
query: Search query
|
|
||||||
limit: Maximum results
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of matching AnimeSeries instances
|
|
||||||
"""
|
|
||||||
result = await db.execute(
|
|
||||||
select(AnimeSeries)
|
|
||||||
.where(AnimeSeries.name.ilike(f"%{query}%"))
|
|
||||||
.limit(limit)
|
|
||||||
)
|
|
||||||
return list(result.scalars().all())
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Episode Service
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
class EpisodeService:
|
|
||||||
"""Service for episode CRUD operations.
|
|
||||||
|
|
||||||
Provides methods for managing episodes within anime series.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def create(
|
|
||||||
db: AsyncSession,
|
|
||||||
series_id: int,
|
|
||||||
season: int,
|
|
||||||
episode_number: int,
|
|
||||||
title: Optional[str] = None,
|
|
||||||
file_path: Optional[str] = None,
|
|
||||||
file_size: Optional[int] = None,
|
|
||||||
is_downloaded: bool = False,
|
|
||||||
) -> Episode:
|
|
||||||
"""Create a new episode.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
series_id: Foreign key to AnimeSeries
|
|
||||||
season: Season number
|
|
||||||
episode_number: Episode number within season
|
|
||||||
title: Optional episode title
|
|
||||||
file_path: Optional local file path
|
|
||||||
file_size: Optional file size in bytes
|
|
||||||
is_downloaded: Whether episode is downloaded
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Created Episode instance
|
|
||||||
"""
|
|
||||||
episode = Episode(
|
|
||||||
series_id=series_id,
|
|
||||||
season=season,
|
|
||||||
episode_number=episode_number,
|
|
||||||
title=title,
|
|
||||||
file_path=file_path,
|
|
||||||
file_size=file_size,
|
|
||||||
is_downloaded=is_downloaded,
|
|
||||||
download_date=datetime.utcnow() if is_downloaded else None,
|
|
||||||
)
|
|
||||||
db.add(episode)
|
|
||||||
await db.flush()
|
|
||||||
await db.refresh(episode)
|
|
||||||
logger.debug(
|
|
||||||
f"Created episode: S{season:02d}E{episode_number:02d} "
|
|
||||||
f"for series_id={series_id}"
|
|
||||||
)
|
|
||||||
return episode
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def get_by_id(db: AsyncSession, episode_id: int) -> Optional[Episode]:
|
|
||||||
"""Get episode by ID.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
episode_id: Episode primary key
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Episode instance or None if not found
|
|
||||||
"""
|
|
||||||
result = await db.execute(
|
|
||||||
select(Episode).where(Episode.id == episode_id)
|
|
||||||
)
|
|
||||||
return result.scalar_one_or_none()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def get_by_series(
|
|
||||||
db: AsyncSession,
|
|
||||||
series_id: int,
|
|
||||||
season: Optional[int] = None,
|
|
||||||
) -> List[Episode]:
|
|
||||||
"""Get episodes for a series.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
series_id: Foreign key to AnimeSeries
|
|
||||||
season: Optional season filter
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of Episode instances
|
|
||||||
"""
|
|
||||||
query = select(Episode).where(Episode.series_id == series_id)
|
|
||||||
|
|
||||||
if season is not None:
|
|
||||||
query = query.where(Episode.season == season)
|
|
||||||
|
|
||||||
query = query.order_by(Episode.season, Episode.episode_number)
|
|
||||||
result = await db.execute(query)
|
|
||||||
return list(result.scalars().all())
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def get_by_episode(
|
|
||||||
db: AsyncSession,
|
|
||||||
series_id: int,
|
|
||||||
season: int,
|
|
||||||
episode_number: int,
|
|
||||||
) -> Optional[Episode]:
|
|
||||||
"""Get specific episode.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
series_id: Foreign key to AnimeSeries
|
|
||||||
season: Season number
|
|
||||||
episode_number: Episode number
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Episode instance or None if not found
|
|
||||||
"""
|
|
||||||
result = await db.execute(
|
|
||||||
select(Episode).where(
|
|
||||||
Episode.series_id == series_id,
|
|
||||||
Episode.season == season,
|
|
||||||
Episode.episode_number == episode_number,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return result.scalar_one_or_none()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def mark_downloaded(
|
|
||||||
db: AsyncSession,
|
|
||||||
episode_id: int,
|
|
||||||
file_path: str,
|
|
||||||
file_size: int,
|
|
||||||
) -> Optional[Episode]:
|
|
||||||
"""Mark episode as downloaded.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
episode_id: Episode primary key
|
|
||||||
file_path: Local file path
|
|
||||||
file_size: File size in bytes
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated Episode instance or None if not found
|
|
||||||
"""
|
|
||||||
episode = await EpisodeService.get_by_id(db, episode_id)
|
|
||||||
if not episode:
|
|
||||||
return None
|
|
||||||
|
|
||||||
episode.is_downloaded = True
|
|
||||||
episode.file_path = file_path
|
|
||||||
episode.file_size = file_size
|
|
||||||
episode.download_date = datetime.utcnow()
|
|
||||||
|
|
||||||
await db.flush()
|
|
||||||
await db.refresh(episode)
|
|
||||||
logger.info(
|
|
||||||
f"Marked episode as downloaded: "
|
|
||||||
f"S{episode.season:02d}E{episode.episode_number:02d}"
|
|
||||||
)
|
|
||||||
return episode
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def delete(db: AsyncSession, episode_id: int) -> bool:
|
|
||||||
"""Delete episode.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
episode_id: Episode primary key
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if deleted, False if not found
|
|
||||||
"""
|
|
||||||
result = await db.execute(
|
|
||||||
delete(Episode).where(Episode.id == episode_id)
|
|
||||||
)
|
|
||||||
return result.rowcount > 0
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Download Queue Service
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
class DownloadQueueService:
|
|
||||||
"""Service for download queue CRUD operations.
|
|
||||||
|
|
||||||
Provides methods for managing the download queue with status tracking,
|
|
||||||
priority management, and progress updates.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def create(
|
|
||||||
db: AsyncSession,
|
|
||||||
series_id: int,
|
|
||||||
season: int,
|
|
||||||
episode_number: int,
|
|
||||||
priority: DownloadPriority = DownloadPriority.NORMAL,
|
|
||||||
download_url: Optional[str] = None,
|
|
||||||
file_destination: Optional[str] = None,
|
|
||||||
) -> DownloadQueueItem:
|
|
||||||
"""Add item to download queue.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
series_id: Foreign key to AnimeSeries
|
|
||||||
season: Season number
|
|
||||||
episode_number: Episode number
|
|
||||||
priority: Download priority
|
|
||||||
download_url: Optional provider download URL
|
|
||||||
file_destination: Optional target file path
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Created DownloadQueueItem instance
|
|
||||||
"""
|
|
||||||
item = DownloadQueueItem(
|
|
||||||
series_id=series_id,
|
|
||||||
season=season,
|
|
||||||
episode_number=episode_number,
|
|
||||||
status=DownloadStatus.PENDING,
|
|
||||||
priority=priority,
|
|
||||||
download_url=download_url,
|
|
||||||
file_destination=file_destination,
|
|
||||||
)
|
|
||||||
db.add(item)
|
|
||||||
await db.flush()
|
|
||||||
await db.refresh(item)
|
|
||||||
logger.info(
|
|
||||||
f"Added to download queue: S{season:02d}E{episode_number:02d} "
|
|
||||||
f"for series_id={series_id} with priority={priority}"
|
|
||||||
)
|
|
||||||
return item
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def get_by_id(
|
|
||||||
db: AsyncSession,
|
|
||||||
item_id: int,
|
|
||||||
) -> Optional[DownloadQueueItem]:
|
|
||||||
"""Get download queue item by ID.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
item_id: Item primary key
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
DownloadQueueItem instance or None if not found
|
|
||||||
"""
|
|
||||||
result = await db.execute(
|
|
||||||
select(DownloadQueueItem).where(DownloadQueueItem.id == item_id)
|
|
||||||
)
|
|
||||||
return result.scalar_one_or_none()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def get_by_status(
|
|
||||||
db: AsyncSession,
|
|
||||||
status: DownloadStatus,
|
|
||||||
limit: Optional[int] = None,
|
|
||||||
) -> List[DownloadQueueItem]:
|
|
||||||
"""Get download queue items by status.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
status: Download status filter
|
|
||||||
limit: Optional limit for results
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of DownloadQueueItem instances
|
|
||||||
"""
|
|
||||||
query = select(DownloadQueueItem).where(
|
|
||||||
DownloadQueueItem.status == status
|
|
||||||
)
|
|
||||||
|
|
||||||
# Order by priority (HIGH first) then creation time
|
|
||||||
query = query.order_by(
|
|
||||||
DownloadQueueItem.priority.desc(),
|
|
||||||
DownloadQueueItem.created_at.asc(),
|
|
||||||
)
|
|
||||||
|
|
||||||
if limit:
|
|
||||||
query = query.limit(limit)
|
|
||||||
|
|
||||||
result = await db.execute(query)
|
|
||||||
return list(result.scalars().all())
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def get_pending(
|
|
||||||
db: AsyncSession,
|
|
||||||
limit: Optional[int] = None,
|
|
||||||
) -> List[DownloadQueueItem]:
|
|
||||||
"""Get pending download queue items.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
limit: Optional limit for results
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of pending DownloadQueueItem instances ordered by priority
|
|
||||||
"""
|
|
||||||
return await DownloadQueueService.get_by_status(
|
|
||||||
db, DownloadStatus.PENDING, limit
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def get_active(db: AsyncSession) -> List[DownloadQueueItem]:
|
|
||||||
"""Get active download queue items.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of downloading DownloadQueueItem instances
|
|
||||||
"""
|
|
||||||
return await DownloadQueueService.get_by_status(
|
|
||||||
db, DownloadStatus.DOWNLOADING
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def get_all(
|
|
||||||
db: AsyncSession,
|
|
||||||
with_series: bool = False,
|
|
||||||
) -> List[DownloadQueueItem]:
|
|
||||||
"""Get all download queue items.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
with_series: Whether to eagerly load series data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of all DownloadQueueItem instances
|
|
||||||
"""
|
|
||||||
query = select(DownloadQueueItem)
|
|
||||||
|
|
||||||
if with_series:
|
|
||||||
query = query.options(selectinload(DownloadQueueItem.series))
|
|
||||||
|
|
||||||
query = query.order_by(
|
|
||||||
DownloadQueueItem.priority.desc(),
|
|
||||||
DownloadQueueItem.created_at.asc(),
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await db.execute(query)
|
|
||||||
return list(result.scalars().all())
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def update_status(
|
|
||||||
db: AsyncSession,
|
|
||||||
item_id: int,
|
|
||||||
status: DownloadStatus,
|
|
||||||
error_message: Optional[str] = None,
|
|
||||||
) -> Optional[DownloadQueueItem]:
|
|
||||||
"""Update download queue item status.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
item_id: Item primary key
|
|
||||||
status: New download status
|
|
||||||
error_message: Optional error message for failed status
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated DownloadQueueItem instance or None if not found
|
|
||||||
"""
|
|
||||||
item = await DownloadQueueService.get_by_id(db, item_id)
|
|
||||||
if not item:
|
|
||||||
return None
|
|
||||||
|
|
||||||
item.status = status
|
|
||||||
|
|
||||||
# Update timestamps based on status
|
|
||||||
if status == DownloadStatus.DOWNLOADING and not item.started_at:
|
|
||||||
item.started_at = datetime.utcnow()
|
|
||||||
elif status in (DownloadStatus.COMPLETED, DownloadStatus.FAILED):
|
|
||||||
item.completed_at = datetime.utcnow()
|
|
||||||
|
|
||||||
# Set error message for failed downloads
|
|
||||||
if status == DownloadStatus.FAILED and error_message:
|
|
||||||
item.error_message = error_message
|
|
||||||
item.retry_count += 1
|
|
||||||
|
|
||||||
await db.flush()
|
|
||||||
await db.refresh(item)
|
|
||||||
logger.debug(f"Updated download queue item {item_id} status to {status}")
|
|
||||||
return item
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def update_progress(
|
|
||||||
db: AsyncSession,
|
|
||||||
item_id: int,
|
|
||||||
progress_percent: float,
|
|
||||||
downloaded_bytes: int,
|
|
||||||
total_bytes: Optional[int] = None,
|
|
||||||
download_speed: Optional[float] = None,
|
|
||||||
) -> Optional[DownloadQueueItem]:
|
|
||||||
"""Update download progress.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
item_id: Item primary key
|
|
||||||
progress_percent: Progress percentage (0-100)
|
|
||||||
downloaded_bytes: Bytes downloaded
|
|
||||||
total_bytes: Optional total file size
|
|
||||||
download_speed: Optional current speed (bytes/sec)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated DownloadQueueItem instance or None if not found
|
|
||||||
"""
|
|
||||||
item = await DownloadQueueService.get_by_id(db, item_id)
|
|
||||||
if not item:
|
|
||||||
return None
|
|
||||||
|
|
||||||
item.progress_percent = progress_percent
|
|
||||||
item.downloaded_bytes = downloaded_bytes
|
|
||||||
|
|
||||||
if total_bytes is not None:
|
|
||||||
item.total_bytes = total_bytes
|
|
||||||
|
|
||||||
if download_speed is not None:
|
|
||||||
item.download_speed = download_speed
|
|
||||||
|
|
||||||
await db.flush()
|
|
||||||
await db.refresh(item)
|
|
||||||
return item
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def delete(db: AsyncSession, item_id: int) -> bool:
|
|
||||||
"""Delete download queue item.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
item_id: Item primary key
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if deleted, False if not found
|
|
||||||
"""
|
|
||||||
result = await db.execute(
|
|
||||||
delete(DownloadQueueItem).where(DownloadQueueItem.id == item_id)
|
|
||||||
)
|
|
||||||
deleted = result.rowcount > 0
|
|
||||||
if deleted:
|
|
||||||
logger.info(f"Deleted download queue item with id={item_id}")
|
|
||||||
return deleted
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def clear_completed(db: AsyncSession) -> int:
|
|
||||||
"""Clear completed downloads from queue.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Number of items cleared
|
|
||||||
"""
|
|
||||||
result = await db.execute(
|
|
||||||
delete(DownloadQueueItem).where(
|
|
||||||
DownloadQueueItem.status == DownloadStatus.COMPLETED
|
|
||||||
)
|
|
||||||
)
|
|
||||||
count = result.rowcount
|
|
||||||
logger.info(f"Cleared {count} completed downloads from queue")
|
|
||||||
return count
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def retry_failed(
|
|
||||||
db: AsyncSession,
|
|
||||||
max_retries: int = 3,
|
|
||||||
) -> List[DownloadQueueItem]:
|
|
||||||
"""Retry failed downloads that haven't exceeded max retries.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
max_retries: Maximum number of retry attempts
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of items marked for retry
|
|
||||||
"""
|
|
||||||
result = await db.execute(
|
|
||||||
select(DownloadQueueItem).where(
|
|
||||||
DownloadQueueItem.status == DownloadStatus.FAILED,
|
|
||||||
DownloadQueueItem.retry_count < max_retries,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
items = list(result.scalars().all())
|
|
||||||
|
|
||||||
for item in items:
|
|
||||||
item.status = DownloadStatus.PENDING
|
|
||||||
item.error_message = None
|
|
||||||
item.progress_percent = 0.0
|
|
||||||
item.downloaded_bytes = 0
|
|
||||||
item.started_at = None
|
|
||||||
item.completed_at = None
|
|
||||||
|
|
||||||
await db.flush()
|
|
||||||
logger.info(f"Marked {len(items)} failed downloads for retry")
|
|
||||||
return items
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# User Session Service
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
class UserSessionService:
|
|
||||||
"""Service for user session CRUD operations.
|
|
||||||
|
|
||||||
Provides methods for managing user authentication sessions with JWT tokens.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def create(
|
|
||||||
db: AsyncSession,
|
|
||||||
session_id: str,
|
|
||||||
token_hash: str,
|
|
||||||
expires_at: datetime,
|
|
||||||
user_id: Optional[str] = None,
|
|
||||||
ip_address: Optional[str] = None,
|
|
||||||
user_agent: Optional[str] = None,
|
|
||||||
) -> UserSession:
|
|
||||||
"""Create a new user session.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
session_id: Unique session identifier
|
|
||||||
token_hash: Hashed JWT token
|
|
||||||
expires_at: Session expiration timestamp
|
|
||||||
user_id: Optional user identifier
|
|
||||||
ip_address: Optional client IP address
|
|
||||||
user_agent: Optional client user agent
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Created UserSession instance
|
|
||||||
"""
|
|
||||||
session = UserSession(
|
|
||||||
session_id=session_id,
|
|
||||||
token_hash=token_hash,
|
|
||||||
expires_at=expires_at,
|
|
||||||
user_id=user_id,
|
|
||||||
ip_address=ip_address,
|
|
||||||
user_agent=user_agent,
|
|
||||||
)
|
|
||||||
db.add(session)
|
|
||||||
await db.flush()
|
|
||||||
await db.refresh(session)
|
|
||||||
logger.info(f"Created user session: {session_id}")
|
|
||||||
return session
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def get_by_session_id(
|
|
||||||
db: AsyncSession,
|
|
||||||
session_id: str,
|
|
||||||
) -> Optional[UserSession]:
|
|
||||||
"""Get session by session ID.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
session_id: Unique session identifier
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
UserSession instance or None if not found
|
|
||||||
"""
|
|
||||||
result = await db.execute(
|
|
||||||
select(UserSession).where(UserSession.session_id == session_id)
|
|
||||||
)
|
|
||||||
return result.scalar_one_or_none()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def get_active_sessions(
|
|
||||||
db: AsyncSession,
|
|
||||||
user_id: Optional[str] = None,
|
|
||||||
) -> List[UserSession]:
|
|
||||||
"""Get active sessions.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
user_id: Optional user ID filter
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of active UserSession instances
|
|
||||||
"""
|
|
||||||
query = select(UserSession).where(
|
|
||||||
UserSession.is_active == True,
|
|
||||||
UserSession.expires_at > datetime.utcnow(),
|
|
||||||
)
|
|
||||||
|
|
||||||
if user_id:
|
|
||||||
query = query.where(UserSession.user_id == user_id)
|
|
||||||
|
|
||||||
result = await db.execute(query)
|
|
||||||
return list(result.scalars().all())
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def update_activity(
|
|
||||||
db: AsyncSession,
|
|
||||||
session_id: str,
|
|
||||||
) -> Optional[UserSession]:
|
|
||||||
"""Update session last activity timestamp.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
session_id: Unique session identifier
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated UserSession instance or None if not found
|
|
||||||
"""
|
|
||||||
session = await UserSessionService.get_by_session_id(db, session_id)
|
|
||||||
if not session:
|
|
||||||
return None
|
|
||||||
|
|
||||||
session.last_activity = datetime.utcnow()
|
|
||||||
await db.flush()
|
|
||||||
await db.refresh(session)
|
|
||||||
return session
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def revoke(db: AsyncSession, session_id: str) -> bool:
|
|
||||||
"""Revoke a session.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
session_id: Unique session identifier
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if revoked, False if not found
|
|
||||||
"""
|
|
||||||
session = await UserSessionService.get_by_session_id(db, session_id)
|
|
||||||
if not session:
|
|
||||||
return False
|
|
||||||
|
|
||||||
session.revoke()
|
|
||||||
await db.flush()
|
|
||||||
logger.info(f"Revoked user session: {session_id}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def cleanup_expired(db: AsyncSession) -> int:
|
|
||||||
"""Clean up expired sessions.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: Database session
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Number of sessions deleted
|
|
||||||
"""
|
|
||||||
result = await db.execute(
|
|
||||||
delete(UserSession).where(
|
|
||||||
UserSession.expires_at < datetime.utcnow()
|
|
||||||
)
|
|
||||||
)
|
|
||||||
count = result.rowcount
|
|
||||||
logger.info(f"Cleaned up {count} expired sessions")
|
|
||||||
return count
|
|
||||||
@ -6,6 +6,67 @@ from typing import List, Optional
|
|||||||
from pydantic import BaseModel, Field, HttpUrl
|
from pydantic import BaseModel, Field, HttpUrl
|
||||||
|
|
||||||
|
|
||||||
|
class EpisodeInfo(BaseModel):
|
||||||
|
"""Information about a single episode."""
|
||||||
|
|
||||||
|
episode_number: int = Field(..., ge=1, description="Episode index (1-based)")
|
||||||
|
title: Optional[str] = Field(None, description="Optional episode title")
|
||||||
|
aired_at: Optional[datetime] = Field(None, description="Air date/time if known")
|
||||||
|
duration_seconds: Optional[int] = Field(None, ge=0, description="Duration in seconds")
|
||||||
|
available: bool = Field(True, description="Whether the episode is available for download")
|
||||||
|
sources: List[HttpUrl] = Field(default_factory=list, description="List of known streaming/download source URLs")
|
||||||
|
|
||||||
|
|
||||||
|
class MissingEpisodeInfo(BaseModel):
|
||||||
|
"""Represents a gap in the episode list for a series."""
|
||||||
|
|
||||||
|
from_episode: int = Field(..., ge=1, description="Starting missing episode number")
|
||||||
|
to_episode: int = Field(..., ge=1, description="Ending missing episode number (inclusive)")
|
||||||
|
reason: Optional[str] = Field(None, description="Optional explanation why episodes are missing")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def count(self) -> int:
|
||||||
|
"""Number of missing episodes in the range."""
|
||||||
|
return max(0, self.to_episode - self.from_episode + 1)
|
||||||
|
|
||||||
|
|
||||||
|
class AnimeSeriesResponse(BaseModel):
|
||||||
|
"""Response model for a series with metadata and episodes."""
|
||||||
|
|
||||||
|
id: str = Field(..., description="Unique series identifier")
|
||||||
|
title: str = Field(..., description="Series title")
|
||||||
|
alt_titles: List[str] = Field(default_factory=list, description="Alternative titles")
|
||||||
|
description: Optional[str] = Field(None, description="Short series description")
|
||||||
|
total_episodes: Optional[int] = Field(None, ge=0, description="Declared total episode count if known")
|
||||||
|
episodes: List[EpisodeInfo] = Field(default_factory=list, description="Known episodes information")
|
||||||
|
missing_episodes: List[MissingEpisodeInfo] = Field(default_factory=list, description="Detected missing episode ranges")
|
||||||
|
thumbnail: Optional[HttpUrl] = Field(None, description="Optional thumbnail image URL")
|
||||||
|
|
||||||
|
|
||||||
|
class SearchRequest(BaseModel):
|
||||||
|
"""Request payload for searching series."""
|
||||||
|
|
||||||
|
query: str = Field(..., min_length=1)
|
||||||
|
limit: int = Field(10, ge=1, le=100)
|
||||||
|
include_adult: bool = Field(False)
|
||||||
|
|
||||||
|
|
||||||
|
class SearchResult(BaseModel):
|
||||||
|
"""Search result item for a series discovery endpoint."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
title: str
|
||||||
|
snippet: Optional[str] = None
|
||||||
|
thumbnail: Optional[HttpUrl] = None
|
||||||
|
score: Optional[float] = None
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, HttpUrl
|
||||||
|
|
||||||
|
|
||||||
class EpisodeInfo(BaseModel):
|
class EpisodeInfo(BaseModel):
|
||||||
"""Information about a single episode."""
|
"""Information about a single episode."""
|
||||||
|
|
||||||
|
|||||||
@ -1,366 +0,0 @@
|
|||||||
"""Configuration persistence service for managing application settings.
|
|
||||||
|
|
||||||
This service handles:
|
|
||||||
- Loading and saving configuration to JSON files
|
|
||||||
- Configuration validation
|
|
||||||
- Backup and restore functionality
|
|
||||||
- Configuration migration for version updates
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import shutil
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, List, Optional
|
|
||||||
|
|
||||||
from src.server.models.config import AppConfig, ConfigUpdate, ValidationResult
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigServiceError(Exception):
|
|
||||||
"""Base exception for configuration service errors."""
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigNotFoundError(ConfigServiceError):
|
|
||||||
"""Raised when configuration file is not found."""
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigValidationError(ConfigServiceError):
|
|
||||||
"""Raised when configuration validation fails."""
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigBackupError(ConfigServiceError):
|
|
||||||
"""Raised when backup operations fail."""
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigService:
|
|
||||||
"""Service for managing application configuration persistence.
|
|
||||||
|
|
||||||
Handles loading, saving, validation, backup, and migration of
|
|
||||||
configuration files. Uses JSON format for human-readable and
|
|
||||||
version-control friendly storage.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Current configuration schema version
|
|
||||||
CONFIG_VERSION = "1.0.0"
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
config_path: Path = Path("data/config.json"),
|
|
||||||
backup_dir: Path = Path("data/config_backups"),
|
|
||||||
max_backups: int = 10
|
|
||||||
):
|
|
||||||
"""Initialize configuration service.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config_path: Path to main configuration file
|
|
||||||
backup_dir: Directory for storing configuration backups
|
|
||||||
max_backups: Maximum number of backups to keep
|
|
||||||
"""
|
|
||||||
self.config_path = config_path
|
|
||||||
self.backup_dir = backup_dir
|
|
||||||
self.max_backups = max_backups
|
|
||||||
|
|
||||||
# Ensure directories exist
|
|
||||||
self.config_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
self.backup_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
def load_config(self) -> AppConfig:
|
|
||||||
"""Load configuration from file.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AppConfig: Loaded configuration
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ConfigNotFoundError: If config file doesn't exist
|
|
||||||
ConfigValidationError: If config validation fails
|
|
||||||
"""
|
|
||||||
if not self.config_path.exists():
|
|
||||||
# Create default configuration
|
|
||||||
default_config = self._create_default_config()
|
|
||||||
self.save_config(default_config)
|
|
||||||
return default_config
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(self.config_path, "r", encoding="utf-8") as f:
|
|
||||||
data = json.load(f)
|
|
||||||
|
|
||||||
# Check if migration is needed
|
|
||||||
file_version = data.get("version", "1.0.0")
|
|
||||||
if file_version != self.CONFIG_VERSION:
|
|
||||||
data = self._migrate_config(data, file_version)
|
|
||||||
|
|
||||||
# Remove version key before constructing AppConfig
|
|
||||||
data.pop("version", None)
|
|
||||||
|
|
||||||
config = AppConfig(**data)
|
|
||||||
|
|
||||||
# Validate configuration
|
|
||||||
validation = config.validate()
|
|
||||||
if not validation.valid:
|
|
||||||
errors = ', '.join(validation.errors or [])
|
|
||||||
raise ConfigValidationError(
|
|
||||||
f"Invalid configuration: {errors}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return config
|
|
||||||
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
raise ConfigValidationError(
|
|
||||||
f"Invalid JSON in config file: {e}"
|
|
||||||
) from e
|
|
||||||
except Exception as e:
|
|
||||||
if isinstance(e, ConfigServiceError):
|
|
||||||
raise
|
|
||||||
raise ConfigValidationError(
|
|
||||||
f"Failed to load config: {e}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
def save_config(
|
|
||||||
self, config: AppConfig, create_backup: bool = True
|
|
||||||
) -> None:
|
|
||||||
"""Save configuration to file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config: Configuration to save
|
|
||||||
create_backup: Whether to create backup before saving
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ConfigValidationError: If config validation fails
|
|
||||||
"""
|
|
||||||
# Validate before saving
|
|
||||||
validation = config.validate()
|
|
||||||
if not validation.valid:
|
|
||||||
errors = ', '.join(validation.errors or [])
|
|
||||||
raise ConfigValidationError(
|
|
||||||
f"Cannot save invalid configuration: {errors}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create backup if requested and file exists
|
|
||||||
if create_backup and self.config_path.exists():
|
|
||||||
try:
|
|
||||||
self.create_backup()
|
|
||||||
except ConfigBackupError as e:
|
|
||||||
# Log but don't fail save operation
|
|
||||||
print(f"Warning: Failed to create backup: {e}")
|
|
||||||
|
|
||||||
# Save configuration with version
|
|
||||||
data = config.model_dump()
|
|
||||||
data["version"] = self.CONFIG_VERSION
|
|
||||||
|
|
||||||
# Write to temporary file first for atomic operation
|
|
||||||
temp_path = self.config_path.with_suffix(".tmp")
|
|
||||||
try:
|
|
||||||
with open(temp_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
|
||||||
|
|
||||||
# Atomic replace
|
|
||||||
temp_path.replace(self.config_path)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
# Clean up temp file on error
|
|
||||||
if temp_path.exists():
|
|
||||||
temp_path.unlink()
|
|
||||||
raise ConfigServiceError(f"Failed to save config: {e}") from e
|
|
||||||
|
|
||||||
def update_config(self, update: ConfigUpdate) -> AppConfig:
|
|
||||||
"""Update configuration with partial changes.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
update: Configuration update to apply
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AppConfig: Updated configuration
|
|
||||||
"""
|
|
||||||
current = self.load_config()
|
|
||||||
updated = update.apply_to(current)
|
|
||||||
self.save_config(updated)
|
|
||||||
return updated
|
|
||||||
|
|
||||||
def validate_config(self, config: AppConfig) -> ValidationResult:
|
|
||||||
"""Validate configuration without saving.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config: Configuration to validate
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ValidationResult: Validation result with errors if any
|
|
||||||
"""
|
|
||||||
return config.validate()
|
|
||||||
|
|
||||||
def create_backup(self, name: Optional[str] = None) -> Path:
|
|
||||||
"""Create backup of current configuration.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: Optional custom backup name (timestamp used if not provided)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Path: Path to created backup file
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ConfigBackupError: If backup creation fails
|
|
||||||
"""
|
|
||||||
if not self.config_path.exists():
|
|
||||||
raise ConfigBackupError("Cannot backup non-existent config file")
|
|
||||||
|
|
||||||
# Generate backup filename
|
|
||||||
if name is None:
|
|
||||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
||||||
name = f"config_backup_{timestamp}.json"
|
|
||||||
elif not name.endswith(".json"):
|
|
||||||
name = f"{name}.json"
|
|
||||||
|
|
||||||
backup_path = self.backup_dir / name
|
|
||||||
|
|
||||||
try:
|
|
||||||
shutil.copy2(self.config_path, backup_path)
|
|
||||||
|
|
||||||
# Clean up old backups
|
|
||||||
self._cleanup_old_backups()
|
|
||||||
|
|
||||||
return backup_path
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise ConfigBackupError(f"Failed to create backup: {e}") from e
|
|
||||||
|
|
||||||
def restore_backup(self, backup_name: str) -> AppConfig:
|
|
||||||
"""Restore configuration from backup.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
backup_name: Name of backup file to restore
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AppConfig: Restored configuration
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ConfigBackupError: If restore fails
|
|
||||||
"""
|
|
||||||
backup_path = self.backup_dir / backup_name
|
|
||||||
|
|
||||||
if not backup_path.exists():
|
|
||||||
raise ConfigBackupError(f"Backup not found: {backup_name}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Create backup of current config before restoring
|
|
||||||
if self.config_path.exists():
|
|
||||||
self.create_backup("pre_restore")
|
|
||||||
|
|
||||||
# Restore backup
|
|
||||||
shutil.copy2(backup_path, self.config_path)
|
|
||||||
|
|
||||||
# Load and validate restored config
|
|
||||||
return self.load_config()
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise ConfigBackupError(
|
|
||||||
f"Failed to restore backup: {e}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
def list_backups(self) -> List[Dict[str, object]]:
|
|
||||||
"""List available configuration backups.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of backup metadata dictionaries with name, size, and
|
|
||||||
created timestamp
|
|
||||||
"""
|
|
||||||
backups: List[Dict[str, object]] = []
|
|
||||||
|
|
||||||
if not self.backup_dir.exists():
|
|
||||||
return backups
|
|
||||||
|
|
||||||
for backup_file in sorted(
|
|
||||||
self.backup_dir.glob("*.json"),
|
|
||||||
key=lambda p: p.stat().st_mtime,
|
|
||||||
reverse=True
|
|
||||||
):
|
|
||||||
stat = backup_file.stat()
|
|
||||||
created_timestamp = datetime.fromtimestamp(stat.st_mtime)
|
|
||||||
backups.append({
|
|
||||||
"name": backup_file.name,
|
|
||||||
"size_bytes": stat.st_size,
|
|
||||||
"created_at": created_timestamp.isoformat(),
|
|
||||||
})
|
|
||||||
|
|
||||||
return backups
|
|
||||||
|
|
||||||
def delete_backup(self, backup_name: str) -> None:
|
|
||||||
"""Delete a configuration backup.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
backup_name: Name of backup file to delete
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ConfigBackupError: If deletion fails
|
|
||||||
"""
|
|
||||||
backup_path = self.backup_dir / backup_name
|
|
||||||
|
|
||||||
if not backup_path.exists():
|
|
||||||
raise ConfigBackupError(f"Backup not found: {backup_name}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
backup_path.unlink()
|
|
||||||
except OSError as e:
|
|
||||||
raise ConfigBackupError(f"Failed to delete backup: {e}") from e
|
|
||||||
|
|
||||||
def _create_default_config(self) -> AppConfig:
|
|
||||||
"""Create default configuration.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AppConfig: Default configuration
|
|
||||||
"""
|
|
||||||
return AppConfig()
|
|
||||||
|
|
||||||
def _cleanup_old_backups(self) -> None:
|
|
||||||
"""Remove old backups exceeding max_backups limit."""
|
|
||||||
if not self.backup_dir.exists():
|
|
||||||
return
|
|
||||||
|
|
||||||
# Get all backups sorted by modification time (oldest first)
|
|
||||||
backups = sorted(
|
|
||||||
self.backup_dir.glob("*.json"),
|
|
||||||
key=lambda p: p.stat().st_mtime
|
|
||||||
)
|
|
||||||
|
|
||||||
# Remove oldest backups if limit exceeded
|
|
||||||
while len(backups) > self.max_backups:
|
|
||||||
oldest = backups.pop(0)
|
|
||||||
try:
|
|
||||||
oldest.unlink()
|
|
||||||
except (OSError, IOError):
|
|
||||||
# Ignore errors during cleanup
|
|
||||||
continue
|
|
||||||
|
|
||||||
def _migrate_config(
|
|
||||||
self, data: Dict, from_version: str # noqa: ARG002
|
|
||||||
) -> Dict:
|
|
||||||
"""Migrate configuration from old version to current.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data: Configuration data to migrate
|
|
||||||
from_version: Version to migrate from (reserved for future use)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict: Migrated configuration data
|
|
||||||
"""
|
|
||||||
# Currently only one version exists
|
|
||||||
# Future migrations would go here
|
|
||||||
# Example:
|
|
||||||
# if from_version == "1.0.0" and self.CONFIG_VERSION == "2.0.0":
|
|
||||||
# data = self._migrate_1_0_to_2_0(data)
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
# Singleton instance
|
|
||||||
_config_service: Optional[ConfigService] = None
|
|
||||||
|
|
||||||
|
|
||||||
def get_config_service() -> ConfigService:
|
|
||||||
"""Get singleton ConfigService instance.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ConfigService: Singleton instance
|
|
||||||
"""
|
|
||||||
global _config_service
|
|
||||||
if _config_service is None:
|
|
||||||
_config_service = ConfigService()
|
|
||||||
return _config_service
|
|
||||||
@ -68,33 +68,18 @@ def reset_series_app() -> None:
|
|||||||
_series_app = None
|
_series_app = None
|
||||||
|
|
||||||
|
|
||||||
async def get_database_session() -> AsyncGenerator:
|
async def get_database_session() -> AsyncGenerator[Optional[object], None]:
|
||||||
"""
|
"""
|
||||||
Dependency to get database session.
|
Dependency to get database session.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
AsyncSession: Database session for async operations
|
AsyncSession: Database session for async operations
|
||||||
|
|
||||||
Example:
|
|
||||||
@app.get("/anime")
|
|
||||||
async def get_anime(db: AsyncSession = Depends(get_database_session)):
|
|
||||||
result = await db.execute(select(AnimeSeries))
|
|
||||||
return result.scalars().all()
|
|
||||||
"""
|
"""
|
||||||
try:
|
# TODO: Implement database session management
|
||||||
from src.server.database import get_db_session
|
# This is a placeholder for future database implementation
|
||||||
|
|
||||||
async with get_db_session() as session:
|
|
||||||
yield session
|
|
||||||
except ImportError:
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||||
detail="Database functionality not installed"
|
detail="Database functionality not yet implemented"
|
||||||
)
|
|
||||||
except RuntimeError as e:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
|
||||||
detail=f"Database not available: {str(e)}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -40,19 +40,10 @@ class AniWorldApp {
|
|||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// First check if we have a token
|
const response = await fetch('/api/auth/status');
|
||||||
const token = localStorage.getItem('access_token');
|
|
||||||
|
|
||||||
// Build request with token if available
|
|
||||||
const headers = {};
|
|
||||||
if (token) {
|
|
||||||
headers['Authorization'] = `Bearer ${token}`;
|
|
||||||
}
|
|
||||||
|
|
||||||
const response = await fetch('/api/auth/status', { headers });
|
|
||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
|
|
||||||
if (!data.configured) {
|
if (!data.has_master_password) {
|
||||||
// No master password set, redirect to setup
|
// No master password set, redirect to setup
|
||||||
window.location.href = '/setup';
|
window.location.href = '/setup';
|
||||||
return;
|
return;
|
||||||
@ -60,58 +51,37 @@ class AniWorldApp {
|
|||||||
|
|
||||||
if (!data.authenticated) {
|
if (!data.authenticated) {
|
||||||
// Not authenticated, redirect to login
|
// Not authenticated, redirect to login
|
||||||
localStorage.removeItem('access_token');
|
|
||||||
localStorage.removeItem('token_expires_at');
|
|
||||||
window.location.href = '/login';
|
window.location.href = '/login';
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// User is authenticated, show logout button
|
// User is authenticated, show logout button if master password is set
|
||||||
const logoutBtn = document.getElementById('logout-btn');
|
if (data.has_master_password) {
|
||||||
if (logoutBtn) {
|
document.getElementById('logout-btn').style.display = 'block';
|
||||||
logoutBtn.style.display = 'block';
|
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Authentication check failed:', error);
|
console.error('Authentication check failed:', error);
|
||||||
// On error, clear token and redirect to login
|
// On error, assume we need to login
|
||||||
localStorage.removeItem('access_token');
|
|
||||||
localStorage.removeItem('token_expires_at');
|
|
||||||
window.location.href = '/login';
|
window.location.href = '/login';
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async logout() {
|
async logout() {
|
||||||
try {
|
try {
|
||||||
const response = await this.makeAuthenticatedRequest('/api/auth/logout', { method: 'POST' });
|
const response = await fetch('/api/auth/logout', { method: 'POST' });
|
||||||
|
|
||||||
// Clear tokens from localStorage
|
|
||||||
localStorage.removeItem('access_token');
|
|
||||||
localStorage.removeItem('token_expires_at');
|
|
||||||
|
|
||||||
if (response && response.ok) {
|
|
||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
if (data.status === 'ok') {
|
|
||||||
this.showToast('Logged out successfully', 'success');
|
|
||||||
} else {
|
|
||||||
this.showToast('Logged out', 'success');
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Even if the API fails, we cleared the token locally
|
|
||||||
this.showToast('Logged out', 'success');
|
|
||||||
}
|
|
||||||
|
|
||||||
|
if (data.status === 'success') {
|
||||||
|
this.showToast('Logged out successfully', 'success');
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
window.location.href = '/login';
|
window.location.href = '/login';
|
||||||
}, 1000);
|
}, 1000);
|
||||||
|
} else {
|
||||||
|
this.showToast('Logout failed', 'error');
|
||||||
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Logout error:', error);
|
console.error('Logout error:', error);
|
||||||
// Clear token even on error
|
this.showToast('Logout failed', 'error');
|
||||||
localStorage.removeItem('access_token');
|
|
||||||
localStorage.removeItem('token_expires_at');
|
|
||||||
this.showToast('Logged out', 'success');
|
|
||||||
setTimeout(() => {
|
|
||||||
window.location.href = '/login';
|
|
||||||
}, 1000);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -564,31 +534,15 @@ class AniWorldApp {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async makeAuthenticatedRequest(url, options = {}) {
|
async makeAuthenticatedRequest(url, options = {}) {
|
||||||
// Get JWT token from localStorage
|
// Ensure credentials are included for session-based authentication
|
||||||
const token = localStorage.getItem('access_token');
|
|
||||||
|
|
||||||
// Check if token exists
|
|
||||||
if (!token) {
|
|
||||||
window.location.href = '/login';
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Include Authorization header with Bearer token
|
|
||||||
const requestOptions = {
|
const requestOptions = {
|
||||||
credentials: 'same-origin',
|
credentials: 'same-origin',
|
||||||
...options,
|
...options
|
||||||
headers: {
|
|
||||||
'Authorization': `Bearer ${token}`,
|
|
||||||
...options.headers
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const response = await fetch(url, requestOptions);
|
const response = await fetch(url, requestOptions);
|
||||||
|
|
||||||
if (response.status === 401) {
|
if (response.status === 401) {
|
||||||
// Token is invalid or expired, clear it and redirect to login
|
|
||||||
localStorage.removeItem('access_token');
|
|
||||||
localStorage.removeItem('token_expires_at');
|
|
||||||
window.location.href = '/login';
|
window.location.href = '/login';
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
@ -1889,16 +1843,20 @@ class AniWorldApp {
|
|||||||
if (!this.isDownloading || this.isPaused) return;
|
if (!this.isDownloading || this.isPaused) return;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const response = await this.makeAuthenticatedRequest('/api/queue/pause', { method: 'POST' });
|
const response = await this.makeAuthenticatedRequest('/api/download/pause', { method: 'POST' });
|
||||||
if (!response) return;
|
if (!response) return;
|
||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
|
|
||||||
|
if (data.status === 'success') {
|
||||||
document.getElementById('pause-download').classList.add('hidden');
|
document.getElementById('pause-download').classList.add('hidden');
|
||||||
document.getElementById('resume-download').classList.remove('hidden');
|
document.getElementById('resume-download').classList.remove('hidden');
|
||||||
this.showToast('Queue paused', 'warning');
|
this.showToast('Download paused', 'warning');
|
||||||
|
} else {
|
||||||
|
this.showToast(`Pause failed: ${data.message}`, 'error');
|
||||||
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Pause error:', error);
|
console.error('Pause error:', error);
|
||||||
this.showToast('Failed to pause queue', 'error');
|
this.showToast('Failed to pause download', 'error');
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1906,32 +1864,40 @@ class AniWorldApp {
|
|||||||
if (!this.isDownloading || !this.isPaused) return;
|
if (!this.isDownloading || !this.isPaused) return;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const response = await this.makeAuthenticatedRequest('/api/queue/resume', { method: 'POST' });
|
const response = await this.makeAuthenticatedRequest('/api/download/resume', { method: 'POST' });
|
||||||
if (!response) return;
|
if (!response) return;
|
||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
|
|
||||||
|
if (data.status === 'success') {
|
||||||
document.getElementById('pause-download').classList.remove('hidden');
|
document.getElementById('pause-download').classList.remove('hidden');
|
||||||
document.getElementById('resume-download').classList.add('hidden');
|
document.getElementById('resume-download').classList.add('hidden');
|
||||||
this.showToast('Queue resumed', 'success');
|
this.showToast('Download resumed', 'success');
|
||||||
|
} else {
|
||||||
|
this.showToast(`Resume failed: ${data.message}`, 'error');
|
||||||
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Resume error:', error);
|
console.error('Resume error:', error);
|
||||||
this.showToast('Failed to resume queue', 'error');
|
this.showToast('Failed to resume download', 'error');
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async cancelDownload() {
|
async cancelDownload() {
|
||||||
if (!this.isDownloading) return;
|
if (!this.isDownloading) return;
|
||||||
|
|
||||||
if (confirm('Are you sure you want to stop the download queue?')) {
|
if (confirm('Are you sure you want to cancel the download?')) {
|
||||||
try {
|
try {
|
||||||
const response = await this.makeAuthenticatedRequest('/api/queue/stop', { method: 'POST' });
|
const response = await this.makeAuthenticatedRequest('/api/download/cancel', { method: 'POST' });
|
||||||
if (!response) return;
|
if (!response) return;
|
||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
|
|
||||||
this.showToast('Queue stopped', 'warning');
|
if (data.status === 'success') {
|
||||||
|
this.showToast('Download cancelled', 'warning');
|
||||||
|
} else {
|
||||||
|
this.showToast(`Cancel failed: ${data.message}`, 'error');
|
||||||
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Stop error:', error);
|
console.error('Cancel error:', error);
|
||||||
this.showToast('Failed to stop queue', 'error');
|
this.showToast('Failed to cancel download', 'error');
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -482,20 +482,20 @@ class QueueManager {
|
|||||||
if (!confirmed) return;
|
if (!confirmed) return;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
if (type === 'completed') {
|
const response = await this.makeAuthenticatedRequest('/api/queue/clear', {
|
||||||
// Use the new DELETE /api/queue/completed endpoint
|
method: 'POST',
|
||||||
const response = await this.makeAuthenticatedRequest('/api/queue/completed', {
|
headers: { 'Content-Type': 'application/json' },
|
||||||
method: 'DELETE'
|
body: JSON.stringify({ type })
|
||||||
});
|
});
|
||||||
|
|
||||||
if (!response) return;
|
if (!response) return;
|
||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
|
|
||||||
this.showToast(`Cleared ${data.cleared_count} completed downloads`, 'success');
|
if (data.status === 'success') {
|
||||||
|
this.showToast(data.message, 'success');
|
||||||
this.loadQueueData();
|
this.loadQueueData();
|
||||||
} else {
|
} else {
|
||||||
// For pending and failed, use the old logic (TODO: implement backend endpoints)
|
this.showToast(data.message, 'error');
|
||||||
this.showToast(`Clear ${type} not yet implemented`, 'warning');
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
@ -509,14 +509,18 @@ class QueueManager {
|
|||||||
const response = await this.makeAuthenticatedRequest('/api/queue/retry', {
|
const response = await this.makeAuthenticatedRequest('/api/queue/retry', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: { 'Content-Type': 'application/json' },
|
headers: { 'Content-Type': 'application/json' },
|
||||||
body: JSON.stringify({ item_ids: [downloadId] }) // New API expects item_ids array
|
body: JSON.stringify({ id: downloadId })
|
||||||
});
|
});
|
||||||
|
|
||||||
if (!response) return;
|
if (!response) return;
|
||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
|
|
||||||
this.showToast(`Retried ${data.retried_count} download(s)`, 'success');
|
if (data.status === 'success') {
|
||||||
|
this.showToast('Download added back to queue', 'success');
|
||||||
this.loadQueueData();
|
this.loadQueueData();
|
||||||
|
} else {
|
||||||
|
this.showToast(data.message, 'error');
|
||||||
|
}
|
||||||
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Error retrying download:', error);
|
console.error('Error retrying download:', error);
|
||||||
@ -541,13 +545,16 @@ class QueueManager {
|
|||||||
|
|
||||||
async removeFromQueue(downloadId) {
|
async removeFromQueue(downloadId) {
|
||||||
try {
|
try {
|
||||||
const response = await this.makeAuthenticatedRequest(`/api/queue/${downloadId}`, {
|
const response = await this.makeAuthenticatedRequest('/api/queue/remove', {
|
||||||
method: 'DELETE'
|
method: 'POST',
|
||||||
|
headers: { 'Content-Type': 'application/json' },
|
||||||
|
body: JSON.stringify({ id: downloadId })
|
||||||
});
|
});
|
||||||
|
|
||||||
if (!response) return;
|
if (!response) return;
|
||||||
|
const data = await response.json();
|
||||||
|
|
||||||
if (response.status === 204) {
|
if (data.status === 'success') {
|
||||||
this.showToast('Download removed from queue', 'success');
|
this.showToast('Download removed from queue', 'success');
|
||||||
this.loadQueueData();
|
this.loadQueueData();
|
||||||
} else {
|
} else {
|
||||||
@ -637,31 +644,15 @@ class QueueManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async makeAuthenticatedRequest(url, options = {}) {
|
async makeAuthenticatedRequest(url, options = {}) {
|
||||||
// Get JWT token from localStorage
|
// Ensure credentials are included for session-based authentication
|
||||||
const token = localStorage.getItem('access_token');
|
|
||||||
|
|
||||||
// Check if token exists
|
|
||||||
if (!token) {
|
|
||||||
window.location.href = '/login';
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Include Authorization header with Bearer token
|
|
||||||
const requestOptions = {
|
const requestOptions = {
|
||||||
credentials: 'same-origin',
|
credentials: 'same-origin',
|
||||||
...options,
|
...options
|
||||||
headers: {
|
|
||||||
'Authorization': `Bearer ${token}`,
|
|
||||||
...options.headers
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const response = await fetch(url, requestOptions);
|
const response = await fetch(url, requestOptions);
|
||||||
|
|
||||||
if (response.status === 401) {
|
if (response.status === 401) {
|
||||||
// Token is invalid or expired, clear it and redirect to login
|
|
||||||
localStorage.removeItem('access_token');
|
|
||||||
localStorage.removeItem('token_expires_at');
|
|
||||||
window.location.href = '/login';
|
window.location.href = '/login';
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -323,19 +323,13 @@
|
|||||||
|
|
||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
|
|
||||||
if (response.ok && data.access_token) {
|
if (data.status === 'success') {
|
||||||
// Store JWT token in localStorage
|
showMessage(data.message, 'success');
|
||||||
localStorage.setItem('access_token', data.access_token);
|
|
||||||
if (data.expires_at) {
|
|
||||||
localStorage.setItem('token_expires_at', data.expires_at);
|
|
||||||
}
|
|
||||||
showMessage('Login successful', 'success');
|
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
window.location.href = '/';
|
window.location.href = '/';
|
||||||
}, 1000);
|
}, 1000);
|
||||||
} else {
|
} else {
|
||||||
const errorMessage = data.detail || data.message || 'Invalid credentials';
|
showMessage(data.message, 'error');
|
||||||
showMessage(errorMessage, 'error');
|
|
||||||
passwordInput.value = '';
|
passwordInput.value = '';
|
||||||
passwordInput.focus();
|
passwordInput.focus();
|
||||||
}
|
}
|
||||||
|
|||||||
@ -503,20 +503,22 @@
|
|||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
},
|
},
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
master_password: password
|
password,
|
||||||
|
directory
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
|
|
||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
|
|
||||||
if (response.ok && data.status === 'ok') {
|
if (data.status === 'success') {
|
||||||
showMessage('Setup completed successfully! Redirecting to login...', 'success');
|
showMessage('Setup completed successfully! Redirecting...', 'success');
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
window.location.href = '/login';
|
// Use redirect_url from API response, fallback to /login
|
||||||
|
const redirectUrl = data.redirect_url || '/login';
|
||||||
|
window.location.href = redirectUrl;
|
||||||
}, 2000);
|
}, 2000);
|
||||||
} else {
|
} else {
|
||||||
const errorMessage = data.detail || data.message || 'Setup failed';
|
showMessage(data.message, 'error');
|
||||||
showMessage(errorMessage, 'error');
|
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
showMessage('Setup failed. Please try again.', 'error');
|
showMessage('Setup failed. Please try again.', 'error');
|
||||||
|
|||||||
@ -1,52 +1,12 @@
|
|||||||
"""Integration tests for configuration API endpoints."""
|
|
||||||
|
|
||||||
import tempfile
|
|
||||||
from pathlib import Path
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from src.server.fastapi_app import app
|
from src.server.fastapi_app import app
|
||||||
from src.server.models.config import AppConfig
|
from src.server.models.config import AppConfig, SchedulerConfig
|
||||||
from src.server.services.config_service import ConfigService
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
def test_get_config_public():
|
||||||
def temp_config_dir():
|
|
||||||
"""Create temporary directory for test config files."""
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
|
||||||
yield Path(tmpdir)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def config_service(temp_config_dir):
|
|
||||||
"""Create ConfigService instance with temporary paths."""
|
|
||||||
config_path = temp_config_dir / "config.json"
|
|
||||||
backup_dir = temp_config_dir / "backups"
|
|
||||||
return ConfigService(
|
|
||||||
config_path=config_path, backup_dir=backup_dir, max_backups=3
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_config_service(config_service):
|
|
||||||
"""Mock get_config_service to return test instance."""
|
|
||||||
with patch(
|
|
||||||
"src.server.api.config.get_config_service",
|
|
||||||
return_value=config_service
|
|
||||||
):
|
|
||||||
yield config_service
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def client():
|
|
||||||
"""Create test client."""
|
|
||||||
return TestClient(app)
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_config_public(client, mock_config_service):
|
|
||||||
"""Test getting configuration."""
|
|
||||||
resp = client.get("/api/config")
|
resp = client.get("/api/config")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
@ -54,8 +14,7 @@ def test_get_config_public(client, mock_config_service):
|
|||||||
assert "data_dir" in data
|
assert "data_dir" in data
|
||||||
|
|
||||||
|
|
||||||
def test_validate_config(client, mock_config_service):
|
def test_validate_config():
|
||||||
"""Test configuration validation."""
|
|
||||||
cfg = {
|
cfg = {
|
||||||
"name": "Aniworld",
|
"name": "Aniworld",
|
||||||
"data_dir": "data",
|
"data_dir": "data",
|
||||||
@ -70,95 +29,8 @@ def test_validate_config(client, mock_config_service):
|
|||||||
assert body.get("valid") is True
|
assert body.get("valid") is True
|
||||||
|
|
||||||
|
|
||||||
def test_validate_invalid_config(client, mock_config_service):
|
def test_update_config_unauthorized():
|
||||||
"""Test validation of invalid configuration."""
|
# update requires auth; without auth should be 401
|
||||||
cfg = {
|
|
||||||
"name": "Aniworld",
|
|
||||||
"backup": {"enabled": True, "path": None}, # Invalid
|
|
||||||
}
|
|
||||||
resp = client.post("/api/config/validate", json=cfg)
|
|
||||||
assert resp.status_code == 200
|
|
||||||
body = resp.json()
|
|
||||||
assert body.get("valid") is False
|
|
||||||
assert len(body.get("errors", [])) > 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_update_config_unauthorized(client):
|
|
||||||
"""Test that update requires authentication."""
|
|
||||||
update = {"scheduler": {"enabled": False}}
|
update = {"scheduler": {"enabled": False}}
|
||||||
resp = client.put("/api/config", json=update)
|
resp = client.put("/api/config", json=update)
|
||||||
assert resp.status_code in (401, 422)
|
assert resp.status_code in (401, 422)
|
||||||
|
|
||||||
|
|
||||||
def test_list_backups(client, mock_config_service):
|
|
||||||
"""Test listing configuration backups."""
|
|
||||||
# Create a sample config first
|
|
||||||
sample_config = AppConfig(name="TestApp", data_dir="test_data")
|
|
||||||
mock_config_service.save_config(sample_config, create_backup=False)
|
|
||||||
mock_config_service.create_backup(name="test_backup")
|
|
||||||
|
|
||||||
resp = client.get("/api/config/backups")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
backups = resp.json()
|
|
||||||
assert isinstance(backups, list)
|
|
||||||
if len(backups) > 0:
|
|
||||||
assert "name" in backups[0]
|
|
||||||
assert "size_bytes" in backups[0]
|
|
||||||
assert "created_at" in backups[0]
|
|
||||||
|
|
||||||
|
|
||||||
def test_create_backup(client, mock_config_service):
|
|
||||||
"""Test creating a configuration backup."""
|
|
||||||
# Create a sample config first
|
|
||||||
sample_config = AppConfig(name="TestApp", data_dir="test_data")
|
|
||||||
mock_config_service.save_config(sample_config, create_backup=False)
|
|
||||||
|
|
||||||
resp = client.post("/api/config/backups")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
data = resp.json()
|
|
||||||
assert "name" in data
|
|
||||||
assert "message" in data
|
|
||||||
|
|
||||||
|
|
||||||
def test_restore_backup(client, mock_config_service):
|
|
||||||
"""Test restoring configuration from backup."""
|
|
||||||
# Create initial config and backup
|
|
||||||
sample_config = AppConfig(name="TestApp", data_dir="test_data")
|
|
||||||
mock_config_service.save_config(sample_config, create_backup=False)
|
|
||||||
mock_config_service.create_backup(name="restore_test")
|
|
||||||
|
|
||||||
# Modify config
|
|
||||||
sample_config.name = "Modified"
|
|
||||||
mock_config_service.save_config(sample_config, create_backup=False)
|
|
||||||
|
|
||||||
# Restore from backup
|
|
||||||
resp = client.post("/api/config/backups/restore_test.json/restore")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
data = resp.json()
|
|
||||||
assert data["name"] == "TestApp" # Original name restored
|
|
||||||
|
|
||||||
|
|
||||||
def test_delete_backup(client, mock_config_service):
|
|
||||||
"""Test deleting a configuration backup."""
|
|
||||||
# Create a sample config and backup
|
|
||||||
sample_config = AppConfig(name="TestApp", data_dir="test_data")
|
|
||||||
mock_config_service.save_config(sample_config, create_backup=False)
|
|
||||||
mock_config_service.create_backup(name="delete_test")
|
|
||||||
|
|
||||||
resp = client.delete("/api/config/backups/delete_test.json")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
data = resp.json()
|
|
||||||
assert "deleted successfully" in data["message"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_config_persistence(client, mock_config_service):
|
|
||||||
"""Test end-to-end configuration persistence."""
|
|
||||||
# Get initial config
|
|
||||||
resp = client.get("/api/config")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
initial = resp.json()
|
|
||||||
|
|
||||||
# Validate it can be loaded again
|
|
||||||
resp2 = client.get("/api/config")
|
|
||||||
assert resp2.status_code == 200
|
|
||||||
assert resp2.json() == initial
|
|
||||||
|
|||||||
@ -1,238 +0,0 @@
|
|||||||
"""
|
|
||||||
Tests for frontend authentication integration.
|
|
||||||
|
|
||||||
These smoke tests verify that the key authentication and API endpoints
|
|
||||||
work correctly with JWT tokens as expected by the frontend.
|
|
||||||
"""
|
|
||||||
import pytest
|
|
||||||
from httpx import ASGITransport, AsyncClient
|
|
||||||
|
|
||||||
from src.server.fastapi_app import app
|
|
||||||
from src.server.services.auth_service import auth_service
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def reset_auth():
|
|
||||||
"""Reset authentication state before each test."""
|
|
||||||
# Reset auth service state
|
|
||||||
original_hash = auth_service._hash
|
|
||||||
auth_service._hash = None
|
|
||||||
auth_service._failed.clear()
|
|
||||||
yield
|
|
||||||
# Restore
|
|
||||||
auth_service._hash = original_hash
|
|
||||||
auth_service._failed.clear()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def client():
|
|
||||||
"""Create an async test client."""
|
|
||||||
transport = ASGITransport(app=app)
|
|
||||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
|
||||||
yield ac
|
|
||||||
|
|
||||||
|
|
||||||
class TestFrontendAuthIntegration:
|
|
||||||
"""Test authentication integration matching frontend expectations."""
|
|
||||||
|
|
||||||
async def test_setup_returns_ok_status(self, client):
|
|
||||||
"""Test setup endpoint returns expected format for frontend."""
|
|
||||||
response = await client.post(
|
|
||||||
"/api/auth/setup",
|
|
||||||
json={"master_password": "StrongP@ss123"}
|
|
||||||
)
|
|
||||||
assert response.status_code == 201
|
|
||||||
data = response.json()
|
|
||||||
# Frontend expects 'status': 'ok'
|
|
||||||
assert data["status"] == "ok"
|
|
||||||
|
|
||||||
async def test_login_returns_access_token(self, client):
|
|
||||||
"""Test login flow and verify JWT token is returned."""
|
|
||||||
# Setup master password first
|
|
||||||
client.post("/api/auth/setup", json={"master_password": "StrongP@ss123"})
|
|
||||||
|
|
||||||
# Login with correct password
|
|
||||||
response = client.post(
|
|
||||||
"/api/auth/login",
|
|
||||||
json={"password": "StrongP@ss123"}
|
|
||||||
)
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
|
|
||||||
# Verify token is returned
|
|
||||||
assert "access_token" in data
|
|
||||||
assert data["token_type"] == "bearer"
|
|
||||||
assert "expires_at" in data
|
|
||||||
|
|
||||||
# Verify token can be used for authenticated requests
|
|
||||||
token = data["access_token"]
|
|
||||||
headers = {"Authorization": f"Bearer {token}"}
|
|
||||||
response = client.get("/api/auth/status", headers=headers)
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert data["authenticated"] is True
|
|
||||||
|
|
||||||
def test_login_with_wrong_password(self, client):
|
|
||||||
"""Test login with incorrect password."""
|
|
||||||
# Setup master password first
|
|
||||||
client.post("/api/auth/setup", json={"master_password": "StrongP@ss123"})
|
|
||||||
|
|
||||||
# Login with wrong password
|
|
||||||
response = client.post(
|
|
||||||
"/api/auth/login",
|
|
||||||
json={"password": "WrongPassword"}
|
|
||||||
)
|
|
||||||
assert response.status_code == 401
|
|
||||||
data = response.json()
|
|
||||||
assert "detail" in data
|
|
||||||
|
|
||||||
def test_logout_clears_session(self, client):
|
|
||||||
"""Test logout functionality."""
|
|
||||||
# Setup and login
|
|
||||||
client.post("/api/auth/setup", json={"master_password": "StrongP@ss123"})
|
|
||||||
login_response = client.post(
|
|
||||||
"/api/auth/login",
|
|
||||||
json={"password": "StrongP@ss123"}
|
|
||||||
)
|
|
||||||
token = login_response.json()["access_token"]
|
|
||||||
headers = {"Authorization": f"Bearer {token}"}
|
|
||||||
|
|
||||||
# Logout
|
|
||||||
response = client.post("/api/auth/logout", headers=headers)
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["status"] == "ok"
|
|
||||||
|
|
||||||
def test_authenticated_request_without_token_returns_401(self, client):
|
|
||||||
"""Test that authenticated endpoints reject requests without tokens."""
|
|
||||||
# Setup master password
|
|
||||||
client.post("/api/auth/setup", json={"master_password": "StrongP@ss123"})
|
|
||||||
|
|
||||||
# Try to access authenticated endpoint without token
|
|
||||||
response = client.get("/api/v1/anime")
|
|
||||||
assert response.status_code == 401
|
|
||||||
|
|
||||||
def test_authenticated_request_with_invalid_token_returns_401(self, client):
|
|
||||||
"""Test that authenticated endpoints reject invalid tokens."""
|
|
||||||
# Setup master password
|
|
||||||
client.post("/api/auth/setup", json={"master_password": "StrongP@ss123"})
|
|
||||||
|
|
||||||
# Try to access authenticated endpoint with invalid token
|
|
||||||
headers = {"Authorization": "Bearer invalid_token_here"}
|
|
||||||
response = client.get("/api/v1/anime", headers=headers)
|
|
||||||
assert response.status_code == 401
|
|
||||||
|
|
||||||
def test_remember_me_extends_token_expiry(self, client):
|
|
||||||
"""Test that remember_me flag affects token expiry."""
|
|
||||||
# Setup master password
|
|
||||||
client.post("/api/auth/setup", json={"master_password": "StrongP@ss123"})
|
|
||||||
|
|
||||||
# Login without remember me
|
|
||||||
response1 = client.post(
|
|
||||||
"/api/auth/login",
|
|
||||||
json={"password": "StrongP@ss123", "remember": False}
|
|
||||||
)
|
|
||||||
data1 = response1.json()
|
|
||||||
|
|
||||||
# Login with remember me
|
|
||||||
response2 = client.post(
|
|
||||||
"/api/auth/login",
|
|
||||||
json={"password": "StrongP@ss123", "remember": True}
|
|
||||||
)
|
|
||||||
data2 = response2.json()
|
|
||||||
|
|
||||||
# Both should return tokens with expiry
|
|
||||||
assert "expires_at" in data1
|
|
||||||
assert "expires_at" in data2
|
|
||||||
|
|
||||||
def test_setup_fails_if_already_configured(self, client):
|
|
||||||
"""Test that setup fails if master password is already set."""
|
|
||||||
# Setup once
|
|
||||||
client.post("/api/auth/setup", json={"master_password": "StrongP@ss123"})
|
|
||||||
|
|
||||||
# Try to setup again
|
|
||||||
response = client.post(
|
|
||||||
"/api/auth/setup",
|
|
||||||
json={"master_password": "AnotherPassword123!"}
|
|
||||||
)
|
|
||||||
assert response.status_code == 400
|
|
||||||
assert "already configured" in response.json()["detail"].lower()
|
|
||||||
|
|
||||||
def test_weak_password_validation_in_setup(self, client):
|
|
||||||
"""Test that setup rejects weak passwords."""
|
|
||||||
# Try with short password
|
|
||||||
response = client.post(
|
|
||||||
"/api/auth/setup",
|
|
||||||
json={"master_password": "short"}
|
|
||||||
)
|
|
||||||
assert response.status_code == 400
|
|
||||||
|
|
||||||
# Try with all lowercase
|
|
||||||
response = client.post(
|
|
||||||
"/api/auth/setup",
|
|
||||||
json={"master_password": "alllowercase"}
|
|
||||||
)
|
|
||||||
assert response.status_code == 400
|
|
||||||
|
|
||||||
# Try without special characters
|
|
||||||
response = client.post(
|
|
||||||
"/api/auth/setup",
|
|
||||||
json={"master_password": "NoSpecialChars123"}
|
|
||||||
)
|
|
||||||
assert response.status_code == 400
|
|
||||||
|
|
||||||
|
|
||||||
class TestTokenAuthenticationFlow:
|
|
||||||
"""Test JWT token-based authentication workflow."""
|
|
||||||
|
|
||||||
def test_full_authentication_workflow(self, client):
|
|
||||||
"""Test complete authentication workflow with token management."""
|
|
||||||
# 1. Check initial status
|
|
||||||
response = client.get("/api/auth/status")
|
|
||||||
assert not response.json()["configured"]
|
|
||||||
|
|
||||||
# 2. Setup master password
|
|
||||||
client.post("/api/auth/setup", json={"master_password": "StrongP@ss123"})
|
|
||||||
|
|
||||||
# 3. Login and get token
|
|
||||||
response = client.post(
|
|
||||||
"/api/auth/login",
|
|
||||||
json={"password": "StrongP@ss123"}
|
|
||||||
)
|
|
||||||
token = response.json()["access_token"]
|
|
||||||
headers = {"Authorization": f"Bearer {token}"}
|
|
||||||
|
|
||||||
# 4. Access authenticated endpoint
|
|
||||||
response = client.get("/api/auth/status", headers=headers)
|
|
||||||
assert response.json()["authenticated"] is True
|
|
||||||
|
|
||||||
# 5. Logout
|
|
||||||
response = client.post("/api/auth/logout", headers=headers)
|
|
||||||
assert response.json()["status"] == "ok"
|
|
||||||
|
|
||||||
def test_token_included_in_all_authenticated_requests(self, client):
|
|
||||||
"""Test that token must be included in authenticated API requests."""
|
|
||||||
# Setup and login
|
|
||||||
client.post("/api/auth/setup", json={"master_password": "StrongP@ss123"})
|
|
||||||
response = client.post(
|
|
||||||
"/api/auth/login",
|
|
||||||
json={"password": "StrongP@ss123"}
|
|
||||||
)
|
|
||||||
token = response.json()["access_token"]
|
|
||||||
headers = {"Authorization": f"Bearer {token}"}
|
|
||||||
|
|
||||||
# Test various authenticated endpoints
|
|
||||||
endpoints = [
|
|
||||||
"/api/v1/anime",
|
|
||||||
"/api/queue/status",
|
|
||||||
"/api/config",
|
|
||||||
]
|
|
||||||
|
|
||||||
for endpoint in endpoints:
|
|
||||||
# Without token - should fail
|
|
||||||
response = client.get(endpoint)
|
|
||||||
assert response.status_code == 401, f"Endpoint {endpoint} should require auth"
|
|
||||||
|
|
||||||
# With token - should work or return expected response
|
|
||||||
response = client.get(endpoint, headers=headers)
|
|
||||||
# Some endpoints may return 503 if services not configured, that's ok
|
|
||||||
assert response.status_code in [200, 503], f"Endpoint {endpoint} failed with token"
|
|
||||||
@ -1,97 +0,0 @@
|
|||||||
"""
|
|
||||||
Smoke tests for frontend-backend integration.
|
|
||||||
|
|
||||||
These tests verify that key authentication and API changes work correctly
|
|
||||||
with the frontend's expectations for JWT tokens.
|
|
||||||
"""
|
|
||||||
import pytest
|
|
||||||
from httpx import ASGITransport, AsyncClient
|
|
||||||
|
|
||||||
from src.server.fastapi_app import app
|
|
||||||
from src.server.services.auth_service import auth_service
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def reset_auth():
|
|
||||||
"""Reset authentication state."""
|
|
||||||
auth_service._hash = None
|
|
||||||
auth_service._failed.clear()
|
|
||||||
yield
|
|
||||||
auth_service._hash = None
|
|
||||||
auth_service._failed.clear()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def client():
|
|
||||||
"""Create async test client."""
|
|
||||||
transport = ASGITransport(app=app)
|
|
||||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
|
||||||
yield ac
|
|
||||||
|
|
||||||
|
|
||||||
class TestFrontendIntegration:
|
|
||||||
"""Test frontend integration with JWT authentication."""
|
|
||||||
|
|
||||||
async def test_login_returns_jwt_token(self, client):
|
|
||||||
"""Test that login returns JWT token in expected format."""
|
|
||||||
# Setup
|
|
||||||
await client.post(
|
|
||||||
"/api/auth/setup",
|
|
||||||
json={"master_password": "StrongP@ss123"}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Login
|
|
||||||
response = await client.post(
|
|
||||||
"/api/auth/login",
|
|
||||||
json={"password": "StrongP@ss123"}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
|
|
||||||
# Frontend expects these fields
|
|
||||||
assert "access_token" in data
|
|
||||||
assert "token_type" in data
|
|
||||||
assert data["token_type"] == "bearer"
|
|
||||||
|
|
||||||
async def test_authenticated_endpoints_require_bearer_token(self, client):
|
|
||||||
"""Test that authenticated endpoints require Bearer token."""
|
|
||||||
# Setup and login
|
|
||||||
await client.post(
|
|
||||||
"/api/auth/setup",
|
|
||||||
json={"master_password": "StrongP@ss123"}
|
|
||||||
)
|
|
||||||
login_resp = await client.post(
|
|
||||||
"/api/auth/login",
|
|
||||||
json={"password": "StrongP@ss123"}
|
|
||||||
)
|
|
||||||
token = login_resp.json()["access_token"]
|
|
||||||
|
|
||||||
# Test without token - should fail
|
|
||||||
response = await client.get("/api/v1/anime")
|
|
||||||
assert response.status_code == 401
|
|
||||||
|
|
||||||
# Test with Bearer token in header - should work or return 503
|
|
||||||
headers = {"Authorization": f"Bearer {token}"}
|
|
||||||
response = await client.get("/api/v1/anime", headers=headers)
|
|
||||||
# May return 503 if anime directory not configured
|
|
||||||
assert response.status_code in [200, 503]
|
|
||||||
|
|
||||||
async def test_queue_endpoints_accessible_with_token(self, client):
|
|
||||||
"""Test queue endpoints work with JWT token."""
|
|
||||||
# Setup and login
|
|
||||||
await client.post(
|
|
||||||
"/api/auth/setup",
|
|
||||||
json={"master_password": "StrongP@ss123"}
|
|
||||||
)
|
|
||||||
login_resp = await client.post(
|
|
||||||
"/api/auth/login",
|
|
||||||
json={"password": "StrongP@ss123"}
|
|
||||||
)
|
|
||||||
token = login_resp.json()["access_token"]
|
|
||||||
headers = {"Authorization": f"Bearer {token}"}
|
|
||||||
|
|
||||||
# Test queue status endpoint
|
|
||||||
response = await client.get("/api/queue/status", headers=headers)
|
|
||||||
# Should work or return 503 if service not configured
|
|
||||||
assert response.status_code in [200, 503]
|
|
||||||
@ -1,420 +0,0 @@
|
|||||||
"""
|
|
||||||
Unit tests for the progress callback system.
|
|
||||||
|
|
||||||
Tests the callback interfaces, context classes, and callback manager
|
|
||||||
functionality.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
from src.core.interfaces.callbacks import (
|
|
||||||
CallbackManager,
|
|
||||||
CompletionCallback,
|
|
||||||
CompletionContext,
|
|
||||||
ErrorCallback,
|
|
||||||
ErrorContext,
|
|
||||||
OperationType,
|
|
||||||
ProgressCallback,
|
|
||||||
ProgressContext,
|
|
||||||
ProgressPhase,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestProgressContext(unittest.TestCase):
|
|
||||||
"""Test ProgressContext dataclass."""
|
|
||||||
|
|
||||||
def test_progress_context_creation(self):
|
|
||||||
"""Test creating a progress context."""
|
|
||||||
context = ProgressContext(
|
|
||||||
operation_type=OperationType.DOWNLOAD,
|
|
||||||
operation_id="test-123",
|
|
||||||
phase=ProgressPhase.IN_PROGRESS,
|
|
||||||
current=50,
|
|
||||||
total=100,
|
|
||||||
percentage=50.0,
|
|
||||||
message="Downloading...",
|
|
||||||
details="Episode 5",
|
|
||||||
metadata={"series": "Test"}
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEqual(context.operation_type, OperationType.DOWNLOAD)
|
|
||||||
self.assertEqual(context.operation_id, "test-123")
|
|
||||||
self.assertEqual(context.phase, ProgressPhase.IN_PROGRESS)
|
|
||||||
self.assertEqual(context.current, 50)
|
|
||||||
self.assertEqual(context.total, 100)
|
|
||||||
self.assertEqual(context.percentage, 50.0)
|
|
||||||
self.assertEqual(context.message, "Downloading...")
|
|
||||||
self.assertEqual(context.details, "Episode 5")
|
|
||||||
self.assertEqual(context.metadata, {"series": "Test"})
|
|
||||||
|
|
||||||
def test_progress_context_to_dict(self):
|
|
||||||
"""Test converting progress context to dictionary."""
|
|
||||||
context = ProgressContext(
|
|
||||||
operation_type=OperationType.SCAN,
|
|
||||||
operation_id="scan-456",
|
|
||||||
phase=ProgressPhase.COMPLETED,
|
|
||||||
current=100,
|
|
||||||
total=100,
|
|
||||||
percentage=100.0,
|
|
||||||
message="Scan complete"
|
|
||||||
)
|
|
||||||
|
|
||||||
result = context.to_dict()
|
|
||||||
|
|
||||||
self.assertEqual(result["operation_type"], "scan")
|
|
||||||
self.assertEqual(result["operation_id"], "scan-456")
|
|
||||||
self.assertEqual(result["phase"], "completed")
|
|
||||||
self.assertEqual(result["current"], 100)
|
|
||||||
self.assertEqual(result["total"], 100)
|
|
||||||
self.assertEqual(result["percentage"], 100.0)
|
|
||||||
self.assertEqual(result["message"], "Scan complete")
|
|
||||||
self.assertIsNone(result["details"])
|
|
||||||
self.assertEqual(result["metadata"], {})
|
|
||||||
|
|
||||||
def test_progress_context_default_metadata(self):
|
|
||||||
"""Test that metadata defaults to empty dict."""
|
|
||||||
context = ProgressContext(
|
|
||||||
operation_type=OperationType.DOWNLOAD,
|
|
||||||
operation_id="test",
|
|
||||||
phase=ProgressPhase.STARTING,
|
|
||||||
current=0,
|
|
||||||
total=100,
|
|
||||||
percentage=0.0,
|
|
||||||
message="Starting"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertIsNotNone(context.metadata)
|
|
||||||
self.assertEqual(context.metadata, {})
|
|
||||||
|
|
||||||
|
|
||||||
class TestErrorContext(unittest.TestCase):
|
|
||||||
"""Test ErrorContext dataclass."""
|
|
||||||
|
|
||||||
def test_error_context_creation(self):
|
|
||||||
"""Test creating an error context."""
|
|
||||||
error = ValueError("Test error")
|
|
||||||
context = ErrorContext(
|
|
||||||
operation_type=OperationType.DOWNLOAD,
|
|
||||||
operation_id="test-789",
|
|
||||||
error=error,
|
|
||||||
message="Download failed",
|
|
||||||
recoverable=True,
|
|
||||||
retry_count=2,
|
|
||||||
metadata={"attempt": 3}
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEqual(context.operation_type, OperationType.DOWNLOAD)
|
|
||||||
self.assertEqual(context.operation_id, "test-789")
|
|
||||||
self.assertEqual(context.error, error)
|
|
||||||
self.assertEqual(context.message, "Download failed")
|
|
||||||
self.assertTrue(context.recoverable)
|
|
||||||
self.assertEqual(context.retry_count, 2)
|
|
||||||
self.assertEqual(context.metadata, {"attempt": 3})
|
|
||||||
|
|
||||||
def test_error_context_to_dict(self):
|
|
||||||
"""Test converting error context to dictionary."""
|
|
||||||
error = RuntimeError("Network error")
|
|
||||||
context = ErrorContext(
|
|
||||||
operation_type=OperationType.SCAN,
|
|
||||||
operation_id="scan-error",
|
|
||||||
error=error,
|
|
||||||
message="Scan error occurred",
|
|
||||||
recoverable=False
|
|
||||||
)
|
|
||||||
|
|
||||||
result = context.to_dict()
|
|
||||||
|
|
||||||
self.assertEqual(result["operation_type"], "scan")
|
|
||||||
self.assertEqual(result["operation_id"], "scan-error")
|
|
||||||
self.assertEqual(result["error_type"], "RuntimeError")
|
|
||||||
self.assertEqual(result["error_message"], "Network error")
|
|
||||||
self.assertEqual(result["message"], "Scan error occurred")
|
|
||||||
self.assertFalse(result["recoverable"])
|
|
||||||
self.assertEqual(result["retry_count"], 0)
|
|
||||||
self.assertEqual(result["metadata"], {})
|
|
||||||
|
|
||||||
|
|
||||||
class TestCompletionContext(unittest.TestCase):
|
|
||||||
"""Test CompletionContext dataclass."""
|
|
||||||
|
|
||||||
def test_completion_context_creation(self):
|
|
||||||
"""Test creating a completion context."""
|
|
||||||
context = CompletionContext(
|
|
||||||
operation_type=OperationType.DOWNLOAD,
|
|
||||||
operation_id="download-complete",
|
|
||||||
success=True,
|
|
||||||
message="Download completed successfully",
|
|
||||||
result_data={"file": "episode.mp4"},
|
|
||||||
statistics={"size": 1024, "time": 60},
|
|
||||||
metadata={"quality": "HD"}
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEqual(context.operation_type, OperationType.DOWNLOAD)
|
|
||||||
self.assertEqual(context.operation_id, "download-complete")
|
|
||||||
self.assertTrue(context.success)
|
|
||||||
self.assertEqual(context.message, "Download completed successfully")
|
|
||||||
self.assertEqual(context.result_data, {"file": "episode.mp4"})
|
|
||||||
self.assertEqual(context.statistics, {"size": 1024, "time": 60})
|
|
||||||
self.assertEqual(context.metadata, {"quality": "HD"})
|
|
||||||
|
|
||||||
def test_completion_context_to_dict(self):
|
|
||||||
"""Test converting completion context to dictionary."""
|
|
||||||
context = CompletionContext(
|
|
||||||
operation_type=OperationType.SCAN,
|
|
||||||
operation_id="scan-complete",
|
|
||||||
success=False,
|
|
||||||
message="Scan failed"
|
|
||||||
)
|
|
||||||
|
|
||||||
result = context.to_dict()
|
|
||||||
|
|
||||||
self.assertEqual(result["operation_type"], "scan")
|
|
||||||
self.assertEqual(result["operation_id"], "scan-complete")
|
|
||||||
self.assertFalse(result["success"])
|
|
||||||
self.assertEqual(result["message"], "Scan failed")
|
|
||||||
self.assertEqual(result["statistics"], {})
|
|
||||||
self.assertEqual(result["metadata"], {})
|
|
||||||
|
|
||||||
|
|
||||||
class MockProgressCallback(ProgressCallback):
|
|
||||||
"""Mock implementation of ProgressCallback for testing."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.calls = []
|
|
||||||
|
|
||||||
def on_progress(self, context: ProgressContext) -> None:
|
|
||||||
self.calls.append(context)
|
|
||||||
|
|
||||||
|
|
||||||
class MockErrorCallback(ErrorCallback):
|
|
||||||
"""Mock implementation of ErrorCallback for testing."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.calls = []
|
|
||||||
|
|
||||||
def on_error(self, context: ErrorContext) -> None:
|
|
||||||
self.calls.append(context)
|
|
||||||
|
|
||||||
|
|
||||||
class MockCompletionCallback(CompletionCallback):
|
|
||||||
"""Mock implementation of CompletionCallback for testing."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.calls = []
|
|
||||||
|
|
||||||
def on_completion(self, context: CompletionContext) -> None:
|
|
||||||
self.calls.append(context)
|
|
||||||
|
|
||||||
|
|
||||||
class TestCallbackManager(unittest.TestCase):
|
|
||||||
"""Test CallbackManager functionality."""
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
"""Set up test fixtures."""
|
|
||||||
self.manager = CallbackManager()
|
|
||||||
|
|
||||||
def test_register_progress_callback(self):
|
|
||||||
"""Test registering a progress callback."""
|
|
||||||
callback = MockProgressCallback()
|
|
||||||
self.manager.register_progress_callback(callback)
|
|
||||||
|
|
||||||
# Callback should be registered
|
|
||||||
self.assertIn(callback, self.manager._progress_callbacks)
|
|
||||||
|
|
||||||
def test_register_duplicate_progress_callback(self):
|
|
||||||
"""Test that duplicate callbacks are not added."""
|
|
||||||
callback = MockProgressCallback()
|
|
||||||
self.manager.register_progress_callback(callback)
|
|
||||||
self.manager.register_progress_callback(callback)
|
|
||||||
|
|
||||||
# Should only be registered once
|
|
||||||
self.assertEqual(
|
|
||||||
self.manager._progress_callbacks.count(callback),
|
|
||||||
1
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_register_error_callback(self):
|
|
||||||
"""Test registering an error callback."""
|
|
||||||
callback = MockErrorCallback()
|
|
||||||
self.manager.register_error_callback(callback)
|
|
||||||
|
|
||||||
self.assertIn(callback, self.manager._error_callbacks)
|
|
||||||
|
|
||||||
def test_register_completion_callback(self):
|
|
||||||
"""Test registering a completion callback."""
|
|
||||||
callback = MockCompletionCallback()
|
|
||||||
self.manager.register_completion_callback(callback)
|
|
||||||
|
|
||||||
self.assertIn(callback, self.manager._completion_callbacks)
|
|
||||||
|
|
||||||
def test_unregister_progress_callback(self):
|
|
||||||
"""Test unregistering a progress callback."""
|
|
||||||
callback = MockProgressCallback()
|
|
||||||
self.manager.register_progress_callback(callback)
|
|
||||||
self.manager.unregister_progress_callback(callback)
|
|
||||||
|
|
||||||
self.assertNotIn(callback, self.manager._progress_callbacks)
|
|
||||||
|
|
||||||
def test_unregister_error_callback(self):
|
|
||||||
"""Test unregistering an error callback."""
|
|
||||||
callback = MockErrorCallback()
|
|
||||||
self.manager.register_error_callback(callback)
|
|
||||||
self.manager.unregister_error_callback(callback)
|
|
||||||
|
|
||||||
self.assertNotIn(callback, self.manager._error_callbacks)
|
|
||||||
|
|
||||||
def test_unregister_completion_callback(self):
|
|
||||||
"""Test unregistering a completion callback."""
|
|
||||||
callback = MockCompletionCallback()
|
|
||||||
self.manager.register_completion_callback(callback)
|
|
||||||
self.manager.unregister_completion_callback(callback)
|
|
||||||
|
|
||||||
self.assertNotIn(callback, self.manager._completion_callbacks)
|
|
||||||
|
|
||||||
def test_notify_progress(self):
|
|
||||||
"""Test notifying progress callbacks."""
|
|
||||||
callback1 = MockProgressCallback()
|
|
||||||
callback2 = MockProgressCallback()
|
|
||||||
self.manager.register_progress_callback(callback1)
|
|
||||||
self.manager.register_progress_callback(callback2)
|
|
||||||
|
|
||||||
context = ProgressContext(
|
|
||||||
operation_type=OperationType.DOWNLOAD,
|
|
||||||
operation_id="test",
|
|
||||||
phase=ProgressPhase.IN_PROGRESS,
|
|
||||||
current=50,
|
|
||||||
total=100,
|
|
||||||
percentage=50.0,
|
|
||||||
message="Test progress"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.manager.notify_progress(context)
|
|
||||||
|
|
||||||
# Both callbacks should be called
|
|
||||||
self.assertEqual(len(callback1.calls), 1)
|
|
||||||
self.assertEqual(len(callback2.calls), 1)
|
|
||||||
self.assertEqual(callback1.calls[0], context)
|
|
||||||
self.assertEqual(callback2.calls[0], context)
|
|
||||||
|
|
||||||
def test_notify_error(self):
|
|
||||||
"""Test notifying error callbacks."""
|
|
||||||
callback = MockErrorCallback()
|
|
||||||
self.manager.register_error_callback(callback)
|
|
||||||
|
|
||||||
error = ValueError("Test error")
|
|
||||||
context = ErrorContext(
|
|
||||||
operation_type=OperationType.DOWNLOAD,
|
|
||||||
operation_id="test",
|
|
||||||
error=error,
|
|
||||||
message="Error occurred"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.manager.notify_error(context)
|
|
||||||
|
|
||||||
self.assertEqual(len(callback.calls), 1)
|
|
||||||
self.assertEqual(callback.calls[0], context)
|
|
||||||
|
|
||||||
def test_notify_completion(self):
|
|
||||||
"""Test notifying completion callbacks."""
|
|
||||||
callback = MockCompletionCallback()
|
|
||||||
self.manager.register_completion_callback(callback)
|
|
||||||
|
|
||||||
context = CompletionContext(
|
|
||||||
operation_type=OperationType.SCAN,
|
|
||||||
operation_id="test",
|
|
||||||
success=True,
|
|
||||||
message="Operation completed"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.manager.notify_completion(context)
|
|
||||||
|
|
||||||
self.assertEqual(len(callback.calls), 1)
|
|
||||||
self.assertEqual(callback.calls[0], context)
|
|
||||||
|
|
||||||
def test_callback_exception_handling(self):
|
|
||||||
"""Test that exceptions in callbacks don't break notification."""
|
|
||||||
# Create a callback that raises an exception
|
|
||||||
class FailingCallback(ProgressCallback):
|
|
||||||
def on_progress(self, context: ProgressContext) -> None:
|
|
||||||
raise RuntimeError("Callback failed")
|
|
||||||
|
|
||||||
failing_callback = FailingCallback()
|
|
||||||
working_callback = MockProgressCallback()
|
|
||||||
|
|
||||||
self.manager.register_progress_callback(failing_callback)
|
|
||||||
self.manager.register_progress_callback(working_callback)
|
|
||||||
|
|
||||||
context = ProgressContext(
|
|
||||||
operation_type=OperationType.DOWNLOAD,
|
|
||||||
operation_id="test",
|
|
||||||
phase=ProgressPhase.IN_PROGRESS,
|
|
||||||
current=50,
|
|
||||||
total=100,
|
|
||||||
percentage=50.0,
|
|
||||||
message="Test"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should not raise exception
|
|
||||||
self.manager.notify_progress(context)
|
|
||||||
|
|
||||||
# Working callback should still be called
|
|
||||||
self.assertEqual(len(working_callback.calls), 1)
|
|
||||||
|
|
||||||
def test_clear_all_callbacks(self):
|
|
||||||
"""Test clearing all callbacks."""
|
|
||||||
self.manager.register_progress_callback(MockProgressCallback())
|
|
||||||
self.manager.register_error_callback(MockErrorCallback())
|
|
||||||
self.manager.register_completion_callback(MockCompletionCallback())
|
|
||||||
|
|
||||||
self.manager.clear_all_callbacks()
|
|
||||||
|
|
||||||
self.assertEqual(len(self.manager._progress_callbacks), 0)
|
|
||||||
self.assertEqual(len(self.manager._error_callbacks), 0)
|
|
||||||
self.assertEqual(len(self.manager._completion_callbacks), 0)
|
|
||||||
|
|
||||||
def test_multiple_notifications(self):
|
|
||||||
"""Test multiple progress notifications."""
|
|
||||||
callback = MockProgressCallback()
|
|
||||||
self.manager.register_progress_callback(callback)
|
|
||||||
|
|
||||||
for i in range(5):
|
|
||||||
context = ProgressContext(
|
|
||||||
operation_type=OperationType.DOWNLOAD,
|
|
||||||
operation_id="test",
|
|
||||||
phase=ProgressPhase.IN_PROGRESS,
|
|
||||||
current=i * 20,
|
|
||||||
total=100,
|
|
||||||
percentage=i * 20.0,
|
|
||||||
message=f"Progress {i}"
|
|
||||||
)
|
|
||||||
self.manager.notify_progress(context)
|
|
||||||
|
|
||||||
self.assertEqual(len(callback.calls), 5)
|
|
||||||
|
|
||||||
|
|
||||||
class TestOperationType(unittest.TestCase):
|
|
||||||
"""Test OperationType enum."""
|
|
||||||
|
|
||||||
def test_operation_types(self):
|
|
||||||
"""Test all operation types are defined."""
|
|
||||||
self.assertEqual(OperationType.SCAN, "scan")
|
|
||||||
self.assertEqual(OperationType.DOWNLOAD, "download")
|
|
||||||
self.assertEqual(OperationType.SEARCH, "search")
|
|
||||||
self.assertEqual(OperationType.INITIALIZATION, "initialization")
|
|
||||||
|
|
||||||
|
|
||||||
class TestProgressPhase(unittest.TestCase):
|
|
||||||
"""Test ProgressPhase enum."""
|
|
||||||
|
|
||||||
def test_progress_phases(self):
|
|
||||||
"""Test all progress phases are defined."""
|
|
||||||
self.assertEqual(ProgressPhase.STARTING, "starting")
|
|
||||||
self.assertEqual(ProgressPhase.IN_PROGRESS, "in_progress")
|
|
||||||
self.assertEqual(ProgressPhase.COMPLETING, "completing")
|
|
||||||
self.assertEqual(ProgressPhase.COMPLETED, "completed")
|
|
||||||
self.assertEqual(ProgressPhase.FAILED, "failed")
|
|
||||||
self.assertEqual(ProgressPhase.CANCELLED, "cancelled")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
@ -1,369 +0,0 @@
|
|||||||
"""Unit tests for ConfigService."""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import tempfile
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from src.server.models.config import (
|
|
||||||
AppConfig,
|
|
||||||
BackupConfig,
|
|
||||||
ConfigUpdate,
|
|
||||||
LoggingConfig,
|
|
||||||
SchedulerConfig,
|
|
||||||
)
|
|
||||||
from src.server.services.config_service import (
|
|
||||||
ConfigBackupError,
|
|
||||||
ConfigService,
|
|
||||||
ConfigServiceError,
|
|
||||||
ConfigValidationError,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def temp_dir():
|
|
||||||
"""Create temporary directory for test config files."""
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
|
||||||
yield Path(tmpdir)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def config_service(temp_dir):
|
|
||||||
"""Create ConfigService instance with temporary paths."""
|
|
||||||
config_path = temp_dir / "config.json"
|
|
||||||
backup_dir = temp_dir / "backups"
|
|
||||||
return ConfigService(
|
|
||||||
config_path=config_path, backup_dir=backup_dir, max_backups=3
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_config():
|
|
||||||
"""Create sample configuration."""
|
|
||||||
return AppConfig(
|
|
||||||
name="TestApp",
|
|
||||||
data_dir="test_data",
|
|
||||||
scheduler=SchedulerConfig(enabled=True, interval_minutes=30),
|
|
||||||
logging=LoggingConfig(level="DEBUG", file="test.log"),
|
|
||||||
backup=BackupConfig(enabled=False),
|
|
||||||
other={"custom_key": "custom_value"},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestConfigServiceInitialization:
|
|
||||||
"""Test ConfigService initialization and directory creation."""
|
|
||||||
|
|
||||||
def test_initialization_creates_directories(self, temp_dir):
|
|
||||||
"""Test that initialization creates necessary directories."""
|
|
||||||
config_path = temp_dir / "subdir" / "config.json"
|
|
||||||
backup_dir = temp_dir / "subdir" / "backups"
|
|
||||||
|
|
||||||
service = ConfigService(config_path=config_path, backup_dir=backup_dir)
|
|
||||||
|
|
||||||
assert config_path.parent.exists()
|
|
||||||
assert backup_dir.exists()
|
|
||||||
assert service.config_path == config_path
|
|
||||||
assert service.backup_dir == backup_dir
|
|
||||||
|
|
||||||
def test_initialization_with_existing_directories(self, config_service):
|
|
||||||
"""Test initialization with existing directories works."""
|
|
||||||
assert config_service.config_path.parent.exists()
|
|
||||||
assert config_service.backup_dir.exists()
|
|
||||||
|
|
||||||
|
|
||||||
class TestConfigServiceLoadSave:
|
|
||||||
"""Test configuration loading and saving."""
|
|
||||||
|
|
||||||
def test_load_creates_default_config_if_not_exists(self, config_service):
|
|
||||||
"""Test that load creates default config if file doesn't exist."""
|
|
||||||
config = config_service.load_config()
|
|
||||||
|
|
||||||
assert isinstance(config, AppConfig)
|
|
||||||
assert config.name == "Aniworld"
|
|
||||||
assert config_service.config_path.exists()
|
|
||||||
|
|
||||||
def test_save_and_load_config(self, config_service, sample_config):
|
|
||||||
"""Test saving and loading configuration."""
|
|
||||||
config_service.save_config(sample_config, create_backup=False)
|
|
||||||
|
|
||||||
loaded_config = config_service.load_config()
|
|
||||||
|
|
||||||
assert loaded_config.name == sample_config.name
|
|
||||||
assert loaded_config.data_dir == sample_config.data_dir
|
|
||||||
assert loaded_config.scheduler.enabled == sample_config.scheduler.enabled
|
|
||||||
assert loaded_config.logging.level == sample_config.logging.level
|
|
||||||
assert loaded_config.other == sample_config.other
|
|
||||||
|
|
||||||
def test_save_includes_version(self, config_service, sample_config):
|
|
||||||
"""Test that saved config includes version field."""
|
|
||||||
config_service.save_config(sample_config, create_backup=False)
|
|
||||||
|
|
||||||
with open(config_service.config_path, "r", encoding="utf-8") as f:
|
|
||||||
data = json.load(f)
|
|
||||||
|
|
||||||
assert "version" in data
|
|
||||||
assert data["version"] == ConfigService.CONFIG_VERSION
|
|
||||||
|
|
||||||
def test_save_creates_backup_by_default(self, config_service, sample_config):
|
|
||||||
"""Test that save creates backup by default if file exists."""
|
|
||||||
# Save initial config
|
|
||||||
config_service.save_config(sample_config, create_backup=False)
|
|
||||||
|
|
||||||
# Modify and save again (should create backup)
|
|
||||||
sample_config.name = "Modified"
|
|
||||||
config_service.save_config(sample_config, create_backup=True)
|
|
||||||
|
|
||||||
backups = list(config_service.backup_dir.glob("*.json"))
|
|
||||||
assert len(backups) == 1
|
|
||||||
|
|
||||||
def test_save_atomic_operation(self, config_service, sample_config):
|
|
||||||
"""Test that save is atomic (uses temp file)."""
|
|
||||||
# Mock exception during JSON dump by using invalid data
|
|
||||||
# This should not corrupt existing config
|
|
||||||
config_service.save_config(sample_config, create_backup=False)
|
|
||||||
|
|
||||||
# Verify temp file is cleaned up after successful save
|
|
||||||
temp_files = list(config_service.config_path.parent.glob("*.tmp"))
|
|
||||||
assert len(temp_files) == 0
|
|
||||||
|
|
||||||
def test_load_invalid_json_raises_error(self, config_service):
|
|
||||||
"""Test that loading invalid JSON raises ConfigValidationError."""
|
|
||||||
# Write invalid JSON
|
|
||||||
config_service.config_path.write_text("invalid json {")
|
|
||||||
|
|
||||||
with pytest.raises(ConfigValidationError, match="Invalid JSON"):
|
|
||||||
config_service.load_config()
|
|
||||||
|
|
||||||
|
|
||||||
class TestConfigServiceValidation:
|
|
||||||
"""Test configuration validation."""
|
|
||||||
|
|
||||||
def test_validate_valid_config(self, config_service, sample_config):
|
|
||||||
"""Test validation of valid configuration."""
|
|
||||||
result = config_service.validate_config(sample_config)
|
|
||||||
|
|
||||||
assert result.valid is True
|
|
||||||
assert result.errors == []
|
|
||||||
|
|
||||||
def test_validate_invalid_config(self, config_service):
|
|
||||||
"""Test validation of invalid configuration."""
|
|
||||||
# Create config with backups enabled but no path
|
|
||||||
invalid_config = AppConfig(
|
|
||||||
backup=BackupConfig(enabled=True, path=None)
|
|
||||||
)
|
|
||||||
|
|
||||||
result = config_service.validate_config(invalid_config)
|
|
||||||
|
|
||||||
assert result.valid is False
|
|
||||||
assert len(result.errors or []) > 0
|
|
||||||
|
|
||||||
def test_save_invalid_config_raises_error(self, config_service):
|
|
||||||
"""Test that saving invalid config raises error."""
|
|
||||||
invalid_config = AppConfig(
|
|
||||||
backup=BackupConfig(enabled=True, path=None)
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(ConfigValidationError, match="Cannot save invalid"):
|
|
||||||
config_service.save_config(invalid_config)
|
|
||||||
|
|
||||||
|
|
||||||
class TestConfigServiceUpdate:
|
|
||||||
"""Test configuration updates."""
|
|
||||||
|
|
||||||
def test_update_config(self, config_service, sample_config):
|
|
||||||
"""Test updating configuration."""
|
|
||||||
config_service.save_config(sample_config, create_backup=False)
|
|
||||||
|
|
||||||
update = ConfigUpdate(
|
|
||||||
scheduler=SchedulerConfig(enabled=False, interval_minutes=60),
|
|
||||||
logging=LoggingConfig(level="INFO"),
|
|
||||||
)
|
|
||||||
|
|
||||||
updated_config = config_service.update_config(update)
|
|
||||||
|
|
||||||
assert updated_config.scheduler.enabled is False
|
|
||||||
assert updated_config.scheduler.interval_minutes == 60
|
|
||||||
assert updated_config.logging.level == "INFO"
|
|
||||||
# Other fields should remain unchanged
|
|
||||||
assert updated_config.name == sample_config.name
|
|
||||||
assert updated_config.data_dir == sample_config.data_dir
|
|
||||||
|
|
||||||
def test_update_persists_changes(self, config_service, sample_config):
|
|
||||||
"""Test that updates are persisted to disk."""
|
|
||||||
config_service.save_config(sample_config, create_backup=False)
|
|
||||||
|
|
||||||
update = ConfigUpdate(logging=LoggingConfig(level="ERROR"))
|
|
||||||
config_service.update_config(update)
|
|
||||||
|
|
||||||
# Load fresh config from disk
|
|
||||||
loaded = config_service.load_config()
|
|
||||||
assert loaded.logging.level == "ERROR"
|
|
||||||
|
|
||||||
|
|
||||||
class TestConfigServiceBackups:
|
|
||||||
"""Test configuration backup functionality."""
|
|
||||||
|
|
||||||
def test_create_backup(self, config_service, sample_config):
|
|
||||||
"""Test creating configuration backup."""
|
|
||||||
config_service.save_config(sample_config, create_backup=False)
|
|
||||||
|
|
||||||
backup_path = config_service.create_backup()
|
|
||||||
|
|
||||||
assert backup_path.exists()
|
|
||||||
assert backup_path.suffix == ".json"
|
|
||||||
assert "config_backup_" in backup_path.name
|
|
||||||
|
|
||||||
def test_create_backup_with_custom_name(
|
|
||||||
self, config_service, sample_config
|
|
||||||
):
|
|
||||||
"""Test creating backup with custom name."""
|
|
||||||
config_service.save_config(sample_config, create_backup=False)
|
|
||||||
|
|
||||||
backup_path = config_service.create_backup(name="my_backup")
|
|
||||||
|
|
||||||
assert backup_path.name == "my_backup.json"
|
|
||||||
|
|
||||||
def test_create_backup_without_config_raises_error(self, config_service):
|
|
||||||
"""Test that creating backup without config file raises error."""
|
|
||||||
with pytest.raises(ConfigBackupError, match="Cannot backup non-existent"):
|
|
||||||
config_service.create_backup()
|
|
||||||
|
|
||||||
def test_list_backups(self, config_service, sample_config):
|
|
||||||
"""Test listing configuration backups."""
|
|
||||||
config_service.save_config(sample_config, create_backup=False)
|
|
||||||
|
|
||||||
# Create multiple backups
|
|
||||||
config_service.create_backup(name="backup1")
|
|
||||||
config_service.create_backup(name="backup2")
|
|
||||||
config_service.create_backup(name="backup3")
|
|
||||||
|
|
||||||
backups = config_service.list_backups()
|
|
||||||
|
|
||||||
assert len(backups) == 3
|
|
||||||
assert all("name" in b for b in backups)
|
|
||||||
assert all("size_bytes" in b for b in backups)
|
|
||||||
assert all("created_at" in b for b in backups)
|
|
||||||
|
|
||||||
# Should be sorted by creation time (newest first)
|
|
||||||
backup_names = [b["name"] for b in backups]
|
|
||||||
assert "backup3.json" in backup_names
|
|
||||||
|
|
||||||
def test_list_backups_empty(self, config_service):
|
|
||||||
"""Test listing backups when none exist."""
|
|
||||||
backups = config_service.list_backups()
|
|
||||||
assert backups == []
|
|
||||||
|
|
||||||
def test_restore_backup(self, config_service, sample_config):
|
|
||||||
"""Test restoring configuration from backup."""
|
|
||||||
# Save initial config and create backup
|
|
||||||
config_service.save_config(sample_config, create_backup=False)
|
|
||||||
config_service.create_backup(name="original")
|
|
||||||
|
|
||||||
# Modify and save config
|
|
||||||
sample_config.name = "Modified"
|
|
||||||
config_service.save_config(sample_config, create_backup=False)
|
|
||||||
|
|
||||||
# Restore from backup
|
|
||||||
restored = config_service.restore_backup("original.json")
|
|
||||||
|
|
||||||
assert restored.name == "TestApp" # Original name
|
|
||||||
|
|
||||||
def test_restore_backup_creates_pre_restore_backup(
|
|
||||||
self, config_service, sample_config
|
|
||||||
):
|
|
||||||
"""Test that restore creates pre-restore backup."""
|
|
||||||
config_service.save_config(sample_config, create_backup=False)
|
|
||||||
config_service.create_backup(name="backup1")
|
|
||||||
|
|
||||||
sample_config.name = "Modified"
|
|
||||||
config_service.save_config(sample_config, create_backup=False)
|
|
||||||
|
|
||||||
config_service.restore_backup("backup1.json")
|
|
||||||
|
|
||||||
backups = config_service.list_backups()
|
|
||||||
backup_names = [b["name"] for b in backups]
|
|
||||||
|
|
||||||
assert any("pre_restore" in name for name in backup_names)
|
|
||||||
|
|
||||||
def test_restore_nonexistent_backup_raises_error(self, config_service):
|
|
||||||
"""Test that restoring non-existent backup raises error."""
|
|
||||||
with pytest.raises(ConfigBackupError, match="Backup not found"):
|
|
||||||
config_service.restore_backup("nonexistent.json")
|
|
||||||
|
|
||||||
def test_delete_backup(self, config_service, sample_config):
|
|
||||||
"""Test deleting configuration backup."""
|
|
||||||
config_service.save_config(sample_config, create_backup=False)
|
|
||||||
config_service.create_backup(name="to_delete")
|
|
||||||
|
|
||||||
config_service.delete_backup("to_delete.json")
|
|
||||||
|
|
||||||
backups = config_service.list_backups()
|
|
||||||
assert len(backups) == 0
|
|
||||||
|
|
||||||
def test_delete_nonexistent_backup_raises_error(self, config_service):
|
|
||||||
"""Test that deleting non-existent backup raises error."""
|
|
||||||
with pytest.raises(ConfigBackupError, match="Backup not found"):
|
|
||||||
config_service.delete_backup("nonexistent.json")
|
|
||||||
|
|
||||||
def test_cleanup_old_backups(self, config_service, sample_config):
|
|
||||||
"""Test that old backups are cleaned up when limit exceeded."""
|
|
||||||
config_service.save_config(sample_config, create_backup=False)
|
|
||||||
|
|
||||||
# Create more backups than max_backups (3)
|
|
||||||
for i in range(5):
|
|
||||||
config_service.create_backup(name=f"backup{i}")
|
|
||||||
|
|
||||||
backups = config_service.list_backups()
|
|
||||||
assert len(backups) == 3 # Should only keep max_backups
|
|
||||||
|
|
||||||
|
|
||||||
class TestConfigServiceMigration:
|
|
||||||
"""Test configuration migration."""
|
|
||||||
|
|
||||||
def test_migration_preserves_data(self, config_service, sample_config):
|
|
||||||
"""Test that migration preserves configuration data."""
|
|
||||||
# Manually save config with old version
|
|
||||||
data = sample_config.model_dump()
|
|
||||||
data["version"] = "0.9.0" # Old version
|
|
||||||
|
|
||||||
with open(config_service.config_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(data, f)
|
|
||||||
|
|
||||||
# Load should migrate automatically
|
|
||||||
loaded = config_service.load_config()
|
|
||||||
|
|
||||||
assert loaded.name == sample_config.name
|
|
||||||
assert loaded.data_dir == sample_config.data_dir
|
|
||||||
|
|
||||||
|
|
||||||
class TestConfigServiceSingleton:
|
|
||||||
"""Test singleton instance management."""
|
|
||||||
|
|
||||||
def test_get_config_service_returns_singleton(self):
|
|
||||||
"""Test that get_config_service returns same instance."""
|
|
||||||
from src.server.services.config_service import get_config_service
|
|
||||||
|
|
||||||
service1 = get_config_service()
|
|
||||||
service2 = get_config_service()
|
|
||||||
|
|
||||||
assert service1 is service2
|
|
||||||
|
|
||||||
|
|
||||||
class TestConfigServiceErrorHandling:
|
|
||||||
"""Test error handling in ConfigService."""
|
|
||||||
|
|
||||||
def test_save_config_creates_temp_file(
|
|
||||||
self, config_service, sample_config
|
|
||||||
):
|
|
||||||
"""Test that save operation uses temporary file."""
|
|
||||||
# Save config and verify temp file is cleaned up
|
|
||||||
config_service.save_config(sample_config, create_backup=False)
|
|
||||||
|
|
||||||
# Verify no temp files remain
|
|
||||||
temp_files = list(config_service.config_path.parent.glob("*.tmp"))
|
|
||||||
assert len(temp_files) == 0
|
|
||||||
|
|
||||||
# Verify config was saved successfully
|
|
||||||
loaded = config_service.load_config()
|
|
||||||
assert loaded.name == sample_config.name
|
|
||||||
@ -1,495 +0,0 @@
|
|||||||
"""Unit tests for database initialization module.
|
|
||||||
|
|
||||||
Tests cover:
|
|
||||||
- Database initialization
|
|
||||||
- Schema creation and validation
|
|
||||||
- Schema version management
|
|
||||||
- Initial data seeding
|
|
||||||
- Health checks
|
|
||||||
- Backup functionality
|
|
||||||
"""
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from sqlalchemy import text
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
|
|
||||||
from sqlalchemy.pool import StaticPool
|
|
||||||
|
|
||||||
from src.server.database.base import Base
|
|
||||||
from src.server.database.init import (
|
|
||||||
CURRENT_SCHEMA_VERSION,
|
|
||||||
EXPECTED_TABLES,
|
|
||||||
check_database_health,
|
|
||||||
create_database_backup,
|
|
||||||
create_database_schema,
|
|
||||||
get_database_info,
|
|
||||||
get_migration_guide,
|
|
||||||
get_schema_version,
|
|
||||||
initialize_database,
|
|
||||||
seed_initial_data,
|
|
||||||
validate_database_schema,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def test_engine():
|
|
||||||
"""Create in-memory SQLite engine for testing."""
|
|
||||||
engine = create_async_engine(
|
|
||||||
"sqlite+aiosqlite:///:memory:",
|
|
||||||
echo=False,
|
|
||||||
poolclass=StaticPool,
|
|
||||||
)
|
|
||||||
yield engine
|
|
||||||
await engine.dispose()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def test_engine_with_tables(test_engine):
|
|
||||||
"""Create engine with tables already created."""
|
|
||||||
async with test_engine.begin() as conn:
|
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
|
||||||
yield test_engine
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Database Initialization Tests
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_initialize_database_success(test_engine):
|
|
||||||
"""Test successful database initialization."""
|
|
||||||
result = await initialize_database(
|
|
||||||
engine=test_engine,
|
|
||||||
create_schema=True,
|
|
||||||
validate_schema=True,
|
|
||||||
seed_data=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result["success"] is True
|
|
||||||
assert result["schema_version"] == CURRENT_SCHEMA_VERSION
|
|
||||||
assert len(result["tables_created"]) == len(EXPECTED_TABLES)
|
|
||||||
assert result["validation_result"]["valid"] is True
|
|
||||||
assert result["health_check"]["healthy"] is True
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_initialize_database_without_schema_creation(test_engine_with_tables):
|
|
||||||
"""Test initialization without creating schema."""
|
|
||||||
result = await initialize_database(
|
|
||||||
engine=test_engine_with_tables,
|
|
||||||
create_schema=False,
|
|
||||||
validate_schema=True,
|
|
||||||
seed_data=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result["success"] is True
|
|
||||||
assert result["schema_version"] == CURRENT_SCHEMA_VERSION
|
|
||||||
assert result["tables_created"] == []
|
|
||||||
assert result["validation_result"]["valid"] is True
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_initialize_database_with_seeding(test_engine):
|
|
||||||
"""Test initialization with data seeding."""
|
|
||||||
result = await initialize_database(
|
|
||||||
engine=test_engine,
|
|
||||||
create_schema=True,
|
|
||||||
validate_schema=True,
|
|
||||||
seed_data=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result["success"] is True
|
|
||||||
# Seeding should complete without errors
|
|
||||||
# (even if no actual data is seeded for empty database)
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Schema Creation Tests
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_database_schema(test_engine):
|
|
||||||
"""Test creating database schema."""
|
|
||||||
tables = await create_database_schema(test_engine)
|
|
||||||
|
|
||||||
assert len(tables) == len(EXPECTED_TABLES)
|
|
||||||
assert set(tables) == EXPECTED_TABLES
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_database_schema_idempotent(test_engine_with_tables):
|
|
||||||
"""Test that creating schema is idempotent."""
|
|
||||||
# Tables already exist
|
|
||||||
tables = await create_database_schema(test_engine_with_tables)
|
|
||||||
|
|
||||||
# Should return existing tables, not create duplicates
|
|
||||||
assert len(tables) == len(EXPECTED_TABLES)
|
|
||||||
assert set(tables) == EXPECTED_TABLES
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_schema_uses_default_engine_when_none():
|
|
||||||
"""Test schema creation with None engine uses default."""
|
|
||||||
with patch("src.server.database.init.get_engine") as mock_get_engine:
|
|
||||||
# Create a real test engine
|
|
||||||
test_engine = create_async_engine(
|
|
||||||
"sqlite+aiosqlite:///:memory:",
|
|
||||||
echo=False,
|
|
||||||
poolclass=StaticPool,
|
|
||||||
)
|
|
||||||
mock_get_engine.return_value = test_engine
|
|
||||||
|
|
||||||
# This should call get_engine() and work with test engine
|
|
||||||
tables = await create_database_schema(engine=None)
|
|
||||||
assert len(tables) == len(EXPECTED_TABLES)
|
|
||||||
|
|
||||||
await test_engine.dispose()
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Schema Validation Tests
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_validate_database_schema_valid(test_engine_with_tables):
|
|
||||||
"""Test validating a valid schema."""
|
|
||||||
result = await validate_database_schema(test_engine_with_tables)
|
|
||||||
|
|
||||||
assert result["valid"] is True
|
|
||||||
assert len(result["missing_tables"]) == 0
|
|
||||||
assert len(result["issues"]) == 0
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_validate_database_schema_empty(test_engine):
|
|
||||||
"""Test validating an empty database."""
|
|
||||||
result = await validate_database_schema(test_engine)
|
|
||||||
|
|
||||||
assert result["valid"] is False
|
|
||||||
assert len(result["missing_tables"]) == len(EXPECTED_TABLES)
|
|
||||||
assert len(result["issues"]) > 0
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_validate_database_schema_partial(test_engine):
|
|
||||||
"""Test validating partially created schema."""
|
|
||||||
# Create only one table
|
|
||||||
async with test_engine.begin() as conn:
|
|
||||||
await conn.execute(
|
|
||||||
text("""
|
|
||||||
CREATE TABLE anime_series (
|
|
||||||
id INTEGER PRIMARY KEY,
|
|
||||||
key VARCHAR(255) UNIQUE NOT NULL,
|
|
||||||
name VARCHAR(500) NOT NULL
|
|
||||||
)
|
|
||||||
""")
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await validate_database_schema(test_engine)
|
|
||||||
|
|
||||||
assert result["valid"] is False
|
|
||||||
assert len(result["missing_tables"]) == len(EXPECTED_TABLES) - 1
|
|
||||||
assert "anime_series" not in result["missing_tables"]
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Schema Version Tests
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_schema_version_empty(test_engine):
|
|
||||||
"""Test getting schema version from empty database."""
|
|
||||||
version = await get_schema_version(test_engine)
|
|
||||||
assert version == "empty"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_schema_version_current(test_engine_with_tables):
|
|
||||||
"""Test getting schema version from current schema."""
|
|
||||||
version = await get_schema_version(test_engine_with_tables)
|
|
||||||
assert version == CURRENT_SCHEMA_VERSION
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_schema_version_unknown(test_engine):
|
|
||||||
"""Test getting schema version from unknown schema."""
|
|
||||||
# Create some random tables
|
|
||||||
async with test_engine.begin() as conn:
|
|
||||||
await conn.execute(
|
|
||||||
text("CREATE TABLE random_table (id INTEGER PRIMARY KEY)")
|
|
||||||
)
|
|
||||||
|
|
||||||
version = await get_schema_version(test_engine)
|
|
||||||
assert version == "unknown"
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Data Seeding Tests
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_seed_initial_data_empty_database(test_engine_with_tables):
|
|
||||||
"""Test seeding data into empty database."""
|
|
||||||
# Should complete without errors
|
|
||||||
await seed_initial_data(test_engine_with_tables)
|
|
||||||
|
|
||||||
# Verify database is still empty (no sample data)
|
|
||||||
async with test_engine_with_tables.connect() as conn:
|
|
||||||
result = await conn.execute(text("SELECT COUNT(*) FROM anime_series"))
|
|
||||||
count = result.scalar()
|
|
||||||
assert count == 0
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_seed_initial_data_existing_data(test_engine_with_tables):
|
|
||||||
"""Test seeding skips if data already exists."""
|
|
||||||
# Add some data
|
|
||||||
async with test_engine_with_tables.begin() as conn:
|
|
||||||
await conn.execute(
|
|
||||||
text("""
|
|
||||||
INSERT INTO anime_series (key, name, site, folder)
|
|
||||||
VALUES ('test-key', 'Test Anime', 'https://test.com', '/test')
|
|
||||||
""")
|
|
||||||
)
|
|
||||||
|
|
||||||
# Seeding should skip
|
|
||||||
await seed_initial_data(test_engine_with_tables)
|
|
||||||
|
|
||||||
# Verify only one record exists
|
|
||||||
async with test_engine_with_tables.connect() as conn:
|
|
||||||
result = await conn.execute(text("SELECT COUNT(*) FROM anime_series"))
|
|
||||||
count = result.scalar()
|
|
||||||
assert count == 1
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Health Check Tests
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_check_database_health_healthy(test_engine_with_tables):
|
|
||||||
"""Test health check on healthy database."""
|
|
||||||
result = await check_database_health(test_engine_with_tables)
|
|
||||||
|
|
||||||
assert result["healthy"] is True
|
|
||||||
assert result["accessible"] is True
|
|
||||||
assert result["tables"] == len(EXPECTED_TABLES)
|
|
||||||
assert result["connectivity_ms"] > 0
|
|
||||||
assert len(result["issues"]) == 0
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_check_database_health_empty(test_engine):
|
|
||||||
"""Test health check on empty database."""
|
|
||||||
result = await check_database_health(test_engine)
|
|
||||||
|
|
||||||
assert result["healthy"] is False
|
|
||||||
assert result["accessible"] is True
|
|
||||||
assert result["tables"] == 0
|
|
||||||
assert len(result["issues"]) > 0
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_check_database_health_connection_error():
|
|
||||||
"""Test health check with connection error."""
|
|
||||||
mock_engine = AsyncMock(spec=AsyncEngine)
|
|
||||||
mock_engine.connect.side_effect = Exception("Connection failed")
|
|
||||||
|
|
||||||
result = await check_database_health(mock_engine)
|
|
||||||
|
|
||||||
assert result["healthy"] is False
|
|
||||||
assert result["accessible"] is False
|
|
||||||
assert len(result["issues"]) > 0
|
|
||||||
assert "Connection failed" in result["issues"][0]
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Backup Tests
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_database_backup_not_sqlite():
|
|
||||||
"""Test backup fails for non-SQLite databases."""
|
|
||||||
with patch("src.server.database.init.settings") as mock_settings:
|
|
||||||
mock_settings.database_url = "postgresql://localhost/test"
|
|
||||||
|
|
||||||
with pytest.raises(NotImplementedError):
|
|
||||||
await create_database_backup()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_database_backup_file_not_found():
|
|
||||||
"""Test backup fails if database file doesn't exist."""
|
|
||||||
with patch("src.server.database.init.settings") as mock_settings:
|
|
||||||
mock_settings.database_url = "sqlite:///nonexistent.db"
|
|
||||||
|
|
||||||
with pytest.raises(RuntimeError, match="Database file not found"):
|
|
||||||
await create_database_backup()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_database_backup_success(tmp_path):
|
|
||||||
"""Test successful database backup."""
|
|
||||||
# Create a temporary database file
|
|
||||||
db_file = tmp_path / "test.db"
|
|
||||||
db_file.write_text("test data")
|
|
||||||
|
|
||||||
backup_file = tmp_path / "backup.db"
|
|
||||||
|
|
||||||
with patch("src.server.database.init.settings") as mock_settings:
|
|
||||||
mock_settings.database_url = f"sqlite:///{db_file}"
|
|
||||||
|
|
||||||
result = await create_database_backup(backup_path=backup_file)
|
|
||||||
|
|
||||||
assert result == backup_file
|
|
||||||
assert backup_file.exists()
|
|
||||||
assert backup_file.read_text() == "test data"
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Utility Function Tests
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_database_info():
|
|
||||||
"""Test getting database configuration info."""
|
|
||||||
info = get_database_info()
|
|
||||||
|
|
||||||
assert "database_url" in info
|
|
||||||
assert "database_type" in info
|
|
||||||
assert "schema_version" in info
|
|
||||||
assert "expected_tables" in info
|
|
||||||
assert info["schema_version"] == CURRENT_SCHEMA_VERSION
|
|
||||||
assert set(info["expected_tables"]) == EXPECTED_TABLES
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_migration_guide():
|
|
||||||
"""Test getting migration guide."""
|
|
||||||
guide = get_migration_guide()
|
|
||||||
|
|
||||||
assert isinstance(guide, str)
|
|
||||||
assert "Alembic" in guide
|
|
||||||
assert "alembic init" in guide
|
|
||||||
assert "alembic upgrade head" in guide
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Integration Tests
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_full_initialization_workflow(test_engine):
|
|
||||||
"""Test complete initialization workflow."""
|
|
||||||
# 1. Initialize database
|
|
||||||
result = await initialize_database(
|
|
||||||
engine=test_engine,
|
|
||||||
create_schema=True,
|
|
||||||
validate_schema=True,
|
|
||||||
seed_data=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result["success"] is True
|
|
||||||
|
|
||||||
# 2. Verify schema
|
|
||||||
validation = await validate_database_schema(test_engine)
|
|
||||||
assert validation["valid"] is True
|
|
||||||
|
|
||||||
# 3. Check version
|
|
||||||
version = await get_schema_version(test_engine)
|
|
||||||
assert version == CURRENT_SCHEMA_VERSION
|
|
||||||
|
|
||||||
# 4. Health check
|
|
||||||
health = await check_database_health(test_engine)
|
|
||||||
assert health["healthy"] is True
|
|
||||||
assert health["accessible"] is True
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_reinitialize_existing_database(test_engine_with_tables):
|
|
||||||
"""Test reinitializing an existing database."""
|
|
||||||
# Should be idempotent - safe to call multiple times
|
|
||||||
result1 = await initialize_database(
|
|
||||||
engine=test_engine_with_tables,
|
|
||||||
create_schema=True,
|
|
||||||
validate_schema=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
result2 = await initialize_database(
|
|
||||||
engine=test_engine_with_tables,
|
|
||||||
create_schema=True,
|
|
||||||
validate_schema=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result1["success"] is True
|
|
||||||
assert result2["success"] is True
|
|
||||||
assert result1["schema_version"] == result2["schema_version"]
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Error Handling Tests
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_initialize_database_with_creation_error():
|
|
||||||
"""Test initialization handles schema creation errors."""
|
|
||||||
mock_engine = AsyncMock(spec=AsyncEngine)
|
|
||||||
mock_engine.begin.side_effect = Exception("Creation failed")
|
|
||||||
|
|
||||||
with pytest.raises(RuntimeError, match="Failed to initialize database"):
|
|
||||||
await initialize_database(
|
|
||||||
engine=mock_engine,
|
|
||||||
create_schema=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_schema_with_connection_error():
|
|
||||||
"""Test schema creation handles connection errors."""
|
|
||||||
mock_engine = AsyncMock(spec=AsyncEngine)
|
|
||||||
mock_engine.begin.side_effect = Exception("Connection failed")
|
|
||||||
|
|
||||||
with pytest.raises(RuntimeError, match="Schema creation failed"):
|
|
||||||
await create_database_schema(mock_engine)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_validate_schema_with_inspection_error():
|
|
||||||
"""Test validation handles inspection errors gracefully."""
|
|
||||||
mock_engine = AsyncMock(spec=AsyncEngine)
|
|
||||||
mock_engine.connect.side_effect = Exception("Inspection failed")
|
|
||||||
|
|
||||||
result = await validate_database_schema(mock_engine)
|
|
||||||
|
|
||||||
assert result["valid"] is False
|
|
||||||
assert len(result["issues"]) > 0
|
|
||||||
assert "Inspection failed" in result["issues"][0]
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Constants Tests
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
def test_schema_constants():
|
|
||||||
"""Test that schema constants are properly defined."""
|
|
||||||
assert CURRENT_SCHEMA_VERSION == "1.0.0"
|
|
||||||
assert len(EXPECTED_TABLES) == 4
|
|
||||||
assert "anime_series" in EXPECTED_TABLES
|
|
||||||
assert "episodes" in EXPECTED_TABLES
|
|
||||||
assert "download_queue" in EXPECTED_TABLES
|
|
||||||
assert "user_sessions" in EXPECTED_TABLES
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
pytest.main([__file__, "-v"])
|
|
||||||
@ -1,561 +0,0 @@
|
|||||||
"""Unit tests for database models and connection management.
|
|
||||||
|
|
||||||
Tests SQLAlchemy models, relationships, session management, and database
|
|
||||||
operations. Uses an in-memory SQLite database for isolated testing.
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from sqlalchemy import create_engine, select
|
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
|
||||||
|
|
||||||
from src.server.database.base import Base, SoftDeleteMixin, TimestampMixin
|
|
||||||
from src.server.database.models import (
|
|
||||||
AnimeSeries,
|
|
||||||
DownloadPriority,
|
|
||||||
DownloadQueueItem,
|
|
||||||
DownloadStatus,
|
|
||||||
Episode,
|
|
||||||
UserSession,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def db_engine():
|
|
||||||
"""Create in-memory SQLite database engine for testing."""
|
|
||||||
engine = create_engine("sqlite:///:memory:", echo=False)
|
|
||||||
Base.metadata.create_all(engine)
|
|
||||||
return engine
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def db_session(db_engine):
|
|
||||||
"""Create database session for testing."""
|
|
||||||
SessionLocal = sessionmaker(bind=db_engine)
|
|
||||||
session = SessionLocal()
|
|
||||||
yield session
|
|
||||||
session.close()
|
|
||||||
|
|
||||||
|
|
||||||
class TestAnimeSeries:
|
|
||||||
"""Test cases for AnimeSeries model."""
|
|
||||||
|
|
||||||
def test_create_anime_series(self, db_session: Session):
|
|
||||||
"""Test creating an anime series."""
|
|
||||||
series = AnimeSeries(
|
|
||||||
key="attack-on-titan",
|
|
||||||
name="Attack on Titan",
|
|
||||||
site="https://aniworld.to",
|
|
||||||
folder="/anime/attack-on-titan",
|
|
||||||
description="Epic anime about titans",
|
|
||||||
status="completed",
|
|
||||||
total_episodes=75,
|
|
||||||
cover_url="https://example.com/cover.jpg",
|
|
||||||
episode_dict={1: [1, 2, 3], 2: [1, 2, 3, 4]},
|
|
||||||
)
|
|
||||||
|
|
||||||
db_session.add(series)
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
# Verify saved
|
|
||||||
assert series.id is not None
|
|
||||||
assert series.key == "attack-on-titan"
|
|
||||||
assert series.name == "Attack on Titan"
|
|
||||||
assert series.created_at is not None
|
|
||||||
assert series.updated_at is not None
|
|
||||||
|
|
||||||
def test_anime_series_unique_key(self, db_session: Session):
|
|
||||||
"""Test that series key must be unique."""
|
|
||||||
series1 = AnimeSeries(
|
|
||||||
key="unique-key",
|
|
||||||
name="Series 1",
|
|
||||||
site="https://example.com",
|
|
||||||
folder="/anime/series1",
|
|
||||||
)
|
|
||||||
series2 = AnimeSeries(
|
|
||||||
key="unique-key",
|
|
||||||
name="Series 2",
|
|
||||||
site="https://example.com",
|
|
||||||
folder="/anime/series2",
|
|
||||||
)
|
|
||||||
|
|
||||||
db_session.add(series1)
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
db_session.add(series2)
|
|
||||||
with pytest.raises(Exception): # IntegrityError
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
def test_anime_series_relationships(self, db_session: Session):
|
|
||||||
"""Test relationships with episodes and download items."""
|
|
||||||
series = AnimeSeries(
|
|
||||||
key="test-series",
|
|
||||||
name="Test Series",
|
|
||||||
site="https://example.com",
|
|
||||||
folder="/anime/test",
|
|
||||||
)
|
|
||||||
db_session.add(series)
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
# Add episodes
|
|
||||||
episode1 = Episode(
|
|
||||||
series_id=series.id,
|
|
||||||
season=1,
|
|
||||||
episode_number=1,
|
|
||||||
title="Episode 1",
|
|
||||||
)
|
|
||||||
episode2 = Episode(
|
|
||||||
series_id=series.id,
|
|
||||||
season=1,
|
|
||||||
episode_number=2,
|
|
||||||
title="Episode 2",
|
|
||||||
)
|
|
||||||
db_session.add_all([episode1, episode2])
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
# Verify relationship
|
|
||||||
assert len(series.episodes) == 2
|
|
||||||
assert series.episodes[0].title == "Episode 1"
|
|
||||||
|
|
||||||
def test_anime_series_cascade_delete(self, db_session: Session):
|
|
||||||
"""Test that deleting series cascades to episodes."""
|
|
||||||
series = AnimeSeries(
|
|
||||||
key="cascade-test",
|
|
||||||
name="Cascade Test",
|
|
||||||
site="https://example.com",
|
|
||||||
folder="/anime/cascade",
|
|
||||||
)
|
|
||||||
db_session.add(series)
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
# Add episodes
|
|
||||||
episode = Episode(
|
|
||||||
series_id=series.id,
|
|
||||||
season=1,
|
|
||||||
episode_number=1,
|
|
||||||
)
|
|
||||||
db_session.add(episode)
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
series_id = series.id
|
|
||||||
|
|
||||||
# Delete series
|
|
||||||
db_session.delete(series)
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
# Verify episodes are deleted
|
|
||||||
result = db_session.execute(
|
|
||||||
select(Episode).where(Episode.series_id == series_id)
|
|
||||||
)
|
|
||||||
assert result.scalar_one_or_none() is None
|
|
||||||
|
|
||||||
|
|
||||||
class TestEpisode:
|
|
||||||
"""Test cases for Episode model."""
|
|
||||||
|
|
||||||
def test_create_episode(self, db_session: Session):
|
|
||||||
"""Test creating an episode."""
|
|
||||||
series = AnimeSeries(
|
|
||||||
key="test-series",
|
|
||||||
name="Test Series",
|
|
||||||
site="https://example.com",
|
|
||||||
folder="/anime/test",
|
|
||||||
)
|
|
||||||
db_session.add(series)
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
episode = Episode(
|
|
||||||
series_id=series.id,
|
|
||||||
season=1,
|
|
||||||
episode_number=5,
|
|
||||||
title="The Fifth Episode",
|
|
||||||
file_path="/anime/test/S01E05.mp4",
|
|
||||||
file_size=524288000, # 500 MB
|
|
||||||
is_downloaded=True,
|
|
||||||
download_date=datetime.utcnow(),
|
|
||||||
)
|
|
||||||
|
|
||||||
db_session.add(episode)
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
# Verify saved
|
|
||||||
assert episode.id is not None
|
|
||||||
assert episode.season == 1
|
|
||||||
assert episode.episode_number == 5
|
|
||||||
assert episode.is_downloaded is True
|
|
||||||
assert episode.created_at is not None
|
|
||||||
|
|
||||||
def test_episode_relationship_to_series(self, db_session: Session):
|
|
||||||
"""Test episode relationship to series."""
|
|
||||||
series = AnimeSeries(
|
|
||||||
key="relationship-test",
|
|
||||||
name="Relationship Test",
|
|
||||||
site="https://example.com",
|
|
||||||
folder="/anime/relationship",
|
|
||||||
)
|
|
||||||
db_session.add(series)
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
episode = Episode(
|
|
||||||
series_id=series.id,
|
|
||||||
season=1,
|
|
||||||
episode_number=1,
|
|
||||||
)
|
|
||||||
db_session.add(episode)
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
# Verify relationship
|
|
||||||
assert episode.series.name == "Relationship Test"
|
|
||||||
assert episode.series.key == "relationship-test"
|
|
||||||
|
|
||||||
|
|
||||||
class TestDownloadQueueItem:
|
|
||||||
"""Test cases for DownloadQueueItem model."""
|
|
||||||
|
|
||||||
def test_create_download_item(self, db_session: Session):
|
|
||||||
"""Test creating a download queue item."""
|
|
||||||
series = AnimeSeries(
|
|
||||||
key="download-test",
|
|
||||||
name="Download Test",
|
|
||||||
site="https://example.com",
|
|
||||||
folder="/anime/download",
|
|
||||||
)
|
|
||||||
db_session.add(series)
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
item = DownloadQueueItem(
|
|
||||||
series_id=series.id,
|
|
||||||
season=1,
|
|
||||||
episode_number=3,
|
|
||||||
status=DownloadStatus.DOWNLOADING,
|
|
||||||
priority=DownloadPriority.HIGH,
|
|
||||||
progress_percent=45.5,
|
|
||||||
downloaded_bytes=250000000,
|
|
||||||
total_bytes=550000000,
|
|
||||||
download_speed=2500000.0,
|
|
||||||
retry_count=0,
|
|
||||||
download_url="https://example.com/download/ep3",
|
|
||||||
file_destination="/anime/download/S01E03.mp4",
|
|
||||||
)
|
|
||||||
|
|
||||||
db_session.add(item)
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
# Verify saved
|
|
||||||
assert item.id is not None
|
|
||||||
assert item.status == DownloadStatus.DOWNLOADING
|
|
||||||
assert item.priority == DownloadPriority.HIGH
|
|
||||||
assert item.progress_percent == 45.5
|
|
||||||
assert item.retry_count == 0
|
|
||||||
|
|
||||||
def test_download_item_status_enum(self, db_session: Session):
|
|
||||||
"""Test download status enum values."""
|
|
||||||
series = AnimeSeries(
|
|
||||||
key="status-test",
|
|
||||||
name="Status Test",
|
|
||||||
site="https://example.com",
|
|
||||||
folder="/anime/status",
|
|
||||||
)
|
|
||||||
db_session.add(series)
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
item = DownloadQueueItem(
|
|
||||||
series_id=series.id,
|
|
||||||
season=1,
|
|
||||||
episode_number=1,
|
|
||||||
status=DownloadStatus.PENDING,
|
|
||||||
)
|
|
||||||
db_session.add(item)
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
# Update status
|
|
||||||
item.status = DownloadStatus.COMPLETED
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
# Verify status change
|
|
||||||
assert item.status == DownloadStatus.COMPLETED
|
|
||||||
|
|
||||||
def test_download_item_error_handling(self, db_session: Session):
|
|
||||||
"""Test download item with error information."""
|
|
||||||
series = AnimeSeries(
|
|
||||||
key="error-test",
|
|
||||||
name="Error Test",
|
|
||||||
site="https://example.com",
|
|
||||||
folder="/anime/error",
|
|
||||||
)
|
|
||||||
db_session.add(series)
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
item = DownloadQueueItem(
|
|
||||||
series_id=series.id,
|
|
||||||
season=1,
|
|
||||||
episode_number=1,
|
|
||||||
status=DownloadStatus.FAILED,
|
|
||||||
error_message="Network timeout after 30 seconds",
|
|
||||||
retry_count=2,
|
|
||||||
)
|
|
||||||
db_session.add(item)
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
# Verify error info
|
|
||||||
assert item.status == DownloadStatus.FAILED
|
|
||||||
assert item.error_message == "Network timeout after 30 seconds"
|
|
||||||
assert item.retry_count == 2
|
|
||||||
|
|
||||||
|
|
||||||
class TestUserSession:
|
|
||||||
"""Test cases for UserSession model."""
|
|
||||||
|
|
||||||
def test_create_user_session(self, db_session: Session):
|
|
||||||
"""Test creating a user session."""
|
|
||||||
expires = datetime.utcnow() + timedelta(hours=24)
|
|
||||||
|
|
||||||
session = UserSession(
|
|
||||||
session_id="test-session-123",
|
|
||||||
token_hash="hashed-token-value",
|
|
||||||
user_id="user-1",
|
|
||||||
ip_address="192.168.1.100",
|
|
||||||
user_agent="Mozilla/5.0",
|
|
||||||
expires_at=expires,
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
db_session.add(session)
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
# Verify saved
|
|
||||||
assert session.id is not None
|
|
||||||
assert session.session_id == "test-session-123"
|
|
||||||
assert session.is_active is True
|
|
||||||
assert session.created_at is not None
|
|
||||||
|
|
||||||
def test_session_unique_session_id(self, db_session: Session):
|
|
||||||
"""Test that session_id must be unique."""
|
|
||||||
expires = datetime.utcnow() + timedelta(hours=24)
|
|
||||||
|
|
||||||
session1 = UserSession(
|
|
||||||
session_id="duplicate-id",
|
|
||||||
token_hash="hash1",
|
|
||||||
expires_at=expires,
|
|
||||||
)
|
|
||||||
session2 = UserSession(
|
|
||||||
session_id="duplicate-id",
|
|
||||||
token_hash="hash2",
|
|
||||||
expires_at=expires,
|
|
||||||
)
|
|
||||||
|
|
||||||
db_session.add(session1)
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
db_session.add(session2)
|
|
||||||
with pytest.raises(Exception): # IntegrityError
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
def test_session_is_expired(self, db_session: Session):
|
|
||||||
"""Test session expiration check."""
|
|
||||||
# Create expired session
|
|
||||||
expired = datetime.utcnow() - timedelta(hours=1)
|
|
||||||
session = UserSession(
|
|
||||||
session_id="expired-session",
|
|
||||||
token_hash="hash",
|
|
||||||
expires_at=expired,
|
|
||||||
)
|
|
||||||
|
|
||||||
db_session.add(session)
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
# Verify is_expired
|
|
||||||
assert session.is_expired is True
|
|
||||||
|
|
||||||
def test_session_revoke(self, db_session: Session):
|
|
||||||
"""Test session revocation."""
|
|
||||||
expires = datetime.utcnow() + timedelta(hours=24)
|
|
||||||
session = UserSession(
|
|
||||||
session_id="revoke-test",
|
|
||||||
token_hash="hash",
|
|
||||||
expires_at=expires,
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
db_session.add(session)
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
# Revoke session
|
|
||||||
session.revoke()
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
# Verify revoked
|
|
||||||
assert session.is_active is False
|
|
||||||
|
|
||||||
|
|
||||||
class TestTimestampMixin:
|
|
||||||
"""Test cases for TimestampMixin."""
|
|
||||||
|
|
||||||
def test_timestamp_auto_creation(self, db_session: Session):
|
|
||||||
"""Test that timestamps are automatically created."""
|
|
||||||
series = AnimeSeries(
|
|
||||||
key="timestamp-test",
|
|
||||||
name="Timestamp Test",
|
|
||||||
site="https://example.com",
|
|
||||||
folder="/anime/timestamp",
|
|
||||||
)
|
|
||||||
|
|
||||||
db_session.add(series)
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
# Verify timestamps exist
|
|
||||||
assert series.created_at is not None
|
|
||||||
assert series.updated_at is not None
|
|
||||||
assert series.created_at == series.updated_at
|
|
||||||
|
|
||||||
def test_timestamp_auto_update(self, db_session: Session):
|
|
||||||
"""Test that updated_at is automatically updated."""
|
|
||||||
series = AnimeSeries(
|
|
||||||
key="update-test",
|
|
||||||
name="Update Test",
|
|
||||||
site="https://example.com",
|
|
||||||
folder="/anime/update",
|
|
||||||
)
|
|
||||||
|
|
||||||
db_session.add(series)
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
original_updated = series.updated_at
|
|
||||||
|
|
||||||
# Update and save
|
|
||||||
series.name = "Updated Name"
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
# Verify updated_at changed
|
|
||||||
# Note: This test may be flaky due to timing
|
|
||||||
assert series.created_at is not None
|
|
||||||
|
|
||||||
|
|
||||||
class TestSoftDeleteMixin:
|
|
||||||
"""Test cases for SoftDeleteMixin."""
|
|
||||||
|
|
||||||
def test_soft_delete_not_applied_to_models(self):
|
|
||||||
"""Test that SoftDeleteMixin is not applied to current models.
|
|
||||||
|
|
||||||
This is a documentation test - models don't currently use
|
|
||||||
SoftDeleteMixin, but it's available for future use.
|
|
||||||
"""
|
|
||||||
# Verify models don't have deleted_at attribute
|
|
||||||
series = AnimeSeries(
|
|
||||||
key="soft-delete-test",
|
|
||||||
name="Soft Delete Test",
|
|
||||||
site="https://example.com",
|
|
||||||
folder="/anime/soft-delete",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Models shouldn't have soft delete attributes
|
|
||||||
assert not hasattr(series, "deleted_at")
|
|
||||||
assert not hasattr(series, "is_deleted")
|
|
||||||
assert not hasattr(series, "soft_delete")
|
|
||||||
|
|
||||||
|
|
||||||
class TestDatabaseQueries:
|
|
||||||
"""Test complex database queries and operations."""
|
|
||||||
|
|
||||||
def test_query_series_with_episodes(self, db_session: Session):
|
|
||||||
"""Test querying series with their episodes."""
|
|
||||||
# Create series with episodes
|
|
||||||
series = AnimeSeries(
|
|
||||||
key="query-test",
|
|
||||||
name="Query Test",
|
|
||||||
site="https://example.com",
|
|
||||||
folder="/anime/query",
|
|
||||||
)
|
|
||||||
db_session.add(series)
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
# Add multiple episodes
|
|
||||||
for i in range(1, 6):
|
|
||||||
episode = Episode(
|
|
||||||
series_id=series.id,
|
|
||||||
season=1,
|
|
||||||
episode_number=i,
|
|
||||||
title=f"Episode {i}",
|
|
||||||
)
|
|
||||||
db_session.add(episode)
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
# Query series with episodes
|
|
||||||
result = db_session.execute(
|
|
||||||
select(AnimeSeries).where(AnimeSeries.key == "query-test")
|
|
||||||
)
|
|
||||||
queried_series = result.scalar_one()
|
|
||||||
|
|
||||||
# Verify episodes loaded
|
|
||||||
assert len(queried_series.episodes) == 5
|
|
||||||
|
|
||||||
def test_query_download_queue_by_status(self, db_session: Session):
|
|
||||||
"""Test querying download queue by status."""
|
|
||||||
series = AnimeSeries(
|
|
||||||
key="queue-query-test",
|
|
||||||
name="Queue Query Test",
|
|
||||||
site="https://example.com",
|
|
||||||
folder="/anime/queue-query",
|
|
||||||
)
|
|
||||||
db_session.add(series)
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
# Create items with different statuses
|
|
||||||
for i, status in enumerate([
|
|
||||||
DownloadStatus.PENDING,
|
|
||||||
DownloadStatus.DOWNLOADING,
|
|
||||||
DownloadStatus.COMPLETED,
|
|
||||||
]):
|
|
||||||
item = DownloadQueueItem(
|
|
||||||
series_id=series.id,
|
|
||||||
season=1,
|
|
||||||
episode_number=i + 1,
|
|
||||||
status=status,
|
|
||||||
)
|
|
||||||
db_session.add(item)
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
# Query pending items
|
|
||||||
result = db_session.execute(
|
|
||||||
select(DownloadQueueItem).where(
|
|
||||||
DownloadQueueItem.status == DownloadStatus.PENDING
|
|
||||||
)
|
|
||||||
)
|
|
||||||
pending = result.scalars().all()
|
|
||||||
|
|
||||||
# Verify query
|
|
||||||
assert len(pending) == 1
|
|
||||||
assert pending[0].episode_number == 1
|
|
||||||
|
|
||||||
def test_query_active_sessions(self, db_session: Session):
|
|
||||||
"""Test querying active user sessions."""
|
|
||||||
expires = datetime.utcnow() + timedelta(hours=24)
|
|
||||||
|
|
||||||
# Create active and inactive sessions
|
|
||||||
active = UserSession(
|
|
||||||
session_id="active-1",
|
|
||||||
token_hash="hash1",
|
|
||||||
expires_at=expires,
|
|
||||||
is_active=True,
|
|
||||||
)
|
|
||||||
inactive = UserSession(
|
|
||||||
session_id="inactive-1",
|
|
||||||
token_hash="hash2",
|
|
||||||
expires_at=expires,
|
|
||||||
is_active=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
db_session.add_all([active, inactive])
|
|
||||||
db_session.commit()
|
|
||||||
|
|
||||||
# Query active sessions
|
|
||||||
result = db_session.execute(
|
|
||||||
select(UserSession).where(UserSession.is_active == True)
|
|
||||||
)
|
|
||||||
active_sessions = result.scalars().all()
|
|
||||||
|
|
||||||
# Verify query
|
|
||||||
assert len(active_sessions) == 1
|
|
||||||
assert active_sessions[0].session_id == "active-1"
|
|
||||||
@ -1,682 +0,0 @@
|
|||||||
"""Unit tests for database service layer.
|
|
||||||
|
|
||||||
Tests CRUD operations for all database services using in-memory SQLite.
|
|
||||||
"""
|
|
||||||
import asyncio
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
|
|
||||||
from src.server.database.base import Base
|
|
||||||
from src.server.database.models import DownloadPriority, DownloadStatus
|
|
||||||
from src.server.database.service import (
|
|
||||||
AnimeSeriesService,
|
|
||||||
DownloadQueueService,
|
|
||||||
EpisodeService,
|
|
||||||
UserSessionService,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def db_engine():
|
|
||||||
"""Create in-memory database engine for testing."""
|
|
||||||
engine = create_async_engine(
|
|
||||||
"sqlite+aiosqlite:///:memory:",
|
|
||||||
echo=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create all tables
|
|
||||||
async with engine.begin() as conn:
|
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
|
||||||
|
|
||||||
yield engine
|
|
||||||
|
|
||||||
# Cleanup
|
|
||||||
await engine.dispose()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def db_session(db_engine):
|
|
||||||
"""Create database session for testing."""
|
|
||||||
async_session = sessionmaker(
|
|
||||||
db_engine,
|
|
||||||
class_=AsyncSession,
|
|
||||||
expire_on_commit=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
async with async_session() as session:
|
|
||||||
yield session
|
|
||||||
await session.rollback()
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# AnimeSeriesService Tests
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_anime_series(db_session):
|
|
||||||
"""Test creating an anime series."""
|
|
||||||
series = await AnimeSeriesService.create(
|
|
||||||
db_session,
|
|
||||||
key="test-anime-1",
|
|
||||||
name="Test Anime",
|
|
||||||
site="https://example.com",
|
|
||||||
folder="/path/to/anime",
|
|
||||||
description="A test anime",
|
|
||||||
status="ongoing",
|
|
||||||
total_episodes=12,
|
|
||||||
cover_url="https://example.com/cover.jpg",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert series.id is not None
|
|
||||||
assert series.key == "test-anime-1"
|
|
||||||
assert series.name == "Test Anime"
|
|
||||||
assert series.description == "A test anime"
|
|
||||||
assert series.total_episodes == 12
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_anime_series_by_id(db_session):
|
|
||||||
"""Test retrieving anime series by ID."""
|
|
||||||
# Create series
|
|
||||||
series = await AnimeSeriesService.create(
|
|
||||||
db_session,
|
|
||||||
key="test-anime-2",
|
|
||||||
name="Test Anime 2",
|
|
||||||
site="https://example.com",
|
|
||||||
folder="/path/to/anime2",
|
|
||||||
)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
# Retrieve series
|
|
||||||
retrieved = await AnimeSeriesService.get_by_id(db_session, series.id)
|
|
||||||
assert retrieved is not None
|
|
||||||
assert retrieved.id == series.id
|
|
||||||
assert retrieved.key == "test-anime-2"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_anime_series_by_key(db_session):
|
|
||||||
"""Test retrieving anime series by provider key."""
|
|
||||||
# Create series
|
|
||||||
await AnimeSeriesService.create(
|
|
||||||
db_session,
|
|
||||||
key="unique-key",
|
|
||||||
name="Test Anime",
|
|
||||||
site="https://example.com",
|
|
||||||
folder="/path/to/anime",
|
|
||||||
)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
# Retrieve by key
|
|
||||||
retrieved = await AnimeSeriesService.get_by_key(db_session, "unique-key")
|
|
||||||
assert retrieved is not None
|
|
||||||
assert retrieved.key == "unique-key"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_all_anime_series(db_session):
|
|
||||||
"""Test retrieving all anime series."""
|
|
||||||
# Create multiple series
|
|
||||||
await AnimeSeriesService.create(
|
|
||||||
db_session,
|
|
||||||
key="anime-1",
|
|
||||||
name="Anime 1",
|
|
||||||
site="https://example.com",
|
|
||||||
folder="/path/1",
|
|
||||||
)
|
|
||||||
await AnimeSeriesService.create(
|
|
||||||
db_session,
|
|
||||||
key="anime-2",
|
|
||||||
name="Anime 2",
|
|
||||||
site="https://example.com",
|
|
||||||
folder="/path/2",
|
|
||||||
)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
# Retrieve all
|
|
||||||
all_series = await AnimeSeriesService.get_all(db_session)
|
|
||||||
assert len(all_series) == 2
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_anime_series(db_session):
|
|
||||||
"""Test updating anime series."""
|
|
||||||
# Create series
|
|
||||||
series = await AnimeSeriesService.create(
|
|
||||||
db_session,
|
|
||||||
key="anime-update",
|
|
||||||
name="Original Name",
|
|
||||||
site="https://example.com",
|
|
||||||
folder="/path/original",
|
|
||||||
)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
# Update series
|
|
||||||
updated = await AnimeSeriesService.update(
|
|
||||||
db_session,
|
|
||||||
series.id,
|
|
||||||
name="Updated Name",
|
|
||||||
total_episodes=24,
|
|
||||||
)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
assert updated is not None
|
|
||||||
assert updated.name == "Updated Name"
|
|
||||||
assert updated.total_episodes == 24
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_anime_series(db_session):
|
|
||||||
"""Test deleting anime series."""
|
|
||||||
# Create series
|
|
||||||
series = await AnimeSeriesService.create(
|
|
||||||
db_session,
|
|
||||||
key="anime-delete",
|
|
||||||
name="To Delete",
|
|
||||||
site="https://example.com",
|
|
||||||
folder="/path/delete",
|
|
||||||
)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
# Delete series
|
|
||||||
deleted = await AnimeSeriesService.delete(db_session, series.id)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
assert deleted is True
|
|
||||||
|
|
||||||
# Verify deletion
|
|
||||||
retrieved = await AnimeSeriesService.get_by_id(db_session, series.id)
|
|
||||||
assert retrieved is None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_search_anime_series(db_session):
|
|
||||||
"""Test searching anime series by name."""
|
|
||||||
# Create series
|
|
||||||
await AnimeSeriesService.create(
|
|
||||||
db_session,
|
|
||||||
key="naruto",
|
|
||||||
name="Naruto Shippuden",
|
|
||||||
site="https://example.com",
|
|
||||||
folder="/path/naruto",
|
|
||||||
)
|
|
||||||
await AnimeSeriesService.create(
|
|
||||||
db_session,
|
|
||||||
key="bleach",
|
|
||||||
name="Bleach",
|
|
||||||
site="https://example.com",
|
|
||||||
folder="/path/bleach",
|
|
||||||
)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
# Search
|
|
||||||
results = await AnimeSeriesService.search(db_session, "naruto")
|
|
||||||
assert len(results) == 1
|
|
||||||
assert results[0].name == "Naruto Shippuden"
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# EpisodeService Tests
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_episode(db_session):
|
|
||||||
"""Test creating an episode."""
|
|
||||||
# Create series first
|
|
||||||
series = await AnimeSeriesService.create(
|
|
||||||
db_session,
|
|
||||||
key="test-series",
|
|
||||||
name="Test Series",
|
|
||||||
site="https://example.com",
|
|
||||||
folder="/path/test",
|
|
||||||
)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
# Create episode
|
|
||||||
episode = await EpisodeService.create(
|
|
||||||
db_session,
|
|
||||||
series_id=series.id,
|
|
||||||
season=1,
|
|
||||||
episode_number=1,
|
|
||||||
title="Episode 1",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert episode.id is not None
|
|
||||||
assert episode.series_id == series.id
|
|
||||||
assert episode.season == 1
|
|
||||||
assert episode.episode_number == 1
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_episodes_by_series(db_session):
|
|
||||||
"""Test retrieving episodes for a series."""
|
|
||||||
# Create series
|
|
||||||
series = await AnimeSeriesService.create(
|
|
||||||
db_session,
|
|
||||||
key="test-series-2",
|
|
||||||
name="Test Series 2",
|
|
||||||
site="https://example.com",
|
|
||||||
folder="/path/test2",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create episodes
|
|
||||||
await EpisodeService.create(
|
|
||||||
db_session,
|
|
||||||
series_id=series.id,
|
|
||||||
season=1,
|
|
||||||
episode_number=1,
|
|
||||||
)
|
|
||||||
await EpisodeService.create(
|
|
||||||
db_session,
|
|
||||||
series_id=series.id,
|
|
||||||
season=1,
|
|
||||||
episode_number=2,
|
|
||||||
)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
# Retrieve episodes
|
|
||||||
episodes = await EpisodeService.get_by_series(db_session, series.id)
|
|
||||||
assert len(episodes) == 2
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_mark_episode_downloaded(db_session):
|
|
||||||
"""Test marking episode as downloaded."""
|
|
||||||
# Create series and episode
|
|
||||||
series = await AnimeSeriesService.create(
|
|
||||||
db_session,
|
|
||||||
key="test-series-3",
|
|
||||||
name="Test Series 3",
|
|
||||||
site="https://example.com",
|
|
||||||
folder="/path/test3",
|
|
||||||
)
|
|
||||||
episode = await EpisodeService.create(
|
|
||||||
db_session,
|
|
||||||
series_id=series.id,
|
|
||||||
season=1,
|
|
||||||
episode_number=1,
|
|
||||||
)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
# Mark as downloaded
|
|
||||||
updated = await EpisodeService.mark_downloaded(
|
|
||||||
db_session,
|
|
||||||
episode.id,
|
|
||||||
file_path="/path/to/file.mp4",
|
|
||||||
file_size=1024000,
|
|
||||||
)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
assert updated is not None
|
|
||||||
assert updated.is_downloaded is True
|
|
||||||
assert updated.file_path == "/path/to/file.mp4"
|
|
||||||
assert updated.download_date is not None
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# DownloadQueueService Tests
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_download_queue_item(db_session):
|
|
||||||
"""Test adding item to download queue."""
|
|
||||||
# Create series
|
|
||||||
series = await AnimeSeriesService.create(
|
|
||||||
db_session,
|
|
||||||
key="test-series-4",
|
|
||||||
name="Test Series 4",
|
|
||||||
site="https://example.com",
|
|
||||||
folder="/path/test4",
|
|
||||||
)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
# Add to queue
|
|
||||||
item = await DownloadQueueService.create(
|
|
||||||
db_session,
|
|
||||||
series_id=series.id,
|
|
||||||
season=1,
|
|
||||||
episode_number=1,
|
|
||||||
priority=DownloadPriority.HIGH,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert item.id is not None
|
|
||||||
assert item.status == DownloadStatus.PENDING
|
|
||||||
assert item.priority == DownloadPriority.HIGH
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_pending_downloads(db_session):
|
|
||||||
"""Test retrieving pending downloads."""
|
|
||||||
# Create series
|
|
||||||
series = await AnimeSeriesService.create(
|
|
||||||
db_session,
|
|
||||||
key="test-series-5",
|
|
||||||
name="Test Series 5",
|
|
||||||
site="https://example.com",
|
|
||||||
folder="/path/test5",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add pending items
|
|
||||||
await DownloadQueueService.create(
|
|
||||||
db_session,
|
|
||||||
series_id=series.id,
|
|
||||||
season=1,
|
|
||||||
episode_number=1,
|
|
||||||
)
|
|
||||||
await DownloadQueueService.create(
|
|
||||||
db_session,
|
|
||||||
series_id=series.id,
|
|
||||||
season=1,
|
|
||||||
episode_number=2,
|
|
||||||
)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
# Retrieve pending
|
|
||||||
pending = await DownloadQueueService.get_pending(db_session)
|
|
||||||
assert len(pending) == 2
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_download_status(db_session):
|
|
||||||
"""Test updating download status."""
|
|
||||||
# Create series and queue item
|
|
||||||
series = await AnimeSeriesService.create(
|
|
||||||
db_session,
|
|
||||||
key="test-series-6",
|
|
||||||
name="Test Series 6",
|
|
||||||
site="https://example.com",
|
|
||||||
folder="/path/test6",
|
|
||||||
)
|
|
||||||
item = await DownloadQueueService.create(
|
|
||||||
db_session,
|
|
||||||
series_id=series.id,
|
|
||||||
season=1,
|
|
||||||
episode_number=1,
|
|
||||||
)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
# Update status
|
|
||||||
updated = await DownloadQueueService.update_status(
|
|
||||||
db_session,
|
|
||||||
item.id,
|
|
||||||
DownloadStatus.DOWNLOADING,
|
|
||||||
)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
assert updated is not None
|
|
||||||
assert updated.status == DownloadStatus.DOWNLOADING
|
|
||||||
assert updated.started_at is not None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_download_progress(db_session):
|
|
||||||
"""Test updating download progress."""
|
|
||||||
# Create series and queue item
|
|
||||||
series = await AnimeSeriesService.create(
|
|
||||||
db_session,
|
|
||||||
key="test-series-7",
|
|
||||||
name="Test Series 7",
|
|
||||||
site="https://example.com",
|
|
||||||
folder="/path/test7",
|
|
||||||
)
|
|
||||||
item = await DownloadQueueService.create(
|
|
||||||
db_session,
|
|
||||||
series_id=series.id,
|
|
||||||
season=1,
|
|
||||||
episode_number=1,
|
|
||||||
)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
# Update progress
|
|
||||||
updated = await DownloadQueueService.update_progress(
|
|
||||||
db_session,
|
|
||||||
item.id,
|
|
||||||
progress_percent=50.0,
|
|
||||||
downloaded_bytes=500000,
|
|
||||||
total_bytes=1000000,
|
|
||||||
download_speed=50000.0,
|
|
||||||
)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
assert updated is not None
|
|
||||||
assert updated.progress_percent == 50.0
|
|
||||||
assert updated.downloaded_bytes == 500000
|
|
||||||
assert updated.total_bytes == 1000000
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_clear_completed_downloads(db_session):
|
|
||||||
"""Test clearing completed downloads."""
|
|
||||||
# Create series and completed items
|
|
||||||
series = await AnimeSeriesService.create(
|
|
||||||
db_session,
|
|
||||||
key="test-series-8",
|
|
||||||
name="Test Series 8",
|
|
||||||
site="https://example.com",
|
|
||||||
folder="/path/test8",
|
|
||||||
)
|
|
||||||
item1 = await DownloadQueueService.create(
|
|
||||||
db_session,
|
|
||||||
series_id=series.id,
|
|
||||||
season=1,
|
|
||||||
episode_number=1,
|
|
||||||
)
|
|
||||||
item2 = await DownloadQueueService.create(
|
|
||||||
db_session,
|
|
||||||
series_id=series.id,
|
|
||||||
season=1,
|
|
||||||
episode_number=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mark items as completed
|
|
||||||
await DownloadQueueService.update_status(
|
|
||||||
db_session,
|
|
||||||
item1.id,
|
|
||||||
DownloadStatus.COMPLETED,
|
|
||||||
)
|
|
||||||
await DownloadQueueService.update_status(
|
|
||||||
db_session,
|
|
||||||
item2.id,
|
|
||||||
DownloadStatus.COMPLETED,
|
|
||||||
)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
# Clear completed
|
|
||||||
count = await DownloadQueueService.clear_completed(db_session)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
assert count == 2
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_retry_failed_downloads(db_session):
|
|
||||||
"""Test retrying failed downloads."""
|
|
||||||
# Create series and failed item
|
|
||||||
series = await AnimeSeriesService.create(
|
|
||||||
db_session,
|
|
||||||
key="test-series-9",
|
|
||||||
name="Test Series 9",
|
|
||||||
site="https://example.com",
|
|
||||||
folder="/path/test9",
|
|
||||||
)
|
|
||||||
item = await DownloadQueueService.create(
|
|
||||||
db_session,
|
|
||||||
series_id=series.id,
|
|
||||||
season=1,
|
|
||||||
episode_number=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mark as failed
|
|
||||||
await DownloadQueueService.update_status(
|
|
||||||
db_session,
|
|
||||||
item.id,
|
|
||||||
DownloadStatus.FAILED,
|
|
||||||
error_message="Network error",
|
|
||||||
)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
# Retry
|
|
||||||
retried = await DownloadQueueService.retry_failed(db_session)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
assert len(retried) == 1
|
|
||||||
assert retried[0].status == DownloadStatus.PENDING
|
|
||||||
assert retried[0].error_message is None
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# UserSessionService Tests
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_user_session(db_session):
|
|
||||||
"""Test creating a user session."""
|
|
||||||
expires_at = datetime.utcnow() + timedelta(hours=24)
|
|
||||||
session = await UserSessionService.create(
|
|
||||||
db_session,
|
|
||||||
session_id="test-session-1",
|
|
||||||
token_hash="hashed-token",
|
|
||||||
expires_at=expires_at,
|
|
||||||
user_id="user123",
|
|
||||||
ip_address="127.0.0.1",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert session.id is not None
|
|
||||||
assert session.session_id == "test-session-1"
|
|
||||||
assert session.is_active is True
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_session_by_id(db_session):
|
|
||||||
"""Test retrieving session by ID."""
|
|
||||||
expires_at = datetime.utcnow() + timedelta(hours=24)
|
|
||||||
session = await UserSessionService.create(
|
|
||||||
db_session,
|
|
||||||
session_id="test-session-2",
|
|
||||||
token_hash="hashed-token",
|
|
||||||
expires_at=expires_at,
|
|
||||||
)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
# Retrieve
|
|
||||||
retrieved = await UserSessionService.get_by_session_id(
|
|
||||||
db_session,
|
|
||||||
"test-session-2",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert retrieved is not None
|
|
||||||
assert retrieved.session_id == "test-session-2"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_active_sessions(db_session):
|
|
||||||
"""Test retrieving active sessions."""
|
|
||||||
expires_at = datetime.utcnow() + timedelta(hours=24)
|
|
||||||
|
|
||||||
# Create active session
|
|
||||||
await UserSessionService.create(
|
|
||||||
db_session,
|
|
||||||
session_id="active-session",
|
|
||||||
token_hash="hashed-token",
|
|
||||||
expires_at=expires_at,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create expired session
|
|
||||||
await UserSessionService.create(
|
|
||||||
db_session,
|
|
||||||
session_id="expired-session",
|
|
||||||
token_hash="hashed-token",
|
|
||||||
expires_at=datetime.utcnow() - timedelta(hours=1),
|
|
||||||
)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
# Retrieve active sessions
|
|
||||||
active = await UserSessionService.get_active_sessions(db_session)
|
|
||||||
assert len(active) == 1
|
|
||||||
assert active[0].session_id == "active-session"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_revoke_session(db_session):
|
|
||||||
"""Test revoking a session."""
|
|
||||||
expires_at = datetime.utcnow() + timedelta(hours=24)
|
|
||||||
session = await UserSessionService.create(
|
|
||||||
db_session,
|
|
||||||
session_id="test-session-3",
|
|
||||||
token_hash="hashed-token",
|
|
||||||
expires_at=expires_at,
|
|
||||||
)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
# Revoke
|
|
||||||
revoked = await UserSessionService.revoke(db_session, "test-session-3")
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
assert revoked is True
|
|
||||||
|
|
||||||
# Verify
|
|
||||||
retrieved = await UserSessionService.get_by_session_id(
|
|
||||||
db_session,
|
|
||||||
"test-session-3",
|
|
||||||
)
|
|
||||||
assert retrieved.is_active is False
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_cleanup_expired_sessions(db_session):
|
|
||||||
"""Test cleaning up expired sessions."""
|
|
||||||
# Create expired sessions
|
|
||||||
await UserSessionService.create(
|
|
||||||
db_session,
|
|
||||||
session_id="expired-1",
|
|
||||||
token_hash="hashed-token",
|
|
||||||
expires_at=datetime.utcnow() - timedelta(hours=1),
|
|
||||||
)
|
|
||||||
await UserSessionService.create(
|
|
||||||
db_session,
|
|
||||||
session_id="expired-2",
|
|
||||||
token_hash="hashed-token",
|
|
||||||
expires_at=datetime.utcnow() - timedelta(hours=2),
|
|
||||||
)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
# Cleanup
|
|
||||||
count = await UserSessionService.cleanup_expired(db_session)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
assert count == 2
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_session_activity(db_session):
|
|
||||||
"""Test updating session last activity."""
|
|
||||||
expires_at = datetime.utcnow() + timedelta(hours=24)
|
|
||||||
session = await UserSessionService.create(
|
|
||||||
db_session,
|
|
||||||
session_id="test-session-4",
|
|
||||||
token_hash="hashed-token",
|
|
||||||
expires_at=expires_at,
|
|
||||||
)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
original_activity = session.last_activity
|
|
||||||
|
|
||||||
# Wait a bit
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
|
|
||||||
# Update activity
|
|
||||||
updated = await UserSessionService.update_activity(
|
|
||||||
db_session,
|
|
||||||
"test-session-4",
|
|
||||||
)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
assert updated is not None
|
|
||||||
assert updated.last_activity > original_activity
|
|
||||||
@ -1,556 +0,0 @@
|
|||||||
"""
|
|
||||||
Unit tests for enhanced SeriesApp with async callback support.
|
|
||||||
|
|
||||||
Tests the functionality of SeriesApp including:
|
|
||||||
- Initialization and configuration
|
|
||||||
- Search functionality
|
|
||||||
- Download with progress callbacks
|
|
||||||
- Directory scanning with progress reporting
|
|
||||||
- Async versions of operations
|
|
||||||
- Cancellation support
|
|
||||||
- Error handling
|
|
||||||
"""
|
|
||||||
|
|
||||||
from unittest.mock import Mock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from src.core.SeriesApp import OperationResult, OperationStatus, ProgressInfo, SeriesApp
|
|
||||||
|
|
||||||
|
|
||||||
class TestSeriesAppInitialization:
|
|
||||||
"""Test SeriesApp initialization."""
|
|
||||||
|
|
||||||
@patch('src.core.SeriesApp.Loaders')
|
|
||||||
@patch('src.core.SeriesApp.SerieScanner')
|
|
||||||
@patch('src.core.SeriesApp.SerieList')
|
|
||||||
def test_init_success(
|
|
||||||
self, mock_serie_list, mock_scanner, mock_loaders
|
|
||||||
):
|
|
||||||
"""Test successful initialization."""
|
|
||||||
test_dir = "/test/anime"
|
|
||||||
|
|
||||||
# Create app
|
|
||||||
app = SeriesApp(test_dir)
|
|
||||||
|
|
||||||
# Verify initialization
|
|
||||||
assert app.directory_to_search == test_dir
|
|
||||||
assert app._operation_status == OperationStatus.IDLE
|
|
||||||
assert app._cancel_flag is False
|
|
||||||
assert app._current_operation is None
|
|
||||||
mock_loaders.assert_called_once()
|
|
||||||
mock_scanner.assert_called_once()
|
|
||||||
|
|
||||||
@patch('src.core.SeriesApp.Loaders')
|
|
||||||
@patch('src.core.SeriesApp.SerieScanner')
|
|
||||||
@patch('src.core.SeriesApp.SerieList')
|
|
||||||
def test_init_with_callbacks(
|
|
||||||
self, mock_serie_list, mock_scanner, mock_loaders
|
|
||||||
):
|
|
||||||
"""Test initialization with progress and error callbacks."""
|
|
||||||
test_dir = "/test/anime"
|
|
||||||
progress_callback = Mock()
|
|
||||||
error_callback = Mock()
|
|
||||||
|
|
||||||
# Create app with callbacks
|
|
||||||
app = SeriesApp(
|
|
||||||
test_dir,
|
|
||||||
progress_callback=progress_callback,
|
|
||||||
error_callback=error_callback
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify callbacks are stored
|
|
||||||
assert app.progress_callback == progress_callback
|
|
||||||
assert app.error_callback == error_callback
|
|
||||||
|
|
||||||
@patch('src.core.SeriesApp.Loaders')
|
|
||||||
def test_init_failure_calls_error_callback(self, mock_loaders):
|
|
||||||
"""Test that initialization failure triggers error callback."""
|
|
||||||
test_dir = "/test/anime"
|
|
||||||
error_callback = Mock()
|
|
||||||
|
|
||||||
# Make Loaders raise an exception
|
|
||||||
mock_loaders.side_effect = RuntimeError("Init failed")
|
|
||||||
|
|
||||||
# Create app should raise but call error callback
|
|
||||||
with pytest.raises(RuntimeError):
|
|
||||||
SeriesApp(test_dir, error_callback=error_callback)
|
|
||||||
|
|
||||||
# Verify error callback was called
|
|
||||||
error_callback.assert_called_once()
|
|
||||||
assert isinstance(
|
|
||||||
error_callback.call_args[0][0],
|
|
||||||
RuntimeError
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestSeriesAppSearch:
|
|
||||||
"""Test search functionality."""
|
|
||||||
|
|
||||||
@patch('src.core.SeriesApp.Loaders')
|
|
||||||
@patch('src.core.SeriesApp.SerieScanner')
|
|
||||||
@patch('src.core.SeriesApp.SerieList')
|
|
||||||
def test_search_success(
|
|
||||||
self, mock_serie_list, mock_scanner, mock_loaders
|
|
||||||
):
|
|
||||||
"""Test successful search."""
|
|
||||||
test_dir = "/test/anime"
|
|
||||||
app = SeriesApp(test_dir)
|
|
||||||
|
|
||||||
# Mock search results
|
|
||||||
expected_results = [
|
|
||||||
{"key": "anime1", "name": "Anime 1"},
|
|
||||||
{"key": "anime2", "name": "Anime 2"}
|
|
||||||
]
|
|
||||||
app.loader.Search = Mock(return_value=expected_results)
|
|
||||||
|
|
||||||
# Perform search
|
|
||||||
results = app.search("test anime")
|
|
||||||
|
|
||||||
# Verify results
|
|
||||||
assert results == expected_results
|
|
||||||
app.loader.Search.assert_called_once_with("test anime")
|
|
||||||
|
|
||||||
@patch('src.core.SeriesApp.Loaders')
|
|
||||||
@patch('src.core.SeriesApp.SerieScanner')
|
|
||||||
@patch('src.core.SeriesApp.SerieList')
|
|
||||||
def test_search_failure_calls_error_callback(
|
|
||||||
self, mock_serie_list, mock_scanner, mock_loaders
|
|
||||||
):
|
|
||||||
"""Test search failure triggers error callback."""
|
|
||||||
test_dir = "/test/anime"
|
|
||||||
error_callback = Mock()
|
|
||||||
app = SeriesApp(test_dir, error_callback=error_callback)
|
|
||||||
|
|
||||||
# Make search raise an exception
|
|
||||||
app.loader.Search = Mock(
|
|
||||||
side_effect=RuntimeError("Search failed")
|
|
||||||
)
|
|
||||||
|
|
||||||
# Search should raise and call error callback
|
|
||||||
with pytest.raises(RuntimeError):
|
|
||||||
app.search("test")
|
|
||||||
|
|
||||||
error_callback.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
class TestSeriesAppDownload:
|
|
||||||
"""Test download functionality."""
|
|
||||||
|
|
||||||
@patch('src.core.SeriesApp.Loaders')
|
|
||||||
@patch('src.core.SeriesApp.SerieScanner')
|
|
||||||
@patch('src.core.SeriesApp.SerieList')
|
|
||||||
def test_download_success(
|
|
||||||
self, mock_serie_list, mock_scanner, mock_loaders
|
|
||||||
):
|
|
||||||
"""Test successful download."""
|
|
||||||
test_dir = "/test/anime"
|
|
||||||
app = SeriesApp(test_dir)
|
|
||||||
|
|
||||||
# Mock download
|
|
||||||
app.loader.Download = Mock()
|
|
||||||
|
|
||||||
# Perform download
|
|
||||||
result = app.download(
|
|
||||||
"anime_folder",
|
|
||||||
season=1,
|
|
||||||
episode=1,
|
|
||||||
key="anime_key"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify result
|
|
||||||
assert result.success is True
|
|
||||||
assert "Successfully downloaded" in result.message
|
|
||||||
# After successful completion, finally block resets operation
|
|
||||||
assert app._current_operation is None
|
|
||||||
app.loader.Download.assert_called_once()
|
|
||||||
|
|
||||||
@patch('src.core.SeriesApp.Loaders')
|
|
||||||
@patch('src.core.SeriesApp.SerieScanner')
|
|
||||||
@patch('src.core.SeriesApp.SerieList')
|
|
||||||
def test_download_with_progress_callback(
|
|
||||||
self, mock_serie_list, mock_scanner, mock_loaders
|
|
||||||
):
|
|
||||||
"""Test download with progress callback."""
|
|
||||||
test_dir = "/test/anime"
|
|
||||||
app = SeriesApp(test_dir)
|
|
||||||
|
|
||||||
# Mock download that calls progress callback
|
|
||||||
def mock_download(*args, **kwargs):
|
|
||||||
callback = args[-1] if len(args) > 6 else kwargs.get('callback')
|
|
||||||
if callback:
|
|
||||||
callback(0.5)
|
|
||||||
callback(1.0)
|
|
||||||
|
|
||||||
app.loader.Download = Mock(side_effect=mock_download)
|
|
||||||
progress_callback = Mock()
|
|
||||||
|
|
||||||
# Perform download
|
|
||||||
result = app.download(
|
|
||||||
"anime_folder",
|
|
||||||
season=1,
|
|
||||||
episode=1,
|
|
||||||
key="anime_key",
|
|
||||||
callback=progress_callback
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify progress callback was called
|
|
||||||
assert result.success is True
|
|
||||||
assert progress_callback.call_count == 2
|
|
||||||
progress_callback.assert_any_call(0.5)
|
|
||||||
progress_callback.assert_any_call(1.0)
|
|
||||||
|
|
||||||
@patch('src.core.SeriesApp.Loaders')
|
|
||||||
@patch('src.core.SeriesApp.SerieScanner')
|
|
||||||
@patch('src.core.SeriesApp.SerieList')
|
|
||||||
def test_download_cancellation(
|
|
||||||
self, mock_serie_list, mock_scanner, mock_loaders
|
|
||||||
):
|
|
||||||
"""Test download cancellation during operation."""
|
|
||||||
test_dir = "/test/anime"
|
|
||||||
app = SeriesApp(test_dir)
|
|
||||||
|
|
||||||
# Mock download that raises InterruptedError for cancellation
|
|
||||||
def mock_download_cancelled(*args, **kwargs):
|
|
||||||
# Simulate cancellation by raising InterruptedError
|
|
||||||
raise InterruptedError("Download cancelled by user")
|
|
||||||
|
|
||||||
app.loader.Download = Mock(side_effect=mock_download_cancelled)
|
|
||||||
|
|
||||||
# Set cancel flag before calling (will be reset by download())
|
|
||||||
# but the mock will raise InterruptedError anyway
|
|
||||||
app._cancel_flag = True
|
|
||||||
|
|
||||||
# Perform download - should catch InterruptedError
|
|
||||||
result = app.download(
|
|
||||||
"anime_folder",
|
|
||||||
season=1,
|
|
||||||
episode=1,
|
|
||||||
key="anime_key"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify cancellation was handled
|
|
||||||
assert result.success is False
|
|
||||||
assert "cancelled" in result.message.lower()
|
|
||||||
assert app._current_operation is None
|
|
||||||
|
|
||||||
@patch('src.core.SeriesApp.Loaders')
|
|
||||||
@patch('src.core.SeriesApp.SerieScanner')
|
|
||||||
@patch('src.core.SeriesApp.SerieList')
|
|
||||||
def test_download_failure(
|
|
||||||
self, mock_serie_list, mock_scanner, mock_loaders
|
|
||||||
):
|
|
||||||
"""Test download failure handling."""
|
|
||||||
test_dir = "/test/anime"
|
|
||||||
error_callback = Mock()
|
|
||||||
app = SeriesApp(test_dir, error_callback=error_callback)
|
|
||||||
|
|
||||||
# Make download fail
|
|
||||||
app.loader.Download = Mock(
|
|
||||||
side_effect=RuntimeError("Download failed")
|
|
||||||
)
|
|
||||||
|
|
||||||
# Perform download
|
|
||||||
result = app.download(
|
|
||||||
"anime_folder",
|
|
||||||
season=1,
|
|
||||||
episode=1,
|
|
||||||
key="anime_key"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify failure
|
|
||||||
assert result.success is False
|
|
||||||
assert "failed" in result.message.lower()
|
|
||||||
assert result.error is not None
|
|
||||||
# After failure, finally block resets operation
|
|
||||||
assert app._current_operation is None
|
|
||||||
error_callback.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
class TestSeriesAppReScan:
|
|
||||||
"""Test directory scanning functionality."""
|
|
||||||
|
|
||||||
@patch('src.core.SeriesApp.Loaders')
|
|
||||||
@patch('src.core.SeriesApp.SerieScanner')
|
|
||||||
@patch('src.core.SeriesApp.SerieList')
|
|
||||||
def test_rescan_success(
|
|
||||||
self, mock_serie_list, mock_scanner, mock_loaders
|
|
||||||
):
|
|
||||||
"""Test successful directory rescan."""
|
|
||||||
test_dir = "/test/anime"
|
|
||||||
app = SeriesApp(test_dir)
|
|
||||||
|
|
||||||
# Mock scanner
|
|
||||||
app.SerieScanner.GetTotalToScan = Mock(return_value=5)
|
|
||||||
app.SerieScanner.Reinit = Mock()
|
|
||||||
app.SerieScanner.Scan = Mock()
|
|
||||||
|
|
||||||
# Perform rescan
|
|
||||||
result = app.ReScan()
|
|
||||||
|
|
||||||
# Verify result
|
|
||||||
assert result.success is True
|
|
||||||
assert "completed" in result.message.lower()
|
|
||||||
# After successful completion, finally block resets operation
|
|
||||||
assert app._current_operation is None
|
|
||||||
app.SerieScanner.Reinit.assert_called_once()
|
|
||||||
app.SerieScanner.Scan.assert_called_once()
|
|
||||||
|
|
||||||
@patch('src.core.SeriesApp.Loaders')
|
|
||||||
@patch('src.core.SeriesApp.SerieScanner')
|
|
||||||
@patch('src.core.SeriesApp.SerieList')
|
|
||||||
def test_rescan_with_progress_callback(
|
|
||||||
self, mock_serie_list, mock_scanner, mock_loaders
|
|
||||||
):
|
|
||||||
"""Test rescan with progress callbacks."""
|
|
||||||
test_dir = "/test/anime"
|
|
||||||
progress_callback = Mock()
|
|
||||||
app = SeriesApp(test_dir, progress_callback=progress_callback)
|
|
||||||
|
|
||||||
# Mock scanner
|
|
||||||
app.SerieScanner.GetTotalToScan = Mock(return_value=3)
|
|
||||||
app.SerieScanner.Reinit = Mock()
|
|
||||||
|
|
||||||
def mock_scan(callback):
|
|
||||||
callback("folder1", 1)
|
|
||||||
callback("folder2", 2)
|
|
||||||
callback("folder3", 3)
|
|
||||||
|
|
||||||
app.SerieScanner.Scan = Mock(side_effect=mock_scan)
|
|
||||||
|
|
||||||
# Perform rescan
|
|
||||||
result = app.ReScan()
|
|
||||||
|
|
||||||
# Verify progress callbacks were called
|
|
||||||
assert result.success is True
|
|
||||||
assert progress_callback.call_count == 3
|
|
||||||
|
|
||||||
@patch('src.core.SeriesApp.Loaders')
|
|
||||||
@patch('src.core.SeriesApp.SerieScanner')
|
|
||||||
@patch('src.core.SeriesApp.SerieList')
|
|
||||||
def test_rescan_cancellation(
|
|
||||||
self, mock_serie_list, mock_scanner, mock_loaders
|
|
||||||
):
|
|
||||||
"""Test rescan cancellation."""
|
|
||||||
test_dir = "/test/anime"
|
|
||||||
app = SeriesApp(test_dir)
|
|
||||||
|
|
||||||
# Mock scanner
|
|
||||||
app.SerieScanner.GetTotalToScan = Mock(return_value=3)
|
|
||||||
app.SerieScanner.Reinit = Mock()
|
|
||||||
|
|
||||||
def mock_scan(callback):
|
|
||||||
app._cancel_flag = True
|
|
||||||
callback("folder1", 1)
|
|
||||||
|
|
||||||
app.SerieScanner.Scan = Mock(side_effect=mock_scan)
|
|
||||||
|
|
||||||
# Perform rescan
|
|
||||||
result = app.ReScan()
|
|
||||||
|
|
||||||
# Verify cancellation
|
|
||||||
assert result.success is False
|
|
||||||
assert "cancelled" in result.message.lower()
|
|
||||||
|
|
||||||
|
|
||||||
class TestSeriesAppAsync:
|
|
||||||
"""Test async operations."""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch('src.core.SeriesApp.Loaders')
|
|
||||||
@patch('src.core.SeriesApp.SerieScanner')
|
|
||||||
@patch('src.core.SeriesApp.SerieList')
|
|
||||||
async def test_async_download(
|
|
||||||
self, mock_serie_list, mock_scanner, mock_loaders
|
|
||||||
):
|
|
||||||
"""Test async download."""
|
|
||||||
test_dir = "/test/anime"
|
|
||||||
app = SeriesApp(test_dir)
|
|
||||||
|
|
||||||
# Mock download
|
|
||||||
app.loader.Download = Mock()
|
|
||||||
|
|
||||||
# Perform async download
|
|
||||||
result = await app.async_download(
|
|
||||||
"anime_folder",
|
|
||||||
season=1,
|
|
||||||
episode=1,
|
|
||||||
key="anime_key"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify result
|
|
||||||
assert isinstance(result, OperationResult)
|
|
||||||
assert result.success is True
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch('src.core.SeriesApp.Loaders')
|
|
||||||
@patch('src.core.SeriesApp.SerieScanner')
|
|
||||||
@patch('src.core.SeriesApp.SerieList')
|
|
||||||
async def test_async_rescan(
|
|
||||||
self, mock_serie_list, mock_scanner, mock_loaders
|
|
||||||
):
|
|
||||||
"""Test async rescan."""
|
|
||||||
test_dir = "/test/anime"
|
|
||||||
app = SeriesApp(test_dir)
|
|
||||||
|
|
||||||
# Mock scanner
|
|
||||||
app.SerieScanner.GetTotalToScan = Mock(return_value=5)
|
|
||||||
app.SerieScanner.Reinit = Mock()
|
|
||||||
app.SerieScanner.Scan = Mock()
|
|
||||||
|
|
||||||
# Perform async rescan
|
|
||||||
result = await app.async_rescan()
|
|
||||||
|
|
||||||
# Verify result
|
|
||||||
assert isinstance(result, OperationResult)
|
|
||||||
assert result.success is True
|
|
||||||
|
|
||||||
|
|
||||||
class TestSeriesAppCancellation:
|
|
||||||
"""Test operation cancellation."""
|
|
||||||
|
|
||||||
@patch('src.core.SeriesApp.Loaders')
|
|
||||||
@patch('src.core.SeriesApp.SerieScanner')
|
|
||||||
@patch('src.core.SeriesApp.SerieList')
|
|
||||||
def test_cancel_operation_when_running(
|
|
||||||
self, mock_serie_list, mock_scanner, mock_loaders
|
|
||||||
):
|
|
||||||
"""Test cancelling a running operation."""
|
|
||||||
test_dir = "/test/anime"
|
|
||||||
app = SeriesApp(test_dir)
|
|
||||||
|
|
||||||
# Set operation as running
|
|
||||||
app._current_operation = "test_operation"
|
|
||||||
app._operation_status = OperationStatus.RUNNING
|
|
||||||
|
|
||||||
# Cancel operation
|
|
||||||
result = app.cancel_operation()
|
|
||||||
|
|
||||||
# Verify cancellation
|
|
||||||
assert result is True
|
|
||||||
assert app._cancel_flag is True
|
|
||||||
|
|
||||||
@patch('src.core.SeriesApp.Loaders')
|
|
||||||
@patch('src.core.SeriesApp.SerieScanner')
|
|
||||||
@patch('src.core.SeriesApp.SerieList')
|
|
||||||
def test_cancel_operation_when_idle(
|
|
||||||
self, mock_serie_list, mock_scanner, mock_loaders
|
|
||||||
):
|
|
||||||
"""Test cancelling when no operation is running."""
|
|
||||||
test_dir = "/test/anime"
|
|
||||||
app = SeriesApp(test_dir)
|
|
||||||
|
|
||||||
# Cancel operation (none running)
|
|
||||||
result = app.cancel_operation()
|
|
||||||
|
|
||||||
# Verify no cancellation occurred
|
|
||||||
assert result is False
|
|
||||||
assert app._cancel_flag is False
|
|
||||||
|
|
||||||
|
|
||||||
class TestSeriesAppGetters:
|
|
||||||
"""Test getter methods."""
|
|
||||||
|
|
||||||
@patch('src.core.SeriesApp.Loaders')
|
|
||||||
@patch('src.core.SeriesApp.SerieScanner')
|
|
||||||
@patch('src.core.SeriesApp.SerieList')
|
|
||||||
def test_get_series_list(
|
|
||||||
self, mock_serie_list, mock_scanner, mock_loaders
|
|
||||||
):
|
|
||||||
"""Test getting series list."""
|
|
||||||
test_dir = "/test/anime"
|
|
||||||
app = SeriesApp(test_dir)
|
|
||||||
|
|
||||||
# Get series list
|
|
||||||
series_list = app.get_series_list()
|
|
||||||
|
|
||||||
# Verify
|
|
||||||
assert series_list is not None
|
|
||||||
|
|
||||||
@patch('src.core.SeriesApp.Loaders')
|
|
||||||
@patch('src.core.SeriesApp.SerieScanner')
|
|
||||||
@patch('src.core.SeriesApp.SerieList')
|
|
||||||
def test_get_operation_status(
|
|
||||||
self, mock_serie_list, mock_scanner, mock_loaders
|
|
||||||
):
|
|
||||||
"""Test getting operation status."""
|
|
||||||
test_dir = "/test/anime"
|
|
||||||
app = SeriesApp(test_dir)
|
|
||||||
|
|
||||||
# Get status
|
|
||||||
status = app.get_operation_status()
|
|
||||||
|
|
||||||
# Verify
|
|
||||||
assert status == OperationStatus.IDLE
|
|
||||||
|
|
||||||
@patch('src.core.SeriesApp.Loaders')
|
|
||||||
@patch('src.core.SeriesApp.SerieScanner')
|
|
||||||
@patch('src.core.SeriesApp.SerieList')
|
|
||||||
def test_get_current_operation(
|
|
||||||
self, mock_serie_list, mock_scanner, mock_loaders
|
|
||||||
):
|
|
||||||
"""Test getting current operation."""
|
|
||||||
test_dir = "/test/anime"
|
|
||||||
app = SeriesApp(test_dir)
|
|
||||||
|
|
||||||
# Get current operation
|
|
||||||
operation = app.get_current_operation()
|
|
||||||
|
|
||||||
# Verify
|
|
||||||
assert operation is None
|
|
||||||
|
|
||||||
# Set an operation
|
|
||||||
app._current_operation = "test_op"
|
|
||||||
operation = app.get_current_operation()
|
|
||||||
assert operation == "test_op"
|
|
||||||
|
|
||||||
|
|
||||||
class TestProgressInfo:
|
|
||||||
"""Test ProgressInfo dataclass."""
|
|
||||||
|
|
||||||
def test_progress_info_creation(self):
|
|
||||||
"""Test creating ProgressInfo."""
|
|
||||||
info = ProgressInfo(
|
|
||||||
current=5,
|
|
||||||
total=10,
|
|
||||||
message="Processing...",
|
|
||||||
percentage=50.0,
|
|
||||||
status=OperationStatus.RUNNING
|
|
||||||
)
|
|
||||||
|
|
||||||
assert info.current == 5
|
|
||||||
assert info.total == 10
|
|
||||||
assert info.message == "Processing..."
|
|
||||||
assert info.percentage == 50.0
|
|
||||||
assert info.status == OperationStatus.RUNNING
|
|
||||||
|
|
||||||
|
|
||||||
class TestOperationResult:
|
|
||||||
"""Test OperationResult dataclass."""
|
|
||||||
|
|
||||||
def test_operation_result_success(self):
|
|
||||||
"""Test creating successful OperationResult."""
|
|
||||||
result = OperationResult(
|
|
||||||
success=True,
|
|
||||||
message="Operation completed",
|
|
||||||
data={"key": "value"}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.success is True
|
|
||||||
assert result.message == "Operation completed"
|
|
||||||
assert result.data == {"key": "value"}
|
|
||||||
assert result.error is None
|
|
||||||
|
|
||||||
def test_operation_result_failure(self):
|
|
||||||
"""Test creating failed OperationResult."""
|
|
||||||
error = RuntimeError("Test error")
|
|
||||||
result = OperationResult(
|
|
||||||
success=False,
|
|
||||||
message="Operation failed",
|
|
||||||
error=error
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result.success is False
|
|
||||||
assert result.message == "Operation failed"
|
|
||||||
assert result.error == error
|
|
||||||
assert result.data is None
|
|
||||||
@ -1,243 +0,0 @@
|
|||||||
"""
|
|
||||||
Tests for static file serving (CSS, JS).
|
|
||||||
|
|
||||||
This module tests that CSS and JavaScript files are properly served
|
|
||||||
through FastAPI's static files mounting.
|
|
||||||
"""
|
|
||||||
import pytest
|
|
||||||
from httpx import ASGITransport, AsyncClient
|
|
||||||
|
|
||||||
from src.server.fastapi_app import app
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def client():
|
|
||||||
"""Create an async test client for the FastAPI app."""
|
|
||||||
transport = ASGITransport(app=app)
|
|
||||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
|
||||||
yield ac
|
|
||||||
|
|
||||||
|
|
||||||
class TestCSSFileServing:
|
|
||||||
"""Test CSS file serving functionality."""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_styles_css_accessible(self, client):
|
|
||||||
"""Test that styles.css is accessible."""
|
|
||||||
response = await client.get("/static/css/styles.css")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert "text/css" in response.headers.get("content-type", "")
|
|
||||||
assert len(response.text) > 0
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_ux_features_css_accessible(self, client):
|
|
||||||
"""Test that ux_features.css is accessible."""
|
|
||||||
response = await client.get("/static/css/ux_features.css")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert "text/css" in response.headers.get("content-type", "")
|
|
||||||
assert len(response.text) > 0
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_css_contains_expected_variables(self, client):
|
|
||||||
"""Test that styles.css contains expected CSS variables."""
|
|
||||||
response = await client.get("/static/css/styles.css")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
content = response.text
|
|
||||||
|
|
||||||
# Check for Fluent UI design system variables
|
|
||||||
assert "--color-bg-primary:" in content
|
|
||||||
assert "--color-accent:" in content
|
|
||||||
assert "--font-family:" in content
|
|
||||||
assert "--spacing-" in content
|
|
||||||
assert "--border-radius-" in content
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_css_contains_dark_theme_support(self, client):
|
|
||||||
"""Test that styles.css contains dark theme support."""
|
|
||||||
response = await client.get("/static/css/styles.css")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
content = response.text
|
|
||||||
|
|
||||||
# Check for dark theme variables
|
|
||||||
assert '[data-theme="dark"]' in content
|
|
||||||
assert "--color-bg-primary-dark:" in content
|
|
||||||
assert "--color-text-primary-dark:" in content
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_css_contains_responsive_design(self, client):
|
|
||||||
"""Test that CSS files contain responsive design media queries."""
|
|
||||||
# Test styles.css
|
|
||||||
response = await client.get("/static/css/styles.css")
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert "@media" in response.text
|
|
||||||
|
|
||||||
# Test ux_features.css
|
|
||||||
response = await client.get("/static/css/ux_features.css")
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert "@media" in response.text
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_ux_features_css_contains_accessibility(self, client):
|
|
||||||
"""Test that ux_features.css contains accessibility features."""
|
|
||||||
response = await client.get("/static/css/ux_features.css")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
content = response.text
|
|
||||||
|
|
||||||
# Check for accessibility features
|
|
||||||
assert ".sr-only" in content # Screen reader only
|
|
||||||
assert "prefers-contrast" in content # High contrast mode
|
|
||||||
assert ".keyboard-focus" in content # Keyboard navigation
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_nonexistent_css_returns_404(self, client):
|
|
||||||
"""Test that requesting a nonexistent CSS file returns 404."""
|
|
||||||
response = await client.get("/static/css/nonexistent.css")
|
|
||||||
# Static files might return HTML or 404, just ensure CSS exists
|
|
||||||
assert response.status_code in [200, 404]
|
|
||||||
|
|
||||||
|
|
||||||
class TestJavaScriptFileServing:
|
|
||||||
"""Test JavaScript file serving functionality."""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_app_js_accessible(self, client):
|
|
||||||
"""Test that app.js is accessible."""
|
|
||||||
response = await client.get("/static/js/app.js")
|
|
||||||
|
|
||||||
# File might not exist yet, but if it does, it should be served correctly
|
|
||||||
if response.status_code == 200:
|
|
||||||
assert "javascript" in response.headers.get("content-type", "").lower()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_websocket_client_js_accessible(self, client):
|
|
||||||
"""Test that websocket_client.js is accessible."""
|
|
||||||
response = await client.get("/static/js/websocket_client.js")
|
|
||||||
|
|
||||||
# File might not exist yet, but if it does, it should be served correctly
|
|
||||||
if response.status_code == 200:
|
|
||||||
assert "javascript" in response.headers.get("content-type", "").lower()
|
|
||||||
|
|
||||||
|
|
||||||
class TestHTMLTemplatesCSS:
|
|
||||||
"""Test that HTML templates correctly reference CSS files."""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_index_page_references_css(self, client):
|
|
||||||
"""Test that index.html correctly references CSS files."""
|
|
||||||
response = await client.get("/")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
content = response.text
|
|
||||||
|
|
||||||
# Check for CSS references
|
|
||||||
assert '/static/css/styles.css' in content
|
|
||||||
assert '/static/css/ux_features.css' in content
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_login_page_references_css(self, client):
|
|
||||||
"""Test that login.html correctly references CSS files."""
|
|
||||||
response = await client.get("/login")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
content = response.text
|
|
||||||
|
|
||||||
# Check for CSS reference
|
|
||||||
assert '/static/css/styles.css' in content
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_setup_page_references_css(self, client):
|
|
||||||
"""Test that setup.html correctly references CSS files."""
|
|
||||||
response = await client.get("/setup")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
content = response.text
|
|
||||||
|
|
||||||
# Check for CSS reference
|
|
||||||
assert '/static/css/styles.css' in content
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_queue_page_references_css(self, client):
|
|
||||||
"""Test that queue.html correctly references CSS files."""
|
|
||||||
response = await client.get("/queue")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
content = response.text
|
|
||||||
|
|
||||||
# Check for CSS reference
|
|
||||||
assert '/static/css/styles.css' in content
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_css_paths_are_absolute(self, client):
|
|
||||||
"""Test that CSS paths in templates are absolute paths."""
|
|
||||||
pages = ["/", "/login", "/setup", "/queue"]
|
|
||||||
|
|
||||||
for page in pages:
|
|
||||||
response = await client.get(page)
|
|
||||||
assert response.status_code == 200
|
|
||||||
content = response.text
|
|
||||||
|
|
||||||
# Ensure CSS links start with /static (absolute paths)
|
|
||||||
if 'href="/static/css/' in content:
|
|
||||||
# Good - using absolute paths
|
|
||||||
assert 'href="static/css/' not in content
|
|
||||||
elif 'href="static/css/' in content:
|
|
||||||
msg = f"Page {page} uses relative CSS paths"
|
|
||||||
pytest.fail(msg)
|
|
||||||
|
|
||||||
|
|
||||||
class TestCSSContentIntegrity:
|
|
||||||
"""Test CSS content integrity and structure."""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_styles_css_structure(self, client):
|
|
||||||
"""Test that styles.css has proper structure."""
|
|
||||||
response = await client.get("/static/css/styles.css")
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
content = response.text
|
|
||||||
|
|
||||||
# Should have CSS variable definitions
|
|
||||||
assert ":root" in content
|
|
||||||
|
|
||||||
# Should have base element styles
|
|
||||||
assert "body" in content or "html" in content
|
|
||||||
|
|
||||||
# Should not have syntax errors (basic check)
|
|
||||||
# Count braces - should be balanced
|
|
||||||
open_braces = content.count("{")
|
|
||||||
close_braces = content.count("}")
|
|
||||||
assert open_braces == close_braces, "CSS has unbalanced braces"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_ux_features_css_structure(self, client):
|
|
||||||
"""Test that ux_features.css has proper structure."""
|
|
||||||
response = await client.get("/static/css/ux_features.css")
|
|
||||||
assert response.status_code == 200
|
|
||||||
|
|
||||||
content = response.text
|
|
||||||
|
|
||||||
# Should not have syntax errors (basic check)
|
|
||||||
open_braces = content.count("{")
|
|
||||||
close_braces = content.count("}")
|
|
||||||
assert open_braces == close_braces, "CSS has unbalanced braces"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_css_file_sizes_reasonable(self, client):
|
|
||||||
"""Test that CSS files are not empty and have reasonable sizes."""
|
|
||||||
# Test styles.css
|
|
||||||
response = await client.get("/static/css/styles.css")
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert len(response.text) > 1000, "styles.css seems too small"
|
|
||||||
assert len(response.text) < 500000, "styles.css seems unusually large"
|
|
||||||
|
|
||||||
# Test ux_features.css
|
|
||||||
response = await client.get("/static/css/ux_features.css")
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert len(response.text) > 100, "ux_features.css seems too small"
|
|
||||||
msg = "ux_features.css seems unusually large"
|
|
||||||
assert len(response.text) < 100000, msg
|
|
||||||
Loading…
x
Reference in New Issue
Block a user