# Compare commits

17 commits: 9096afbace...8f7c489bd2

Commits: 8f7c489bd2, 99e24a2fc3, 043d8a2877, 71207bc935, 8c8853d26e, 94de91ffa0, 42a07be4cb, 577c55f32a, 028d91283e, 1ba4336291, d0f63063ca, 9323eb6371, 3ffab4e70a, 5b80824f3a, 6b979eb57a, 52b96da8dc, 4aa7adba3a
FRONTEND_INTEGRATION.md (new file, 338 lines)

@@ -0,0 +1,338 @@
# Frontend Integration Changes

## Overview

This document details the changes made to integrate the existing frontend JavaScript with the new FastAPI backend and native WebSocket implementation.

## Key Changes

### 1. WebSocket Migration (Socket.IO → Native WebSocket)
**Files Created:**

- `src/server/web/static/js/websocket_client.js` - Native WebSocket wrapper with a Socket.IO-compatible interface

**Files Modified:**

- `src/server/web/templates/index.html` - Replace the Socket.IO CDN script with `websocket_client.js`
- `src/server/web/templates/queue.html` - Replace the Socket.IO CDN script with `websocket_client.js`

**Migration Details:**

- Created a `WebSocketClient` class that provides Socket.IO-style `.on()` and `.emit()` methods
- Automatic reconnection with exponential backoff
- Room-based subscriptions (join/leave rooms for topic filtering)
- Message queueing during disconnection
- Native WebSocket URL: `ws://host:port/ws/connect` (or `wss://` for HTTPS)
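The reconnection behavior listed above can be sketched as a small pure function. The 1 s base and 30 s cap below are illustrative assumptions, not values taken from `websocket_client.js`:

```javascript
// Reconnect delay for a given attempt (1-based): doubles each attempt,
// capped so repeated failures don't grow the wait without bound.
function reconnectDelay(attempt, baseDelayMs = 1000, maxDelayMs = 30000) {
  return Math.min(baseDelayMs * 2 ** (attempt - 1), maxDelayMs);
}
```

The client would wait `reconnectDelay(n)` before the n-th reconnect attempt and reset the counter once a connection succeeds.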
### 2. WebSocket Message Format Changes

**Old Format (Socket.IO custom events):**

```javascript
socket.on('download_progress', (data) => { ... });
// data was sent directly
```

**New Format (Structured messages):**

```javascript
{
    "type": "download_progress",
    "timestamp": "2025-10-17T12:34:56.789Z",
    "data": {
        // Message payload
    }
}
```
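A compatibility layer can translate the structured envelope back into Socket.IO-style callbacks by dispatching on `type` and handing each handler only the `data` payload. This is a minimal sketch, not the actual `WebSocketClient` internals:

```javascript
// Dispatches structured messages to handlers registered Socket.IO-style.
class MessageDispatcher {
  constructor() {
    this.handlers = {};
  }

  // Register a handler for a message type.
  on(type, handler) {
    (this.handlers[type] = this.handlers[type] || []).push(handler);
  }

  // Parse the envelope and invoke handlers with the payload only,
  // so existing handlers keep receiving "data" as before.
  dispatch(rawMessage) {
    const msg = JSON.parse(rawMessage);
    for (const handler of this.handlers[msg.type] || []) {
      handler(msg.data);
    }
  }
}
```

Unknown message types are silently ignored, which keeps the client tolerant of new backend events.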
**Event Mapping:**

| Old Socket.IO Event | New WebSocket Type | Room | Notes |
| --- | --- | --- | --- |
| `scan_progress` | `scan_progress` | `scan_progress` | Scan updates |
| `scan_completed` | `scan_complete` | `scan_progress` | Scan finished |
| `scan_error` | `scan_failed` | `scan_progress` | Scan error |
| `download_progress` | `download_progress` | `download_progress` | Real-time download updates |
| `download_completed` | `download_complete` | `downloads` | Single download finished |
| `download_error` | `download_failed` | `downloads` | Download failed |
| `download_queue_update` | `queue_status` | `downloads` | Queue state changes |
| `queue_started` | `queue_started` | `downloads` | Queue processing started |
| `queue_stopped` | `queue_stopped` | `downloads` | Queue processing stopped |
| `queue_paused` | `queue_paused` | `downloads` | Queue paused |
| `queue_resumed` | `queue_resumed` | `downloads` | Queue resumed |
### 3. API Endpoint Changes

**Authentication Endpoints:**

- ✅ `/api/auth/status` - Check auth status (GET)
- ✅ `/api/auth/login` - Login (POST)
- ✅ `/api/auth/logout` - Logout (POST)
- ✅ `/api/auth/setup` - Initial setup (POST)

**Anime Endpoints:**

- ✅ `/api/v1/anime` - List anime with missing episodes (GET)
- ✅ `/api/v1/anime/rescan` - Trigger rescan (POST)
- ✅ `/api/v1/anime/search` - Search for anime (POST)
- ✅ `/api/v1/anime/{anime_id}` - Get anime details (GET)

**Download Queue Endpoints:**

- ✅ `/api/queue/status` - Get queue status (GET)
- ✅ `/api/queue/add` - Add to queue (POST)
- ✅ `/api/queue/{item_id}` - Remove single item (DELETE)
- ✅ `/api/queue/` - Remove multiple items (DELETE)
- ✅ `/api/queue/start` - Start queue (POST)
- ✅ `/api/queue/stop` - Stop queue (POST)
- ✅ `/api/queue/pause` - Pause queue (POST)
- ✅ `/api/queue/resume` - Resume queue (POST)
- ✅ `/api/queue/reorder` - Reorder queue (POST)
- ✅ `/api/queue/completed` - Clear completed (DELETE)
- ✅ `/api/queue/retry` - Retry failed (POST)

**WebSocket Endpoints:**

- ✅ `/ws/connect` - WebSocket connection (WebSocket)
- ✅ `/ws/status` - WebSocket status (GET)
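Since `/ws/connect` must use `ws://` over HTTP and `wss://` over HTTPS, the client needs to derive the scheme from the page's protocol. A small helper (the function name is hypothetical, not taken from `websocket_client.js`) might look like:

```javascript
// Build the WebSocket URL for /ws/connect, matching the page's protocol.
// pageProtocol is window.location.protocol ('http:' or 'https:');
// host is window.location.host (hostname plus optional port).
function websocketUrl(pageProtocol, host) {
  const scheme = pageProtocol === 'https:' ? 'wss' : 'ws';
  return `${scheme}://${host}/ws/connect`;
}
```

In the browser this would be called as `websocketUrl(window.location.protocol, window.location.host)`.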
### 4. Required JavaScript Updates

**app.js Changes Needed:**

1. **WebSocket Initialization** - Add room subscriptions:

```javascript
initSocket() {
    this.socket = io();

    // Subscribe to relevant rooms after connection
    this.socket.on('connected', () => {
        this.socket.join('scan_progress');
        this.socket.join('download_progress');
        this.socket.join('downloads');
        this.isConnected = true;
        // ... rest of connect handler
    });

    // ... rest of event handlers
}
```
2. **Event Handler Updates** - Map new message types:

- `scan_completed` → `scan_complete`
- `scan_error` → `scan_failed`
- Legacy events that are no longer sent need to be handled differently or removed

3. **API Call Updates** - Already correct:

- `/api/v1/anime` for anime list ✅
- `/api/auth/*` for authentication ✅

**queue.js Changes Needed:**

1. **WebSocket Initialization** - Add room subscriptions:

```javascript
initSocket() {
    this.socket = io();

    this.socket.on('connected', () => {
        this.socket.join('downloads');
        this.socket.join('download_progress');
        // ... rest of connect handler
    });

    // ... rest of event handlers
}
```

2. **API Calls** - Already mostly correct:

- `/api/queue/status` ✅
- `/api/queue/*` operations ✅

3. **Event Handlers** - Map to new types:

- `queue_updated` → `queue_status`
- `download_progress_update` → `download_progress`
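One low-risk way to apply these renames without editing every handler registration is a translation map. The map contents below mirror the renames documented in this section; the helper name itself is hypothetical:

```javascript
// Legacy Socket.IO event name → new WebSocket message type.
const EVENT_RENAMES = {
  scan_completed: 'scan_complete',
  scan_error: 'scan_failed',
  queue_updated: 'queue_status',
  download_progress_update: 'download_progress',
};

// Resolve the type a handler should actually subscribe to;
// unrenamed events pass through unchanged.
function resolveEventName(legacyName) {
  return EVENT_RENAMES[legacyName] || legacyName;
}
```

Wrapping `.on()` with this lookup lets app.js and queue.js keep their old event names at the call sites during the migration.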
### 5. Authentication Flow

**Current Implementation:**

- JWT tokens stored in localStorage (via the auth service)
- Tokens included in the Authorization header for API requests
- WebSocket connections can optionally authenticate (user_id in session)

**JavaScript Implementation Needed:**

Add a helper for authenticated requests:
```javascript
async makeAuthenticatedRequest(url, options = {}) {
    const token = localStorage.getItem('auth_token');

    if (!token) {
        window.location.href = '/login';
        return null;
    }

    const headers = {
        'Content-Type': 'application/json',
        'Authorization': `Bearer ${token}`,
        ...options.headers
    };

    const response = await fetch(url, { ...options, headers });

    if (response.status === 401) {
        // Token expired or invalid
        localStorage.removeItem('auth_token');
        window.location.href = '/login';
        return null;
    }

    return response;
}
```
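A complementary check (an assumption, not part of the current helper) is to decode the JWT's `exp` claim client-side and treat an expired token like a missing one, skipping a guaranteed 401 round trip. This Node-flavored sketch uses `Buffer`; in the browser, `atob` would replace it:

```javascript
// Return true when a JWT's exp claim is in the past.
// Assumes a standard three-part JWT with a base64url-encoded JSON payload.
function isTokenExpired(token, nowSeconds = Date.now() / 1000) {
  try {
    const payload = JSON.parse(
      Buffer.from(token.split('.')[1], 'base64url').toString('utf8')
    );
    return typeof payload.exp === 'number' && payload.exp < nowSeconds;
  } catch {
    return true; // Malformed tokens are treated as expired.
  }
}
```

`makeAuthenticatedRequest` could call this before sending and redirect to `/login` immediately when it returns true.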
### 6. Backend Router Registration

**Fixed in fastapi_app.py:**

- ✅ Added the `anime_router` import
- ✅ Registered `app.include_router(anime_router)`

All routers are now properly registered:

- health_router
- page_router
- auth_router
- anime_router ⭐ (newly added)
- download_router
- websocket_router
## Implementation Status

### ✅ Completed

1. Created native WebSocket client wrapper
2. Updated HTML templates to use the new WebSocket client
3. Registered anime router in the FastAPI app
4. Documented API endpoint mappings
5. Documented WebSocket message format changes

### 🔄 In Progress

1. Update app.js WebSocket initialization and room subscriptions
2. Update app.js event handlers for new message types
3. Update queue.js WebSocket initialization and room subscriptions
4. Update queue.js event handlers for new message types

### ⏳ Pending

1. Add authentication token handling to all API requests
2. Test the complete workflow (auth → scan → download)
3. Update other JavaScript modules if they use the WebSocket/API
4. Integration tests for frontend-backend communication
5. Update the infrastructure.md documentation
## Testing Plan

1. **Authentication Flow:**

   - Test setup page → creates master password
   - Test login page → authenticates with master password
   - Test logout → clears session
   - Test that protected pages redirect to login

2. **Anime Management:**

   - Test loading the anime list
   - Test rescan functionality with progress updates
   - Test search functionality

3. **Download Queue:**

   - Test adding items to the queue
   - Test queue operations (start, stop, pause, resume)
   - Test progress updates via WebSocket
   - Test retry and clear operations

4. **WebSocket Communication:**

   - Test connection/reconnection
   - Test room subscriptions
   - Test message routing to the correct handlers
   - Test disconnect handling
## Known Issues & Limitations

1. **Legacy Events:** Some Socket.IO events in app.js don't have backend equivalents:

   - `scheduled_rescan_*` events
   - `auto_download_*` events
   - `download_episode_update` event
   - `download_series_completed` event

   **Solution:** Either remove these handlers or implement the corresponding backend events.

2. **Configuration Endpoints:** Many config-related API calls in app.js don't have backend implementations:

   - Scheduler configuration
   - Logging configuration
   - Advanced configuration
   - Config backups

   **Solution:** Implement these endpoints or remove the UI features.

3. **Process Status Monitoring:** The `checkProcessLocks()` method may not work with the new backend.

   **Solution:** Implement an equivalent status endpoint or remove the feature.
## Migration Guide for Developers

### Adding New WebSocket Events

1. Define the message type in `src/server/models/websocket.py`:

```python
class WebSocketMessageType(str, Enum):
    MY_NEW_EVENT = "my_new_event"
```

2. Broadcast from the service:

```python
await ws_service.broadcast_to_room(
    {"type": "my_new_event", "data": {...}},
    "my_room"
)
```

3. Subscribe and handle in JavaScript:

```javascript
this.socket.join("my_room");
this.socket.on("my_new_event", (data) => {
    // Handle event
});
```
### Adding New API Endpoints

1. Define Pydantic models in `src/server/models/`
2. Create the endpoint in the appropriate router file in `src/server/api/`
3. Add the endpoint to this documentation
4. Update the JavaScript to call the new endpoint
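Step 4 usually amounts to one small fetch wrapper per endpoint. As a sketch against the existing `/api/v1/anime/search` route (which takes a `query` field), a helper could build the request arguments separately from sending them, which keeps the URL/body logic unit-testable:

```javascript
// Build the [url, options] pair for the anime search endpoint,
// ready to be spread into fetch(...).
function buildSearchRequest(baseUrl, query) {
  return [
    `${baseUrl}/api/v1/anime/search`,
    {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ query }),
    },
  ];
}
```

The caller would then do `const [url, opts] = buildSearchRequest('', text); await makeAuthenticatedRequest(url, opts);`.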
## References

- FastAPI Application: `src/server/fastapi_app.py`
- WebSocket Service: `src/server/services/websocket_service.py`
- WebSocket Models: `src/server/models/websocket.py`
- Download Service: `src/server/services/download_service.py`
- Anime Service: `src/server/services/anime_service.py`
- Progress Service: `src/server/services/progress_service.py`
- Infrastructure Doc: `infrastructure.md`
File diff suppressed because it is too large.

instructions.md (123 lines)

@@ -38,127 +38,13 @@ The tasks should be completed in the following order to ensure proper dependencies
2. Process the task
3. Make tests.
4. Remove the task from instructions.md.
5. Update infrastructure.md, but only add text that belongs in an infrastructure doc. Make sure to summarize or delete text that does not belong in infrastructure.md. Keep it clear and short.
6. Commit in git

## Core Tasks

### 3. Configuration Management

#### [ ] Implement configuration models

- [ ] Create `src/server/models/config.py`
- [ ] Define ConfigResponse, ConfigUpdate models
- [ ] Add SchedulerConfig, LoggingConfig models
- [ ] Include ValidationResult model

#### [ ] Create configuration service

- [ ] Create `src/server/services/config_service.py`
- [ ] Implement configuration loading/saving
- [ ] Add configuration validation
- [ ] Include backup/restore functionality
- [ ] Add scheduler configuration management

#### [ ] Implement configuration API endpoints

- [ ] Create `src/server/api/config.py`
- [ ] Add GET `/api/config` - get configuration
- [ ] Add PUT `/api/config` - update configuration
- [ ] Add POST `/api/config/validate` - validate config

### 4. Anime Management Integration

#### [ ] Implement anime models

- [ ] Create `src/server/models/anime.py`
- [ ] Define AnimeSeriesResponse, EpisodeInfo models
- [ ] Add SearchRequest, SearchResult models
- [ ] Include MissingEpisodeInfo model

#### [ ] Create anime service wrapper

- [ ] Create `src/server/services/anime_service.py`
- [ ] Wrap SeriesApp functionality for the web layer
- [ ] Implement async wrappers for blocking operations
- [ ] Add caching for frequently accessed data
- [ ] Include error handling and logging

#### [ ] Implement anime API endpoints

- [ ] Create `src/server/api/anime.py`
- [ ] Add GET `/api/v1/anime` - list series with missing episodes
- [ ] Add POST `/api/v1/anime/rescan` - trigger rescan
- [ ] Add POST `/api/v1/anime/search` - search for new anime
- [ ] Add GET `/api/v1/anime/{id}` - get series details

### 5. Download Queue Management

#### [ ] Implement download queue models

- [ ] Create `src/server/models/download.py`
- [ ] Define DownloadItem, QueueStatus models
- [ ] Add DownloadProgress, QueueStats models
- [ ] Include DownloadRequest model

#### [ ] Create download queue service

- [ ] Create `src/server/services/download_service.py`
- [ ] Implement queue management (add, remove, reorder)
- [ ] Add download progress tracking
- [ ] Include queue persistence and recovery
- [ ] Add concurrent download management

#### [ ] Implement download API endpoints

- [ ] Create `src/server/api/download.py`
- [ ] Add GET `/api/queue/status` - get queue status
- [ ] Add POST `/api/queue/add` - add to queue
- [ ] Add DELETE `/api/queue/{id}` - remove from queue
- [ ] Add POST `/api/queue/start` - start downloads
- [ ] Add POST `/api/queue/stop` - stop downloads

### 6. WebSocket Real-time Updates

#### [ ] Implement WebSocket manager

- [ ] Create `src/server/services/websocket_service.py`
- [ ] Add connection management
- [ ] Implement broadcast functionality
- [ ] Include room-based messaging
- [ ] Add connection cleanup

#### [ ] Add real-time progress updates

- [ ] Create `src/server/services/progress_service.py`
- [ ] Implement download progress broadcasting
- [ ] Add scan progress updates
- [ ] Include queue status changes
- [ ] Add error notifications

#### [ ] Integrate WebSocket with core services

- [ ] Update download service to emit progress
- [ ] Add scan progress notifications
- [ ] Include queue change broadcasts
- [ ] Add error/completion notifications

### 7. Frontend Integration

#### [ ] Integrate existing HTML templates

- [ ] Review and integrate existing HTML templates in `src/server/web/templates/`
- [ ] Ensure templates work with the FastAPI Jinja2 setup
- [ ] Update template paths and static file references if needed
- [ ] Maintain the existing responsive layout and theme switching

#### [ ] Integrate existing JavaScript functionality

- [ ] Review existing JavaScript files in `src/server/web/static/js/`
- [ ] Update API endpoint URLs to match FastAPI routes
- [ ] Ensure WebSocket connections work with the new backend
- [ ] Maintain existing functionality for app.js and queue.js

#### [ ] Integrate existing CSS styling

- [ ] Review and integrate existing CSS files in `src/server/web/static/css/`
@@ -251,13 +137,6 @@ The tasks should be completed in the following order to ensure proper dependencies

### 11. Deployment and Configuration

#### [ ] Create Docker configuration

- [ ] Create `Dockerfile`
- [ ] Create `docker-compose.yml`
- [ ] Add environment configuration
- [ ] Include volume mappings for existing web assets

#### [ ] Create production configuration

- [ ] Create `src/server/config/production.py`
@@ -8,6 +8,7 @@ python-jose[cryptography]==3.3.0
passlib[bcrypt]==1.7.4
aiofiles==23.2.1
websockets==12.0
structlog==24.1.0
pytest==7.4.3
pytest-asyncio==0.21.1
httpx==0.25.2
src/server/api/anime.py (new file, 117 lines)

@@ -0,0 +1,117 @@
```python
from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel

from src.server.utils.dependencies import get_series_app

router = APIRouter(prefix="/api/v1/anime", tags=["anime"])


class AnimeSummary(BaseModel):
    id: str
    title: str
    missing_episodes: int


class AnimeDetail(BaseModel):
    id: str
    title: str
    episodes: List[str]
    description: Optional[str] = None


@router.get("/", response_model=List[AnimeSummary])
async def list_anime(series_app=Depends(get_series_app)):
    """List series with missing episodes using the core SeriesApp."""
    try:
        series = series_app.List.GetMissingEpisode()
        result = []
        for s in series:
            missing = 0
            try:
                missing = len(s.episodeDict) if getattr(s, "episodeDict", None) is not None else 0
            except Exception:
                missing = 0
            result.append(
                AnimeSummary(
                    id=getattr(s, "key", getattr(s, "folder", "")),
                    title=getattr(s, "name", ""),
                    missing_episodes=missing,
                )
            )
        return result
    except HTTPException:
        raise
    except Exception:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve anime list",
        )


@router.post("/rescan")
async def trigger_rescan(series_app=Depends(get_series_app)):
    """Trigger a rescan of local series data using SeriesApp.ReScan."""
    try:
        # SeriesApp.ReScan expects a callback; pass a no-op
        if hasattr(series_app, "ReScan"):
            series_app.ReScan(lambda *args, **kwargs: None)
            return {"success": True, "message": "Rescan started"}
        else:
            raise HTTPException(
                status_code=status.HTTP_501_NOT_IMPLEMENTED,
                detail="Rescan not available",
            )
    except HTTPException:
        raise
    except Exception:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to start rescan",
        )


class SearchRequest(BaseModel):
    query: str


@router.post("/search", response_model=List[AnimeSummary])
async def search_anime(request: SearchRequest, series_app=Depends(get_series_app)):
    """Search for new anime by query text using the SeriesApp loader."""
    try:
        matches = []
        if hasattr(series_app, "search"):
            # SeriesApp.search is synchronous in core; call directly
            matches = series_app.search(request.query)

        result = []
        for m in matches:
            # matches may be dicts or objects
            if isinstance(m, dict):
                mid = m.get("key") or m.get("id") or ""
                title = m.get("title") or m.get("name") or ""
                missing = int(m.get("missing", 0)) if m.get("missing") is not None else 0
            else:
                mid = getattr(m, "key", getattr(m, "id", ""))
                title = getattr(m, "title", getattr(m, "name", ""))
                missing = int(getattr(m, "missing", 0))
            result.append(AnimeSummary(id=mid, title=title, missing_episodes=missing))

        return result
    except HTTPException:
        raise
    except Exception:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Search failed",
        )


@router.get("/{anime_id}", response_model=AnimeDetail)
async def get_anime(anime_id: str, series_app=Depends(get_series_app)):
    """Return detailed info about a series from SeriesApp.List."""
    try:
        series = series_app.List.GetList()
        found = None
        for s in series:
            if getattr(s, "key", None) == anime_id or getattr(s, "folder", None) == anime_id:
                found = s
                break

        if not found:
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Series not found")

        episodes = []
        epdict = getattr(found, "episodeDict", {}) or {}
        for season, eps in epdict.items():
            for e in eps:
                episodes.append(f"{season}-{e}")

        return AnimeDetail(
            id=getattr(found, "key", getattr(found, "folder", "")),
            title=getattr(found, "name", ""),
            episodes=episodes,
            description=getattr(found, "description", None),
        )
    except HTTPException:
        raise
    except Exception:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to retrieve series details",
        )
```
src/server/api/config.py (new file, 68 lines)

@@ -0,0 +1,68 @@
```python
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, status

from src.config.settings import settings
from src.server.models.config import AppConfig, ConfigUpdate, ValidationResult
from src.server.utils.dependencies import require_auth

router = APIRouter(prefix="/api/config", tags=["config"])


@router.get("", response_model=AppConfig)
def get_config(auth: Optional[dict] = Depends(require_auth)) -> AppConfig:
    """Return current application configuration (read-only)."""
    # Construct AppConfig from pydantic-settings where possible
    cfg_data = {
        "name": getattr(settings, "app_name", "Aniworld"),
        "data_dir": getattr(settings, "data_dir", "data"),
        "scheduler": getattr(settings, "scheduler", {}),
        "logging": getattr(settings, "logging", {}),
        "backup": getattr(settings, "backup", {}),
        "other": getattr(settings, "other", {}),
    }
    try:
        return AppConfig(**cfg_data)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to read config: {e}")


@router.put("", response_model=AppConfig)
def update_config(update: ConfigUpdate, auth: dict = Depends(require_auth)) -> AppConfig:
    """Apply an update to the configuration and return the new config.

    Note: the persistence strategy for settings is out of scope for this task.
    This endpoint updates the in-memory Settings where possible and returns
    the merged result as an AppConfig.
    """
    # Build the current AppConfig from settings, then apply the update
    current = get_config(auth)
    new_cfg = update.apply_to(current)

    # Mirror some fields back into the pydantic-settings 'settings' where safe.
    # Avoid writing secrets or unsupported fields.
    try:
        if new_cfg.data_dir:
            setattr(settings, "data_dir", new_cfg.data_dir)
        # scheduler/logging/backup/other kept in memory only for now
        setattr(settings, "scheduler", new_cfg.scheduler.model_dump())
        setattr(settings, "logging", new_cfg.logging.model_dump())
        setattr(settings, "backup", new_cfg.backup.model_dump())
        setattr(settings, "other", new_cfg.other)
    except Exception:
        # Best-effort; do not fail the request if persistence is not available
        pass

    return new_cfg


@router.post("/validate", response_model=ValidationResult)
def validate_config(cfg: AppConfig, auth: dict = Depends(require_auth)) -> ValidationResult:
    """Validate a provided AppConfig without applying it.

    Returns a ValidationResult with any validation errors.
    """
    try:
        return cfg.validate()
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
```
src/server/api/download.py (new file, 474 lines)

@@ -0,0 +1,474 @@
|
||||
"""Download queue API endpoints for Aniworld web application.
|
||||
|
||||
This module provides REST API endpoints for managing the anime download queue,
|
||||
including adding episodes, removing items, controlling queue processing, and
|
||||
retrieving queue status and statistics.
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, status
|
||||
|
||||
from src.server.models.download import (
|
||||
DownloadRequest,
|
||||
DownloadResponse,
|
||||
QueueOperationRequest,
|
||||
QueueReorderRequest,
|
||||
QueueStatusResponse,
|
||||
)
|
||||
from src.server.services.download_service import DownloadService, DownloadServiceError
|
||||
from src.server.utils.dependencies import get_download_service, require_auth
|
||||
|
||||
router = APIRouter(prefix="/api/queue", tags=["download"])
|
||||
|
||||
|
||||
@router.get("/status", response_model=QueueStatusResponse)
|
||||
async def get_queue_status(
|
||||
download_service: DownloadService = Depends(get_download_service),
|
||||
_: dict = Depends(require_auth),
|
||||
):
|
||||
"""Get current download queue status and statistics.
|
||||
|
||||
Returns comprehensive information about all queue items including:
|
||||
- Active downloads with progress
|
||||
- Pending items waiting to be processed
|
||||
- Recently completed downloads
|
||||
- Failed downloads
|
||||
|
||||
Requires authentication.
|
||||
|
||||
Returns:
|
||||
QueueStatusResponse: Complete queue status and statistics
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if not authenticated, 500 on service error
|
||||
"""
|
||||
try:
|
||||
queue_status = await download_service.get_queue_status()
|
||||
queue_stats = await download_service.get_queue_stats()
|
||||
|
||||
return QueueStatusResponse(status=queue_status, statistics=queue_stats)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to retrieve queue status: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/add",
|
||||
response_model=DownloadResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
async def add_to_queue(
|
||||
request: DownloadRequest,
|
||||
download_service: DownloadService = Depends(get_download_service),
|
||||
_: dict = Depends(require_auth),
|
||||
):
|
||||
"""Add episodes to the download queue.
|
||||
|
||||
Adds one or more episodes to the download queue with specified priority.
|
||||
Episodes are validated and queued for processing based on priority level:
|
||||
- HIGH priority items are processed first
|
||||
- NORMAL and LOW priority items follow FIFO order
|
||||
|
||||
Requires authentication.
|
||||
|
||||
Args:
|
||||
request: Download request with serie info, episodes, and priority
|
||||
|
||||
Returns:
|
||||
DownloadResponse: Status and list of created download item IDs
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if not authenticated, 400 for invalid request,
|
||||
500 on service error
|
||||
"""
|
||||
try:
|
||||
# Validate request
|
||||
if not request.episodes:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="At least one episode must be specified",
|
||||
)
|
||||
|
||||
# Add to queue
|
||||
added_ids = await download_service.add_to_queue(
|
||||
serie_id=request.serie_id,
|
||||
serie_name=request.serie_name,
|
||||
episodes=request.episodes,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
return DownloadResponse(
|
||||
status="success",
|
||||
message=f"Added {len(added_ids)} episode(s) to download queue",
|
||||
added_items=added_ids,
|
||||
failed_items=[],
|
||||
)
|
||||
|
||||
except DownloadServiceError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to add episodes to queue: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{item_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def remove_from_queue(
|
||||
item_id: str = Path(..., description="Download item ID to remove"),
|
||||
download_service: DownloadService = Depends(get_download_service),
|
||||
_: dict = Depends(require_auth),
|
||||
):
|
||||
"""Remove a specific item from the download queue.
|
||||
|
||||
Removes a download item from the queue. If the item is currently
|
||||
downloading, it will be cancelled and marked as cancelled. If it's
|
||||
pending, it will simply be removed from the queue.
|
||||
|
||||
Requires authentication.
|
||||
|
||||
Args:
|
||||
item_id: Unique identifier of the download item to remove
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if not authenticated, 404 if item not found,
|
||||
500 on service error
|
||||
"""
|
||||
try:
|
||||
removed_ids = await download_service.remove_from_queue([item_id])
|
||||
|
||||
if not removed_ids:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Download item {item_id} not found in queue",
|
||||
)
|
||||
|
||||
except DownloadServiceError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to remove item from queue: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/", status_code=status.HTTP_204_NO_CONTENT)
async def remove_multiple_from_queue(
    request: QueueOperationRequest,
    download_service: DownloadService = Depends(get_download_service),
    _: dict = Depends(require_auth),
):
    """Remove multiple items from the download queue.

    Batch removal of multiple download items. Each item is processed
    individually, and the operation continues even if some items are not
    found.

    Requires authentication.

    Args:
        request: List of download item IDs to remove

    Raises:
        HTTPException: 401 if not authenticated, 400 for invalid request,
            500 on service error
    """
    try:
        if not request.item_ids:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="At least one item ID must be specified",
            )

        await download_service.remove_from_queue(request.item_ids)

        # Note: We don't raise 404 if some items weren't found, as this is
        # a batch operation and partial success is acceptable

    except DownloadServiceError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to remove items from queue: {str(e)}",
        )


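The partial-success contract noted in the handler above (missing IDs are skipped rather than triggering a 404) can be illustrated with a toy in-memory queue. `ToyQueue` is hypothetical and not the real `DownloadService`; it only mirrors the `remove_from_queue` return convention:

```python
import asyncio


class ToyQueue:
    """Toy in-memory queue illustrating partial-success batch removal."""

    def __init__(self, items):
        # Map item_id -> item payload.
        self._items = dict(items)

    async def remove_from_queue(self, item_ids):
        # Remove every id that exists; silently skip unknown ids and
        # report back only the ids that were actually removed.
        removed = []
        for item_id in item_ids:
            if item_id in self._items:
                del self._items[item_id]
                removed.append(item_id)
        return removed


async def main():
    queue = ToyQueue({"a": 1, "b": 2})
    return await queue.remove_from_queue(["a", "missing"])


removed = asyncio.run(main())
print(removed)  # ['a']
```

The single-item endpoint can then treat an empty result as "not found", while the batch endpoint simply ignores it.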
@router.post("/start", status_code=status.HTTP_200_OK)
async def start_queue(
    download_service: DownloadService = Depends(get_download_service),
    _: dict = Depends(require_auth),
):
    """Start the download queue processor.

    Starts processing the download queue. Downloads will be processed according
    to priority and concurrency limits. If the queue is already running, this
    operation is idempotent.

    Requires authentication.

    Returns:
        dict: Status message indicating queue has been started

    Raises:
        HTTPException: 401 if not authenticated, 500 on service error
    """
    try:
        await download_service.start()

        return {
            "status": "success",
            "message": "Download queue processing started",
        }

    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to start download queue: {str(e)}",
        )


@router.post("/stop", status_code=status.HTTP_200_OK)
async def stop_queue(
    download_service: DownloadService = Depends(get_download_service),
    _: dict = Depends(require_auth),
):
    """Stop the download queue processor.

    Stops processing the download queue. Active downloads will be allowed to
    complete (with a timeout), then the queue processor will shut down.
    Queue state is persisted before shutdown.

    Requires authentication.

    Returns:
        dict: Status message indicating queue has been stopped

    Raises:
        HTTPException: 401 if not authenticated, 500 on service error
    """
    try:
        await download_service.stop()

        return {
            "status": "success",
            "message": "Download queue processing stopped",
        }

    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to stop download queue: {str(e)}",
        )


@router.post("/pause", status_code=status.HTTP_200_OK)
async def pause_queue(
    download_service: DownloadService = Depends(get_download_service),
    _: dict = Depends(require_auth),
):
    """Pause the download queue processor.

    Pauses download processing. Active downloads will continue, but no new
    downloads will be started until the queue is resumed.

    Requires authentication.

    Returns:
        dict: Status message indicating queue has been paused

    Raises:
        HTTPException: 401 if not authenticated, 500 on service error
    """
    try:
        await download_service.pause_queue()

        return {
            "status": "success",
            "message": "Download queue paused",
        }

    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to pause download queue: {str(e)}",
        )


@router.post("/resume", status_code=status.HTTP_200_OK)
async def resume_queue(
    download_service: DownloadService = Depends(get_download_service),
    _: dict = Depends(require_auth),
):
    """Resume the download queue processor.

    Resumes download processing after being paused. The queue will continue
    processing pending items according to priority.

    Requires authentication.

    Returns:
        dict: Status message indicating queue has been resumed

    Raises:
        HTTPException: 401 if not authenticated, 500 on service error
    """
    try:
        await download_service.resume_queue()

        return {
            "status": "success",
            "message": "Download queue resumed",
        }

    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to resume download queue: {str(e)}",
        )


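The pause/resume semantics documented above (in-flight downloads continue, new ones are not started) are the classic `asyncio.Event` gate pattern. This is a hypothetical sketch of one way a processor could implement it, not the actual `DownloadService` internals:

```python
import asyncio


async def worker(pause_gate: asyncio.Event, processed: list, items: list):
    # Each new item is started only while the gate is set; pausing never
    # interrupts the item that is already in flight.
    for item in items:
        await pause_gate.wait()
        processed.append(item)
        await asyncio.sleep(0)


async def main():
    gate = asyncio.Event()
    gate.set()  # the queue starts in the running state
    processed: list = []
    task = asyncio.create_task(worker(gate, processed, [1, 2, 3]))
    await asyncio.sleep(0)  # let the worker start item 1
    gate.clear()            # pause: no new items are started
    await asyncio.sleep(0)
    snapshot = list(processed)
    gate.set()              # resume: remaining items are processed
    await task
    return snapshot, processed


snapshot, final = asyncio.run(main())
print(snapshot, final)
```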
@router.post("/reorder", status_code=status.HTTP_200_OK)
async def reorder_queue(
    request: QueueReorderRequest,
    download_service: DownloadService = Depends(get_download_service),
    _: dict = Depends(require_auth),
):
    """Reorder an item in the pending queue.

    Changes the position of a pending download item in the queue. This only
    affects items that haven't started downloading yet. The position is
    0-based.

    Requires authentication.

    Args:
        request: Item ID and new position in queue

    Returns:
        dict: Status message indicating item has been reordered

    Raises:
        HTTPException: 401 if not authenticated, 404 if item not found,
            400 for invalid request, 500 on service error
    """
    try:
        success = await download_service.reorder_queue(
            item_id=request.item_id,
            new_position=request.new_position,
        )

        if not success:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Item {request.item_id} not found in pending queue",
            )

        return {
            "status": "success",
            "message": "Queue item reordered successfully",
        }

    except DownloadServiceError as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=str(e),
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to reorder queue item: {str(e)}",
        )


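The 0-based reorder described above, including the "False means not pending" contract that the endpoint maps to a 404, can be sketched on a plain list (a hypothetical stand-in for the real pending queue):

```python
def reorder(pending: list, item_id: str, new_position: int) -> bool:
    """Move item_id to a 0-based new_position in the pending list.

    Returns False when the item is not pending, mirroring the 404 path
    in the endpoint above.
    """
    if item_id not in pending:
        return False
    pending.remove(item_id)
    # Clamp so an out-of-range position appends at the end.
    pending.insert(min(new_position, len(pending)), item_id)
    return True


queue = ["a", "b", "c", "d"]
ok = reorder(queue, "c", 0)
print(ok, queue)  # True ['c', 'a', 'b', 'd']
```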
@router.delete("/completed", status_code=status.HTTP_200_OK)
async def clear_completed(
    download_service: DownloadService = Depends(get_download_service),
    _: dict = Depends(require_auth),
):
    """Clear completed downloads from history.

    Removes all completed download items from the queue history. This helps
    keep the queue display clean and manageable.

    Requires authentication.

    Returns:
        dict: Status message with count of cleared items

    Raises:
        HTTPException: 401 if not authenticated, 500 on service error
    """
    try:
        cleared_count = await download_service.clear_completed()

        return {
            "status": "success",
            "message": f"Cleared {cleared_count} completed item(s)",
            "count": cleared_count,
        }

    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to clear completed items: {str(e)}",
        )


@router.post("/retry", status_code=status.HTTP_200_OK)
async def retry_failed(
    request: QueueOperationRequest,
    download_service: DownloadService = Depends(get_download_service),
    _: dict = Depends(require_auth),
):
    """Retry failed downloads.

    Moves failed download items back to the pending queue for retry. Only items
    that haven't exceeded the maximum retry count will be retried.

    Requires authentication.

    Args:
        request: List of download item IDs to retry (empty list retries all)

    Returns:
        dict: Status message with count of retried items

    Raises:
        HTTPException: 401 if not authenticated, 500 on service error
    """
    try:
        # If no specific IDs provided, retry all failed items
        item_ids = request.item_ids if request.item_ids else None

        retried_ids = await download_service.retry_failed(item_ids)

        return {
            "status": "success",
            "message": f"Retrying {len(retried_ids)} failed item(s)",
            "retried_ids": retried_ids,
        }

    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to retry downloads: {str(e)}",
        )
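The retry-cap rule stated in the docstring (only items below the maximum retry count go back to pending) can be sketched as a simple filter. `MAX_RETRIES` and the `failed` mapping are hypothetical, not taken from the real service:

```python
MAX_RETRIES = 3


def retry_failed(failed: dict, item_ids=None) -> list:
    """Return the ids moved back to pending, skipping items over the cap.

    `failed` maps item_id -> retry_count; item_ids=None retries everything.
    """
    candidates = item_ids if item_ids is not None else list(failed)
    retried = []
    for item_id in candidates:
        if item_id in failed and failed[item_id] < MAX_RETRIES:
            retried.append(item_id)
    return retried


failed = {"a": 1, "b": 3, "c": 0}
print(retry_failed(failed))         # ['a', 'c']  ('b' hit the cap)
print(retry_failed(failed, ["b"]))  # []
```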
236
src/server/api/websocket.py
Normal file

@@ -0,0 +1,236 @@
"""WebSocket API endpoints for real-time communication.

This module provides WebSocket endpoints for clients to connect and receive
real-time updates about downloads, queue status, and system events.
"""
from __future__ import annotations

import uuid
from typing import Optional

import structlog
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect, status
from fastapi.responses import JSONResponse

from src.server.models.websocket import (
    ClientMessage,
    RoomSubscriptionRequest,
    WebSocketMessageType,
)
from src.server.services.websocket_service import (
    WebSocketService,
    get_websocket_service,
)
from src.server.utils.dependencies import get_current_user_optional

logger = structlog.get_logger(__name__)

router = APIRouter(prefix="/ws", tags=["websocket"])


@router.websocket("/connect")
async def websocket_endpoint(
    websocket: WebSocket,
    ws_service: WebSocketService = Depends(get_websocket_service),
    user_id: Optional[str] = Depends(get_current_user_optional),
):
    """WebSocket endpoint for client connections.

    Clients connect to this endpoint to receive real-time updates.
    The connection is maintained until the client disconnects or
    an error occurs.

    Message flow:
    1. Client connects
    2. Server sends "connected" message
    3. Client can send subscription requests (join/leave rooms)
    4. Server broadcasts updates to subscribed rooms
    5. Client disconnects

    Example client subscription:
    ```json
    {
        "action": "join",
        "room": "downloads"
    }
    ```

    Server message format:
    ```json
    {
        "type": "download_progress",
        "timestamp": "2025-10-17T10:30:00.000Z",
        "data": {
            "download_id": "abc123",
            "percent": 45.2,
            "speed_mbps": 2.5,
            "eta_seconds": 180
        }
    }
    ```
    """
    connection_id = str(uuid.uuid4())

    try:
        # Accept connection and register with service
        await ws_service.connect(websocket, connection_id, user_id=user_id)

        # Send connection confirmation
        await ws_service.manager.send_personal_message(
            {
                "type": WebSocketMessageType.CONNECTED,
                "data": {
                    "connection_id": connection_id,
                    "message": "Connected to Aniworld WebSocket",
                },
            },
            connection_id,
        )

        logger.info(
            "WebSocket client connected",
            connection_id=connection_id,
            user_id=user_id,
        )

        # Handle incoming messages
        while True:
            try:
                # Receive message from client
                data = await websocket.receive_json()

                # Parse client message
                try:
                    client_msg = ClientMessage(**data)
                except Exception as e:
                    logger.warning(
                        "Invalid client message format",
                        connection_id=connection_id,
                        error=str(e),
                    )
                    await ws_service.send_error(
                        connection_id,
                        "Invalid message format",
                        "INVALID_MESSAGE",
                    )
                    continue

                # Handle room subscription requests
                if client_msg.action in ["join", "leave"]:
                    try:
                        room_req = RoomSubscriptionRequest(
                            action=client_msg.action,
                            room=client_msg.data.get("room", ""),
                        )

                        if room_req.action == "join":
                            await ws_service.manager.join_room(
                                connection_id, room_req.room
                            )
                            await ws_service.manager.send_personal_message(
                                {
                                    "type": WebSocketMessageType.SYSTEM_INFO,
                                    "data": {
                                        "message": (
                                            f"Joined room: {room_req.room}"
                                        )
                                    },
                                },
                                connection_id,
                            )
                        elif room_req.action == "leave":
                            await ws_service.manager.leave_room(
                                connection_id, room_req.room
                            )
                            await ws_service.manager.send_personal_message(
                                {
                                    "type": WebSocketMessageType.SYSTEM_INFO,
                                    "data": {
                                        "message": (
                                            f"Left room: {room_req.room}"
                                        )
                                    },
                                },
                                connection_id,
                            )

                    except Exception as e:
                        logger.warning(
                            "Invalid room subscription request",
                            connection_id=connection_id,
                            error=str(e),
                        )
                        await ws_service.send_error(
                            connection_id,
                            "Invalid room subscription",
                            "INVALID_SUBSCRIPTION",
                        )

                # Handle ping/pong for keepalive
                elif client_msg.action == "ping":
                    await ws_service.manager.send_personal_message(
                        {"type": WebSocketMessageType.PONG, "data": {}},
                        connection_id,
                    )

                else:
                    logger.debug(
                        "Unknown action from client",
                        connection_id=connection_id,
                        action=client_msg.action,
                    )
                    await ws_service.send_error(
                        connection_id,
                        f"Unknown action: {client_msg.action}",
                        "UNKNOWN_ACTION",
                    )

            except WebSocketDisconnect:
                logger.info(
                    "WebSocket client disconnected",
                    connection_id=connection_id,
                )
                break
            except Exception as e:
                logger.error(
                    "Error handling WebSocket message",
                    connection_id=connection_id,
                    error=str(e),
                )
                await ws_service.send_error(
                    connection_id,
                    "Internal server error",
                    "SERVER_ERROR",
                )

    except Exception as e:
        logger.error(
            "WebSocket connection error",
            connection_id=connection_id,
            error=str(e),
        )
    finally:
        # Cleanup connection
        await ws_service.disconnect(connection_id)
        logger.info("WebSocket connection closed", connection_id=connection_id)


@router.get("/status")
async def websocket_status(
    ws_service: WebSocketService = Depends(get_websocket_service),
):
    """Get WebSocket service status and statistics.

    Returns information about active connections and rooms.
    Useful for monitoring and debugging.
    """
    connection_count = await ws_service.manager.get_connection_count()

    return JSONResponse(
        status_code=status.HTTP_200_OK,
        content={
            "status": "operational",
            "active_connections": connection_count,
            "supported_message_types": [t.value for t in WebSocketMessageType],
        },
    )
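The room-based topic filtering that the endpoint above delegates to `ws_service.manager` can be sketched as a toy registry. `ToyRoomManager` is hypothetical; it records deliveries in a list instead of writing to real sockets, but mirrors the join/leave/broadcast flow:

```python
import asyncio


class ToyRoomManager:
    """Toy room registry mirroring the join/leave/broadcast flow above."""

    def __init__(self):
        self._rooms: dict[str, set[str]] = {}
        self.delivered: list[tuple[str, dict]] = []

    async def join_room(self, connection_id: str, room: str) -> None:
        self._rooms.setdefault(room, set()).add(connection_id)

    async def leave_room(self, connection_id: str, room: str) -> None:
        self._rooms.get(room, set()).discard(connection_id)

    async def broadcast_to_room(self, message: dict, room: str) -> None:
        # Deliver only to connections currently subscribed to the room.
        for connection_id in sorted(self._rooms.get(room, set())):
            self.delivered.append((connection_id, message))


async def main():
    mgr = ToyRoomManager()
    await mgr.join_room("conn-1", "downloads")
    await mgr.join_room("conn-2", "downloads")
    await mgr.leave_room("conn-2", "downloads")
    await mgr.broadcast_to_room({"type": "download_progress"}, "downloads")
    return mgr.delivered


delivered = asyncio.run(main())
print(delivered)  # only conn-1 receives the message
```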
@@ -6,7 +6,7 @@ This module provides custom error handlers for different HTTP status codes.
 from fastapi import HTTPException, Request
 from fastapi.responses import JSONResponse
 
-from src.server.utils.templates import templates
+from src.server.utils.template_helpers import render_template
 
 
 async def not_found_handler(request: Request, exc: HTTPException):
@@ -16,9 +16,11 @@ async def not_found_handler(request: Request, exc: HTTPException):
         status_code=404,
         content={"detail": "API endpoint not found"}
     )
-    return templates.TemplateResponse(
+    return render_template(
         "error.html",
-        {"request": request, "error": "Page not found", "status_code": 404}
+        request,
+        context={"error": "Page not found", "status_code": 404},
+        title="404 - Not Found"
     )
 
 
@@ -29,11 +31,9 @@ async def server_error_handler(request: Request, exc: Exception):
         status_code=500,
         content={"detail": "Internal server error"}
     )
-    return templates.TemplateResponse(
+    return render_template(
         "error.html",
-        {
-            "request": request,
-            "error": "Internal server error",
-            "status_code": 500
-        }
-    )
+        request,
+        context={"error": "Internal server error", "status_code": 500},
+        title="500 - Server Error"
+    )
 
@@ -6,7 +6,7 @@ This module provides endpoints for serving HTML pages using Jinja2 templates.
 from fastapi import APIRouter, Request
 from fastapi.responses import HTMLResponse
 
-from src.server.utils.templates import templates
+from src.server.utils.template_helpers import render_template
 
 router = APIRouter(tags=["pages"])
 
@@ -14,34 +14,38 @@ router = APIRouter(tags=["pages"])
 @router.get("/", response_class=HTMLResponse)
 async def root(request: Request):
     """Serve the main application page."""
-    return templates.TemplateResponse(
+    return render_template(
         "index.html",
-        {"request": request, "title": "Aniworld Download Manager"}
+        request,
+        title="Aniworld Download Manager"
     )
 
 
 @router.get("/setup", response_class=HTMLResponse)
 async def setup_page(request: Request):
     """Serve the setup page."""
-    return templates.TemplateResponse(
+    return render_template(
         "setup.html",
-        {"request": request, "title": "Setup - Aniworld"}
+        request,
+        title="Setup - Aniworld"
     )
 
 
 @router.get("/login", response_class=HTMLResponse)
 async def login_page(request: Request):
     """Serve the login page."""
-    return templates.TemplateResponse(
+    return render_template(
         "login.html",
-        {"request": request, "title": "Login - Aniworld"}
+        request,
+        title="Login - Aniworld"
     )
 
 
 @router.get("/queue", response_class=HTMLResponse)
 async def queue_page(request: Request):
     """Serve the download queue page."""
-    return templates.TemplateResponse(
+    return render_template(
         "queue.html",
-        {"request": request, "title": "Download Queue - Aniworld"}
-    )
+        request,
+        title="Download Queue - Aniworld"
+    )
@@ -17,7 +17,10 @@ from src.config.settings import settings
 
 # Import core functionality
 from src.core.SeriesApp import SeriesApp
+from src.server.api.anime import router as anime_router
 from src.server.api.auth import router as auth_router
+from src.server.api.download import router as download_router
+from src.server.api.websocket import router as websocket_router
 from src.server.controllers.error_controller import (
     not_found_handler,
     server_error_handler,
@@ -27,6 +30,8 @@ from src.server.controllers.error_controller import (
 from src.server.controllers.health_controller import router as health_router
 from src.server.controllers.page_controller import router as page_router
 from src.server.middleware.auth import AuthMiddleware
+from src.server.services.progress_service import get_progress_service
+from src.server.services.websocket_service import get_websocket_service
 
 # Initialize FastAPI app
 app = FastAPI(
@@ -57,6 +62,9 @@ app.add_middleware(AuthMiddleware, rate_limit_per_minute=5)
 app.include_router(health_router)
 app.include_router(page_router)
 app.include_router(auth_router)
+app.include_router(anime_router)
+app.include_router(download_router)
+app.include_router(websocket_router)
 
 # Global variables for application state
 series_app: Optional[SeriesApp] = None
@@ -70,6 +78,23 @@ async def startup_event():
         # Initialize SeriesApp with configured directory
         if settings.anime_directory:
             series_app = SeriesApp(settings.anime_directory)
+
+        # Initialize progress service with websocket callback
+        progress_service = get_progress_service()
+        ws_service = get_websocket_service()
+
+        async def broadcast_callback(
+            message_type: str, data: dict, room: str
+        ):
+            """Broadcast progress updates via WebSocket."""
+            message = {
+                "type": message_type,
+                "data": data,
+            }
+            await ws_service.manager.broadcast_to_room(message, room)
+
+        progress_service.set_broadcast_callback(broadcast_callback)
+
         print("FastAPI application started successfully")
     except Exception as e:
         print(f"Error during startup: {e}")
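The `broadcast_callback` wired up at startup wraps each progress update in the structured `{"type": ..., "data": ...}` envelope before handing it to the room broadcaster. A hypothetical self-contained sketch of that pattern, recording broadcasts in a list instead of sending them over real WebSockets:

```python
import asyncio


def make_broadcast_callback(broadcasts: list):
    """Build a callback with the (message_type, data, room) signature above."""

    async def broadcast_callback(message_type: str, data: dict, room: str):
        # Wrap the payload in the structured envelope and record it in
        # place of a real WebSocket broadcast.
        broadcasts.append((room, {"type": message_type, "data": data}))

    return broadcast_callback


async def main():
    sent: list = []
    callback = make_broadcast_callback(sent)
    await callback("download_progress", {"percent": 45.2}, "downloads")
    return sent


sent = asyncio.run(main())
print(sent)
```

Injecting the callback keeps the progress service decoupled from the WebSocket layer: it only knows the envelope shape, not the transport.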
@@ -1,3 +1,3 @@
 """Models package for server-side Pydantic models."""
 
-__all__ = ["auth"]
+__all__ = ["auth", "anime", "config", "download"]
122
src/server/models/anime.py
Normal file

@@ -0,0 +1,122 @@
from __future__ import annotations

from datetime import datetime
from typing import List, Optional

from pydantic import BaseModel, Field, HttpUrl


class EpisodeInfo(BaseModel):
    """Information about a single episode."""

    episode_number: int = Field(..., ge=1, description="Episode index (1-based)")
    title: Optional[str] = Field(None, description="Optional episode title")
    aired_at: Optional[datetime] = Field(None, description="Air date/time if known")
    duration_seconds: Optional[int] = Field(None, ge=0, description="Duration in seconds")
    available: bool = Field(True, description="Whether the episode is available for download")
    sources: List[HttpUrl] = Field(default_factory=list, description="List of known streaming/download source URLs")


class MissingEpisodeInfo(BaseModel):
    """Represents a gap in the episode list for a series."""

    from_episode: int = Field(..., ge=1, description="Starting missing episode number")
    to_episode: int = Field(..., ge=1, description="Ending missing episode number (inclusive)")
    reason: Optional[str] = Field(None, description="Optional explanation why episodes are missing")

    @property
    def count(self) -> int:
        """Number of missing episodes in the range."""
        return max(0, self.to_episode - self.from_episode + 1)


class AnimeSeriesResponse(BaseModel):
    """Response model for a series with metadata and episodes."""

    id: str = Field(..., description="Unique series identifier")
    title: str = Field(..., description="Series title")
    alt_titles: List[str] = Field(default_factory=list, description="Alternative titles")
    description: Optional[str] = Field(None, description="Short series description")
    total_episodes: Optional[int] = Field(None, ge=0, description="Declared total episode count if known")
    episodes: List[EpisodeInfo] = Field(default_factory=list, description="Known episodes information")
    missing_episodes: List[MissingEpisodeInfo] = Field(default_factory=list, description="Detected missing episode ranges")
    thumbnail: Optional[HttpUrl] = Field(None, description="Optional thumbnail image URL")


class SearchRequest(BaseModel):
    """Request payload for searching series."""

    query: str = Field(..., min_length=1)
    limit: int = Field(10, ge=1, le=100)
    include_adult: bool = Field(False)


class SearchResult(BaseModel):
    """Search result item for a series discovery endpoint."""

    id: str
    title: str
    snippet: Optional[str] = None
    thumbnail: Optional[HttpUrl] = None
    score: Optional[float] = None
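The `missing_episodes` field above holds detected gaps as inclusive `(from_episode, to_episode)` ranges. How those ranges are derived is not part of this diff; a hypothetical stdlib-only sketch that collapses absent episode numbers into such ranges:

```python
def missing_ranges(known: set, total: int) -> list:
    """Collapse absent episode numbers 1..total into (from, to) ranges."""
    ranges = []
    start = None
    for ep in range(1, total + 1):
        if ep not in known:
            if start is None:
                start = ep  # open a new gap
        elif start is not None:
            ranges.append((start, ep - 1))  # close the gap
            start = None
    if start is not None:
        ranges.append((start, total))  # gap runs to the end
    return ranges


print(missing_ranges({1, 2, 5, 6, 9}, 10))  # [(3, 4), (7, 8), (10, 10)]
```

Each tuple maps directly onto a `MissingEpisodeInfo(from_episode=..., to_episode=...)`.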
130
src/server/models/config.py
Normal file

@@ -0,0 +1,130 @@
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationError, validator
|
||||
|
||||
|
||||
class SchedulerConfig(BaseModel):
|
||||
"""Scheduler related configuration."""
|
||||
|
||||
enabled: bool = Field(
|
||||
default=True, description="Whether the scheduler is enabled"
|
||||
)
|
||||
interval_minutes: int = Field(
|
||||
default=60, ge=1, description="Scheduler interval in minutes"
|
||||
)
|
||||
|
||||
|
||||
class BackupConfig(BaseModel):
|
||||
"""Configuration for automatic backups of application data."""
|
||||
|
||||
enabled: bool = Field(
|
||||
default=False, description="Whether backups are enabled"
|
||||
)
|
||||
path: Optional[str] = Field(
|
||||
default="data/backups", description="Path to store backups"
|
||||
)
|
||||
keep_days: int = Field(
|
||||
default=30, ge=0, description="How many days to keep backups"
|
||||
)
|
||||
|
||||
|
||||
class LoggingConfig(BaseModel):
|
||||
"""Logging configuration with basic validation for level."""
|
||||
|
||||
level: str = Field(
|
||||
default="INFO", description="Logging level"
|
||||
)
|
||||
file: Optional[str] = Field(
|
||||
default=None, description="Optional file path for log output"
|
||||
)
|
||||
max_bytes: Optional[int] = Field(
|
||||
default=None, ge=0, description="Max bytes per log file for rotation"
|
||||
)
|
||||
backup_count: Optional[int] = Field(
|
||||
default=3, ge=0, description="Number of rotated log files to keep"
|
||||
)
|
||||
|
||||
@validator("level")
|
||||
def validate_level(cls, v: str) -> str:
|
||||
allowed = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}
|
||||
lvl = (v or "").upper()
|
||||
if lvl not in allowed:
|
||||
raise ValueError(f"invalid logging level: {v}")
|
||||
return lvl
|
||||
|
||||
|
||||
class ValidationResult(BaseModel):
|
||||
"""Result of a configuration validation attempt."""
|
||||
|
||||
valid: bool = Field(..., description="Whether the configuration is valid")
|
||||
errors: Optional[List[str]] = Field(
|
||||
default_factory=list, description="List of validation error messages"
|
||||
)
|
||||
|
||||
|
||||
class AppConfig(BaseModel):
|
||||
"""Top-level application configuration model used by the web layer.
|
||||
|
||||
This model intentionally keeps things small and serializable to JSON.
|
||||
"""
    name: str = Field(default="Aniworld", description="Application name")
    data_dir: str = Field(default="data", description="Base data directory")
    scheduler: SchedulerConfig = Field(default_factory=SchedulerConfig)
    logging: LoggingConfig = Field(default_factory=LoggingConfig)
    backup: BackupConfig = Field(default_factory=BackupConfig)
    other: Dict[str, object] = Field(
        default_factory=dict, description="Arbitrary other settings"
    )

    def validate(self) -> ValidationResult:
        """Perform light-weight validation and return a ValidationResult.

        This method intentionally avoids performing IO (no filesystem checks)
        so it remains fast and side-effect free for unit tests and API use.
        """
        errors: List[str] = []

        # Pydantic field validators already run on construction; re-run a
        # quick check for common constraints and collect messages.
        try:
            # Reconstruct to ensure nested validators are executed
            AppConfig(**self.model_dump())
        except ValidationError as exc:
            for e in exc.errors():
                loc = ".".join(str(x) for x in e.get("loc", []))
                msg = f"{loc}: {e.get('msg')}"
                errors.append(msg)

        # backup.path must be set when backups are enabled
        if self.backup.enabled and not self.backup.path:
            errors.append(
                "backup.path must be set when backup.enabled is true"
            )

        return ValidationResult(valid=(len(errors) == 0), errors=errors)


class ConfigUpdate(BaseModel):
    scheduler: Optional[SchedulerConfig] = None
    logging: Optional[LoggingConfig] = None
    backup: Optional[BackupConfig] = None
    other: Optional[Dict[str, object]] = None

    def apply_to(self, current: AppConfig) -> AppConfig:
        """Return a new AppConfig with updates applied to the current config.

        Performs a shallow merge for `other`.
        """
        data = current.model_dump()
        if self.scheduler is not None:
            data["scheduler"] = self.scheduler.model_dump()
        if self.logging is not None:
            data["logging"] = self.logging.model_dump()
        if self.backup is not None:
            data["backup"] = self.backup.model_dump()
        if self.other is not None:
            merged = dict(current.other or {})
            merged.update(self.other)
            data["other"] = merged
        return AppConfig(**data)
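The shallow-merge semantics of `apply_to` for the `other` mapping can be sketched with plain dicts (the key names below are illustrative, not part of the config schema): top-level keys from the update win, and nested dicts are replaced wholesale rather than merged recursively.

```python
# Shallow merge, as performed by ConfigUpdate.apply_to for `other`.
current_other = {"theme": "dark", "paths": {"tmp": "/tmp", "cache": "/c"}}
update_other = {"paths": {"tmp": "/var/tmp"}, "lang": "en"}

merged = dict(current_other)
merged.update(update_other)

print(merged["theme"])  # untouched top-level key survives: dark
print(merged["lang"])   # new key added: en
print(merged["paths"])  # nested dict replaced wholesale: {'tmp': '/var/tmp'}
```

Note that `"cache"` is lost because the whole `paths` value is replaced; a recursive deep merge would preserve it.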
207
src/server/models/download.py
Normal file
@ -0,0 +1,207 @@
"""Download queue Pydantic models for the Aniworld web application.

This module defines request/response models used by the download queue API
and the download service. Models are intentionally lightweight and focused
on serialization, validation, and OpenAPI documentation.
"""
from __future__ import annotations

from datetime import datetime
from enum import Enum
from typing import List, Optional

from pydantic import BaseModel, Field, HttpUrl


class DownloadStatus(str, Enum):
    """Status of a download item in the queue."""

    PENDING = "pending"
    DOWNLOADING = "downloading"
    PAUSED = "paused"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


class DownloadPriority(str, Enum):
    """Priority level for download queue items."""

    LOW = "low"
    NORMAL = "normal"
    HIGH = "high"


class EpisodeIdentifier(BaseModel):
    """Episode identification information for a download item."""

    season: int = Field(..., ge=1, description="Season number (1-based)")
    episode: int = Field(
        ..., ge=1, description="Episode number within season (1-based)"
    )
    title: Optional[str] = Field(None, description="Episode title if known")


class DownloadProgress(BaseModel):
    """Real-time progress information for an active download."""

    percent: float = Field(
        0.0, ge=0.0, le=100.0, description="Download progress percentage"
    )
    downloaded_mb: float = Field(
        0.0, ge=0.0, description="Downloaded size in megabytes"
    )
    total_mb: Optional[float] = Field(
        None, ge=0.0, description="Total size in megabytes if known"
    )
    speed_mbps: Optional[float] = Field(
        None, ge=0.0, description="Download speed in MB/s"
    )
    eta_seconds: Optional[int] = Field(
        None, ge=0, description="Estimated time remaining in seconds"
    )


class DownloadItem(BaseModel):
    """Represents a single download item in the queue."""

    id: str = Field(..., description="Unique download item identifier")
    serie_id: str = Field(..., description="Series identifier")
    serie_name: str = Field(..., min_length=1, description="Series name")
    episode: EpisodeIdentifier = Field(
        ..., description="Episode identification"
    )
    status: DownloadStatus = Field(
        DownloadStatus.PENDING, description="Current download status"
    )
    priority: DownloadPriority = Field(
        DownloadPriority.NORMAL, description="Queue priority"
    )

    # Timestamps
    added_at: datetime = Field(
        default_factory=datetime.utcnow,
        description="When item was added to queue"
    )
    started_at: Optional[datetime] = Field(
        None, description="When download started"
    )
    completed_at: Optional[datetime] = Field(
        None, description="When download completed/failed"
    )

    # Progress tracking
    progress: Optional[DownloadProgress] = Field(
        None, description="Current progress if downloading"
    )

    # Error handling
    error: Optional[str] = Field(None, description="Error message if failed")
    retry_count: int = Field(0, ge=0, description="Number of retry attempts")

    # Download source
    source_url: Optional[HttpUrl] = Field(
        None, description="Source URL for download"
    )


class QueueStatus(BaseModel):
    """Overall status of the download queue system."""

    is_running: bool = Field(
        False, description="Whether the queue processor is running"
    )
    is_paused: bool = Field(False, description="Whether downloads are paused")
    active_downloads: List[DownloadItem] = Field(
        default_factory=list, description="Currently downloading items"
    )
    pending_queue: List[DownloadItem] = Field(
        default_factory=list, description="Items waiting to be downloaded"
    )
    completed_downloads: List[DownloadItem] = Field(
        default_factory=list, description="Recently completed downloads"
    )
    failed_downloads: List[DownloadItem] = Field(
        default_factory=list, description="Failed download items"
    )


class QueueStats(BaseModel):
    """Statistics about the download queue."""

    total_items: int = Field(
        0, ge=0, description="Total number of items in all queues"
    )
    pending_count: int = Field(0, ge=0, description="Number of pending items")
    active_count: int = Field(
        0, ge=0, description="Number of active downloads"
    )
    completed_count: int = Field(
        0, ge=0, description="Number of completed downloads"
    )
    failed_count: int = Field(
        0, ge=0, description="Number of failed downloads"
    )

    total_downloaded_mb: float = Field(
        0.0, ge=0.0, description="Total megabytes downloaded"
    )
    average_speed_mbps: Optional[float] = Field(
        None, ge=0.0, description="Average download speed in MB/s"
    )
    estimated_time_remaining: Optional[int] = Field(
        None, ge=0, description="Estimated time to complete queue in seconds"
    )


class DownloadRequest(BaseModel):
    """Request to add episode(s) to the download queue."""

    serie_id: str = Field(..., description="Series identifier")
    serie_name: str = Field(
        ..., min_length=1, description="Series name for display"
    )
    episodes: List[EpisodeIdentifier] = Field(
        ..., min_length=1, description="List of episodes to download"
    )
    priority: DownloadPriority = Field(
        DownloadPriority.NORMAL, description="Priority level for queue items"
    )


class DownloadResponse(BaseModel):
    """Response after adding items to the download queue."""

    status: str = Field(..., description="Status of the request")
    message: str = Field(..., description="Human-readable status message")
    added_items: List[str] = Field(
        default_factory=list,
        description="IDs of successfully added download items"
    )
    failed_items: List[str] = Field(
        default_factory=list, description="Episodes that failed to be added"
    )


class QueueOperationRequest(BaseModel):
    """Request to perform operations on queue items."""

    item_ids: List[str] = Field(
        ..., min_length=1, description="List of download item IDs"
    )


class QueueReorderRequest(BaseModel):
    """Request to reorder items in the pending queue."""

    item_id: str = Field(..., description="Download item ID to move")
    new_position: int = Field(
        ..., ge=0, description="New position in queue (0-based)"
    )


class QueueStatusResponse(BaseModel):
    """Complete response for queue status endpoint."""

    status: QueueStatus = Field(..., description="Current queue status")
    statistics: QueueStats = Field(..., description="Queue statistics")
285
src/server/models/websocket.py
Normal file
@ -0,0 +1,285 @@
"""WebSocket message Pydantic models for the Aniworld web application.

This module defines message models for WebSocket communication between
the server and clients. Models ensure type safety and provide validation
for real-time updates.
"""
from __future__ import annotations

from datetime import datetime
from enum import Enum
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field


class WebSocketMessageType(str, Enum):
    """Types of WebSocket messages."""

    # Download-related messages
    DOWNLOAD_PROGRESS = "download_progress"
    DOWNLOAD_COMPLETE = "download_complete"
    DOWNLOAD_FAILED = "download_failed"
    DOWNLOAD_ADDED = "download_added"
    DOWNLOAD_REMOVED = "download_removed"

    # Queue-related messages
    QUEUE_STATUS = "queue_status"
    QUEUE_STARTED = "queue_started"
    QUEUE_STOPPED = "queue_stopped"
    QUEUE_PAUSED = "queue_paused"
    QUEUE_RESUMED = "queue_resumed"

    # Progress-related messages
    SCAN_PROGRESS = "scan_progress"
    SCAN_COMPLETE = "scan_complete"
    SCAN_FAILED = "scan_failed"

    # System messages
    SYSTEM_INFO = "system_info"
    SYSTEM_WARNING = "system_warning"
    SYSTEM_ERROR = "system_error"

    # Error messages
    ERROR = "error"

    # Connection messages
    CONNECTED = "connected"
    PING = "ping"
    PONG = "pong"


class WebSocketMessage(BaseModel):
    """Base WebSocket message structure."""

    type: WebSocketMessageType = Field(
        ..., description="Type of the message"
    )
    timestamp: str = Field(
        default_factory=lambda: datetime.utcnow().isoformat(),
        description="ISO 8601 timestamp when message was created",
    )
    data: Dict[str, Any] = Field(
        default_factory=dict, description="Message payload"
    )


class DownloadProgressMessage(BaseModel):
    """Download progress update message."""

    type: WebSocketMessageType = Field(
        default=WebSocketMessageType.DOWNLOAD_PROGRESS,
        description="Message type",
    )
    timestamp: str = Field(
        default_factory=lambda: datetime.utcnow().isoformat(),
        description="ISO 8601 timestamp",
    )
    data: Dict[str, Any] = Field(
        ...,
        description="Progress data including download_id, percent, speed, eta",
    )


class DownloadCompleteMessage(BaseModel):
    """Download completion message."""

    type: WebSocketMessageType = Field(
        default=WebSocketMessageType.DOWNLOAD_COMPLETE,
        description="Message type",
    )
    timestamp: str = Field(
        default_factory=lambda: datetime.utcnow().isoformat(),
        description="ISO 8601 timestamp",
    )
    data: Dict[str, Any] = Field(
        ..., description="Completion data including download_id, file_path"
    )


class DownloadFailedMessage(BaseModel):
    """Download failure message."""

    type: WebSocketMessageType = Field(
        default=WebSocketMessageType.DOWNLOAD_FAILED,
        description="Message type",
    )
    timestamp: str = Field(
        default_factory=lambda: datetime.utcnow().isoformat(),
        description="ISO 8601 timestamp",
    )
    data: Dict[str, Any] = Field(
        ..., description="Error data including download_id, error_message"
    )


class QueueStatusMessage(BaseModel):
    """Queue status update message."""

    type: WebSocketMessageType = Field(
        default=WebSocketMessageType.QUEUE_STATUS,
        description="Message type",
    )
    timestamp: str = Field(
        default_factory=lambda: datetime.utcnow().isoformat(),
        description="ISO 8601 timestamp",
    )
    data: Dict[str, Any] = Field(
        ...,
        description="Queue status including active, pending, completed counts",
    )


class SystemMessage(BaseModel):
    """System-level message (info, warning, error)."""

    type: WebSocketMessageType = Field(
        ..., description="System message type"
    )
    timestamp: str = Field(
        default_factory=lambda: datetime.utcnow().isoformat(),
        description="ISO 8601 timestamp",
    )
    data: Dict[str, Any] = Field(
        ..., description="System message data"
    )


class ErrorMessage(BaseModel):
    """Error message to client."""

    type: WebSocketMessageType = Field(
        default=WebSocketMessageType.ERROR, description="Message type"
    )
    timestamp: str = Field(
        default_factory=lambda: datetime.utcnow().isoformat(),
        description="ISO 8601 timestamp",
    )
    data: Dict[str, Any] = Field(
        ..., description="Error data including code and message"
    )


class ConnectionMessage(BaseModel):
    """Connection-related message (connected, ping, pong)."""

    type: WebSocketMessageType = Field(
        ..., description="Connection message type"
    )
    timestamp: str = Field(
        default_factory=lambda: datetime.utcnow().isoformat(),
        description="ISO 8601 timestamp",
    )
    data: Dict[str, Any] = Field(
        default_factory=dict, description="Connection message data"
    )


class ClientMessage(BaseModel):
    """Inbound message from client to server."""

    action: str = Field(..., description="Action requested by client")
    data: Optional[Dict[str, Any]] = Field(
        default_factory=dict, description="Action payload"
    )


class RoomSubscriptionRequest(BaseModel):
    """Request to join or leave a room."""

    action: str = Field(
        ..., description="Action: 'join' or 'leave'"
    )
    room: str = Field(
        ..., min_length=1, description="Room name to join or leave"
    )


class ScanProgressMessage(BaseModel):
    """Scan progress update message."""

    type: WebSocketMessageType = Field(
        default=WebSocketMessageType.SCAN_PROGRESS,
        description="Message type",
    )
    timestamp: str = Field(
        default_factory=lambda: datetime.utcnow().isoformat(),
        description="ISO 8601 timestamp",
    )
    data: Dict[str, Any] = Field(
        ...,
        description="Scan progress data including current, total, percent",
    )


class ScanCompleteMessage(BaseModel):
    """Scan completion message."""

    type: WebSocketMessageType = Field(
        default=WebSocketMessageType.SCAN_COMPLETE,
        description="Message type",
    )
    timestamp: str = Field(
        default_factory=lambda: datetime.utcnow().isoformat(),
        description="ISO 8601 timestamp",
    )
    data: Dict[str, Any] = Field(
        ...,
        description="Scan completion data including series_found, duration",
    )


class ScanFailedMessage(BaseModel):
    """Scan failure message."""

    type: WebSocketMessageType = Field(
        default=WebSocketMessageType.SCAN_FAILED,
        description="Message type",
    )
    timestamp: str = Field(
        default_factory=lambda: datetime.utcnow().isoformat(),
        description="ISO 8601 timestamp",
    )
    data: Dict[str, Any] = Field(
        ..., description="Scan error data including error_message"
    )


class ErrorNotificationMessage(BaseModel):
    """Error notification message for critical errors."""

    type: WebSocketMessageType = Field(
        default=WebSocketMessageType.SYSTEM_ERROR,
        description="Message type",
    )
    timestamp: str = Field(
        default_factory=lambda: datetime.utcnow().isoformat(),
        description="ISO 8601 timestamp",
    )
    data: Dict[str, Any] = Field(
        ...,
        description=(
            "Error notification data including severity, message, details"
        ),
    )


class ProgressUpdateMessage(BaseModel):
    """Generic progress update message.

    Can be used for any type of progress (download, scan, queue, etc.)
    """

    type: WebSocketMessageType = Field(
        ..., description="Type of progress message"
    )
    timestamp: str = Field(
        default_factory=lambda: datetime.utcnow().isoformat(),
        description="ISO 8601 timestamp",
    )
    data: Dict[str, Any] = Field(
        ...,
        description=(
            "Progress data including id, status, percent, current, total"
        ),
    )
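All of these models share the `{type, timestamp, data}` envelope. What actually crosses the wire can be sketched with the standard library alone (field values below are illustrative, not a documented payload):

```python
import json
from datetime import datetime, timezone


def make_message(msg_type: str, data: dict) -> str:
    """Serialize a {type, timestamp, data} envelope like the models above."""
    envelope = {
        "type": msg_type,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "data": data,
    }
    return json.dumps(envelope)


raw = make_message("download_progress", {"download_id": "abc", "percent": 42.0})
parsed = json.loads(raw)
print(parsed["type"])             # download_progress
print(parsed["data"]["percent"])  # 42.0
```

On the client side this is the same shape the `WebSocketClient` dispatcher switches on: it reads `parsed["type"]` and routes `parsed["data"]` to the matching handler.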
170
src/server/services/anime_service.py
Normal file
@ -0,0 +1,170 @@
from __future__ import annotations

import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache
from typing import Callable, List, Optional

import structlog

from src.core.SeriesApp import SeriesApp
from src.server.services.progress_service import (
    ProgressService,
    ProgressType,
    get_progress_service,
)

logger = structlog.get_logger(__name__)


class AnimeServiceError(Exception):
    """Service-level exception for anime operations."""


class AnimeService:
    """Wraps the blocking SeriesApp for use in the FastAPI web layer.

    - Runs blocking operations in a threadpool
    - Exposes async methods
    - Adds simple in-memory caching for read operations
    """

    def __init__(
        self,
        directory: str,
        max_workers: int = 4,
        progress_service: Optional[ProgressService] = None,
    ):
        self._directory = directory
        self._executor = ThreadPoolExecutor(max_workers=max_workers)
        self._progress_service = progress_service or get_progress_service()
        # SeriesApp is blocking; instantiate per-service
        try:
            self._app = SeriesApp(directory)
        except Exception as e:
            logger.exception("Failed to initialize SeriesApp")
            raise AnimeServiceError("Initialization failed") from e

    async def _run_in_executor(self, func, *args, **kwargs):
        loop = asyncio.get_running_loop()
        try:
            return await loop.run_in_executor(
                self._executor, lambda: func(*args, **kwargs)
            )
        except Exception as e:
            logger.exception("Executor task failed")
            raise AnimeServiceError(str(e)) from e

    @lru_cache(maxsize=128)
    def _cached_list_missing(self) -> List[dict]:
        # Synchronous cached call used by async wrapper
        try:
            series = self._app.series_list
            # normalize to simple dicts
            return [s.to_dict() if hasattr(s, "to_dict") else s for s in series]
        except Exception:
            logger.exception("Failed to get missing episodes list")
            raise

    async def list_missing(self) -> List[dict]:
        """Return list of series with missing episodes."""
        try:
            return await self._run_in_executor(self._cached_list_missing)
        except AnimeServiceError:
            raise
        except Exception as e:
            logger.exception("list_missing failed")
            raise AnimeServiceError("Failed to list missing series") from e

    async def search(self, query: str) -> List[dict]:
        """Search for series using underlying loader.Search."""
        if not query:
            return []
        try:
            result = await self._run_in_executor(self._app.search, query)
            # result may already be list of dicts or objects
            return result
        except Exception as e:
            logger.exception("search failed")
            raise AnimeServiceError("Search failed") from e

    async def rescan(self, callback: Optional[Callable] = None) -> None:
        """Trigger a re-scan. Accepts an optional callback function.

        The callback is executed in the threadpool by SeriesApp.
        Progress updates are tracked and broadcast via ProgressService.
        """
        scan_id = "library_scan"

        try:
            # Start progress tracking
            await self._progress_service.start_progress(
                progress_id=scan_id,
                progress_type=ProgressType.SCAN,
                title="Scanning anime library",
                message="Initializing scan...",
            )

            # Capture the running loop: the progress callback executes in a
            # worker thread, where asyncio.create_task would raise because
            # there is no running event loop in that thread.
            loop = asyncio.get_running_loop()

            # Create wrapped callback for progress updates
            def progress_callback(progress_data: dict) -> None:
                """Update progress during scan (runs in the executor thread)."""
                try:
                    if callback:
                        callback(progress_data)

                    # Update progress service
                    current = progress_data.get("current", 0)
                    total = progress_data.get("total", 0)
                    message = progress_data.get("message", "Scanning...")

                    asyncio.run_coroutine_threadsafe(
                        self._progress_service.update_progress(
                            progress_id=scan_id,
                            current=current,
                            total=total,
                            message=message,
                        ),
                        loop,
                    )
                except Exception as e:
                    logger.error("Scan progress callback error", error=str(e))

            # Run scan
            await self._run_in_executor(self._app.ReScan, progress_callback)

            # invalidate cache
            try:
                self._cached_list_missing.cache_clear()
            except Exception:
                pass

            # Complete progress tracking
            await self._progress_service.complete_progress(
                progress_id=scan_id,
                message="Scan completed successfully",
            )

        except Exception as e:
            logger.exception("rescan failed")

            # Fail progress tracking
            await self._progress_service.fail_progress(
                progress_id=scan_id,
                error_message=str(e),
            )

            raise AnimeServiceError("Rescan failed") from e

    async def download(
        self,
        serie_folder: str,
        season: int,
        episode: int,
        key: str,
        callback=None,
    ) -> bool:
        """Start a download via the underlying loader.

        Returns True on success or raises AnimeServiceError on failure.
        """
        try:
            result = await self._run_in_executor(
                self._app.download, serie_folder, season, episode, key, callback
            )
            return bool(result)
        except Exception as e:
            logger.exception("download failed")
            raise AnimeServiceError("Download failed") from e


def get_anime_service(directory: str = "./") -> AnimeService:
    """Factory used by FastAPI dependency injection."""
    return AnimeService(directory)
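The executor pattern behind `_run_in_executor` is the standard asyncio idiom for offloading blocking calls; isolated as a self-contained sketch (`blocking_scan` is a placeholder for a blocking SeriesApp call, not a real API):

```python
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=2)


def blocking_scan(n: int) -> int:
    """Stand-in for a blocking library call (sleeps, then returns)."""
    time.sleep(0.01)
    return n * 2


async def run_in_executor(func, *args):
    # Submit the blocking call to a worker thread; await without blocking
    # the event loop.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(executor, lambda: func(*args))


async def main() -> None:
    # Both calls run concurrently in worker threads.
    results = await asyncio.gather(
        run_in_executor(blocking_scan, 2),
        run_in_executor(blocking_scan, 5),
    )
    print(results)  # [4, 10]


asyncio.run(main())
```

`asyncio.gather` preserves argument order, so results line up with the submitted calls even when the threads finish out of order.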
859
src/server/services/download_service.py
Normal file
@ -0,0 +1,859 @@
|
||||
"""Download queue service for managing anime episode downloads.
|
||||
|
||||
This module provides a comprehensive queue management system for handling
|
||||
concurrent anime episode downloads with priority-based scheduling, progress
|
||||
tracking, persistence, and automatic retry functionality.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from collections import deque
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
import structlog
|
||||
|
||||
from src.server.models.download import (
|
||||
DownloadItem,
|
||||
DownloadPriority,
|
||||
DownloadProgress,
|
||||
DownloadStatus,
|
||||
EpisodeIdentifier,
|
||||
QueueStats,
|
||||
QueueStatus,
|
||||
)
|
||||
from src.server.services.anime_service import AnimeService, AnimeServiceError
|
||||
from src.server.services.progress_service import (
|
||||
ProgressService,
|
||||
ProgressType,
|
||||
get_progress_service,
|
||||
)
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
class DownloadServiceError(Exception):
|
||||
"""Service-level exception for download queue operations."""
|
||||
|
||||
|
||||
class DownloadService:
|
||||
"""Manages the download queue with concurrent processing and persistence.
|
||||
|
||||
Features:
|
||||
- Priority-based queue management
|
||||
- Concurrent download processing
|
||||
- Real-time progress tracking
|
||||
- Queue persistence and recovery
|
||||
- Automatic retry logic
|
||||
- WebSocket broadcast support
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
anime_service: AnimeService,
|
||||
max_concurrent_downloads: int = 2,
|
||||
max_retries: int = 3,
|
||||
persistence_path: str = "./data/download_queue.json",
|
||||
progress_service: Optional[ProgressService] = None,
|
||||
):
|
||||
"""Initialize the download service.
|
||||
|
||||
Args:
|
||||
anime_service: Service for anime operations
|
||||
max_concurrent_downloads: Maximum simultaneous downloads
|
||||
max_retries: Maximum retry attempts for failed downloads
|
||||
persistence_path: Path to persist queue state
|
||||
progress_service: Optional progress service for tracking
|
||||
"""
|
||||
self._anime_service = anime_service
|
||||
self._max_concurrent = max_concurrent_downloads
|
||||
self._max_retries = max_retries
|
||||
self._persistence_path = Path(persistence_path)
|
||||
self._progress_service = progress_service or get_progress_service()
|
||||
|
||||
# Queue storage by status
|
||||
self._pending_queue: deque[DownloadItem] = deque()
|
||||
self._active_downloads: Dict[str, DownloadItem] = {}
|
||||
self._completed_items: deque[DownloadItem] = deque(maxlen=100)
|
||||
self._failed_items: deque[DownloadItem] = deque(maxlen=50)
|
||||
|
||||
# Control flags
|
||||
self._is_running = False
|
||||
self._is_paused = False
|
||||
self._shutdown_event = asyncio.Event()
|
||||
|
||||
# Executor for blocking operations
|
||||
self._executor = ThreadPoolExecutor(
|
||||
max_workers=max_concurrent_downloads
|
||||
)
|
||||
|
||||
# WebSocket broadcast callback
|
||||
self._broadcast_callback: Optional[Callable] = None
|
||||
|
||||
# Statistics tracking
|
||||
self._total_downloaded_mb: float = 0.0
|
||||
self._download_speeds: deque[float] = deque(maxlen=10)
|
||||
|
||||
# Load persisted queue
|
||||
self._load_queue()
|
||||
|
||||
logger.info(
|
||||
"DownloadService initialized",
|
||||
max_concurrent=max_concurrent_downloads,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
|
||||
def set_broadcast_callback(self, callback: Callable) -> None:
|
||||
"""Set callback for broadcasting status updates via WebSocket."""
|
||||
self._broadcast_callback = callback
|
||||
logger.debug("Broadcast callback registered")
|
||||
|
||||
async def _broadcast_update(self, update_type: str, data: dict) -> None:
|
||||
"""Broadcast update to connected WebSocket clients.
|
||||
|
||||
Args:
|
||||
update_type: Type of update (download_progress, queue_status, etc.)
|
||||
data: Update data to broadcast
|
||||
"""
|
||||
if self._broadcast_callback:
|
||||
try:
|
||||
await self._broadcast_callback(update_type, data)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to broadcast update",
|
||||
update_type=update_type,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
def _generate_item_id(self) -> str:
|
||||
"""Generate unique identifier for download items."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
def _load_queue(self) -> None:
|
||||
"""Load persisted queue from disk."""
|
||||
try:
|
||||
if self._persistence_path.exists():
|
||||
with open(self._persistence_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Restore pending items
|
||||
for item_dict in data.get("pending", []):
|
||||
item = DownloadItem(**item_dict)
|
||||
# Reset status if was downloading when saved
|
||||
if item.status == DownloadStatus.DOWNLOADING:
|
||||
item.status = DownloadStatus.PENDING
|
||||
self._pending_queue.append(item)
|
||||
|
||||
# Restore failed items that can be retried
|
||||
for item_dict in data.get("failed", []):
|
||||
item = DownloadItem(**item_dict)
|
||||
if item.retry_count < self._max_retries:
|
||||
item.status = DownloadStatus.PENDING
|
||||
self._pending_queue.append(item)
|
||||
else:
|
||||
self._failed_items.append(item)
|
||||
|
||||
logger.info(
|
||||
"Queue restored from disk",
|
||||
pending_count=len(self._pending_queue),
|
||||
failed_count=len(self._failed_items),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to load persisted queue", error=str(e))
|
||||
|
||||
def _save_queue(self) -> None:
|
||||
"""Persist current queue state to disk."""
|
||||
try:
|
||||
self._persistence_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
data = {
|
||||
"pending": [
|
||||
item.model_dump(mode="json")
|
||||
for item in self._pending_queue
|
||||
],
|
||||
"active": [
|
||||
item.model_dump(mode="json")
|
||||
for item in self._active_downloads.values()
|
||||
],
|
||||
"failed": [
|
||||
item.model_dump(mode="json")
|
||||
for item in self._failed_items
|
||||
],
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
}
|
||||
|
||||
with open(self._persistence_path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
logger.debug("Queue persisted to disk")
|
||||
except Exception as e:
|
||||
logger.error("Failed to persist queue", error=str(e))
|
||||
|
||||
async def add_to_queue(
|
||||
self,
|
||||
serie_id: str,
|
||||
serie_name: str,
|
||||
episodes: List[EpisodeIdentifier],
|
||||
priority: DownloadPriority = DownloadPriority.NORMAL,
|
||||
) -> List[str]:
|
||||
"""Add episodes to the download queue.
|
||||
|
||||
Args:
|
||||
serie_id: Series identifier
|
||||
serie_name: Series display name
|
||||
episodes: List of episodes to download
|
||||
priority: Queue priority level
|
||||
|
||||
Returns:
|
||||
List of created download item IDs
|
||||
|
||||
Raises:
|
||||
DownloadServiceError: If adding items fails
|
||||
"""
|
||||
created_ids = []
|
||||
|
||||
try:
|
||||
for episode in episodes:
|
||||
item = DownloadItem(
|
||||
id=self._generate_item_id(),
|
||||
serie_id=serie_id,
|
||||
serie_name=serie_name,
|
||||
episode=episode,
|
||||
status=DownloadStatus.PENDING,
|
||||
priority=priority,
|
||||
added_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
# Insert based on priority
|
||||
if priority == DownloadPriority.HIGH:
|
||||
self._pending_queue.appendleft(item)
|
||||
else:
|
||||
self._pending_queue.append(item)
|
||||
|
||||
created_ids.append(item.id)
|
||||
|
||||
logger.info(
|
||||
"Item added to queue",
|
||||
item_id=item.id,
|
||||
serie=serie_name,
|
||||
season=episode.season,
|
||||
episode=episode.episode,
|
||||
priority=priority.value,
|
||||
)
|
||||
|
||||
self._save_queue()
|
||||
|
||||
# Broadcast queue status update
|
||||
queue_status = await self.get_queue_status()
|
||||
await self._broadcast_update(
|
||||
"queue_status",
|
||||
{
|
||||
"action": "items_added",
|
||||
"added_ids": created_ids,
|
||||
"queue_status": queue_status.model_dump(mode="json"),
|
||||
},
|
||||
)
|
||||
|
||||
return created_ids
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to add items to queue", error=str(e))
|
||||
raise DownloadServiceError(f"Failed to add items: {str(e)}") from e
|
||||
|
||||
    async def remove_from_queue(self, item_ids: List[str]) -> List[str]:
        """Remove items from the queue.

        Args:
            item_ids: List of download item IDs to remove

        Returns:
            List of successfully removed item IDs

        Raises:
            DownloadServiceError: If removal fails
        """
        removed_ids = []

        try:
            for item_id in item_ids:
                # Check if item is currently downloading
                if item_id in self._active_downloads:
                    item = self._active_downloads[item_id]
                    item.status = DownloadStatus.CANCELLED
                    item.completed_at = datetime.utcnow()
                    self._failed_items.append(item)
                    del self._active_downloads[item_id]
                    removed_ids.append(item_id)
                    logger.info("Cancelled active download", item_id=item_id)
                    continue

                # Check pending queue
                for item in list(self._pending_queue):
                    if item.id == item_id:
                        self._pending_queue.remove(item)
                        removed_ids.append(item_id)
                        logger.info(
                            "Removed from pending queue", item_id=item_id
                        )
                        break

            if removed_ids:
                self._save_queue()
                # Broadcast queue status update
                queue_status = await self.get_queue_status()
                await self._broadcast_update(
                    "queue_status",
                    {
                        "action": "items_removed",
                        "removed_ids": removed_ids,
                        "queue_status": queue_status.model_dump(mode="json"),
                    },
                )

            return removed_ids

        except Exception as e:
            logger.error("Failed to remove items", error=str(e))
            raise DownloadServiceError(
                f"Failed to remove items: {str(e)}"
            ) from e

    async def reorder_queue(self, item_id: str, new_position: int) -> bool:
        """Reorder an item in the pending queue.

        Args:
            item_id: Download item ID to reorder
            new_position: New position in queue (0-based)

        Returns:
            True if reordering was successful

        Raises:
            DownloadServiceError: If reordering fails
        """
        try:
            # Find and remove item
            item_to_move = None
            for item in list(self._pending_queue):
                if item.id == item_id:
                    self._pending_queue.remove(item)
                    item_to_move = item
                    break

            if not item_to_move:
                raise DownloadServiceError(
                    f"Item {item_id} not found in pending queue"
                )

            # Insert at new position
            queue_list = list(self._pending_queue)
            new_position = max(0, min(new_position, len(queue_list)))
            queue_list.insert(new_position, item_to_move)
            self._pending_queue = deque(queue_list)

            self._save_queue()

            # Broadcast queue status update
            queue_status = await self.get_queue_status()
            await self._broadcast_update(
                "queue_status",
                {
                    "action": "queue_reordered",
                    "item_id": item_id,
                    "new_position": new_position,
                    "queue_status": queue_status.model_dump(mode="json"),
                },
            )

            logger.info(
                "Queue item reordered",
                item_id=item_id,
                new_position=new_position,
            )
            return True

        except Exception as e:
            logger.error("Failed to reorder queue", error=str(e))
            raise DownloadServiceError(
                f"Failed to reorder: {str(e)}"
            ) from e

    async def get_queue_status(self) -> QueueStatus:
        """Get current status of all queues.

        Returns:
            Complete queue status with all items
        """
        return QueueStatus(
            is_running=self._is_running,
            is_paused=self._is_paused,
            active_downloads=list(self._active_downloads.values()),
            pending_queue=list(self._pending_queue),
            completed_downloads=list(self._completed_items),
            failed_downloads=list(self._failed_items),
        )

    async def get_queue_stats(self) -> QueueStats:
        """Calculate queue statistics.

        Returns:
            Statistics about the download queue
        """
        active_count = len(self._active_downloads)
        pending_count = len(self._pending_queue)
        completed_count = len(self._completed_items)
        failed_count = len(self._failed_items)

        # Calculate average speed
        avg_speed = None
        if self._download_speeds:
            avg_speed = (
                sum(self._download_speeds) / len(self._download_speeds)
            )

        # Estimate remaining time
        eta_seconds = None
        if avg_speed and avg_speed > 0 and pending_count > 0:
            # Rough estimation based on average file size
            estimated_size_per_episode = 500  # MB
            remaining_mb = pending_count * estimated_size_per_episode
            eta_seconds = int(remaining_mb / avg_speed)

        return QueueStats(
            total_items=(
                active_count + pending_count + completed_count + failed_count
            ),
            pending_count=pending_count,
            active_count=active_count,
            completed_count=completed_count,
            failed_count=failed_count,
            total_downloaded_mb=self._total_downloaded_mb,
            average_speed_mbps=avg_speed,
            estimated_time_remaining=eta_seconds,
        )

    async def pause_queue(self) -> None:
        """Pause download processing."""
        self._is_paused = True
        logger.info("Download queue paused")

        # Broadcast queue status update
        queue_status = await self.get_queue_status()
        await self._broadcast_update(
            "queue_paused",
            {
                "is_paused": True,
                "queue_status": queue_status.model_dump(mode="json"),
            },
        )

    async def resume_queue(self) -> None:
        """Resume download processing."""
        self._is_paused = False
        logger.info("Download queue resumed")

        # Broadcast queue status update
        queue_status = await self.get_queue_status()
        await self._broadcast_update(
            "queue_resumed",
            {
                "is_paused": False,
                "queue_status": queue_status.model_dump(mode="json"),
            },
        )

    async def clear_completed(self) -> int:
        """Clear completed downloads from history.

        Returns:
            Number of items cleared
        """
        count = len(self._completed_items)
        self._completed_items.clear()
        logger.info("Cleared completed items", count=count)

        # Broadcast queue status update
        if count > 0:
            queue_status = await self.get_queue_status()
            await self._broadcast_update(
                "queue_status",
                {
                    "action": "completed_cleared",
                    "cleared_count": count,
                    "queue_status": queue_status.model_dump(mode="json"),
                },
            )

        return count

    async def retry_failed(
        self, item_ids: Optional[List[str]] = None
    ) -> List[str]:
        """Retry failed downloads.

        Args:
            item_ids: Specific item IDs to retry, or None for all failed items

        Returns:
            List of item IDs moved back to pending queue
        """
        retried_ids = []

        try:
            failed_list = list(self._failed_items)

            for item in failed_list:
                # Skip if specific IDs requested and this isn't one
                if item_ids and item.id not in item_ids:
                    continue

                # Skip if max retries reached
                if item.retry_count >= self._max_retries:
                    continue

                # Move back to pending
                self._failed_items.remove(item)
                item.status = DownloadStatus.PENDING
                item.retry_count += 1
                item.error = None
                item.progress = None
                self._pending_queue.append(item)
                retried_ids.append(item.id)

                logger.info(
                    "Retrying failed item",
                    item_id=item.id,
                    retry_count=item.retry_count,
                )

            if retried_ids:
                self._save_queue()
                # Broadcast queue status update
                queue_status = await self.get_queue_status()
                await self._broadcast_update(
                    "queue_status",
                    {
                        "action": "items_retried",
                        "retried_ids": retried_ids,
                        "queue_status": queue_status.model_dump(mode="json"),
                    },
                )

            return retried_ids

        except Exception as e:
            logger.error("Failed to retry items", error=str(e))
            raise DownloadServiceError(
                f"Failed to retry: {str(e)}"
            ) from e

    def _create_progress_callback(self, item: DownloadItem) -> Callable:
        """Create a progress callback for a download item.

        Args:
            item: Download item to track progress for

        Returns:
            Callback function for progress updates
        """
        def progress_callback(progress_data: dict) -> None:
            """Update progress and broadcast to clients."""
            try:
                # Update item progress
                item.progress = DownloadProgress(
                    percent=progress_data.get("percent", 0.0),
                    downloaded_mb=progress_data.get("downloaded_mb", 0.0),
                    total_mb=progress_data.get("total_mb"),
                    speed_mbps=progress_data.get("speed_mbps"),
                    eta_seconds=progress_data.get("eta_seconds"),
                )

                # Track speed for statistics
                if item.progress.speed_mbps:
                    self._download_speeds.append(item.progress.speed_mbps)

                # Update progress service
                if item.progress.total_mb and item.progress.total_mb > 0:
                    current_mb = int(item.progress.downloaded_mb)
                    total_mb = int(item.progress.total_mb)

                    asyncio.create_task(
                        self._progress_service.update_progress(
                            progress_id=f"download_{item.id}",
                            current=current_mb,
                            total=total_mb,
                            metadata={
                                "speed_mbps": item.progress.speed_mbps,
                                "eta_seconds": item.progress.eta_seconds,
                            },
                        )
                    )

                # Broadcast update (fire and forget)
                asyncio.create_task(
                    self._broadcast_update(
                        "download_progress",
                        {
                            "download_id": item.id,
                            "item_id": item.id,
                            "serie_name": item.serie_name,
                            "season": item.episode.season,
                            "episode": item.episode.episode,
                            "progress": item.progress.model_dump(mode="json"),
                        },
                    )
                )
            except Exception as e:
                logger.error("Progress callback error", error=str(e))

        return progress_callback

    async def _process_download(self, item: DownloadItem) -> None:
        """Process a single download item.

        Args:
            item: Download item to process
        """
        try:
            # Update status
            item.status = DownloadStatus.DOWNLOADING
            item.started_at = datetime.utcnow()
            self._active_downloads[item.id] = item

            logger.info(
                "Starting download",
                item_id=item.id,
                serie=item.serie_name,
                season=item.episode.season,
                episode=item.episode.episode,
            )

            # Start progress tracking
            await self._progress_service.start_progress(
                progress_id=f"download_{item.id}",
                progress_type=ProgressType.DOWNLOAD,
                title=f"Downloading {item.serie_name}",
                message=(
                    f"S{item.episode.season:02d}E{item.episode.episode:02d}"
                ),
                metadata={
                    "item_id": item.id,
                    "serie_name": item.serie_name,
                    "season": item.episode.season,
                    "episode": item.episode.episode,
                },
            )

            # Create progress callback
            progress_callback = self._create_progress_callback(item)

            # Execute download via anime service
            success = await self._anime_service.download(
                serie_folder=item.serie_id,
                season=item.episode.season,
                episode=item.episode.episode,
                key=item.serie_id,  # Assuming serie_id is the provider key
                callback=progress_callback,
            )

            # Handle result
            if success:
                item.status = DownloadStatus.COMPLETED
                item.completed_at = datetime.utcnow()

                # Track downloaded size
                if item.progress and item.progress.downloaded_mb:
                    self._total_downloaded_mb += item.progress.downloaded_mb

                self._completed_items.append(item)

                logger.info(
                    "Download completed successfully", item_id=item.id
                )

                # Complete progress tracking
                await self._progress_service.complete_progress(
                    progress_id=f"download_{item.id}",
                    message="Download completed successfully",
                    metadata={
                        "downloaded_mb": item.progress.downloaded_mb
                        if item.progress
                        else 0,
                    },
                )

                await self._broadcast_update(
                    "download_complete",
                    {
                        "download_id": item.id,
                        "item_id": item.id,
                        "serie_name": item.serie_name,
                        "season": item.episode.season,
                        "episode": item.episode.episode,
                        "downloaded_mb": item.progress.downloaded_mb
                        if item.progress
                        else 0,
                    },
                )
            else:
                raise AnimeServiceError("Download returned False")

        except Exception as e:
            # Handle failure
            item.status = DownloadStatus.FAILED
            item.completed_at = datetime.utcnow()
            item.error = str(e)
            self._failed_items.append(item)

            logger.error(
                "Download failed",
                item_id=item.id,
                error=str(e),
                retry_count=item.retry_count,
            )

            # Fail progress tracking
            await self._progress_service.fail_progress(
                progress_id=f"download_{item.id}",
                error_message=str(e),
                metadata={"retry_count": item.retry_count},
            )

            await self._broadcast_update(
                "download_failed",
                {
                    "download_id": item.id,
                    "item_id": item.id,
                    "serie_name": item.serie_name,
                    "season": item.episode.season,
                    "episode": item.episode.episode,
                    "error": item.error,
                    "retry_count": item.retry_count,
                },
            )

        finally:
            # Remove from active downloads
            if item.id in self._active_downloads:
                del self._active_downloads[item.id]

            self._save_queue()

    async def _queue_processor(self) -> None:
        """Main queue processing loop."""
        logger.info("Queue processor started")

        while not self._shutdown_event.is_set():
            try:
                # Wait if paused
                if self._is_paused:
                    await asyncio.sleep(1)
                    continue

                # Check if we can start more downloads
                if len(self._active_downloads) >= self._max_concurrent:
                    await asyncio.sleep(1)
                    continue

                # Get next item from queue
                if not self._pending_queue:
                    await asyncio.sleep(1)
                    continue

                item = self._pending_queue.popleft()

                # Process download in background
                asyncio.create_task(self._process_download(item))

            except Exception as e:
                logger.error("Queue processor error", error=str(e))
                await asyncio.sleep(5)

        logger.info("Queue processor stopped")

    async def start(self) -> None:
        """Start the download queue processor."""
        if self._is_running:
            logger.warning("Queue processor already running")
            return

        self._is_running = True
        self._shutdown_event.clear()

        # Start processor task
        asyncio.create_task(self._queue_processor())

        logger.info("Download queue service started")

        # Broadcast queue started event
        queue_status = await self.get_queue_status()
        await self._broadcast_update(
            "queue_started",
            {
                "is_running": True,
                "queue_status": queue_status.model_dump(mode="json"),
            },
        )

    async def stop(self) -> None:
        """Stop the download queue processor."""
        if not self._is_running:
            return

        logger.info("Stopping download queue service...")

        self._is_running = False
        self._shutdown_event.set()

        # Wait for active downloads to complete (with timeout)
        timeout = 30  # seconds
        start_time = asyncio.get_event_loop().time()

        while (
            self._active_downloads
            and (asyncio.get_event_loop().time() - start_time) < timeout
        ):
            await asyncio.sleep(1)

        # Save final state
        self._save_queue()

        # Shutdown executor
        self._executor.shutdown(wait=True)

        logger.info("Download queue service stopped")

        # Broadcast queue stopped event
        queue_status = await self.get_queue_status()
        await self._broadcast_update(
            "queue_stopped",
            {
                "is_running": False,
                "queue_status": queue_status.model_dump(mode="json"),
            },
        )


# Singleton instance
_download_service_instance: Optional[DownloadService] = None


def get_download_service(anime_service: AnimeService) -> DownloadService:
    """Factory function for FastAPI dependency injection.

    Args:
        anime_service: AnimeService instance

    Returns:
        Singleton DownloadService instance
    """
    global _download_service_instance

    if _download_service_instance is None:
        _download_service_instance = DownloadService(anime_service)

    return _download_service_instance
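The lazy module-level singleton used by `get_download_service` can be sketched in isolation. The names below (`Service`, `get_service`) are illustrative stand-ins, not the project's actual classes; the point is the one caveat worth knowing about this pattern, which applies to `get_download_service` as well: only the first caller's dependency is ever used.

```python
from typing import Optional


class Service:
    """Stand-in for DownloadService; just holds its injected dependency."""

    def __init__(self, dependency: object) -> None:
        self.dependency = dependency


_instance: Optional[Service] = None


def get_service(dependency: object) -> Service:
    """Create the singleton lazily on first call, then reuse it.

    Note: later calls silently ignore their `dependency` argument,
    because the instance already exists.
    """
    global _instance
    if _instance is None:
        _instance = Service(dependency)
    return _instance
```

In a FastAPI app this factory would typically be wired in via `Depends`; if different dependencies could legitimately be passed, the factory should raise or log rather than silently reuse the first instance.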

485 src/server/services/progress_service.py (new file)
@@ -0,0 +1,485 @@
"""Progress service for managing real-time progress updates.
|
||||
|
||||
This module provides a centralized service for tracking and broadcasting
|
||||
real-time progress updates for downloads, scans, queue changes, and
|
||||
system events. It integrates with the WebSocket service to push updates
|
||||
to connected clients.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
class ProgressType(str, Enum):
|
||||
"""Types of progress updates."""
|
||||
|
||||
DOWNLOAD = "download"
|
||||
SCAN = "scan"
|
||||
QUEUE = "queue"
|
||||
SYSTEM = "system"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class ProgressStatus(str, Enum):
|
||||
"""Status of a progress operation."""
|
||||
|
||||
STARTED = "started"
|
||||
IN_PROGRESS = "in_progress"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
@dataclass
class ProgressUpdate:
    """Represents a progress update event.

    Attributes:
        id: Unique identifier for this progress operation
        type: Type of progress (download, scan, etc.)
        status: Current status of the operation
        title: Human-readable title
        message: Detailed message
        percent: Completion percentage (0-100)
        current: Current progress value
        total: Total progress value
        metadata: Additional metadata
        started_at: When operation started
        updated_at: When last updated
    """

    id: str
    type: ProgressType
    status: ProgressStatus
    title: str
    message: str = ""
    percent: float = 0.0
    current: int = 0
    total: int = 0
    metadata: Dict[str, Any] = field(default_factory=dict)
    started_at: datetime = field(default_factory=datetime.utcnow)
    updated_at: datetime = field(default_factory=datetime.utcnow)

    def to_dict(self) -> Dict[str, Any]:
        """Convert progress update to dictionary."""
        return {
            "id": self.id,
            "type": self.type.value,
            "status": self.status.value,
            "title": self.title,
            "message": self.message,
            "percent": round(self.percent, 2),
            "current": self.current,
            "total": self.total,
            "metadata": self.metadata,
            "started_at": self.started_at.isoformat(),
            "updated_at": self.updated_at.isoformat(),
        }


class ProgressServiceError(Exception):
    """Service-level exception for progress operations."""


class ProgressService:
    """Manages real-time progress updates and broadcasting.

    Features:
    - Track multiple concurrent progress operations
    - Calculate progress percentages and rates
    - Broadcast updates via WebSocket
    - Manage progress lifecycle (start, update, complete, fail)
    - Support for different progress types (download, scan, queue)
    """

    def __init__(self):
        """Initialize the progress service."""
        # Active progress operations: id -> ProgressUpdate
        self._active_progress: Dict[str, ProgressUpdate] = {}

        # Completed progress history (limited size)
        self._history: Dict[str, ProgressUpdate] = {}
        self._max_history_size = 50

        # WebSocket broadcast callback
        self._broadcast_callback: Optional[Callable] = None

        # Lock for thread-safe operations
        self._lock = asyncio.Lock()

        logger.info("ProgressService initialized")

    def set_broadcast_callback(self, callback: Callable) -> None:
        """Set callback for broadcasting progress updates via WebSocket.

        Args:
            callback: Async function to call for broadcasting updates
        """
        self._broadcast_callback = callback
        logger.debug("Progress broadcast callback registered")

    async def _broadcast(self, update: ProgressUpdate, room: str) -> None:
        """Broadcast progress update to WebSocket clients.

        Args:
            update: Progress update to broadcast
            room: WebSocket room to broadcast to
        """
        if self._broadcast_callback:
            try:
                await self._broadcast_callback(
                    message_type=f"{update.type.value}_progress",
                    data=update.to_dict(),
                    room=room,
                )
            except Exception as e:
                logger.error(
                    "Failed to broadcast progress update",
                    error=str(e),
                    progress_id=update.id,
                )

    async def start_progress(
        self,
        progress_id: str,
        progress_type: ProgressType,
        title: str,
        total: int = 0,
        message: str = "",
        metadata: Optional[Dict[str, Any]] = None,
    ) -> ProgressUpdate:
        """Start a new progress operation.

        Args:
            progress_id: Unique identifier for this progress
            progress_type: Type of progress operation
            title: Human-readable title
            total: Total items/bytes to process
            message: Initial message
            metadata: Additional metadata

        Returns:
            Created progress update object

        Raises:
            ProgressServiceError: If progress already exists
        """
        async with self._lock:
            if progress_id in self._active_progress:
                raise ProgressServiceError(
                    f"Progress with id '{progress_id}' already exists"
                )

            update = ProgressUpdate(
                id=progress_id,
                type=progress_type,
                status=ProgressStatus.STARTED,
                title=title,
                message=message,
                total=total,
                metadata=metadata or {},
            )

            self._active_progress[progress_id] = update

            logger.info(
                "Progress started",
                progress_id=progress_id,
                type=progress_type.value,
                title=title,
            )

            # Broadcast to appropriate room
            room = f"{progress_type.value}_progress"
            await self._broadcast(update, room)

            return update

    async def update_progress(
        self,
        progress_id: str,
        current: Optional[int] = None,
        total: Optional[int] = None,
        message: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
        force_broadcast: bool = False,
    ) -> ProgressUpdate:
        """Update an existing progress operation.

        Args:
            progress_id: Progress identifier
            current: Current progress value
            total: Updated total value
            message: Updated message
            metadata: Additional metadata to merge
            force_broadcast: Force broadcasting even for small changes

        Returns:
            Updated progress object

        Raises:
            ProgressServiceError: If progress not found
        """
        async with self._lock:
            if progress_id not in self._active_progress:
                raise ProgressServiceError(
                    f"Progress with id '{progress_id}' not found"
                )

            update = self._active_progress[progress_id]
            old_percent = update.percent

            # Update fields
            if current is not None:
                update.current = current
            if total is not None:
                update.total = total
            if message is not None:
                update.message = message
            if metadata:
                update.metadata.update(metadata)

            # Calculate percentage
            if update.total > 0:
                update.percent = (update.current / update.total) * 100
            else:
                update.percent = 0.0

            update.status = ProgressStatus.IN_PROGRESS
            update.updated_at = datetime.utcnow()

            # Only broadcast if significant change or forced
            percent_change = abs(update.percent - old_percent)
            should_broadcast = force_broadcast or percent_change >= 1.0

            if should_broadcast:
                room = f"{update.type.value}_progress"
                await self._broadcast(update, room)

            return update

    async def complete_progress(
        self,
        progress_id: str,
        message: str = "Completed successfully",
        metadata: Optional[Dict[str, Any]] = None,
    ) -> ProgressUpdate:
        """Mark a progress operation as completed.

        Args:
            progress_id: Progress identifier
            message: Completion message
            metadata: Additional metadata

        Returns:
            Completed progress object

        Raises:
            ProgressServiceError: If progress not found
        """
        async with self._lock:
            if progress_id not in self._active_progress:
                raise ProgressServiceError(
                    f"Progress with id '{progress_id}' not found"
                )

            update = self._active_progress[progress_id]
            update.status = ProgressStatus.COMPLETED
            update.message = message
            update.percent = 100.0
            update.current = update.total
            update.updated_at = datetime.utcnow()

            if metadata:
                update.metadata.update(metadata)

            # Move to history
            del self._active_progress[progress_id]
            self._add_to_history(update)

            logger.info(
                "Progress completed",
                progress_id=progress_id,
                type=update.type.value,
            )

            # Broadcast completion
            room = f"{update.type.value}_progress"
            await self._broadcast(update, room)

            return update

    async def fail_progress(
        self,
        progress_id: str,
        error_message: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> ProgressUpdate:
        """Mark a progress operation as failed.

        Args:
            progress_id: Progress identifier
            error_message: Error description
            metadata: Additional error metadata

        Returns:
            Failed progress object

        Raises:
            ProgressServiceError: If progress not found
        """
        async with self._lock:
            if progress_id not in self._active_progress:
                raise ProgressServiceError(
                    f"Progress with id '{progress_id}' not found"
                )

            update = self._active_progress[progress_id]
            update.status = ProgressStatus.FAILED
            update.message = error_message
            update.updated_at = datetime.utcnow()

            if metadata:
                update.metadata.update(metadata)

            # Move to history
            del self._active_progress[progress_id]
            self._add_to_history(update)

            logger.error(
                "Progress failed",
                progress_id=progress_id,
                type=update.type.value,
                error=error_message,
            )

            # Broadcast failure
            room = f"{update.type.value}_progress"
            await self._broadcast(update, room)

            return update

    async def cancel_progress(
        self,
        progress_id: str,
        message: str = "Cancelled by user",
    ) -> ProgressUpdate:
        """Cancel a progress operation.

        Args:
            progress_id: Progress identifier
            message: Cancellation message

        Returns:
            Cancelled progress object

        Raises:
            ProgressServiceError: If progress not found
        """
        async with self._lock:
            if progress_id not in self._active_progress:
                raise ProgressServiceError(
                    f"Progress with id '{progress_id}' not found"
                )

            update = self._active_progress[progress_id]
            update.status = ProgressStatus.CANCELLED
            update.message = message
            update.updated_at = datetime.utcnow()

            # Move to history
            del self._active_progress[progress_id]
            self._add_to_history(update)

            logger.info(
                "Progress cancelled",
                progress_id=progress_id,
                type=update.type.value,
            )

            # Broadcast cancellation
            room = f"{update.type.value}_progress"
            await self._broadcast(update, room)

            return update

    def _add_to_history(self, update: ProgressUpdate) -> None:
        """Add completed progress to history with size limit."""
        self._history[update.id] = update

        # Maintain history size limit
        if len(self._history) > self._max_history_size:
            # Remove oldest entries
            oldest_keys = sorted(
                self._history.keys(),
                key=lambda k: self._history[k].updated_at,
            )[: len(self._history) - self._max_history_size]

            for key in oldest_keys:
                del self._history[key]

    async def get_progress(self, progress_id: str) -> Optional[ProgressUpdate]:
        """Get current progress state.

        Args:
            progress_id: Progress identifier

        Returns:
            Progress update object or None if not found
        """
        async with self._lock:
            if progress_id in self._active_progress:
                return self._active_progress[progress_id]
            if progress_id in self._history:
                return self._history[progress_id]
            return None

    async def get_all_active_progress(
        self, progress_type: Optional[ProgressType] = None
    ) -> Dict[str, ProgressUpdate]:
        """Get all active progress operations.

        Args:
            progress_type: Optional filter by progress type

        Returns:
            Dictionary of progress_id -> ProgressUpdate
        """
        async with self._lock:
            if progress_type:
                return {
                    pid: update
                    for pid, update in self._active_progress.items()
                    if update.type == progress_type
                }
            return self._active_progress.copy()

    async def clear_history(self) -> None:
        """Clear progress history."""
        async with self._lock:
            self._history.clear()
            logger.info("Progress history cleared")


# Global singleton instance
_progress_service: Optional[ProgressService] = None


def get_progress_service() -> ProgressService:
    """Get or create the global progress service instance.

    Returns:
        Global ProgressService instance
    """
    global _progress_service
    if _progress_service is None:
        _progress_service = ProgressService()
    return _progress_service
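The throttling rule in `update_progress` above (broadcast only when forced, or when the percentage has moved by at least one point) can be sketched in isolation. The function below is an illustrative stand-alone extract, not part of the service; it mirrors the percent calculation and the broadcast decision so the rule can be reasoned about without the service's lock and WebSocket plumbing.

```python
from typing import Tuple


def should_broadcast(
    old_percent: float,
    current: int,
    total: int,
    force: bool = False,
) -> Tuple[float, bool]:
    """Recompute the percentage and decide whether to push an update.

    Mirrors update_progress: percent is 0.0 when total is unknown (<= 0),
    and a broadcast happens only when forced or when the percentage
    changed by at least 1.0 point since the last stored value.
    """
    new_percent = (current / total) * 100 if total > 0 else 0.0
    return new_percent, force or abs(new_percent - old_percent) >= 1.0
```

One consequence worth noting: when `total` drops to 0 mid-operation, the percent snaps back to 0.0, which itself counts as a large change and triggers a broadcast.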

461 src/server/services/websocket_service.py (new file)
@@ -0,0 +1,461 @@
"""WebSocket service for real-time communication with clients.

This module provides a comprehensive WebSocket manager for handling
real-time updates, connection management, room-based messaging, and
broadcast functionality for the Aniworld web application.
"""
from __future__ import annotations

import asyncio
from collections import defaultdict
from datetime import datetime
from typing import Any, Dict, List, Optional, Set

import structlog
from fastapi import WebSocket, WebSocketDisconnect

logger = structlog.get_logger(__name__)


class WebSocketServiceError(Exception):
    """Service-level exception for WebSocket operations."""


class ConnectionManager:
    """Manages WebSocket connections with room-based messaging support.

    Features:
    - Connection lifecycle management
    - Room-based messaging (rooms for specific topics)
    - Broadcast to all connections or specific rooms
    - Connection health monitoring
    - Automatic cleanup on disconnect
    """

    def __init__(self):
        """Initialize the connection manager."""
        # Active connections: connection_id -> WebSocket
        self._active_connections: Dict[str, WebSocket] = {}

        # Room memberships: room_name -> set of connection_ids
        self._rooms: Dict[str, Set[str]] = defaultdict(set)

        # Connection metadata: connection_id -> metadata dict
        self._connection_metadata: Dict[str, Dict[str, Any]] = {}

        # Lock for thread-safe operations
        self._lock = asyncio.Lock()

        logger.info("ConnectionManager initialized")

    async def connect(
        self,
        websocket: WebSocket,
        connection_id: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Accept and register a new WebSocket connection.

        Args:
            websocket: The WebSocket connection to accept
            connection_id: Unique identifier for this connection
            metadata: Optional metadata to associate with the connection
        """
        await websocket.accept()

        async with self._lock:
            self._active_connections[connection_id] = websocket
            self._connection_metadata[connection_id] = metadata or {}

        logger.info(
            "WebSocket connected",
            connection_id=connection_id,
            total_connections=len(self._active_connections),
        )
    async def disconnect(self, connection_id: str) -> None:
        """Remove a WebSocket connection and cleanup associated resources.

        Args:
            connection_id: The connection to remove
        """
        async with self._lock:
            # Remove from all rooms
            for room_members in self._rooms.values():
                room_members.discard(connection_id)

            # Remove empty rooms (rebuild as a defaultdict so join_room
            # can still auto-create rooms afterwards)
            self._rooms = defaultdict(
                set,
                {
                    room: members
                    for room, members in self._rooms.items()
                    if members
                },
            )

            # Remove connection and metadata
            self._active_connections.pop(connection_id, None)
            self._connection_metadata.pop(connection_id, None)

        logger.info(
            "WebSocket disconnected",
            connection_id=connection_id,
            total_connections=len(self._active_connections),
        )
    async def join_room(self, connection_id: str, room: str) -> None:
        """Add a connection to a room.

        Args:
            connection_id: The connection to add
            room: The room name to join
        """
        async with self._lock:
            if connection_id in self._active_connections:
                self._rooms[room].add(connection_id)
                logger.debug(
                    "Connection joined room",
                    connection_id=connection_id,
                    room=room,
                    room_size=len(self._rooms[room]),
                )
            else:
                logger.warning(
                    "Attempted to join room with inactive connection",
                    connection_id=connection_id,
                    room=room,
                )

    async def leave_room(self, connection_id: str, room: str) -> None:
        """Remove a connection from a room.

        Args:
            connection_id: The connection to remove
            room: The room name to leave
        """
        async with self._lock:
            if room in self._rooms:
                self._rooms[room].discard(connection_id)

                # Remove empty room
                if not self._rooms[room]:
                    del self._rooms[room]

                logger.debug(
                    "Connection left room",
                    connection_id=connection_id,
                    room=room,
                )
    async def send_personal_message(
        self, message: Dict[str, Any], connection_id: str
    ) -> None:
        """Send a message to a specific connection.

        Args:
            message: The message to send (will be JSON serialized)
            connection_id: Target connection identifier
        """
        websocket = self._active_connections.get(connection_id)
        if websocket:
            try:
                await websocket.send_json(message)
                logger.debug(
                    "Personal message sent",
                    connection_id=connection_id,
                    message_type=message.get("type", "unknown"),
                )
            except WebSocketDisconnect:
                logger.warning(
                    "Connection disconnected during send",
                    connection_id=connection_id,
                )
                await self.disconnect(connection_id)
            except Exception as e:
                logger.error(
                    "Failed to send personal message",
                    connection_id=connection_id,
                    error=str(e),
                )
        else:
            logger.warning(
                "Attempted to send message to inactive connection",
                connection_id=connection_id,
            )
    async def broadcast(
        self, message: Dict[str, Any], exclude: Optional[Set[str]] = None
    ) -> None:
        """Broadcast a message to all active connections.

        Args:
            message: The message to broadcast (will be JSON serialized)
            exclude: Optional set of connection IDs to exclude from broadcast
        """
        exclude = exclude or set()
        disconnected = []

        # Snapshot the items so a concurrent connect/disconnect cannot
        # invalidate the iterator while we await each send
        for connection_id, websocket in list(self._active_connections.items()):
            if connection_id in exclude:
                continue

            try:
                await websocket.send_json(message)
            except WebSocketDisconnect:
                logger.warning(
                    "Connection disconnected during broadcast",
                    connection_id=connection_id,
                )
                disconnected.append(connection_id)
            except Exception as e:
                logger.error(
                    "Failed to broadcast to connection",
                    connection_id=connection_id,
                    error=str(e),
                )

        # Cleanup disconnected connections
        for connection_id in disconnected:
            await self.disconnect(connection_id)

        logger.debug(
            "Message broadcast",
            message_type=message.get("type", "unknown"),
            recipient_count=len(self._active_connections) - len(exclude),
            failed_count=len(disconnected),
        )
    async def broadcast_to_room(
        self, message: Dict[str, Any], room: str
    ) -> None:
        """Broadcast a message to all connections in a specific room.

        Args:
            message: The message to broadcast (will be JSON serialized)
            room: The room to broadcast to
        """
        room_members = self._rooms.get(room, set()).copy()
        disconnected = []

        for connection_id in room_members:
            websocket = self._active_connections.get(connection_id)
            if not websocket:
                continue

            try:
                await websocket.send_json(message)
            except WebSocketDisconnect:
                logger.warning(
                    "Connection disconnected during room broadcast",
                    connection_id=connection_id,
                    room=room,
                )
                disconnected.append(connection_id)
            except Exception as e:
                logger.error(
                    "Failed to broadcast to room member",
                    connection_id=connection_id,
                    room=room,
                    error=str(e),
                )

        # Cleanup disconnected connections
        for connection_id in disconnected:
            await self.disconnect(connection_id)

        logger.debug(
            "Message broadcast to room",
            room=room,
            message_type=message.get("type", "unknown"),
            recipient_count=len(room_members),
            failed_count=len(disconnected),
        )
    async def get_connection_count(self) -> int:
        """Get the total number of active connections."""
        return len(self._active_connections)

    async def get_room_members(self, room: str) -> List[str]:
        """Get list of connection IDs in a specific room."""
        return list(self._rooms.get(room, set()))

    async def get_connection_metadata(
        self, connection_id: str
    ) -> Optional[Dict[str, Any]]:
        """Get metadata associated with a connection."""
        return self._connection_metadata.get(connection_id)

    async def update_connection_metadata(
        self, connection_id: str, metadata: Dict[str, Any]
    ) -> None:
        """Update metadata for a connection."""
        if connection_id in self._active_connections:
            async with self._lock:
                self._connection_metadata[connection_id].update(metadata)
        else:
            logger.warning(
                "Attempted to update metadata for inactive connection",
                connection_id=connection_id,
            )
class WebSocketService:
    """High-level WebSocket service for application-wide messaging.

    This service provides a convenient interface for broadcasting
    application events and managing WebSocket connections. It wraps
    the ConnectionManager with application-specific message types.
    """

    def __init__(self):
        """Initialize the WebSocket service."""
        self._manager = ConnectionManager()
        logger.info("WebSocketService initialized")

    @property
    def manager(self) -> ConnectionManager:
        """Access the underlying connection manager."""
        return self._manager

    async def connect(
        self,
        websocket: WebSocket,
        connection_id: str,
        user_id: Optional[str] = None,
    ) -> None:
        """Connect a new WebSocket client.

        Args:
            websocket: The WebSocket connection
            connection_id: Unique connection identifier
            user_id: Optional user identifier for authentication
        """
        metadata = {
            "connected_at": datetime.utcnow().isoformat(),
            "user_id": user_id,
        }
        await self._manager.connect(websocket, connection_id, metadata)

    async def disconnect(self, connection_id: str) -> None:
        """Disconnect a WebSocket client."""
        await self._manager.disconnect(connection_id)
    async def broadcast_download_progress(
        self, download_id: str, progress_data: Dict[str, Any]
    ) -> None:
        """Broadcast download progress update to all clients.

        Args:
            download_id: The download item identifier
            progress_data: Progress information (percent, speed, etc.)
        """
        message = {
            "type": "download_progress",
            "timestamp": datetime.utcnow().isoformat(),
            "data": {
                "download_id": download_id,
                **progress_data,
            },
        }
        await self._manager.broadcast_to_room(message, "downloads")

    async def broadcast_download_complete(
        self, download_id: str, result_data: Dict[str, Any]
    ) -> None:
        """Broadcast download completion to all clients.

        Args:
            download_id: The download item identifier
            result_data: Download result information
        """
        message = {
            "type": "download_complete",
            "timestamp": datetime.utcnow().isoformat(),
            "data": {
                "download_id": download_id,
                **result_data,
            },
        }
        await self._manager.broadcast_to_room(message, "downloads")

    async def broadcast_download_failed(
        self, download_id: str, error_data: Dict[str, Any]
    ) -> None:
        """Broadcast download failure to all clients.

        Args:
            download_id: The download item identifier
            error_data: Error information
        """
        message = {
            "type": "download_failed",
            "timestamp": datetime.utcnow().isoformat(),
            "data": {
                "download_id": download_id,
                **error_data,
            },
        }
        await self._manager.broadcast_to_room(message, "downloads")

    async def broadcast_queue_status(self, status_data: Dict[str, Any]) -> None:
        """Broadcast queue status update to all clients.

        Args:
            status_data: Queue status information
        """
        message = {
            "type": "queue_status",
            "timestamp": datetime.utcnow().isoformat(),
            "data": status_data,
        }
        await self._manager.broadcast_to_room(message, "downloads")

    async def broadcast_system_message(
        self, message_type: str, data: Dict[str, Any]
    ) -> None:
        """Broadcast a system message to all clients.

        Args:
            message_type: Type of system message
            data: Message data
        """
        message = {
            "type": f"system_{message_type}",
            "timestamp": datetime.utcnow().isoformat(),
            "data": data,
        }
        await self._manager.broadcast(message)

    async def send_error(
        self, connection_id: str, error_message: str, error_code: str = "ERROR"
    ) -> None:
        """Send an error message to a specific connection.

        Args:
            connection_id: Target connection
            error_message: Error description
            error_code: Error code for client handling
        """
        message = {
            "type": "error",
            "timestamp": datetime.utcnow().isoformat(),
            "data": {
                "code": error_code,
                "message": error_message,
            },
        }
        await self._manager.send_personal_message(message, connection_id)


# Singleton instance for application-wide access
_websocket_service: Optional[WebSocketService] = None


def get_websocket_service() -> WebSocketService:
    """Get or create the singleton WebSocket service instance.

    Returns:
        The WebSocket service instance
    """
    global _websocket_service
    if _websocket_service is None:
        _websocket_service = WebSocketService()
    return _websocket_service
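The `ConnectionManager` above keeps rooms as a `defaultdict(set)` keyed by room name and prunes rooms that become empty. A minimal, self-contained model of just that bookkeeping (the `RoomRegistry` name is invented for illustration; no real sockets or locking involved):

```python
from collections import defaultdict
from typing import Dict, Set


class RoomRegistry:
    """Toy model of ConnectionManager's room bookkeeping."""

    def __init__(self) -> None:
        self.rooms: Dict[str, Set[str]] = defaultdict(set)

    def join(self, connection_id: str, room: str) -> None:
        # defaultdict creates the member set on first join
        self.rooms[room].add(connection_id)

    def leave(self, connection_id: str, room: str) -> None:
        if room in self.rooms:
            self.rooms[room].discard(connection_id)
            if not self.rooms[room]:
                del self.rooms[room]  # prune empty room

    def disconnect(self, connection_id: str) -> None:
        # Remove the connection from every room, then prune empties
        for members in self.rooms.values():
            members.discard(connection_id)
        self.rooms = defaultdict(
            set, {r: m for r, m in self.rooms.items() if m}
        )


registry = RoomRegistry()
registry.join("conn-1", "downloads")
registry.join("conn-2", "downloads")
registry.disconnect("conn-1")
print(sorted(registry.rooms["downloads"]))  # → ['conn-2']
registry.disconnect("conn-2")
print("downloads" in dict(registry.rooms))  # → False
```

Rebuilding the pruned mapping as a `defaultdict` matters: a plain dict comprehension would make a later `rooms[room].add(...)` raise `KeyError` for a brand-new room.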
@@ -2,14 +2,18 @@
 Dependency injection utilities for FastAPI.
 
 This module provides dependency injection functions for the FastAPI
-application, including SeriesApp instances, database sessions, and
-authentication dependencies.
+application, including SeriesApp instances, AnimeService, DownloadService,
+database sessions, and authentication dependencies.
 """
 from typing import AsyncGenerator, Optional
 
 from fastapi import Depends, HTTPException, status
 from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
-from sqlalchemy.ext.asyncio import AsyncSession
+
+try:
+    from sqlalchemy.ext.asyncio import AsyncSession
+except Exception:  # pragma: no cover - optional dependency
+    AsyncSession = object
 
 from src.config.settings import settings
 from src.core.SeriesApp import SeriesApp
@@ -22,6 +26,10 @@ security = HTTPBearer()
 # Global SeriesApp instance
 _series_app: Optional[SeriesApp] = None
 
+# Global service instances
+_anime_service: Optional[object] = None
+_download_service: Optional[object] = None
+
 
 def get_series_app() -> SeriesApp:
     """
@@ -146,6 +154,26 @@ def optional_auth(
     return None
 
 
+def get_current_user_optional(
+    credentials: Optional[HTTPAuthorizationCredentials] = Depends(
+        HTTPBearer(auto_error=False)
+    )
+) -> Optional[str]:
+    """
+    Dependency to get optional current user ID.
+
+    Args:
+        credentials: Optional JWT token from Authorization header
+
+    Returns:
+        Optional[str]: User ID if authenticated, None otherwise
+    """
+    user_dict = optional_auth(credentials)
+    if user_dict:
+        return user_dict.get("user_id")
+    return None
+
+
 class CommonQueryParams:
     """Common query parameters for API endpoints."""
 
@@ -189,3 +217,106 @@ async def log_request_dependency():
     TODO: Implement request logging logic
     """
     pass
+
+
+def get_anime_service() -> object:
+    """
+    Dependency to get AnimeService instance.
+
+    Returns:
+        AnimeService: The anime service for async operations
+
+    Raises:
+        HTTPException: If anime directory is not configured or
+            AnimeService initialization fails
+    """
+    global _anime_service
+
+    if not settings.anime_directory:
+        raise HTTPException(
+            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+            detail="Anime directory not configured. Please complete setup.",
+        )
+
+    if _anime_service is None:
+        try:
+            from src.server.services.anime_service import AnimeService
+            _anime_service = AnimeService(settings.anime_directory)
+        except Exception as e:
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                detail=f"Failed to initialize AnimeService: {str(e)}",
+            ) from e
+
+    return _anime_service
+
+
+def get_download_service() -> object:
+    """
+    Dependency to get DownloadService instance.
+
+    Returns:
+        DownloadService: The download queue service
+
+    Raises:
+        HTTPException: If DownloadService initialization fails
+    """
+    global _download_service
+
+    if _download_service is None:
+        try:
+            from src.server.services.download_service import DownloadService
+            from src.server.services.websocket_service import get_websocket_service
+
+            # Get anime service first (required dependency)
+            anime_service = get_anime_service()
+
+            # Initialize download service with anime service
+            _download_service = DownloadService(anime_service)
+
+            # Setup WebSocket broadcast callback
+            ws_service = get_websocket_service()
+
+            async def broadcast_callback(update_type: str, data: dict):
+                """Broadcast download updates via WebSocket."""
+                if update_type == "download_progress":
+                    await ws_service.broadcast_download_progress(
+                        data.get("download_id", ""), data
+                    )
+                elif update_type == "download_complete":
+                    await ws_service.broadcast_download_complete(
+                        data.get("download_id", ""), data
+                    )
+                elif update_type == "download_failed":
+                    await ws_service.broadcast_download_failed(
+                        data.get("download_id", ""), data
+                    )
+                elif update_type == "queue_status":
+                    await ws_service.broadcast_queue_status(data)
+                else:
+                    # Generic queue update
+                    await ws_service.broadcast_queue_status(data)
+
+            _download_service.set_broadcast_callback(broadcast_callback)
+
+        except HTTPException:
+            raise
+        except Exception as e:
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                detail=f"Failed to initialize DownloadService: {str(e)}",
+            ) from e
+
+    return _download_service
+
+
+def reset_anime_service() -> None:
+    """Reset global AnimeService instance (for testing/config changes)."""
+    global _anime_service
+    _anime_service = None
+
+
+def reset_download_service() -> None:
+    """Reset global DownloadService instance (for testing/config changes)."""
+    global _download_service
+    _download_service = None
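The `broadcast_callback` registered in `get_download_service` is a plain string dispatch from update type to broadcast coroutine, with `queue_status` as the fallback. A sketch of that dispatch against a recording stub (the `StubWebSocketService` name is hypothetical, introduced only so the example runs without sockets):

```python
import asyncio
from typing import Any, Dict, List, Tuple


class StubWebSocketService:
    """Records broadcasts instead of pushing them over real sockets."""

    def __init__(self) -> None:
        self.sent: List[Tuple[str, Dict[str, Any]]] = []

    async def broadcast_download_progress(
        self, download_id: str, data: Dict[str, Any]
    ) -> None:
        self.sent.append(("download_progress", data))

    async def broadcast_queue_status(self, data: Dict[str, Any]) -> None:
        self.sent.append(("queue_status", data))


async def main() -> List[str]:
    ws_service = StubWebSocketService()

    async def broadcast_callback(update_type: str, data: Dict[str, Any]) -> None:
        # Known update types get a dedicated broadcast; anything else
        # falls back to a generic queue_status broadcast.
        if update_type == "download_progress":
            await ws_service.broadcast_download_progress(
                data.get("download_id", ""), data
            )
        else:
            await ws_service.broadcast_queue_status(data)

    await broadcast_callback("download_progress", {"download_id": "d1", "percent": 40})
    await broadcast_callback("unknown_event", {"queued": 3})
    return [kind for kind, _ in ws_service.sent]


recorded = asyncio.run(main())
print(recorded)  # → ['download_progress', 'queue_status']
```

Keeping the callback a closure over the service (rather than importing a global inside `DownloadService`) is what lets the download layer stay ignorant of the WebSocket layer.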
96 src/server/utils/template_helpers.py Normal file
@@ -0,0 +1,96 @@
"""
Template integration utilities for FastAPI application.

This module provides utilities for template rendering with common context
and helper functions.
"""
from pathlib import Path
from typing import Any, Dict, Optional

from fastapi import Request
from fastapi.templating import Jinja2Templates

# Configure templates directory
TEMPLATES_DIR = Path(__file__).parent.parent / "web" / "templates"
templates = Jinja2Templates(directory=str(TEMPLATES_DIR))


def get_base_context(
    request: Request, title: str = "Aniworld"
) -> Dict[str, Any]:
    """
    Get base context for all templates.

    Args:
        request: FastAPI request object
        title: Page title

    Returns:
        Dictionary with base context variables
    """
    return {
        "request": request,
        "title": title,
        "app_name": "Aniworld Download Manager",
        "version": "1.0.0"
    }


def render_template(
    template_name: str,
    request: Request,
    context: Optional[Dict[str, Any]] = None,
    title: Optional[str] = None
):
    """
    Render a template with base context.

    Args:
        template_name: Name of the template file
        request: FastAPI request object
        context: Additional context variables
        title: Page title (optional)

    Returns:
        TemplateResponse object
    """
    base_context = get_base_context(
        request,
        title or template_name.replace('.html', '').replace('_', ' ').title()
    )

    if context:
        base_context.update(context)

    return templates.TemplateResponse(template_name, base_context)


def validate_template_exists(template_name: str) -> bool:
    """
    Check if a template file exists.

    Args:
        template_name: Name of the template file

    Returns:
        True if template exists, False otherwise
    """
    template_path = TEMPLATES_DIR / template_name
    return template_path.exists()


def list_available_templates() -> list[str]:
    """
    Get list of all available template files.

    Returns:
        List of template file names
    """
    if not TEMPLATES_DIR.exists():
        return []

    return [
        f.name
        for f in TEMPLATES_DIR.glob("*.html")
        if f.is_file()
    ]
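When no explicit title is passed, `render_template` derives one from the template filename with plain string methods. Extracted as a tiny standalone helper (the `default_title` name is ours, for illustration):

```python
def default_title(template_name: str) -> str:
    """Derive a page title the way render_template does:
    drop the .html suffix, turn underscores into spaces, title-case."""
    return template_name.replace(".html", "").replace("_", " ").title()


print(default_title("download_queue.html"))  # → Download Queue
print(default_title("index.html"))  # → Index
```

Note that `str.replace(".html", "")` removes the substring anywhere in the name, not just as a suffix; for the template names in this project that distinction does not matter.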
202 src/server/web/static/css/ux_features.css Normal file
@@ -0,0 +1,202 @@
/**
 * UX Features CSS
 * Additional styling for enhanced user experience features
 */

/* Drag and drop indicators */
.drag-over {
    border: 2px dashed var(--color-accent);
    background-color: var(--color-bg-tertiary);
    opacity: 0.8;
}

.dragging {
    opacity: 0.5;
    cursor: move;
}

/* Bulk operation selection */
.bulk-select-mode .series-card {
    cursor: pointer;
}

.bulk-select-mode .series-card.selected {
    border: 2px solid var(--color-accent);
    background-color: var(--color-surface-hover);
}

/* Keyboard navigation focus indicators */
.keyboard-focus {
    outline: 2px solid var(--color-accent);
    outline-offset: 2px;
}

/* Touch gestures feedback */
.touch-feedback {
    animation: touchPulse 0.3s ease-out;
}

@keyframes touchPulse {
    0% {
        transform: scale(1);
    }
    50% {
        transform: scale(0.95);
    }
    100% {
        transform: scale(1);
    }
}

/* Mobile responsive enhancements */
@media (max-width: 768px) {
    .mobile-hide {
        display: none !important;
    }

    .mobile-full-width {
        width: 100% !important;
    }
}

/* Accessibility high contrast mode */
@media (prefers-contrast: high) {
    :root {
        --color-border: #000000;
        --color-text-primary: #000000;
        --color-bg-primary: #ffffff;
    }

    [data-theme="dark"] {
        --color-border: #ffffff;
        --color-text-primary: #ffffff;
        --color-bg-primary: #000000;
    }
}

/* Screen reader only content */
.sr-only {
    position: absolute;
    width: 1px;
    height: 1px;
    padding: 0;
    margin: -1px;
    overflow: hidden;
    clip: rect(0, 0, 0, 0);
    white-space: nowrap;
    border-width: 0;
}

/* Multi-screen support */
.window-controls {
    display: flex;
    gap: var(--spacing-sm);
    padding: var(--spacing-sm);
}

.window-control-btn {
    width: 32px;
    height: 32px;
    border-radius: 4px;
    border: 1px solid var(--color-border);
    background: var(--color-surface);
    cursor: pointer;
    display: flex;
    align-items: center;
    justify-content: center;
    transition: all 0.2s ease;
}

.window-control-btn:hover {
    background: var(--color-surface-hover);
}

/* Undo/Redo notification */
.undo-notification {
    position: fixed;
    bottom: 20px;
    right: 20px;
    background: var(--color-surface);
    border: 1px solid var(--color-border);
    border-radius: 8px;
    padding: var(--spacing-md);
    box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);
    z-index: 1000;
    animation: slideInUp 0.3s ease-out;
}

@keyframes slideInUp {
    from {
        transform: translateY(100%);
        opacity: 0;
    }
    to {
        transform: translateY(0);
        opacity: 1;
    }
}

/* Advanced search panel */
.advanced-search-panel {
    background: var(--color-surface);
    border: 1px solid var(--color-border);
    border-radius: 8px;
    padding: var(--spacing-lg);
    margin-top: var(--spacing-md);
    display: none;
}

.advanced-search-panel.active {
    display: block;
}

/* Loading states */
.loading-skeleton {
    background: linear-gradient(
        90deg,
        var(--color-bg-tertiary) 25%,
        var(--color-surface-hover) 50%,
        var(--color-bg-tertiary) 75%
    );
    background-size: 200% 100%;
    animation: loading 1.5s ease-in-out infinite;
}

@keyframes loading {
    0% {
        background-position: 200% 0;
    }
    100% {
        background-position: -200% 0;
    }
}

/* Tooltip enhancements */
.tooltip {
    position: absolute;
    background: var(--color-surface);
    border: 1px solid var(--color-border);
    border-radius: 4px;
    padding: var(--spacing-sm);
    font-size: var(--font-size-caption);
    box-shadow: 0 2px 8px rgba(0, 0, 0, 0.15);
    z-index: 1000;
    pointer-events: none;
    opacity: 0;
    transition: opacity 0.2s ease;
}

.tooltip.show {
    opacity: 1;
}

/* Reduced motion support */
@media (prefers-reduced-motion: reduce) {
    *,
    *::before,
    *::after {
        animation-duration: 0.01ms !important;
        animation-iteration-count: 1 !important;
        transition-duration: 0.01ms !important;
    }
}
77 src/server/web/static/js/accessibility_features.js Normal file
@@ -0,0 +1,77 @@
/**
 * Accessibility Features Module
 * Enhances accessibility for all users
 */

(function() {
    'use strict';

    /**
     * Initialize accessibility features
     */
    function initAccessibilityFeatures() {
        setupFocusManagement();
        setupAriaLabels();
        console.log('[Accessibility Features] Initialized');
    }

    /**
     * Setup focus management
     */
    function setupFocusManagement() {
        // Add focus visible class for keyboard navigation
        document.addEventListener('keydown', (e) => {
            if (e.key === 'Tab') {
                document.body.classList.add('keyboard-navigation');
            }
        });

        document.addEventListener('mousedown', () => {
            document.body.classList.remove('keyboard-navigation');
        });
    }

    /**
     * Setup ARIA labels for dynamic content
     */
    function setupAriaLabels() {
        // Ensure all interactive elements have proper ARIA labels
        const buttons = document.querySelectorAll('button:not([aria-label])');
        buttons.forEach(button => {
            if (!button.getAttribute('aria-label') && button.title) {
                button.setAttribute('aria-label', button.title);
            }
        });
    }

    /**
     * Announce message to screen readers
     */
    function announceToScreenReader(message, priority = 'polite') {
        const announcement = document.createElement('div');
        announcement.setAttribute('role', 'status');
        announcement.setAttribute('aria-live', priority);
        announcement.setAttribute('aria-atomic', 'true');
        announcement.className = 'sr-only';
        announcement.textContent = message;

        document.body.appendChild(announcement);

        setTimeout(() => {
            announcement.remove();
        }, 1000);
    }

    // Export functions
    window.Accessibility = {
        announce: announceToScreenReader
    };

    // Initialize on DOM ready
    if (document.readyState === 'loading') {
        document.addEventListener('DOMContentLoaded', initAccessibilityFeatures);
    } else {
        initAccessibilityFeatures();
    }

})();
29 src/server/web/static/js/advanced_search.js Normal file
@@ -0,0 +1,29 @@
/**
 * Advanced Search Module
 * Provides advanced search and filtering capabilities
 */

(function() {
    'use strict';

    /**
     * Initialize advanced search
     */
    function initAdvancedSearch() {
        console.log('[Advanced Search] Module loaded (functionality to be implemented)');

        // TODO: Implement advanced search features
        // - Filter by genre
        // - Filter by year
        // - Filter by status
        // - Sort options
    }

    // Initialize on DOM ready
    if (document.readyState === 'loading') {
        document.addEventListener('DOMContentLoaded', initAdvancedSearch);
    } else {
        initAdvancedSearch();
    }

})();
@@ -133,9 +133,20 @@ class AniWorldApp {
     initSocket() {
         this.socket = io();
 
+        // Handle initial connection message from server
+        this.socket.on('connected', (data) => {
+            console.log('WebSocket connection confirmed', data);
+        });
+
         this.socket.on('connect', () => {
             this.isConnected = true;
             console.log('Connected to server');
+
+            // Subscribe to rooms for targeted updates
+            this.socket.join('scan_progress');
+            this.socket.join('download_progress');
+            this.socket.join('downloads');
+
             this.showToast(this.localization.getText('connected-server'), 'success');
             this.updateConnectionStatus();
             this.checkProcessLocks();
@@ -158,18 +169,24 @@
             this.updateStatus(`Scanning: ${data.folder} (${data.counter})`);
         });
 
-        this.socket.on('scan_completed', () => {
+        // Handle both 'scan_completed' (legacy) and 'scan_complete' (new backend)
+        const handleScanComplete = () => {
             this.hideStatus();
             this.showToast('Scan completed successfully', 'success');
             this.updateProcessStatus('rescan', false);
             this.loadSeries();
-        });
+        };
+        this.socket.on('scan_completed', handleScanComplete);
+        this.socket.on('scan_complete', handleScanComplete);
 
-        this.socket.on('scan_error', (data) => {
+        // Handle both 'scan_error' (legacy) and 'scan_failed' (new backend)
+        const handleScanError = (data) => {
             this.hideStatus();
-            this.showToast(`Scan error: ${data.message}`, 'error');
+            this.showToast(`Scan error: ${data.message || data.error}`, 'error');
             this.updateProcessStatus('rescan', false, true);
-        });
+        };
+        this.socket.on('scan_error', handleScanError);
+        this.socket.on('scan_failed', handleScanError);
 
         // Scheduled scan events
         this.socket.on('scheduled_rescan_started', () => {
src/server/web/static/js/bulk_operations.js (Normal file, 29 lines)
@@ -0,0 +1,29 @@
/**
 * Bulk Operations Module
 * Handles bulk selection and operations on multiple series
 */

(function() {
    'use strict';

    /**
     * Initialize bulk operations
     */
    function initBulkOperations() {
        console.log('[Bulk Operations] Module loaded (functionality to be implemented)');

        // TODO: Implement bulk operations
        // - Select multiple series
        // - Bulk download
        // - Bulk mark as watched
        // - Bulk delete
    }

    // Initialize on DOM ready
    if (document.readyState === 'loading') {
        document.addEventListener('DOMContentLoaded', initBulkOperations);
    } else {
        initBulkOperations();
    }

})();
src/server/web/static/js/color_contrast_compliance.js (Normal file, 42 lines)
@@ -0,0 +1,42 @@
/**
 * Color Contrast Compliance Module
 * Ensures WCAG color contrast compliance
 */

(function() {
    'use strict';

    /**
     * Initialize color contrast compliance
     */
    function initColorContrastCompliance() {
        checkContrastCompliance();
        console.log('[Color Contrast Compliance] Initialized');
    }

    /**
     * Check if color contrast meets WCAG standards
     */
    function checkContrastCompliance() {
        // This would typically check computed styles
        // For now, we rely on CSS variables defined in styles.css
        console.log('[Color Contrast] Relying on predefined WCAG-compliant color scheme');
    }

    /**
     * Calculate contrast ratio between two colors
     */
    function calculateContrastRatio(color1, color2) {
        // Simplified contrast calculation
        // Real implementation would use relative luminance
        return 4.5; // Placeholder
    }

    // Initialize on DOM ready
    if (document.readyState === 'loading') {
        document.addEventListener('DOMContentLoaded', initColorContrastCompliance);
    } else {
        initColorContrastCompliance();
    }

})();
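The committed `calculateContrastRatio` is an acknowledged placeholder that always returns 4.5. For reference, the actual WCAG 2.x formula is `(L1 + 0.05) / (L2 + 0.05)`, where `L1` and `L2` are the relative luminances of the lighter and darker colors. A standalone sketch of what could replace the placeholder (helper names here are illustrative, not part of the committed module):

```javascript
// Convert an sRGB channel (0-255) to its linearized value per WCAG 2.x.
function linearize(channel) {
    const c = channel / 255;
    return c <= 0.03928 ? c / 12.92 : Math.pow((c + 0.055) / 1.055, 2.4);
}

// Relative luminance of an [r, g, b] color.
function relativeLuminance([r, g, b]) {
    return 0.2126 * linearize(r) + 0.7152 * linearize(g) + 0.0722 * linearize(b);
}

// Contrast ratio between two colors: (L1 + 0.05) / (L2 + 0.05), lighter first.
function contrastRatio(color1, color2) {
    const l1 = relativeLuminance(color1);
    const l2 = relativeLuminance(color2);
    const [lighter, darker] = l1 >= l2 ? [l1, l2] : [l2, l1];
    return (lighter + 0.05) / (darker + 0.05);
}

// White on black is the maximum possible ratio, 21:1.
console.log(contrastRatio([255, 255, 255], [0, 0, 0])); // 21
```

WCAG AA requires at least 4.5:1 for normal text, which is presumably why 4.5 was chosen as the placeholder value.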
src/server/web/static/js/drag_drop.js (Normal file, 26 lines)
@@ -0,0 +1,26 @@
/**
 * Drag and Drop Module
 * Handles drag-and-drop functionality for series cards
 */

(function() {
    'use strict';

    /**
     * Initialize drag and drop
     */
    function initDragDrop() {
        console.log('[Drag & Drop] Module loaded (functionality to be implemented)');

        // TODO: Implement drag-and-drop for series cards
        // This will allow users to reorder series or add to queue via drag-and-drop
    }

    // Initialize on DOM ready
    if (document.readyState === 'loading') {
        document.addEventListener('DOMContentLoaded', initDragDrop);
    } else {
        initDragDrop();
    }

})();
src/server/web/static/js/keyboard_shortcuts.js (Normal file, 144 lines)
@@ -0,0 +1,144 @@
/**
 * Keyboard Shortcuts Module
 * Handles keyboard navigation and shortcuts for improved accessibility
 */

(function() {
    'use strict';

    // Keyboard shortcuts configuration
    const shortcuts = {
        'ctrl+k': 'focusSearch',
        'ctrl+r': 'triggerRescan',
        'ctrl+q': 'openQueue',
        'escape': 'closeModals',
        'tab': 'navigationMode',
        '/': 'focusSearch'
    };

    /**
     * Initialize keyboard shortcuts
     */
    function initKeyboardShortcuts() {
        document.addEventListener('keydown', handleKeydown);
        console.log('[Keyboard Shortcuts] Initialized');
    }

    /**
     * Handle keydown events
     */
    function handleKeydown(event) {
        const key = getKeyCombo(event);

        if (shortcuts[key]) {
            const action = shortcuts[key];
            handleShortcut(action, event);
        }
    }

    /**
     * Get key combination string
     */
    function getKeyCombo(event) {
        const parts = [];

        if (event.ctrlKey) parts.push('ctrl');
        if (event.altKey) parts.push('alt');
        if (event.shiftKey) parts.push('shift');

        const key = event.key.toLowerCase();
        parts.push(key);

        return parts.join('+');
    }

    /**
     * Handle keyboard shortcut action
     */
    function handleShortcut(action, event) {
        switch (action) {
            case 'focusSearch':
                event.preventDefault();
                focusSearchInput();
                break;
            case 'triggerRescan':
                event.preventDefault();
                triggerRescan();
                break;
            case 'openQueue':
                event.preventDefault();
                openQueue();
                break;
            case 'closeModals':
                closeAllModals();
                break;
            case 'navigationMode':
                handleTabNavigation(event);
                break;
        }
    }

    /**
     * Focus search input
     */
    function focusSearchInput() {
        const searchInput = document.getElementById('search-input');
        if (searchInput) {
            searchInput.focus();
            searchInput.select();
        }
    }

    /**
     * Trigger rescan
     */
    function triggerRescan() {
        const rescanBtn = document.getElementById('rescan-btn');
        if (rescanBtn && !rescanBtn.disabled) {
            rescanBtn.click();
        }
    }

    /**
     * Open queue page
     */
    function openQueue() {
        window.location.href = '/queue';
    }

    /**
     * Close all open modals
     */
    function closeAllModals() {
        const modals = document.querySelectorAll('.modal.active');
        modals.forEach(modal => {
            modal.classList.remove('active');
        });
    }

    /**
     * Handle tab navigation with visual indicators
     */
    function handleTabNavigation(event) {
        // Add keyboard-focus class to focused element
        const previousFocus = document.querySelector('.keyboard-focus');
        if (previousFocus) {
            previousFocus.classList.remove('keyboard-focus');
        }

        // Will be applied after tab completes
        setTimeout(() => {
            if (document.activeElement) {
                document.activeElement.classList.add('keyboard-focus');
            }
        }, 0);
    }

    // Initialize on DOM ready
    if (document.readyState === 'loading') {
        document.addEventListener('DOMContentLoaded', initKeyboardShortcuts);
    } else {
        initKeyboardShortcuts();
    }

})();
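Because `getKeyCombo` only reads four fields of the event, the combo lookup can be exercised without a DOM. A standalone sketch of that logic (the object literal stands in for a real `KeyboardEvent`):

```javascript
// Build a normalized combo string from modifier flags and the key name,
// mirroring getKeyCombo in keyboard_shortcuts.js.
function getKeyCombo(event) {
    const parts = [];
    if (event.ctrlKey) parts.push('ctrl');
    if (event.altKey) parts.push('alt');
    if (event.shiftKey) parts.push('shift');
    parts.push(event.key.toLowerCase());
    return parts.join('+');
}

// Subset of the module's shortcut table.
const shortcuts = {
    'ctrl+k': 'focusSearch',
    'ctrl+r': 'triggerRescan',
    '/': 'focusSearch'
};

// A Ctrl+K press resolves to the focusSearch action.
const combo = getKeyCombo({ ctrlKey: true, altKey: false, shiftKey: false, key: 'K' });
console.log(combo, shortcuts[combo]); // ctrl+k focusSearch
```

Note that `event.key.toLowerCase()` makes the lookup case-insensitive, so `Shift+K` normalizes to `shift+k` rather than `shift+K`.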
src/server/web/static/js/mobile_responsive.js (Normal file, 80 lines)
@@ -0,0 +1,80 @@
/**
 * Mobile Responsive Module
 * Handles mobile-specific functionality and responsive behavior
 */

(function() {
    'use strict';

    let isMobile = false;

    /**
     * Initialize mobile responsive features
     */
    function initMobileResponsive() {
        detectMobile();
        setupResponsiveHandlers();
        console.log('[Mobile Responsive] Initialized');
    }

    /**
     * Detect if device is mobile
     */
    function detectMobile() {
        isMobile = /Android|webOS|iPhone|iPad|iPod|BlackBerry|IEMobile|Opera Mini/i.test(navigator.userAgent);

        if (isMobile) {
            document.body.classList.add('mobile-device');
        }
    }

    /**
     * Setup responsive event handlers
     */
    function setupResponsiveHandlers() {
        window.addEventListener('resize', handleResize);
        handleResize(); // Initial call
    }

    /**
     * Handle window resize
     */
    function handleResize() {
        const width = window.innerWidth;

        if (width < 768) {
            applyMobileLayout();
        } else {
            applyDesktopLayout();
        }
    }

    /**
     * Apply mobile-specific layout
     */
    function applyMobileLayout() {
        document.body.classList.add('mobile-layout');
        document.body.classList.remove('desktop-layout');
    }

    /**
     * Apply desktop-specific layout
     */
    function applyDesktopLayout() {
        document.body.classList.add('desktop-layout');
        document.body.classList.remove('mobile-layout');
    }

    // Export functions
    window.MobileResponsive = {
        isMobile: () => isMobile
    };

    // Initialize on DOM ready
    if (document.readyState === 'loading') {
        document.addEventListener('DOMContentLoaded', initMobileResponsive);
    } else {
        initMobileResponsive();
    }

})();
src/server/web/static/js/multi_screen_support.js (Normal file, 76 lines)
@@ -0,0 +1,76 @@
/**
 * Multi-Screen Support Module
 * Handles multi-monitor and window management
 */

(function() {
    'use strict';

    /**
     * Initialize multi-screen support
     */
    function initMultiScreenSupport() {
        if ('screen' in window) {
            detectScreens();
            console.log('[Multi-Screen Support] Initialized');
        }
    }

    /**
     * Detect available screens
     */
    function detectScreens() {
        // Modern browsers support window.screen
        const screenInfo = {
            width: window.screen.width,
            height: window.screen.height,
            availWidth: window.screen.availWidth,
            availHeight: window.screen.availHeight,
            colorDepth: window.screen.colorDepth,
            pixelDepth: window.screen.pixelDepth
        };

        console.log('[Multi-Screen] Screen info:', screenInfo);
    }

    /**
     * Request fullscreen
     */
    function requestFullscreen() {
        const elem = document.documentElement;
        if (elem.requestFullscreen) {
            elem.requestFullscreen();
        } else if (elem.webkitRequestFullscreen) {
            elem.webkitRequestFullscreen();
        } else if (elem.msRequestFullscreen) {
            elem.msRequestFullscreen();
        }
    }

    /**
     * Exit fullscreen
     */
    function exitFullscreen() {
        if (document.exitFullscreen) {
            document.exitFullscreen();
        } else if (document.webkitExitFullscreen) {
            document.webkitExitFullscreen();
        } else if (document.msExitFullscreen) {
            document.msExitFullscreen();
        }
    }

    // Export functions
    window.MultiScreen = {
        requestFullscreen: requestFullscreen,
        exitFullscreen: exitFullscreen
    };

    // Initialize on DOM ready
    if (document.readyState === 'loading') {
        document.addEventListener('DOMContentLoaded', initMultiScreenSupport);
    } else {
        initMultiScreenSupport();
    }

})();
@@ -22,8 +22,18 @@ class QueueManager {
     initSocket() {
         this.socket = io();

+        // Handle initial connection message from server
+        this.socket.on('connected', (data) => {
+            console.log('WebSocket connection confirmed', data);
+        });
+
         this.socket.on('connect', () => {
             console.log('Connected to server');

+            // Subscribe to rooms for targeted updates
+            this.socket.join('downloads');
+            this.socket.join('download_progress');
+
             this.showToast('Connected to server', 'success');
         });

@@ -32,10 +42,18 @@ class QueueManager {
             this.showToast('Disconnected from server', 'warning');
         });

-        // Queue update events
+        // Queue update events - handle both old and new message types
         this.socket.on('queue_updated', (data) => {
             this.updateQueueDisplay(data);
         });
+        this.socket.on('queue_status', (data) => {
+            // New backend sends queue_status messages
+            if (data.queue_status) {
+                this.updateQueueDisplay(data.queue_status);
+            } else {
+                this.updateQueueDisplay(data);
+            }
+        });

         this.socket.on('download_progress_update', (data) => {
             this.updateDownloadProgress(data);
@@ -46,21 +64,33 @@ class QueueManager {
             this.showToast('Download queue started', 'success');
             this.loadQueueData(); // Refresh data
         });
+        this.socket.on('queue_started', () => {
+            this.showToast('Download queue started', 'success');
+            this.loadQueueData(); // Refresh data
+        });

         this.socket.on('download_progress', (data) => {
             this.updateDownloadProgress(data);
         });

-        this.socket.on('download_completed', (data) => {
-            this.showToast(`Completed: ${data.serie} - Episode ${data.episode}`, 'success');
+        // Handle both old and new download completion events
+        const handleDownloadComplete = (data) => {
+            const serieName = data.serie_name || data.serie || 'Unknown';
+            const episode = data.episode || '';
+            this.showToast(`Completed: ${serieName}${episode ? ' - Episode ' + episode : ''}`, 'success');
             this.loadQueueData(); // Refresh data
-        });
+        };
+        this.socket.on('download_completed', handleDownloadComplete);
+        this.socket.on('download_complete', handleDownloadComplete);

-        this.socket.on('download_error', (data) => {
+        // Handle both old and new download error events
+        const handleDownloadError = (data) => {
+            const message = data.error || data.message || 'Unknown error';
             this.showToast(`Download failed: ${message}`, 'error');
             this.loadQueueData(); // Refresh data
-        });
+        };
+        this.socket.on('download_error', handleDownloadError);
+        this.socket.on('download_failed', handleDownloadError);

         this.socket.on('download_queue_completed', () => {
             this.showToast('All downloads completed!', 'success');
@@ -71,9 +101,23 @@ class QueueManager {
             this.showToast('Stopping downloads...', 'info');
         });

-        this.socket.on('download_stopped', () => {
+        // Handle both old and new queue stopped events
+        const handleQueueStopped = () => {
             this.showToast('Download queue stopped', 'success');
             this.loadQueueData(); // Refresh data
-        });
+        };
+        this.socket.on('download_stopped', handleQueueStopped);
+        this.socket.on('queue_stopped', handleQueueStopped);
+
+        // Handle queue paused/resumed
+        this.socket.on('queue_paused', () => {
+            this.showToast('Queue paused', 'info');
+            this.loadQueueData();
+        });
+
+        this.socket.on('queue_resumed', () => {
+            this.showToast('Queue resumed', 'success');
+            this.loadQueueData();
+        });
     }
src/server/web/static/js/screen_reader_support.js (Normal file, 65 lines)
@@ -0,0 +1,65 @@
/**
 * Screen Reader Support Module
 * Provides enhanced screen reader support
 */

(function() {
    'use strict';

    /**
     * Initialize screen reader support
     */
    function initScreenReaderSupport() {
        setupLiveRegions();
        setupNavigationAnnouncements();
        console.log('[Screen Reader Support] Initialized');
    }

    /**
     * Setup live regions for dynamic content
     */
    function setupLiveRegions() {
        // Create global live region if it doesn't exist
        if (!document.getElementById('sr-live-region')) {
            const liveRegion = document.createElement('div');
            liveRegion.id = 'sr-live-region';
            liveRegion.className = 'sr-only';
            liveRegion.setAttribute('role', 'status');
            liveRegion.setAttribute('aria-live', 'polite');
            liveRegion.setAttribute('aria-atomic', 'true');
            document.body.appendChild(liveRegion);
        }
    }

    /**
     * Setup navigation announcements
     */
    function setupNavigationAnnouncements() {
        // Announce page navigation
        const pageTitle = document.title;
        announceToScreenReader(`Page loaded: ${pageTitle}`);
    }

    /**
     * Announce message to screen readers
     */
    function announceToScreenReader(message) {
        const liveRegion = document.getElementById('sr-live-region');
        if (liveRegion) {
            liveRegion.textContent = message;
        }
    }

    // Export functions
    window.ScreenReader = {
        announce: announceToScreenReader
    };

    // Initialize on DOM ready
    if (document.readyState === 'loading') {
        document.addEventListener('DOMContentLoaded', initScreenReaderSupport);
    } else {
        initScreenReaderSupport();
    }

})();
src/server/web/static/js/touch_gestures.js (Normal file, 66 lines)
@@ -0,0 +1,66 @@
/**
 * Touch Gestures Module
 * Handles touch gestures for mobile devices
 */

(function() {
    'use strict';

    /**
     * Initialize touch gestures
     */
    function initTouchGestures() {
        if ('ontouchstart' in window) {
            setupSwipeGestures();
            console.log('[Touch Gestures] Initialized');
        }
    }

    /**
     * Setup swipe gesture handlers
     */
    function setupSwipeGestures() {
        let touchStartX = 0;
        let touchStartY = 0;
        let touchEndX = 0;
        let touchEndY = 0;

        document.addEventListener('touchstart', (e) => {
            touchStartX = e.changedTouches[0].screenX;
            touchStartY = e.changedTouches[0].screenY;
        }, { passive: true });

        document.addEventListener('touchend', (e) => {
            touchEndX = e.changedTouches[0].screenX;
            touchEndY = e.changedTouches[0].screenY;
            handleSwipe();
        }, { passive: true });

        function handleSwipe() {
            const deltaX = touchEndX - touchStartX;
            const deltaY = touchEndY - touchStartY;
            const minSwipeDistance = 50;

            if (Math.abs(deltaX) > Math.abs(deltaY)) {
                // Horizontal swipe
                if (Math.abs(deltaX) > minSwipeDistance) {
                    if (deltaX > 0) {
                        // Swipe right
                        console.log('[Touch Gestures] Swipe right detected');
                    } else {
                        // Swipe left
                        console.log('[Touch Gestures] Swipe left detected');
                    }
                }
            }
        }
    }

    // Initialize on DOM ready
    if (document.readyState === 'loading') {
        document.addEventListener('DOMContentLoaded', initTouchGestures);
    } else {
        initTouchGestures();
    }

})();
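The swipe decision inside `handleSwipe` reduces to a pure function of the start and end coordinates, which makes it testable without touch events. A standalone sketch (the function name is illustrative; thresholds match the module's 50px minimum):

```javascript
// Classify a touch movement as a horizontal swipe, mirroring the logic in
// touch_gestures.js: horizontal movement must dominate vertical movement
// and exceed the minimum swipe distance.
function classifySwipe(startX, startY, endX, endY, minSwipeDistance = 50) {
    const deltaX = endX - startX;
    const deltaY = endY - startY;
    if (Math.abs(deltaX) <= Math.abs(deltaY)) return null;  // vertical-dominant movement
    if (Math.abs(deltaX) <= minSwipeDistance) return null;  // too short to count
    return deltaX > 0 ? 'right' : 'left';
}

// A long leftward drag with little vertical drift is a left swipe.
console.log(classifySwipe(200, 100, 40, 110)); // left
```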
src/server/web/static/js/undo_redo.js (Normal file, 111 lines)
@@ -0,0 +1,111 @@
/**
 * Undo/Redo Module
 * Provides undo/redo functionality for user actions
 */

(function() {
    'use strict';

    const actionHistory = [];
    let currentIndex = -1;

    /**
     * Initialize undo/redo system
     */
    function initUndoRedo() {
        setupKeyboardShortcuts();
        console.log('[Undo/Redo] Initialized');
    }

    /**
     * Setup keyboard shortcuts for undo/redo
     */
    function setupKeyboardShortcuts() {
        document.addEventListener('keydown', (event) => {
            if (event.ctrlKey || event.metaKey) {
                if (event.key === 'z' && !event.shiftKey) {
                    event.preventDefault();
                    undo();
                } else if ((event.key === 'z' && event.shiftKey) || event.key === 'y') {
                    event.preventDefault();
                    redo();
                }
            }
        });
    }

    /**
     * Add action to history
     */
    function addAction(action) {
        // Remove any actions after current index
        actionHistory.splice(currentIndex + 1);

        // Add new action
        actionHistory.push(action);
        currentIndex++;

        // Limit history size
        if (actionHistory.length > 50) {
            actionHistory.shift();
            currentIndex--;
        }
    }

    /**
     * Undo last action
     */
    function undo() {
        if (currentIndex >= 0) {
            const action = actionHistory[currentIndex];
            if (action && action.undo) {
                action.undo();
                currentIndex--;
                showNotification('Action undone');
            }
        }
    }

    /**
     * Redo last undone action
     */
    function redo() {
        if (currentIndex < actionHistory.length - 1) {
            currentIndex++;
            const action = actionHistory[currentIndex];
            if (action && action.redo) {
                action.redo();
                showNotification('Action redone');
            }
        }
    }

    /**
     * Show undo/redo notification
     */
    function showNotification(message) {
        const notification = document.createElement('div');
        notification.className = 'undo-notification';
        notification.textContent = message;
        document.body.appendChild(notification);

        setTimeout(() => {
            notification.remove();
        }, 2000);
    }

    // Export functions
    window.UndoRedo = {
        add: addAction,
        undo: undo,
        redo: redo
    };

    // Initialize on DOM ready
    if (document.readyState === 'loading') {
        document.addEventListener('DOMContentLoaded', initUndoRedo);
    } else {
        initUndoRedo();
    }

})();
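The splice/push/index bookkeeping above can be sketched as a self-contained history object to show how an action moves through add, undo, and redo. Names here are illustrative (the module itself exposes `window.UndoRedo`, not a factory):

```javascript
// Minimal in-memory version of the undo/redo history: the index points at
// the last applied action, add() truncates the redo branch, and the history
// is capped at a fixed number of entries.
function createHistory(limit = 50) {
    const actions = [];
    let index = -1;
    return {
        add(action) {
            actions.splice(index + 1);     // drop any undone (redoable) actions
            actions.push(action);
            index++;
            if (actions.length > limit) {  // cap history size
                actions.shift();
                index--;
            }
        },
        undo() {
            if (index >= 0) { actions[index].undo(); index--; }
        },
        redo() {
            if (index < actions.length - 1) { index++; actions[index].redo(); }
        }
    };
}

// Usage: record a mutation so it can be reverted and reapplied.
let value = 0;
const history = createHistory();
value = 1;
history.add({ undo: () => { value = 0; }, redo: () => { value = 1; } });
history.undo();     // value reverts to 0
history.redo();     // value restored to 1
console.log(value); // 1
```

The splice-on-add step is what discards the redo branch: after an undo, recording a new action makes the previously undone action unreachable, matching the module's behavior.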
src/server/web/static/js/user_preferences.js (Normal file, 94 lines)
@@ -0,0 +1,94 @@
/**
 * User Preferences Module
 * Manages user preferences and settings persistence
 */

(function() {
    'use strict';

    const STORAGE_KEY = 'aniworld_preferences';

    /**
     * Initialize user preferences
     */
    function initUserPreferences() {
        loadPreferences();
        console.log('[User Preferences] Initialized');
    }

    /**
     * Load preferences from localStorage
     */
    function loadPreferences() {
        try {
            const stored = localStorage.getItem(STORAGE_KEY);
            if (stored) {
                const preferences = JSON.parse(stored);
                applyPreferences(preferences);
            }
        } catch (error) {
            console.error('[User Preferences] Error loading:', error);
        }
    }

    /**
     * Save preferences to localStorage
     */
    function savePreferences(preferences) {
        try {
            localStorage.setItem(STORAGE_KEY, JSON.stringify(preferences));
        } catch (error) {
            console.error('[User Preferences] Error saving:', error);
        }
    }

    /**
     * Apply preferences to the application
     */
    function applyPreferences(preferences) {
        if (preferences.theme) {
            document.documentElement.setAttribute('data-theme', preferences.theme);
        }
        if (preferences.language) {
            // Language preference would be applied here
        }
    }

    /**
     * Get current preferences
     */
    function getPreferences() {
        try {
            const stored = localStorage.getItem(STORAGE_KEY);
            return stored ? JSON.parse(stored) : {};
        } catch (error) {
            console.error('[User Preferences] Error getting preferences:', error);
            return {};
        }
    }

    /**
     * Update specific preference
     */
    function updatePreference(key, value) {
        const preferences = getPreferences();
        preferences[key] = value;
        savePreferences(preferences);
    }

    // Export functions
    window.UserPreferences = {
        load: loadPreferences,
        save: savePreferences,
        get: getPreferences,
        update: updatePreference
    };

    // Initialize on DOM ready
    if (document.readyState === 'loading') {
        document.addEventListener('DOMContentLoaded', initUserPreferences);
    } else {
        initUserPreferences();
    }

})();
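The load/save/update cycle above only touches `getItem` and `setItem`, so it can be exercised outside the browser with a plain object standing in for `localStorage`. A sketch of the round-trip (the storage shim is illustrative, not part of the module):

```javascript
// Same key the module uses.
const STORAGE_KEY = 'aniworld_preferences';

// Minimal stand-in for localStorage, implementing only the two calls the
// preferences module makes.
const fakeStorage = {
    _data: {},
    getItem(key) { return key in this._data ? this._data[key] : null; },
    setItem(key, value) { this._data[key] = String(value); }
};

// Read the stored preferences object, defaulting to empty.
function getPreferences(storage) {
    const stored = storage.getItem(STORAGE_KEY);
    return stored ? JSON.parse(stored) : {};
}

// Read-modify-write a single preference, as updatePreference does.
function updatePreference(storage, key, value) {
    const preferences = getPreferences(storage);
    preferences[key] = value;
    storage.setItem(STORAGE_KEY, JSON.stringify(preferences));
}

updatePreference(fakeStorage, 'theme', 'dark');
updatePreference(fakeStorage, 'language', 'en');
console.log(getPreferences(fakeStorage)); // { theme: 'dark', language: 'en' }
```

Because each update does a full read-modify-write of one JSON blob, earlier keys survive later updates, which is the behavior the module relies on.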
233
src/server/web/static/js/websocket_client.js
Normal file
233
src/server/web/static/js/websocket_client.js
Normal file
@ -0,0 +1,233 @@
|
||||
/**
|
||||
* Native WebSocket Client Wrapper
|
||||
* Provides Socket.IO-like interface using native WebSocket API
|
||||
*
|
||||
* This wrapper maintains compatibility with existing Socket.IO-style
|
||||
* event handlers while using the modern WebSocket API underneath.
|
||||
*/
|
||||
|
||||
class WebSocketClient {
|
||||
constructor(url = null) {
|
||||
this.ws = null;
|
||||
this.url = url || this.getWebSocketUrl();
|
||||
this.eventHandlers = new Map();
|
||||
this.reconnectAttempts = 0;
|
||||
this.maxReconnectAttempts = 5;
|
||||
this.reconnectDelay = 1000;
|
||||
this.isConnected = false;
|
||||
this.rooms = new Set();
|
||||
this.messageQueue = [];
|
||||
this.autoReconnect = true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get WebSocket URL based on current page URL
|
||||
*/
|
||||
getWebSocketUrl() {
|
||||
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
|
||||
const host = window.location.host;
|
||||
return `${protocol}//${host}/ws/connect`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Connect to WebSocket server
|
||||
*/
|
||||
connect() {
|
||||
if (this.ws && this.ws.readyState === WebSocket.OPEN) {
|
||||
console.log('WebSocket already connected');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
this.ws = new WebSocket(this.url);
|
||||
|
||||
this.ws.onopen = () => {
|
||||
console.log('WebSocket connected');
|
||||
this.isConnected = true;
|
||||
this.reconnectAttempts = 0;
|
||||
|
||||
// Emit connect event
|
||||
this.emit('connect');
|
||||
|
||||
// Rejoin rooms
|
||||
this.rejoinRooms();
|
||||
|
||||
// Process queued messages
|
||||
this.processMessageQueue();
|
||||
};
|
||||
|
||||
this.ws.onmessage = (event) => {
|
||||
this.handleMessage(event.data);
|
||||
};
|
||||
|
||||
this.ws.onerror = (error) => {
|
||||
console.error('WebSocket error:', error);
|
||||
this.emit('error', { error: 'WebSocket connection error' });
|
||||
};
|
||||
|
||||
this.ws.onclose = (event) => {
|
||||
console.log('WebSocket disconnected', event.code, event.reason);
|
||||
this.isConnected = false;
|
||||
this.emit('disconnect', { code: event.code, reason: event.reason });
|
||||
|
||||
// Attempt reconnection
|
||||
if (this.autoReconnect && this.reconnectAttempts < this.maxReconnectAttempts) {
|
||||
this.reconnectAttempts++;
|
||||
const delay = this.reconnectDelay * this.reconnectAttempts;
|
||||
console.log(`Attempting reconnection in ${delay}ms (attempt ${this.reconnectAttempts})`);
|
||||
setTimeout(() => this.connect(), delay);
|
||||
}
|
||||
};
|
||||
} catch (error) {
|
||||
console.error('Failed to create WebSocket connection:', error);
|
||||
this.emit('error', { error: 'Failed to connect' });
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Disconnect from WebSocket server
|
||||
*/
|
||||
disconnect() {
|
||||
this.autoReconnect = false;
|
||||
if (this.ws) {
|
||||
this.ws.close(1000, 'Client disconnected');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle incoming WebSocket message
|
||||
*/
|
||||
handleMessage(data) {
|
||||
try {
|
||||
const message = JSON.parse(data);
|
||||
const { type, data: payload, timestamp } = message;
|
||||
|
||||
// Emit event with payload
|
||||
if (type) {
|
||||
this.emit(type, payload || {});
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to parse WebSocket message:', error, data);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Register event handler (Socket.IO-style)
|
||||
*/
|
||||
on(event, handler) {
|
||||
if (!this.eventHandlers.has(event)) {
|
||||
this.eventHandlers.set(event, []);
|
||||
}
|
||||
this.eventHandlers.get(event).push(handler);
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove event handler
|
||||
*/
|
||||
off(event, handler) {
|
||||
if (!this.eventHandlers.has(event)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const handlers = this.eventHandlers.get(event);
|
||||
const index = handlers.indexOf(handler);
|
||||
if (index !== -1) {
|
||||
handlers.splice(index, 1);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Emit event to registered handlers
|
||||
*/
|
||||
emit(event, data = null) {
|
||||
if (!this.eventHandlers.has(event)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const handlers = this.eventHandlers.get(event);
|
||||
        handlers.forEach(handler => {
            try {
                if (data !== null) {
                    handler(data);
                } else {
                    handler();
                }
            } catch (error) {
                console.error(`Error in event handler for ${event}:`, error);
            }
        });
    }

    /**
     * Send message to server
     */
    send(type, data = {}) {
        const message = JSON.stringify({
            type,
            data,
            timestamp: new Date().toISOString()
        });

        if (this.isConnected && this.ws.readyState === WebSocket.OPEN) {
            this.ws.send(message);
        } else {
            console.warn('WebSocket not connected, queueing message');
            this.messageQueue.push(message);
        }
    }

    /**
     * Join a room (subscribe to topic)
     */
    join(room) {
        this.rooms.add(room);
        if (this.isConnected) {
            this.send('join', { room });
        }
    }

    /**
     * Leave a room (unsubscribe from topic)
     */
    leave(room) {
        this.rooms.delete(room);
        if (this.isConnected) {
            this.send('leave', { room });
        }
    }

    /**
     * Rejoin all rooms after reconnection
     */
    rejoinRooms() {
        this.rooms.forEach(room => {
            this.send('join', { room });
        });
    }

    /**
     * Process queued messages after connection
     */
    processMessageQueue() {
        while (this.messageQueue.length > 0 && this.isConnected) {
            const message = this.messageQueue.shift();
            this.ws.send(message);
        }
    }

    /**
     * Check if connected
     */
    connected() {
        return this.isConnected && this.ws && this.ws.readyState === WebSocket.OPEN;
    }
}

/**
 * Create global io() function for Socket.IO compatibility
 */
function io(url = null) {
    const client = new WebSocketClient(url);
    client.connect();
    return client;
}
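The wrapper above reduces to two core behaviours: event dispatch with per-handler error isolation, and queueing of outbound messages until the socket is (re)connected. A minimal, stdlib-only Python sketch of the same pattern (the `SketchClient` name and its members are hypothetical, not part of this repository):

```python
# Minimal sketch (hypothetical names): per-handler error isolation on dispatch,
# plus queueing of outbound messages while disconnected, as in the JS wrapper.
import json


class SketchClient:
    def __init__(self):
        self.handlers = {}        # event name -> list of callbacks
        self.message_queue = []   # JSON strings queued while disconnected
        self.is_connected = False
        self.sent = []            # stands in for ws.send()

    def on(self, event, handler):
        self.handlers.setdefault(event, []).append(handler)

    def dispatch(self, event, data=None):
        for handler in self.handlers.get(event, []):
            try:
                handler(data) if data is not None else handler()
            except Exception as exc:  # one failing handler must not break the rest
                print(f"Error in event handler for {event}: {exc}")

    def send(self, type_, data=None):
        message = json.dumps({"type": type_, "data": data or {}})
        if self.is_connected:
            self.sent.append(message)
        else:
            self.message_queue.append(message)  # flushed by connect()

    def connect(self):
        self.is_connected = True
        while self.message_queue:
            self.sent.append(self.message_queue.pop(0))


client = SketchClient()
client.send("join", {"room": "download_progress"})  # queued: not connected yet
client.connect()                                    # queue flushed in order
client.on("download_progress", lambda d: print("percent:", d["percent"]))
client.dispatch("download_progress", {"percent": 50})
```

The try/except around each handler mirrors the JS `try { … } catch (error)` block: a broken listener logs instead of tearing down dispatch for the remaining listeners.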
@@ -455,7 +455,7 @@
     </div>

     <!-- Scripts -->
-    <script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/4.0.1/socket.io.js"></script>
+    <script src="/static/js/websocket_client.js"></script>
     <script src="/static/js/localization.js"></script>

     <!-- UX Enhancement Scripts -->
@@ -245,7 +245,7 @@
     </div>

     <!-- Scripts -->
-    <script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/4.0.1/socket.io.js"></script>
+    <script src="/static/js/websocket_client.js"></script>
     <script src="/static/js/queue.js"></script>
 </body>
tests/api/test_anime_endpoints.py (Normal file, 49 lines)
@@ -0,0 +1,49 @@
import asyncio

from src.server.api import anime as anime_module


class FakeSerie:
    def __init__(self, key, name, folder, episodeDict=None):
        self.key = key
        self.name = name
        self.folder = folder
        self.episodeDict = episodeDict or {}


class FakeSeriesApp:
    def __init__(self):
        self.List = self
        self._items = [
            FakeSerie("1", "Test Show", "test_show", {1: [1, 2]}),
            FakeSerie("2", "Complete Show", "complete_show", {}),
        ]

    def GetMissingEpisode(self):
        return [s for s in self._items if s.episodeDict]

    def GetList(self):
        return self._items

    def ReScan(self, callback):
        callback()


def test_list_anime_direct_call():
    fake = FakeSeriesApp()
    result = asyncio.run(anime_module.list_anime(series_app=fake))
    assert isinstance(result, list)
    assert any(item.title == "Test Show" for item in result)


def test_get_anime_detail_direct_call():
    fake = FakeSeriesApp()
    result = asyncio.run(anime_module.get_anime("1", series_app=fake))
    assert result.title == "Test Show"
    assert "1-1" in result.episodes


def test_rescan_direct_call():
    fake = FakeSeriesApp()
    result = asyncio.run(anime_module.trigger_rescan(series_app=fake))
    assert result["success"] is True
tests/api/test_config_endpoints.py (Normal file, 36 lines)
@@ -0,0 +1,36 @@
from fastapi.testclient import TestClient

from src.server.fastapi_app import app
from src.server.models.config import AppConfig, SchedulerConfig

client = TestClient(app)


def test_get_config_public():
    resp = client.get("/api/config")
    assert resp.status_code == 200
    data = resp.json()
    assert "name" in data
    assert "data_dir" in data


def test_validate_config():
    cfg = {
        "name": "Aniworld",
        "data_dir": "data",
        "scheduler": {"enabled": True, "interval_minutes": 30},
        "logging": {"level": "INFO"},
        "backup": {"enabled": False},
        "other": {},
    }
    resp = client.post("/api/config/validate", json=cfg)
    assert resp.status_code == 200
    body = resp.json()
    assert body.get("valid") is True


def test_update_config_unauthorized():
    # update requires auth; without auth should be 401
    update = {"scheduler": {"enabled": False}}
    resp = client.put("/api/config", json=update)
    assert resp.status_code in (401, 422)
tests/api/test_download_endpoints.py (Normal file, 443 lines)
@@ -0,0 +1,443 @@
"""Tests for download queue API endpoints."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from src.server.fastapi_app import app
|
||||
from src.server.models.download import DownloadPriority, QueueStats, QueueStatus
|
||||
from src.server.services.auth_service import auth_service
|
||||
from src.server.services.download_service import DownloadServiceError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def authenticated_client():
|
||||
"""Create authenticated async client."""
|
||||
# Ensure auth is configured for test
|
||||
if not auth_service.is_configured():
|
||||
auth_service.setup_master_password("TestPass123!")
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(
|
||||
transport=transport, base_url="http://test"
|
||||
) as client:
|
||||
# Login to get token
|
||||
r = await client.post(
|
||||
"/api/auth/login", json={"password": "TestPass123!"}
|
||||
)
|
||||
assert r.status_code == 200
|
||||
token = r.json()["access_token"]
|
||||
|
||||
# Set authorization header for all requests
|
||||
client.headers["Authorization"] = f"Bearer {token}"
|
||||
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_download_service():
|
||||
"""Mock DownloadService for testing."""
|
||||
with patch(
|
||||
"src.server.utils.dependencies.get_download_service"
|
||||
) as mock:
|
||||
service = MagicMock()
|
||||
|
||||
# Mock queue status
|
||||
service.get_queue_status = AsyncMock(
|
||||
return_value=QueueStatus(
|
||||
is_running=True,
|
||||
is_paused=False,
|
||||
active_downloads=[],
|
||||
pending_queue=[],
|
||||
completed_downloads=[],
|
||||
failed_downloads=[],
|
||||
)
|
||||
)
|
||||
|
||||
# Mock queue stats
|
||||
service.get_queue_stats = AsyncMock(
|
||||
return_value=QueueStats(
|
||||
total_items=0,
|
||||
pending_count=0,
|
||||
active_count=0,
|
||||
completed_count=0,
|
||||
failed_count=0,
|
||||
total_downloaded_mb=0.0,
|
||||
)
|
||||
)
|
||||
|
||||
# Mock add_to_queue
|
||||
service.add_to_queue = AsyncMock(
|
||||
return_value=["item-id-1", "item-id-2"]
|
||||
)
|
||||
|
||||
# Mock remove_from_queue
|
||||
service.remove_from_queue = AsyncMock(return_value=["item-id-1"])
|
||||
|
||||
# Mock reorder_queue
|
||||
service.reorder_queue = AsyncMock(return_value=True)
|
||||
|
||||
# Mock start/stop/pause/resume
|
||||
service.start = AsyncMock()
|
||||
service.stop = AsyncMock()
|
||||
service.pause_queue = AsyncMock()
|
||||
service.resume_queue = AsyncMock()
|
||||
|
||||
# Mock clear_completed and retry_failed
|
||||
service.clear_completed = AsyncMock(return_value=5)
|
||||
service.retry_failed = AsyncMock(return_value=["item-id-3"])
|
||||
|
||||
mock.return_value = service
|
||||
yield service
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_queue_status(authenticated_client, mock_download_service):
|
||||
"""Test GET /api/queue/status endpoint."""
|
||||
response = await authenticated_client.get("/api/queue/status")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert "status" in data
|
||||
assert "statistics" in data
|
||||
assert data["status"]["is_running"] is True
|
||||
assert data["status"]["is_paused"] is False
|
||||
|
||||
mock_download_service.get_queue_status.assert_called_once()
|
||||
mock_download_service.get_queue_stats.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_queue_status_unauthorized():
|
||||
"""Test GET /api/queue/status without authentication."""
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(
|
||||
transport=transport, base_url="http://test"
|
||||
) as client:
|
||||
response = await client.get("/api/queue/status")
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_add_to_queue(authenticated_client, mock_download_service):
|
||||
"""Test POST /api/queue/add endpoint."""
|
||||
request_data = {
|
||||
"serie_id": "series-1",
|
||||
"serie_name": "Test Anime",
|
||||
"episodes": [
|
||||
{"season": 1, "episode": 1},
|
||||
{"season": 1, "episode": 2},
|
||||
],
|
||||
"priority": "normal",
|
||||
}
|
||||
|
||||
response = await authenticated_client.post(
|
||||
"/api/queue/add", json=request_data
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
|
||||
assert data["status"] == "success"
|
||||
assert len(data["added_items"]) == 2
|
||||
assert data["added_items"] == ["item-id-1", "item-id-2"]
|
||||
|
||||
mock_download_service.add_to_queue.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_add_to_queue_with_high_priority(
|
||||
authenticated_client, mock_download_service
|
||||
):
|
||||
"""Test adding items with HIGH priority."""
|
||||
request_data = {
|
||||
"serie_id": "series-1",
|
||||
"serie_name": "Test Anime",
|
||||
"episodes": [{"season": 1, "episode": 1}],
|
||||
"priority": "high",
|
||||
}
|
||||
|
||||
response = await authenticated_client.post(
|
||||
"/api/queue/add", json=request_data
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
|
||||
# Verify priority was passed correctly
|
||||
call_args = mock_download_service.add_to_queue.call_args
|
||||
assert call_args[1]["priority"] == DownloadPriority.HIGH
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_add_to_queue_empty_episodes(
|
||||
authenticated_client, mock_download_service
|
||||
):
|
||||
"""Test adding empty episodes list returns 400."""
|
||||
request_data = {
|
||||
"serie_id": "series-1",
|
||||
"serie_name": "Test Anime",
|
||||
"episodes": [],
|
||||
"priority": "normal",
|
||||
}
|
||||
|
||||
response = await authenticated_client.post(
|
||||
"/api/queue/add", json=request_data
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_add_to_queue_service_error(
|
||||
authenticated_client, mock_download_service
|
||||
):
|
||||
"""Test adding to queue when service raises error."""
|
||||
mock_download_service.add_to_queue.side_effect = DownloadServiceError(
|
||||
"Queue full"
|
||||
)
|
||||
|
||||
request_data = {
|
||||
"serie_id": "series-1",
|
||||
"serie_name": "Test Anime",
|
||||
"episodes": [{"season": 1, "episode": 1}],
|
||||
"priority": "normal",
|
||||
}
|
||||
|
||||
response = await authenticated_client.post(
|
||||
"/api/queue/add", json=request_data
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "Queue full" in response.json()["detail"]
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_remove_from_queue_single(
|
||||
authenticated_client, mock_download_service
|
||||
):
|
||||
"""Test DELETE /api/queue/{item_id} endpoint."""
|
||||
response = await authenticated_client.delete("/api/queue/item-id-1")
|
||||
|
||||
assert response.status_code == 204
|
||||
|
||||
mock_download_service.remove_from_queue.assert_called_once_with(
|
||||
["item-id-1"]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_remove_from_queue_not_found(
|
||||
authenticated_client, mock_download_service
|
||||
):
|
||||
"""Test removing non-existent item returns 404."""
|
||||
mock_download_service.remove_from_queue.return_value = []
|
||||
|
||||
response = await authenticated_client.delete(
|
||||
"/api/queue/non-existent-id"
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_remove_multiple_from_queue(
|
||||
authenticated_client, mock_download_service
|
||||
):
|
||||
"""Test DELETE /api/queue/ with multiple items."""
|
||||
request_data = {"item_ids": ["item-id-1", "item-id-2"]}
|
||||
|
||||
response = await authenticated_client.request(
|
||||
"DELETE", "/api/queue/", json=request_data
|
||||
)
|
||||
|
||||
assert response.status_code == 204
|
||||
|
||||
mock_download_service.remove_from_queue.assert_called_once_with(
|
||||
["item-id-1", "item-id-2"]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_remove_multiple_empty_list(
|
||||
authenticated_client, mock_download_service
|
||||
):
|
||||
"""Test removing with empty item list returns 400."""
|
||||
request_data = {"item_ids": []}
|
||||
|
||||
response = await authenticated_client.request(
|
||||
"DELETE", "/api/queue/", json=request_data
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_start_queue(authenticated_client, mock_download_service):
|
||||
"""Test POST /api/queue/start endpoint."""
|
||||
response = await authenticated_client.post("/api/queue/start")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert data["status"] == "success"
|
||||
assert "started" in data["message"].lower()
|
||||
|
||||
mock_download_service.start.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_stop_queue(authenticated_client, mock_download_service):
|
||||
"""Test POST /api/queue/stop endpoint."""
|
||||
response = await authenticated_client.post("/api/queue/stop")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert data["status"] == "success"
|
||||
assert "stopped" in data["message"].lower()
|
||||
|
||||
mock_download_service.stop.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_pause_queue(authenticated_client, mock_download_service):
|
||||
"""Test POST /api/queue/pause endpoint."""
|
||||
response = await authenticated_client.post("/api/queue/pause")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert data["status"] == "success"
|
||||
assert "paused" in data["message"].lower()
|
||||
|
||||
mock_download_service.pause_queue.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_resume_queue(authenticated_client, mock_download_service):
|
||||
"""Test POST /api/queue/resume endpoint."""
|
||||
response = await authenticated_client.post("/api/queue/resume")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert data["status"] == "success"
|
||||
assert "resumed" in data["message"].lower()
|
||||
|
||||
mock_download_service.resume_queue.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_reorder_queue(authenticated_client, mock_download_service):
|
||||
"""Test POST /api/queue/reorder endpoint."""
|
||||
request_data = {"item_id": "item-id-1", "new_position": 0}
|
||||
|
||||
response = await authenticated_client.post(
|
||||
"/api/queue/reorder", json=request_data
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert data["status"] == "success"
|
||||
|
||||
mock_download_service.reorder_queue.assert_called_once_with(
|
||||
item_id="item-id-1", new_position=0
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_reorder_queue_not_found(
|
||||
authenticated_client, mock_download_service
|
||||
):
|
||||
"""Test reordering non-existent item returns 404."""
|
||||
mock_download_service.reorder_queue.return_value = False
|
||||
|
||||
request_data = {"item_id": "non-existent", "new_position": 0}
|
||||
|
||||
response = await authenticated_client.post(
|
||||
"/api/queue/reorder", json=request_data
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_clear_completed(authenticated_client, mock_download_service):
|
||||
"""Test DELETE /api/queue/completed endpoint."""
|
||||
response = await authenticated_client.delete("/api/queue/completed")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert data["status"] == "success"
|
||||
assert data["count"] == 5
|
||||
|
||||
mock_download_service.clear_completed.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_retry_failed(authenticated_client, mock_download_service):
|
||||
"""Test POST /api/queue/retry endpoint."""
|
||||
request_data = {"item_ids": ["item-id-3"]}
|
||||
|
||||
response = await authenticated_client.post(
|
||||
"/api/queue/retry", json=request_data
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert data["status"] == "success"
|
||||
assert len(data["retried_ids"]) == 1
|
||||
|
||||
mock_download_service.retry_failed.assert_called_once_with(
|
||||
["item-id-3"]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_retry_all_failed(authenticated_client, mock_download_service):
|
||||
"""Test retrying all failed items with empty list."""
|
||||
request_data = {"item_ids": []}
|
||||
|
||||
response = await authenticated_client.post(
|
||||
"/api/queue/retry", json=request_data
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
# Should call retry_failed with None to retry all
|
||||
mock_download_service.retry_failed.assert_called_once_with(None)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_queue_endpoints_require_auth():
|
||||
"""Test that all queue endpoints require authentication."""
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(
|
||||
transport=transport, base_url="http://test"
|
||||
) as client:
|
||||
# Test all endpoints without auth
|
||||
endpoints = [
|
||||
("GET", "/api/queue/status"),
|
||||
("POST", "/api/queue/add"),
|
||||
("DELETE", "/api/queue/item-1"),
|
||||
("POST", "/api/queue/start"),
|
||||
("POST", "/api/queue/stop"),
|
||||
("POST", "/api/queue/pause"),
|
||||
("POST", "/api/queue/resume"),
|
||||
]
|
||||
|
||||
for method, url in endpoints:
|
||||
if method == "GET":
|
||||
response = await client.get(url)
|
||||
elif method == "POST":
|
||||
response = await client.post(url, json={})
|
||||
elif method == "DELETE":
|
||||
response = await client.delete(url)
|
||||
|
||||
assert response.status_code == 401, (
|
||||
f"{method} {url} should require auth"
|
||||
)
|
||||
tests/integration/test_websocket_integration.py (Normal file, 470 lines)
@@ -0,0 +1,470 @@
"""Integration tests for WebSocket integration with core services.
|
||||
|
||||
This module tests the integration between WebSocket broadcasting and
|
||||
core services (DownloadService, AnimeService, ProgressService) to ensure
|
||||
real-time updates are properly broadcasted to connected clients.
|
||||
"""
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.server.models.download import (
|
||||
DownloadPriority,
|
||||
DownloadStatus,
|
||||
EpisodeIdentifier,
|
||||
)
|
||||
from src.server.services.anime_service import AnimeService
|
||||
from src.server.services.download_service import DownloadService
|
||||
from src.server.services.progress_service import ProgressService, ProgressType
|
||||
from src.server.services.websocket_service import WebSocketService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_series_app():
|
||||
"""Mock SeriesApp for testing."""
|
||||
app = Mock()
|
||||
app.series_list = []
|
||||
app.search = Mock(return_value=[])
|
||||
app.ReScan = Mock()
|
||||
app.download = Mock(return_value=True)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def progress_service():
|
||||
"""Create a ProgressService instance for testing."""
|
||||
return ProgressService()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def websocket_service():
|
||||
"""Create a WebSocketService instance for testing."""
|
||||
return WebSocketService()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def anime_service(mock_series_app, progress_service):
|
||||
"""Create an AnimeService with mocked dependencies."""
|
||||
with patch("src.server.services.anime_service.SeriesApp", return_value=mock_series_app):
|
||||
service = AnimeService(
|
||||
directory="/test/anime",
|
||||
progress_service=progress_service,
|
||||
)
|
||||
yield service
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def download_service(anime_service, progress_service):
|
||||
"""Create a DownloadService with dependencies."""
|
||||
service = DownloadService(
|
||||
anime_service=anime_service,
|
||||
max_concurrent_downloads=2,
|
||||
progress_service=progress_service,
|
||||
persistence_path="/tmp/test_queue.json",
|
||||
)
|
||||
yield service
|
||||
await service.stop()
|
||||
|
||||
|
||||
class TestWebSocketDownloadIntegration:
|
||||
"""Test WebSocket integration with DownloadService."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_progress_broadcast(
|
||||
self, download_service, websocket_service
|
||||
):
|
||||
"""Test that download progress updates are broadcasted."""
|
||||
broadcasts: List[Dict[str, Any]] = []
|
||||
|
||||
async def mock_broadcast(update_type: str, data: dict):
|
||||
"""Capture broadcast calls."""
|
||||
broadcasts.append({"type": update_type, "data": data})
|
||||
|
||||
download_service.set_broadcast_callback(mock_broadcast)
|
||||
|
||||
# Add item to queue
|
||||
item_ids = await download_service.add_to_queue(
|
||||
serie_id="test_serie",
|
||||
serie_name="Test Anime",
|
||||
episodes=[EpisodeIdentifier(season=1, episode=1)],
|
||||
priority=DownloadPriority.HIGH,
|
||||
)
|
||||
|
||||
assert len(item_ids) == 1
|
||||
assert len(broadcasts) == 1
|
||||
assert broadcasts[0]["type"] == "queue_status"
|
||||
assert broadcasts[0]["data"]["action"] == "items_added"
|
||||
assert item_ids[0] in broadcasts[0]["data"]["added_ids"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_queue_operations_broadcast(
|
||||
self, download_service
|
||||
):
|
||||
"""Test that queue operations broadcast status updates."""
|
||||
broadcasts: List[Dict[str, Any]] = []
|
||||
|
||||
async def mock_broadcast(update_type: str, data: dict):
|
||||
broadcasts.append({"type": update_type, "data": data})
|
||||
|
||||
download_service.set_broadcast_callback(mock_broadcast)
|
||||
|
||||
# Add items
|
||||
item_ids = await download_service.add_to_queue(
|
||||
serie_id="test",
|
||||
serie_name="Test",
|
||||
episodes=[EpisodeIdentifier(season=1, episode=i) for i in range(1, 4)],
|
||||
priority=DownloadPriority.NORMAL,
|
||||
)
|
||||
|
||||
# Remove items
|
||||
removed = await download_service.remove_from_queue([item_ids[0]])
|
||||
assert len(removed) == 1
|
||||
|
||||
# Check broadcasts
|
||||
add_broadcast = next(
|
||||
b for b in broadcasts
|
||||
if b["data"].get("action") == "items_added"
|
||||
)
|
||||
remove_broadcast = next(
|
||||
b for b in broadcasts
|
||||
if b["data"].get("action") == "items_removed"
|
||||
)
|
||||
|
||||
assert add_broadcast["type"] == "queue_status"
|
||||
assert len(add_broadcast["data"]["added_ids"]) == 3
|
||||
|
||||
assert remove_broadcast["type"] == "queue_status"
|
||||
assert item_ids[0] in remove_broadcast["data"]["removed_ids"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_queue_start_stop_broadcast(
|
||||
self, download_service
|
||||
):
|
||||
"""Test that start/stop operations broadcast updates."""
|
||||
broadcasts: List[Dict[str, Any]] = []
|
||||
|
||||
async def mock_broadcast(update_type: str, data: dict):
|
||||
broadcasts.append({"type": update_type, "data": data})
|
||||
|
||||
download_service.set_broadcast_callback(mock_broadcast)
|
||||
|
||||
# Start queue
|
||||
await download_service.start()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Stop queue
|
||||
await download_service.stop()
|
||||
|
||||
# Find start/stop broadcasts
|
||||
start_broadcast = next(
|
||||
(b for b in broadcasts if b["type"] == "queue_started"),
|
||||
None,
|
||||
)
|
||||
stop_broadcast = next(
|
||||
(b for b in broadcasts if b["type"] == "queue_stopped"),
|
||||
None,
|
||||
)
|
||||
|
||||
assert start_broadcast is not None
|
||||
assert start_broadcast["data"]["is_running"] is True
|
||||
|
||||
assert stop_broadcast is not None
|
||||
assert stop_broadcast["data"]["is_running"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_queue_pause_resume_broadcast(
|
||||
self, download_service
|
||||
):
|
||||
"""Test that pause/resume operations broadcast updates."""
|
||||
broadcasts: List[Dict[str, Any]] = []
|
||||
|
||||
async def mock_broadcast(update_type: str, data: dict):
|
||||
broadcasts.append({"type": update_type, "data": data})
|
||||
|
||||
download_service.set_broadcast_callback(mock_broadcast)
|
||||
|
||||
# Pause queue
|
||||
await download_service.pause_queue()
|
||||
|
||||
# Resume queue
|
||||
await download_service.resume_queue()
|
||||
|
||||
# Find pause/resume broadcasts
|
||||
pause_broadcast = next(
|
||||
(b for b in broadcasts if b["type"] == "queue_paused"),
|
||||
None,
|
||||
)
|
||||
resume_broadcast = next(
|
||||
(b for b in broadcasts if b["type"] == "queue_resumed"),
|
||||
None,
|
||||
)
|
||||
|
||||
assert pause_broadcast is not None
|
||||
assert pause_broadcast["data"]["is_paused"] is True
|
||||
|
||||
assert resume_broadcast is not None
|
||||
assert resume_broadcast["data"]["is_paused"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_completed_broadcast(
|
||||
self, download_service
|
||||
):
|
||||
"""Test that clearing completed items broadcasts update."""
|
||||
broadcasts: List[Dict[str, Any]] = []
|
||||
|
||||
async def mock_broadcast(update_type: str, data: dict):
|
||||
broadcasts.append({"type": update_type, "data": data})
|
||||
|
||||
download_service.set_broadcast_callback(mock_broadcast)
|
||||
|
||||
# Manually add a completed item to test
|
||||
from datetime import datetime
|
||||
|
||||
from src.server.models.download import DownloadItem
|
||||
|
||||
completed_item = DownloadItem(
|
||||
id="test_completed",
|
||||
serie_id="test",
|
||||
serie_name="Test",
|
||||
episode=EpisodeIdentifier(season=1, episode=1),
|
||||
status=DownloadStatus.COMPLETED,
|
||||
priority=DownloadPriority.NORMAL,
|
||||
added_at=datetime.utcnow(),
|
||||
)
|
||||
download_service._completed_items.append(completed_item)
|
||||
|
||||
# Clear completed
|
||||
count = await download_service.clear_completed()
|
||||
|
||||
assert count == 1
|
||||
|
||||
# Find clear broadcast
|
||||
clear_broadcast = next(
|
||||
(
|
||||
b for b in broadcasts
|
||||
if b["data"].get("action") == "completed_cleared"
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
assert clear_broadcast is not None
|
||||
assert clear_broadcast["data"]["cleared_count"] == 1
|
||||
|
||||
|
||||
class TestWebSocketScanIntegration:
|
||||
"""Test WebSocket integration with AnimeService scan operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_progress_broadcast(
|
||||
self, anime_service, progress_service, mock_series_app
|
||||
):
|
||||
"""Test that scan progress updates are broadcasted."""
|
||||
broadcasts: List[Dict[str, Any]] = []
|
||||
|
||||
async def mock_broadcast(message_type: str, data: dict, room: str):
|
||||
"""Capture broadcast calls."""
|
||||
broadcasts.append({
|
||||
"type": message_type,
|
||||
"data": data,
|
||||
"room": room,
|
||||
})
|
||||
|
||||
progress_service.set_broadcast_callback(mock_broadcast)
|
||||
|
||||
# Mock scan callback to simulate progress
|
||||
def mock_scan_callback(callback):
|
||||
"""Simulate scan progress."""
|
||||
if callback:
|
||||
callback({"current": 5, "total": 10, "message": "Scanning..."})
|
||||
callback({"current": 10, "total": 10, "message": "Complete"})
|
||||
|
||||
mock_series_app.ReScan = mock_scan_callback
|
||||
|
||||
# Run scan
|
||||
await anime_service.rescan()
|
||||
|
||||
# Verify broadcasts were made
|
||||
assert len(broadcasts) >= 2 # At least start and complete
|
||||
|
||||
# Check for scan progress broadcasts
|
||||
scan_broadcasts = [
|
||||
b for b in broadcasts if b["room"] == "scan_progress"
|
||||
]
|
||||
assert len(scan_broadcasts) >= 2
|
||||
|
||||
# Verify start broadcast
|
||||
start_broadcast = scan_broadcasts[0]
|
||||
assert start_broadcast["data"]["status"] == "started"
|
||||
assert start_broadcast["data"]["type"] == ProgressType.SCAN.value
|
||||
|
||||
# Verify completion broadcast
|
||||
complete_broadcast = scan_broadcasts[-1]
|
||||
assert complete_broadcast["data"]["status"] == "completed"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_failure_broadcast(
|
||||
self, anime_service, progress_service, mock_series_app
|
||||
):
|
||||
"""Test that scan failures are broadcasted."""
|
||||
broadcasts: List[Dict[str, Any]] = []
|
||||
|
||||
async def mock_broadcast(message_type: str, data: dict, room: str):
|
||||
broadcasts.append({
|
||||
"type": message_type,
|
||||
"data": data,
|
||||
"room": room,
|
||||
})
|
||||
|
||||
progress_service.set_broadcast_callback(mock_broadcast)
|
||||
|
||||
# Mock scan to raise error
|
||||
def mock_scan_error(callback):
|
||||
raise RuntimeError("Scan failed")
|
||||
|
||||
mock_series_app.ReScan = mock_scan_error
|
||||
|
||||
# Run scan (should fail)
|
||||
with pytest.raises(Exception):
|
||||
await anime_service.rescan()
|
||||
|
||||
# Verify failure broadcast
|
||||
scan_broadcasts = [
|
||||
b for b in broadcasts if b["room"] == "scan_progress"
|
||||
]
|
||||
assert len(scan_broadcasts) >= 2 # Start and fail
|
||||
|
||||
# Verify failure broadcast
|
||||
fail_broadcast = scan_broadcasts[-1]
|
||||
assert fail_broadcast["data"]["status"] == "failed"
|
||||
# Verify error message or failed status
|
||||
is_error = "error" in fail_broadcast["data"]["message"].lower()
|
||||
is_failed = fail_broadcast["data"]["status"] == "failed"
|
||||
assert is_error or is_failed
|
||||
|
||||
|
||||
class TestWebSocketProgressIntegration:
|
||||
"""Test WebSocket integration with ProgressService."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_progress_lifecycle_broadcast(
|
||||
self, progress_service
|
||||
):
|
||||
"""Test that progress lifecycle events are broadcasted."""
|
||||
broadcasts: List[Dict[str, Any]] = []
|
||||
|
||||
async def mock_broadcast(message_type: str, data: dict, room: str):
|
||||
broadcasts.append({
|
||||
"type": message_type,
|
||||
"data": data,
|
||||
"room": room,
|
||||
})
|
||||
|
||||
progress_service.set_broadcast_callback(mock_broadcast)
|
||||
|
||||
# Start progress
|
||||
await progress_service.start_progress(
|
||||
progress_id="test_progress",
|
||||
progress_type=ProgressType.DOWNLOAD,
|
||||
title="Test Download",
|
||||
total=100,
|
||||
)
|
||||
|
||||
# Update progress
|
||||
await progress_service.update_progress(
|
||||
progress_id="test_progress",
|
||||
current=50,
|
||||
force_broadcast=True,
|
||||
)
|
||||
|
||||
# Complete progress
|
||||
await progress_service.complete_progress(
|
||||
progress_id="test_progress",
|
||||
message="Download complete",
|
||||
)
|
||||
|
||||
# Verify broadcasts
|
||||
assert len(broadcasts) == 3
|
||||
|
||||
start_broadcast = broadcasts[0]
|
||||
assert start_broadcast["data"]["status"] == "started"
|
||||
assert start_broadcast["room"] == "download_progress"
|
||||
|
||||
update_broadcast = broadcasts[1]
|
||||
assert update_broadcast["data"]["status"] == "in_progress"
|
||||
assert update_broadcast["data"]["percent"] == 50.0
|
||||
|
||||
complete_broadcast = broadcasts[2]
|
||||
assert complete_broadcast["data"]["status"] == "completed"
|
||||
assert complete_broadcast["data"]["percent"] == 100.0
|
||||
|
||||
|
||||
class TestWebSocketEndToEnd:
|
||||
"""End-to-end integration tests with all services."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_download_flow_with_broadcasts(
|
||||
self, download_service, anime_service, progress_service
|
||||
):
|
||||
"""Test complete download flow with all broadcasts."""
|
||||
all_broadcasts: List[Dict[str, Any]] = []
|
||||
|
||||
async def capture_download_broadcast(update_type: str, data: dict):
|
||||
all_broadcasts.append({
|
||||
"source": "download",
|
||||
"type": update_type,
|
||||
"data": data,
|
||||
})
|
||||
|
||||
async def capture_progress_broadcast(
|
||||
message_type: str, data: dict, room: str
|
||||
):
|
||||
all_broadcasts.append({
|
||||
"source": "progress",
|
||||
"type": message_type,
|
||||
"data": data,
|
||||
"room": room,
|
||||
})
|
||||
|
||||
download_service.set_broadcast_callback(capture_download_broadcast)
|
||||
progress_service.set_broadcast_callback(capture_progress_broadcast)
|
||||
|
||||
# Add items to queue
|
||||
item_ids = await download_service.add_to_queue(
|
||||
serie_id="test",
|
||||
serie_name="Test Anime",
|
||||
episodes=[EpisodeIdentifier(season=1, episode=1)],
|
||||
priority=DownloadPriority.HIGH,
|
||||
)
|
||||
|
||||
# Start queue
|
||||
await download_service.start()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Pause queue
|
||||
await download_service.pause_queue()
|
||||
|
||||
# Resume queue
|
||||
await download_service.resume_queue()
|
||||
|
||||
# Stop queue
|
||||
await download_service.stop()
|
||||
|
||||
# Verify we received broadcasts from both services
|
||||
download_broadcasts = [
|
||||
b for b in all_broadcasts if b["source"] == "download"
|
||||
]
|
||||
|
||||
assert len(download_broadcasts) >= 4 # add, start, pause, resume, stop
|
||||
assert len(item_ids) == 1
|
||||
|
||||
# Verify queue status broadcasts
|
||||
queue_status_broadcasts = [
|
||||
b for b in download_broadcasts if b["type"] == "queue_status"
|
||||
]
|
||||
assert len(queue_status_broadcasts) >= 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
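The tests above all follow one pattern: inject an async broadcast callback into a service, drive the service, and assert on the captured messages. A minimal self-contained sketch of that pattern, using a hypothetical `TinyProgressService` (not the project's `ProgressService`):

```python
import asyncio


class TinyProgressService:
    """Hypothetical service with the same broadcast-callback hook as above."""

    def __init__(self):
        self._broadcast = None

    def set_broadcast_callback(self, callback):
        self._broadcast = callback

    async def update(self, percent: float):
        # Forward progress to whoever registered a callback.
        if self._broadcast is not None:
            await self._broadcast(
                "download_progress",
                {"status": "in_progress", "percent": percent},
                "download_progress",
            )


async def main():
    captured = []

    async def capture(message_type, data, room):
        captured.append({"type": message_type, "data": data, "room": room})

    svc = TinyProgressService()
    svc.set_broadcast_callback(capture)
    await svc.update(50.0)
    return captured


result = asyncio.run(main())
print(result[0]["data"]["percent"])
```

Because the callback is awaited inside the service, the test can assert on `captured` immediately after the awaited call returns, with no sleeps or polling.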
109
tests/unit/test_anime_models.py
Normal file
@ -0,0 +1,109 @@
import pytest
from pydantic import ValidationError

from src.server.models.anime import (
    AnimeSeriesResponse,
    EpisodeInfo,
    MissingEpisodeInfo,
    SearchRequest,
    SearchResult,
)


def test_episode_info_basic():
    ep = EpisodeInfo(episode_number=1, title="Pilot", duration_seconds=1500)
    assert ep.episode_number == 1
    assert ep.title == "Pilot"
    assert ep.duration_seconds == 1500
    assert ep.available is True


def test_missing_episode_count():
    m = MissingEpisodeInfo(from_episode=5, to_episode=7)
    assert m.count == 3


def test_anime_series_response():
    ep = EpisodeInfo(episode_number=1, title="Ep1")
    series = AnimeSeriesResponse(
        id="series-123",
        title="My Anime",
        episodes=[ep],
        total_episodes=12,
    )

    assert series.id == "series-123"
    assert series.episodes[0].title == "Ep1"


def test_search_request_validation():
    # valid
    req = SearchRequest(query="naruto", limit=5)
    assert req.query == "naruto"

    # invalid: empty query should raise a ValidationError
    with pytest.raises(ValidationError):
        SearchRequest(query="", limit=5)


def test_search_result_optional_fields():
    res = SearchResult(id="s1", title="T1", snippet="snip", score=0.9)
    assert res.score == 0.9
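The `test_missing_episode_count` assertion above implies that `MissingEpisodeInfo.count` is an inclusive range: `count = to_episode - from_episode + 1`. A standalone check of that arithmetic (the helper name is illustrative, not part of the project):

```python
def missing_count(from_episode: int, to_episode: int) -> int:
    # Inclusive range: episodes 5..7 are three episodes (5, 6, 7).
    return to_episode - from_episode + 1


print(missing_count(5, 7))
```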
27
tests/unit/test_anime_service.py
Normal file
@ -0,0 +1,27 @@
import pytest

from src.server.services.anime_service import AnimeService


@pytest.mark.asyncio
async def test_list_missing_empty(tmp_path):
    svc = AnimeService(directory=str(tmp_path))
    # SeriesApp may return an empty list depending on the filesystem;
    # ensure the call returns a list either way.
    result = await svc.list_missing()
    assert isinstance(result, list)


@pytest.mark.asyncio
async def test_search_empty_query(tmp_path):
    svc = AnimeService(directory=str(tmp_path))
    res = await svc.search("")
    assert res == []


@pytest.mark.asyncio
async def test_rescan_and_cache_clear(tmp_path):
    svc = AnimeService(directory=str(tmp_path))
    # calling rescan should not raise
    await svc.rescan()
55
tests/unit/test_config_models.py
Normal file
@ -0,0 +1,55 @@
import pytest

from src.server.models.config import (
    AppConfig,
    ConfigUpdate,
    LoggingConfig,
    SchedulerConfig,
    ValidationResult,
)


def test_scheduler_defaults_and_validation():
    sched = SchedulerConfig()
    assert sched.enabled is True
    assert sched.interval_minutes == 60

    with pytest.raises(ValueError):
        SchedulerConfig(interval_minutes=0)


def test_logging_config_defaults_and_values():
    log = LoggingConfig()
    assert log.level == "INFO"
    assert log.file is None
    assert log.backup_count == 3


def test_appconfig_and_config_update_apply_to():
    base = AppConfig()

    upd = ConfigUpdate(
        scheduler=SchedulerConfig(enabled=False, interval_minutes=30)
    )
    new = upd.apply_to(base)
    assert isinstance(new, AppConfig)
    assert new.scheduler.enabled is False
    assert new.scheduler.interval_minutes == 30

    upd2 = ConfigUpdate(other={"b": 2})
    new2 = upd2.apply_to(base)
    assert new2.other.get("b") == 2


def test_backup_and_validation():
    cfg = AppConfig()
    # default backups disabled -> valid
    res: ValidationResult = cfg.validate()
    assert res.valid is True

    # enable backups but leave path empty -> invalid
    cfg.backup.enabled = True
    cfg.backup.path = ""
    res2 = cfg.validate()
    assert res2.valid is False
    assert any("backup.path" in e for e in res2.errors)
550
tests/unit/test_download_models.py
Normal file
@ -0,0 +1,550 @@
"""Unit tests for download queue Pydantic models.

This module tests all download-related models including validation,
serialization, and field constraints.
"""
from datetime import datetime, timedelta

import pytest
from pydantic import ValidationError

from src.server.models.download import (
    DownloadItem,
    DownloadPriority,
    DownloadProgress,
    DownloadRequest,
    DownloadResponse,
    DownloadStatus,
    EpisodeIdentifier,
    QueueOperationRequest,
    QueueReorderRequest,
    QueueStats,
    QueueStatus,
    QueueStatusResponse,
)
class TestDownloadStatus:
    """Test DownloadStatus enum."""

    def test_all_statuses_exist(self):
        """Test that all expected statuses are defined."""
        assert DownloadStatus.PENDING == "pending"
        assert DownloadStatus.DOWNLOADING == "downloading"
        assert DownloadStatus.PAUSED == "paused"
        assert DownloadStatus.COMPLETED == "completed"
        assert DownloadStatus.FAILED == "failed"
        assert DownloadStatus.CANCELLED == "cancelled"

    def test_status_values(self):
        """Test that status values are lowercase strings."""
        for status in DownloadStatus:
            assert isinstance(status.value, str)
            assert status.value.islower()


class TestDownloadPriority:
    """Test DownloadPriority enum."""

    def test_all_priorities_exist(self):
        """Test that all expected priorities are defined."""
        assert DownloadPriority.LOW == "low"
        assert DownloadPriority.NORMAL == "normal"
        assert DownloadPriority.HIGH == "high"

    def test_priority_values(self):
        """Test that priority values are lowercase strings."""
        for priority in DownloadPriority:
            assert isinstance(priority.value, str)
            assert priority.value.islower()


class TestEpisodeIdentifier:
    """Test EpisodeIdentifier model."""

    def test_valid_episode_identifier(self):
        """Test creating a valid episode identifier."""
        episode = EpisodeIdentifier(
            season=1,
            episode=5,
            title="Test Episode"
        )
        assert episode.season == 1
        assert episode.episode == 5
        assert episode.title == "Test Episode"

    def test_episode_identifier_without_title(self):
        """Test creating episode identifier without title."""
        episode = EpisodeIdentifier(season=2, episode=10)
        assert episode.season == 2
        assert episode.episode == 10
        assert episode.title is None

    def test_invalid_season_number(self):
        """Test that season must be positive."""
        with pytest.raises(ValidationError) as exc_info:
            EpisodeIdentifier(season=0, episode=1)
        errors = exc_info.value.errors()
        assert any("season" in str(e["loc"]) for e in errors)

    def test_invalid_episode_number(self):
        """Test that episode must be positive."""
        with pytest.raises(ValidationError) as exc_info:
            EpisodeIdentifier(season=1, episode=0)
        errors = exc_info.value.errors()
        assert any("episode" in str(e["loc"]) for e in errors)

    def test_negative_season_rejected(self):
        """Test that negative season is rejected."""
        with pytest.raises(ValidationError):
            EpisodeIdentifier(season=-1, episode=1)

    def test_negative_episode_rejected(self):
        """Test that negative episode is rejected."""
        with pytest.raises(ValidationError):
            EpisodeIdentifier(season=1, episode=-1)
class TestDownloadProgress:
    """Test DownloadProgress model."""

    def test_valid_progress(self):
        """Test creating valid progress information."""
        progress = DownloadProgress(
            percent=45.5,
            downloaded_mb=100.0,
            total_mb=220.0,
            speed_mbps=5.5,
            eta_seconds=120
        )
        assert progress.percent == 45.5
        assert progress.downloaded_mb == 100.0
        assert progress.total_mb == 220.0
        assert progress.speed_mbps == 5.5
        assert progress.eta_seconds == 120

    def test_progress_defaults(self):
        """Test default values for progress."""
        progress = DownloadProgress()
        assert progress.percent == 0.0
        assert progress.downloaded_mb == 0.0
        assert progress.total_mb is None
        assert progress.speed_mbps is None
        assert progress.eta_seconds is None

    def test_percent_range_validation(self):
        """Test that percent must be between 0 and 100."""
        # Valid boundary values
        DownloadProgress(percent=0.0)
        DownloadProgress(percent=100.0)

        # Invalid values
        with pytest.raises(ValidationError):
            DownloadProgress(percent=-0.1)
        with pytest.raises(ValidationError):
            DownloadProgress(percent=100.1)

    def test_negative_downloaded_mb_rejected(self):
        """Test that negative downloaded_mb is rejected."""
        with pytest.raises(ValidationError):
            DownloadProgress(downloaded_mb=-1.0)

    def test_negative_total_mb_rejected(self):
        """Test that negative total_mb is rejected."""
        with pytest.raises(ValidationError):
            DownloadProgress(total_mb=-1.0)

    def test_negative_speed_rejected(self):
        """Test that negative speed is rejected."""
        with pytest.raises(ValidationError):
            DownloadProgress(speed_mbps=-1.0)

    def test_negative_eta_rejected(self):
        """Test that negative ETA is rejected."""
        with pytest.raises(ValidationError):
            DownloadProgress(eta_seconds=-1)
class TestDownloadItem:
    """Test DownloadItem model."""

    def test_valid_download_item(self):
        """Test creating a valid download item."""
        episode = EpisodeIdentifier(season=1, episode=5)
        item = DownloadItem(
            id="download_123",
            serie_id="serie_456",
            serie_name="Test Series",
            episode=episode,
            status=DownloadStatus.PENDING,
            priority=DownloadPriority.HIGH
        )
        assert item.id == "download_123"
        assert item.serie_id == "serie_456"
        assert item.serie_name == "Test Series"
        assert item.episode == episode
        assert item.status == DownloadStatus.PENDING
        assert item.priority == DownloadPriority.HIGH

    def test_download_item_defaults(self):
        """Test default values for download item."""
        episode = EpisodeIdentifier(season=1, episode=1)
        item = DownloadItem(
            id="test_id",
            serie_id="serie_id",
            serie_name="Test",
            episode=episode
        )
        assert item.status == DownloadStatus.PENDING
        assert item.priority == DownloadPriority.NORMAL
        assert item.started_at is None
        assert item.completed_at is None
        assert item.progress is None
        assert item.error is None
        assert item.retry_count == 0
        assert item.source_url is None

    def test_download_item_with_progress(self):
        """Test download item with progress information."""
        episode = EpisodeIdentifier(season=1, episode=1)
        progress = DownloadProgress(percent=50.0, downloaded_mb=100.0)
        item = DownloadItem(
            id="test_id",
            serie_id="serie_id",
            serie_name="Test",
            episode=episode,
            progress=progress
        )
        assert item.progress is not None
        assert item.progress.percent == 50.0

    def test_download_item_with_timestamps(self):
        """Test download item with timestamp fields."""
        episode = EpisodeIdentifier(season=1, episode=1)
        now = datetime.utcnow()
        item = DownloadItem(
            id="test_id",
            serie_id="serie_id",
            serie_name="Test",
            episode=episode,
            started_at=now,
            completed_at=now + timedelta(minutes=5)
        )
        assert item.started_at == now
        assert item.completed_at == now + timedelta(minutes=5)

    def test_empty_serie_name_rejected(self):
        """Test that empty serie name is rejected."""
        episode = EpisodeIdentifier(season=1, episode=1)
        with pytest.raises(ValidationError):
            DownloadItem(
                id="test_id",
                serie_id="serie_id",
                serie_name="",
                episode=episode
            )

    def test_negative_retry_count_rejected(self):
        """Test that negative retry count is rejected."""
        episode = EpisodeIdentifier(season=1, episode=1)
        with pytest.raises(ValidationError):
            DownloadItem(
                id="test_id",
                serie_id="serie_id",
                serie_name="Test",
                episode=episode,
                retry_count=-1
            )

    def test_added_at_auto_generated(self):
        """Test that added_at is automatically set."""
        episode = EpisodeIdentifier(season=1, episode=1)
        before = datetime.utcnow()
        item = DownloadItem(
            id="test_id",
            serie_id="serie_id",
            serie_name="Test",
            episode=episode
        )
        after = datetime.utcnow()
        assert before <= item.added_at <= after
class TestQueueStatus:
    """Test QueueStatus model."""

    def test_valid_queue_status(self):
        """Test creating valid queue status."""
        episode = EpisodeIdentifier(season=1, episode=1)
        item = DownloadItem(
            id="test_id",
            serie_id="serie_id",
            serie_name="Test",
            episode=episode
        )
        status = QueueStatus(
            is_running=True,
            is_paused=False,
            active_downloads=[item],
            pending_queue=[item],
            completed_downloads=[],
            failed_downloads=[]
        )
        assert status.is_running is True
        assert status.is_paused is False
        assert len(status.active_downloads) == 1
        assert len(status.pending_queue) == 1

    def test_queue_status_defaults(self):
        """Test default values for queue status."""
        status = QueueStatus()
        assert status.is_running is False
        assert status.is_paused is False
        assert status.active_downloads == []
        assert status.pending_queue == []
        assert status.completed_downloads == []
        assert status.failed_downloads == []


class TestQueueStats:
    """Test QueueStats model."""

    def test_valid_queue_stats(self):
        """Test creating valid queue statistics."""
        stats = QueueStats(
            total_items=10,
            pending_count=3,
            active_count=2,
            completed_count=4,
            failed_count=1,
            total_downloaded_mb=500.5,
            average_speed_mbps=5.0,
            estimated_time_remaining=120
        )
        assert stats.total_items == 10
        assert stats.pending_count == 3
        assert stats.active_count == 2
        assert stats.completed_count == 4
        assert stats.failed_count == 1
        assert stats.total_downloaded_mb == 500.5
        assert stats.average_speed_mbps == 5.0
        assert stats.estimated_time_remaining == 120

    def test_queue_stats_defaults(self):
        """Test default values for queue stats."""
        stats = QueueStats()
        assert stats.total_items == 0
        assert stats.pending_count == 0
        assert stats.active_count == 0
        assert stats.completed_count == 0
        assert stats.failed_count == 0
        assert stats.total_downloaded_mb == 0.0
        assert stats.average_speed_mbps is None
        assert stats.estimated_time_remaining is None

    def test_negative_counts_rejected(self):
        """Test that negative counts are rejected."""
        with pytest.raises(ValidationError):
            QueueStats(total_items=-1)
        with pytest.raises(ValidationError):
            QueueStats(pending_count=-1)
        with pytest.raises(ValidationError):
            QueueStats(active_count=-1)
        with pytest.raises(ValidationError):
            QueueStats(completed_count=-1)
        with pytest.raises(ValidationError):
            QueueStats(failed_count=-1)

    def test_negative_speed_rejected(self):
        """Test that negative speed is rejected."""
        with pytest.raises(ValidationError):
            QueueStats(average_speed_mbps=-1.0)

    def test_negative_eta_rejected(self):
        """Test that negative ETA is rejected."""
        with pytest.raises(ValidationError):
            QueueStats(estimated_time_remaining=-1)
class TestDownloadRequest:
    """Test DownloadRequest model."""

    def test_valid_download_request(self):
        """Test creating a valid download request."""
        episode1 = EpisodeIdentifier(season=1, episode=1)
        episode2 = EpisodeIdentifier(season=1, episode=2)
        request = DownloadRequest(
            serie_id="serie_123",
            serie_name="Test Series",
            episodes=[episode1, episode2],
            priority=DownloadPriority.HIGH
        )
        assert request.serie_id == "serie_123"
        assert request.serie_name == "Test Series"
        assert len(request.episodes) == 2
        assert request.priority == DownloadPriority.HIGH

    def test_download_request_default_priority(self):
        """Test default priority for download request."""
        episode = EpisodeIdentifier(season=1, episode=1)
        request = DownloadRequest(
            serie_id="serie_123",
            serie_name="Test Series",
            episodes=[episode]
        )
        assert request.priority == DownloadPriority.NORMAL

    def test_empty_episodes_list_rejected(self):
        """Test that empty episodes list is rejected."""
        with pytest.raises(ValidationError):
            DownloadRequest(
                serie_id="serie_123",
                serie_name="Test Series",
                episodes=[]
            )

    def test_empty_serie_name_rejected(self):
        """Test that empty serie name is rejected."""
        episode = EpisodeIdentifier(season=1, episode=1)
        with pytest.raises(ValidationError):
            DownloadRequest(
                serie_id="serie_123",
                serie_name="",
                episodes=[episode]
            )


class TestDownloadResponse:
    """Test DownloadResponse model."""

    def test_valid_download_response(self):
        """Test creating a valid download response."""
        response = DownloadResponse(
            status="success",
            message="Added 2 episodes to queue",
            added_items=["item1", "item2"],
            failed_items=[]
        )
        assert response.status == "success"
        assert response.message == "Added 2 episodes to queue"
        assert len(response.added_items) == 2
        assert response.failed_items == []

    def test_download_response_defaults(self):
        """Test default values for download response."""
        response = DownloadResponse(
            status="success",
            message="Test message"
        )
        assert response.added_items == []
        assert response.failed_items == []


class TestQueueOperationRequest:
    """Test QueueOperationRequest model."""

    def test_valid_operation_request(self):
        """Test creating a valid operation request."""
        request = QueueOperationRequest(
            item_ids=["item1", "item2", "item3"]
        )
        assert len(request.item_ids) == 3
        assert "item1" in request.item_ids

    def test_empty_item_ids_rejected(self):
        """Test that empty item_ids list is rejected."""
        with pytest.raises(ValidationError):
            QueueOperationRequest(item_ids=[])


class TestQueueReorderRequest:
    """Test QueueReorderRequest model."""

    def test_valid_reorder_request(self):
        """Test creating a valid reorder request."""
        request = QueueReorderRequest(
            item_id="item_123",
            new_position=5
        )
        assert request.item_id == "item_123"
        assert request.new_position == 5

    def test_zero_position_allowed(self):
        """Test that position zero is allowed."""
        request = QueueReorderRequest(
            item_id="item_123",
            new_position=0
        )
        assert request.new_position == 0

    def test_negative_position_rejected(self):
        """Test that negative position is rejected."""
        with pytest.raises(ValidationError):
            QueueReorderRequest(
                item_id="item_123",
                new_position=-1
            )


class TestQueueStatusResponse:
    """Test QueueStatusResponse model."""

    def test_valid_status_response(self):
        """Test creating a valid status response."""
        status = QueueStatus()
        stats = QueueStats()
        response = QueueStatusResponse(
            status=status,
            statistics=stats
        )
        assert response.status is not None
        assert response.statistics is not None


class TestModelSerialization:
    """Test model serialization and deserialization."""

    def test_download_item_to_dict(self):
        """Test serializing download item to dict."""
        episode = EpisodeIdentifier(season=1, episode=5, title="Test")
        item = DownloadItem(
            id="test_id",
            serie_id="serie_id",
            serie_name="Test Series",
            episode=episode
        )
        data = item.model_dump()
        assert data["id"] == "test_id"
        assert data["serie_name"] == "Test Series"
        assert data["episode"]["season"] == 1
        assert data["episode"]["episode"] == 5

    def test_download_item_from_dict(self):
        """Test deserializing download item from dict."""
        data = {
            "id": "test_id",
            "serie_id": "serie_id",
            "serie_name": "Test Series",
            "episode": {
                "season": 1,
                "episode": 5,
                "title": "Test Episode"
            }
        }
        item = DownloadItem(**data)
        assert item.id == "test_id"
        assert item.serie_name == "Test Series"
        assert item.episode.season == 1

    def test_queue_status_to_json(self):
        """Test serializing queue status to JSON."""
        status = QueueStatus(is_running=True)
        json_str = status.model_dump_json()
        assert '"is_running":true' in json_str.lower()

    def test_queue_stats_from_json(self):
        """Test deserializing queue stats from JSON."""
        json_str = '{"total_items": 5, "pending_count": 3}'
        stats = QueueStats.model_validate_json(json_str)
        assert stats.total_items == 5
        assert stats.pending_count == 3
491
tests/unit/test_download_service.py
Normal file
@ -0,0 +1,491 @@
"""Unit tests for the download queue service.

Tests cover queue management, priority handling, persistence,
concurrent downloads, and error scenarios.
"""
from __future__ import annotations

import asyncio
import json
from datetime import datetime
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock

import pytest

from src.server.models.download import (
    DownloadItem,
    DownloadPriority,
    DownloadStatus,
    EpisodeIdentifier,
)
from src.server.services.anime_service import AnimeService
from src.server.services.download_service import DownloadService, DownloadServiceError


@pytest.fixture
def mock_anime_service():
    """Create a mock AnimeService."""
    service = MagicMock(spec=AnimeService)
    service.download = AsyncMock(return_value=True)
    return service


@pytest.fixture
def temp_persistence_path(tmp_path):
    """Create a temporary persistence path."""
    return str(tmp_path / "test_queue.json")


@pytest.fixture
def download_service(mock_anime_service, temp_persistence_path):
    """Create a DownloadService instance for testing."""
    return DownloadService(
        anime_service=mock_anime_service,
        max_concurrent_downloads=2,
        max_retries=3,
        persistence_path=temp_persistence_path,
    )
class TestDownloadServiceInitialization:
    """Test download service initialization."""

    def test_initialization_creates_queues(
        self, mock_anime_service, temp_persistence_path
    ):
        """Test that initialization creates empty queues."""
        service = DownloadService(
            anime_service=mock_anime_service,
            persistence_path=temp_persistence_path,
        )

        assert len(service._pending_queue) == 0
        assert len(service._active_downloads) == 0
        assert len(service._completed_items) == 0
        assert len(service._failed_items) == 0
        assert service._is_running is False
        assert service._is_paused is False

    def test_initialization_loads_persisted_queue(
        self, mock_anime_service, temp_persistence_path
    ):
        """Test that initialization loads persisted queue state."""
        # Create a persisted queue file
        persistence_file = Path(temp_persistence_path)
        persistence_file.parent.mkdir(parents=True, exist_ok=True)

        test_data = {
            "pending": [
                {
                    "id": "test-id-1",
                    "serie_id": "series-1",
                    "serie_name": "Test Series",
                    "episode": {"season": 1, "episode": 1, "title": None},
                    "status": "pending",
                    "priority": "normal",
                    "added_at": datetime.utcnow().isoformat(),
                    "started_at": None,
                    "completed_at": None,
                    "progress": None,
                    "error": None,
                    "retry_count": 0,
                    "source_url": None,
                }
            ],
            "active": [],
            "failed": [],
            "timestamp": datetime.utcnow().isoformat(),
        }

        with open(persistence_file, "w", encoding="utf-8") as f:
            json.dump(test_data, f)

        service = DownloadService(
            anime_service=mock_anime_service,
            persistence_path=temp_persistence_path,
        )

        assert len(service._pending_queue) == 1
        assert service._pending_queue[0].id == "test-id-1"
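The persistence test above fixes the on-disk queue format: a JSON object with `pending`, `active`, and `failed` lists plus a `timestamp`. A self-contained round-trip of that shape using only the standard library (field names copied from the test fixture; the real `DownloadService` may persist more):

```python
import json
import os
import tempfile

# Queue state in the shape the test fixture writes to disk.
state = {
    "pending": [
        {
            "id": "test-id-1",
            "serie_id": "series-1",
            "serie_name": "Test Series",
            "episode": {"season": 1, "episode": 1, "title": None},
            "status": "pending",
            "priority": "normal",
            "retry_count": 0,
        }
    ],
    "active": [],
    "failed": [],
}

# Write, then read back, as the service does on shutdown/startup.
path = os.path.join(tempfile.mkdtemp(), "queue.json")
with open(path, "w", encoding="utf-8") as f:
    json.dump(state, f)

with open(path, encoding="utf-8") as f:
    loaded = json.load(f)

print(loaded["pending"][0]["id"])
```

Note that JSON has no `None`-preserving quirks here: `null` round-trips back to `None`, so optional fields like `title` survive the reload unchanged.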
class TestQueueManagement:
    """Test queue management operations."""

    @pytest.mark.asyncio
    async def test_add_to_queue_single_episode(self, download_service):
        """Test adding a single episode to queue."""
        episodes = [EpisodeIdentifier(season=1, episode=1)]

        item_ids = await download_service.add_to_queue(
            serie_id="series-1",
            serie_name="Test Series",
            episodes=episodes,
            priority=DownloadPriority.NORMAL,
        )

        assert len(item_ids) == 1
        assert len(download_service._pending_queue) == 1
        assert download_service._pending_queue[0].serie_id == "series-1"
        assert (
            download_service._pending_queue[0].status
            == DownloadStatus.PENDING
        )

    @pytest.mark.asyncio
    async def test_add_to_queue_multiple_episodes(self, download_service):
        """Test adding multiple episodes to queue."""
        episodes = [
            EpisodeIdentifier(season=1, episode=1),
            EpisodeIdentifier(season=1, episode=2),
            EpisodeIdentifier(season=1, episode=3),
        ]

        item_ids = await download_service.add_to_queue(
            serie_id="series-1",
            serie_name="Test Series",
            episodes=episodes,
            priority=DownloadPriority.NORMAL,
        )

        assert len(item_ids) == 3
        assert len(download_service._pending_queue) == 3

    @pytest.mark.asyncio
    async def test_add_high_priority_to_front(self, download_service):
        """Test that high priority items are added to front of queue."""
        # Add normal priority item
        await download_service.add_to_queue(
            serie_id="series-1",
            serie_name="Test Series",
            episodes=[EpisodeIdentifier(season=1, episode=1)],
            priority=DownloadPriority.NORMAL,
        )

        # Add high priority item
        await download_service.add_to_queue(
            serie_id="series-2",
            serie_name="Priority Series",
            episodes=[EpisodeIdentifier(season=1, episode=1)],
            priority=DownloadPriority.HIGH,
        )

        # High priority should be at front
        assert download_service._pending_queue[0].serie_id == "series-2"
        assert download_service._pending_queue[1].serie_id == "series-1"

    @pytest.mark.asyncio
    async def test_remove_from_pending_queue(self, download_service):
        """Test removing items from pending queue."""
        item_ids = await download_service.add_to_queue(
            serie_id="series-1",
            serie_name="Test Series",
            episodes=[EpisodeIdentifier(season=1, episode=1)],
        )

        removed_ids = await download_service.remove_from_queue(item_ids)

        assert len(removed_ids) == 1
        assert removed_ids[0] == item_ids[0]
        assert len(download_service._pending_queue) == 0

    @pytest.mark.asyncio
    async def test_reorder_queue(self, download_service):
        """Test reordering items in queue."""
        # Add three items
        await download_service.add_to_queue(
            serie_id="series-1",
            serie_name="Series 1",
            episodes=[EpisodeIdentifier(season=1, episode=1)],
        )
        await download_service.add_to_queue(
            serie_id="series-2",
            serie_name="Series 2",
            episodes=[EpisodeIdentifier(season=1, episode=1)],
        )
        await download_service.add_to_queue(
            serie_id="series-3",
            serie_name="Series 3",
            episodes=[EpisodeIdentifier(season=1, episode=1)],
        )

        # Move last item to position 0
        item_to_move = download_service._pending_queue[2].id
        success = await download_service.reorder_queue(item_to_move, 0)

        assert success is True
        assert download_service._pending_queue[0].id == item_to_move
        assert download_service._pending_queue[0].serie_id == "series-3"
class TestQueueStatus:
|
||||
"""Test queue status reporting."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_queue_status(self, download_service):
|
||||
"""Test getting queue status."""
|
||||
# Add items to queue
|
||||
await download_service.add_to_queue(
|
||||
serie_id="series-1",
|
||||
serie_name="Test Series",
|
||||
episodes=[
|
||||
EpisodeIdentifier(season=1, episode=1),
|
||||
EpisodeIdentifier(season=1, episode=2),
|
||||
],
|
||||
)
|
||||
|
||||
status = await download_service.get_queue_status()
|
||||
|
||||
assert status.is_running is False
|
||||
assert status.is_paused is False
|
||||
assert len(status.pending_queue) == 2
|
||||
assert len(status.active_downloads) == 0
|
||||
assert len(status.completed_downloads) == 0
|
||||
assert len(status.failed_downloads) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_queue_stats(self, download_service):
|
||||
"""Test getting queue statistics."""
|
||||
# Add items
|
||||
await download_service.add_to_queue(
|
||||
serie_id="series-1",
|
||||
serie_name="Test Series",
|
||||
episodes=[
|
||||
EpisodeIdentifier(season=1, episode=1),
|
||||
EpisodeIdentifier(season=1, episode=2),
|
||||
],
|
||||
)
|
||||
|
||||
stats = await download_service.get_queue_stats()
|
||||
|
||||
assert stats.total_items == 2
|
||||
assert stats.pending_count == 2
|
||||
assert stats.active_count == 0
|
||||
assert stats.completed_count == 0
|
||||
assert stats.failed_count == 0
|
||||
assert stats.total_downloaded_mb == 0.0
|
||||
|
||||
|
||||
class TestQueueControl:
|
||||
"""Test queue control operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_queue(self, download_service):
|
||||
"""Test pausing the queue."""
|
||||
await download_service.pause_queue()
|
||||
assert download_service._is_paused is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_queue(self, download_service):
|
||||
"""Test resuming the queue."""
|
||||
await download_service.pause_queue()
|
||||
await download_service.resume_queue()
|
||||
assert download_service._is_paused is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_completed(self, download_service):
|
||||
"""Test clearing completed downloads."""
|
||||
# Manually add completed item
|
||||
completed_item = DownloadItem(
|
||||
id="completed-1",
|
||||
serie_id="series-1",
|
||||
serie_name="Test Series",
|
||||
episode=EpisodeIdentifier(season=1, episode=1),
|
||||
status=DownloadStatus.COMPLETED,
|
||||
)
|
||||
download_service._completed_items.append(completed_item)
|
||||
|
||||
count = await download_service.clear_completed()
|
||||
|
||||
assert count == 1
|
||||
assert len(download_service._completed_items) == 0
|
||||
|
||||
|
||||
class TestPersistence:
|
||||
"""Test queue persistence functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_queue_persistence(self, download_service):
|
||||
"""Test that queue state is persisted to disk."""
|
||||
await download_service.add_to_queue(
|
||||
serie_id="series-1",
|
||||
serie_name="Test Series",
|
||||
episodes=[EpisodeIdentifier(season=1, episode=1)],
|
||||
)
|
||||
|
||||
# Persistence file should exist
|
||||
persistence_path = Path(download_service._persistence_path)
|
||||
assert persistence_path.exists()
|
||||
|
||||
# Check file contents
|
||||
with open(persistence_path, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
assert len(data["pending"]) == 1
|
||||
assert data["pending"][0]["serie_id"] == "series-1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_queue_recovery_after_restart(
|
||||
self, mock_anime_service, temp_persistence_path
|
||||
):
|
||||
"""Test that queue is recovered after service restart."""
|
||||
# Create and populate first service
|
||||
service1 = DownloadService(
|
||||
anime_service=mock_anime_service,
|
||||
persistence_path=temp_persistence_path,
|
||||
)
|
||||
|
||||
await service1.add_to_queue(
|
||||
serie_id="series-1",
|
||||
serie_name="Test Series",
|
||||
episodes=[
|
||||
EpisodeIdentifier(season=1, episode=1),
|
||||
EpisodeIdentifier(season=1, episode=2),
|
||||
],
|
||||
)
|
||||
|
||||
# Create new service with same persistence path
|
||||
service2 = DownloadService(
|
||||
anime_service=mock_anime_service,
|
||||
persistence_path=temp_persistence_path,
|
||||
)
|
||||
|
||||
# Should recover pending items
|
||||
assert len(service2._pending_queue) == 2
|
||||
|
||||
|
||||
class TestRetryLogic:
|
||||
"""Test retry logic for failed downloads."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_failed_items(self, download_service):
|
||||
"""Test retrying failed downloads."""
|
||||
# Manually add failed item
|
||||
failed_item = DownloadItem(
|
||||
id="failed-1",
|
||||
serie_id="series-1",
|
||||
serie_name="Test Series",
|
||||
episode=EpisodeIdentifier(season=1, episode=1),
|
||||
status=DownloadStatus.FAILED,
|
||||
retry_count=0,
|
||||
error="Test error",
|
||||
)
|
||||
download_service._failed_items.append(failed_item)
|
||||
|
||||
retried_ids = await download_service.retry_failed()
|
||||
|
||||
assert len(retried_ids) == 1
|
||||
assert len(download_service._failed_items) == 0
|
||||
assert len(download_service._pending_queue) == 1
|
||||
assert download_service._pending_queue[0].retry_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_retries_not_exceeded(self, download_service):
|
||||
"""Test that items with max retries are not retried."""
|
||||
# Create item with max retries
|
||||
failed_item = DownloadItem(
|
||||
id="failed-1",
|
||||
serie_id="series-1",
|
||||
serie_name="Test Series",
|
||||
episode=EpisodeIdentifier(season=1, episode=1),
|
||||
status=DownloadStatus.FAILED,
|
||||
retry_count=3, # Max retries
|
||||
error="Test error",
|
||||
)
|
||||
download_service._failed_items.append(failed_item)
|
||||
|
||||
retried_ids = await download_service.retry_failed()
|
||||
|
||||
assert len(retried_ids) == 0
|
||||
assert len(download_service._failed_items) == 1
|
||||
assert len(download_service._pending_queue) == 0
|
||||
|
||||
|
||||
class TestBroadcastCallbacks:
|
||||
"""Test WebSocket broadcast functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_broadcast_callback(self, download_service):
|
||||
"""Test setting broadcast callback."""
|
||||
mock_callback = AsyncMock()
|
||||
download_service.set_broadcast_callback(mock_callback)
|
||||
|
||||
assert download_service._broadcast_callback == mock_callback
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_on_queue_update(self, download_service):
|
||||
"""Test that broadcasts are sent on queue updates."""
|
||||
mock_callback = AsyncMock()
|
||||
download_service.set_broadcast_callback(mock_callback)
|
||||
|
||||
await download_service.add_to_queue(
|
||||
serie_id="series-1",
|
||||
serie_name="Test Series",
|
||||
episodes=[EpisodeIdentifier(season=1, episode=1)],
|
||||
)
|
||||
|
||||
# Allow async callback to execute
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Verify callback was called
|
||||
mock_callback.assert_called()
|
||||
|
||||
|
||||
class TestServiceLifecycle:
|
||||
"""Test service start and stop operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_service(self, download_service):
|
||||
"""Test starting the service."""
|
||||
await download_service.start()
|
||||
assert download_service._is_running is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_service(self, download_service):
|
||||
"""Test stopping the service."""
|
||||
await download_service.start()
|
||||
await download_service.stop()
|
||||
assert download_service._is_running is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_already_running(self, download_service):
|
||||
"""Test starting service when already running."""
|
||||
await download_service.start()
|
||||
await download_service.start() # Should not raise error
|
||||
assert download_service._is_running is True
|
||||
|
||||
|
||||
class TestErrorHandling:
|
||||
"""Test error handling in download service."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reorder_nonexistent_item(self, download_service):
|
||||
"""Test reordering non-existent item raises error."""
|
||||
with pytest.raises(DownloadServiceError):
|
||||
await download_service.reorder_queue("nonexistent-id", 0)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_failure_moves_to_failed(self, download_service):
|
||||
"""Test that download failures are handled correctly."""
|
||||
# Mock download to fail
|
||||
download_service._anime_service.download = AsyncMock(
|
||||
side_effect=Exception("Download failed")
|
||||
)
|
||||
|
||||
await download_service.add_to_queue(
|
||||
serie_id="series-1",
|
||||
serie_name="Test Series",
|
||||
episodes=[EpisodeIdentifier(season=1, episode=1)],
|
||||
)
|
||||
|
||||
# Process the download
|
||||
item = download_service._pending_queue.popleft()
|
||||
await download_service._process_download(item)
|
||||
|
||||
# Item should be in failed queue
|
||||
assert len(download_service._failed_items) == 1
|
||||
assert (
|
||||
download_service._failed_items[0].status == DownloadStatus.FAILED
|
||||
)
|
||||
assert download_service._failed_items[0].error is not None
|
||||
499
tests/unit/test_progress_service.py
Normal file
@@ -0,0 +1,499 @@
"""Unit tests for ProgressService.

This module contains comprehensive tests for the progress tracking service,
including progress lifecycle, broadcasting, error handling, and concurrency.
"""
import asyncio
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock

import pytest

from src.server.services.progress_service import (
    ProgressService,
    ProgressServiceError,
    ProgressStatus,
    ProgressType,
    ProgressUpdate,
)


class TestProgressUpdate:
    """Test ProgressUpdate dataclass."""

    def test_progress_update_creation(self):
        """Test creating a progress update."""
        update = ProgressUpdate(
            id="test-1",
            type=ProgressType.DOWNLOAD,
            status=ProgressStatus.STARTED,
            title="Test Download",
            message="Starting download",
            total=100,
        )

        assert update.id == "test-1"
        assert update.type == ProgressType.DOWNLOAD
        assert update.status == ProgressStatus.STARTED
        assert update.title == "Test Download"
        assert update.message == "Starting download"
        assert update.total == 100
        assert update.current == 0
        assert update.percent == 0.0

    def test_progress_update_to_dict(self):
        """Test converting progress update to dictionary."""
        update = ProgressUpdate(
            id="test-1",
            type=ProgressType.SCAN,
            status=ProgressStatus.IN_PROGRESS,
            title="Test Scan",
            message="Scanning files",
            current=50,
            total=100,
            metadata={"test_key": "test_value"},
        )

        result = update.to_dict()

        assert result["id"] == "test-1"
        assert result["type"] == "scan"
        assert result["status"] == "in_progress"
        assert result["title"] == "Test Scan"
        assert result["message"] == "Scanning files"
        assert result["current"] == 50
        assert result["total"] == 100
        assert result["percent"] == 0.0
        assert result["metadata"]["test_key"] == "test_value"
        assert "started_at" in result
        assert "updated_at" in result


class TestProgressService:
    """Test ProgressService class."""

    @pytest.fixture
    def service(self):
        """Create a fresh ProgressService instance for each test."""
        return ProgressService()

    @pytest.fixture
    def mock_broadcast(self):
        """Create a mock broadcast callback."""
        return AsyncMock()

    @pytest.mark.asyncio
    async def test_start_progress(self, service):
        """Test starting a new progress operation."""
        update = await service.start_progress(
            progress_id="download-1",
            progress_type=ProgressType.DOWNLOAD,
            title="Downloading episode",
            total=1000,
            message="Starting...",
            metadata={"episode": "S01E01"},
        )

        assert update.id == "download-1"
        assert update.type == ProgressType.DOWNLOAD
        assert update.status == ProgressStatus.STARTED
        assert update.title == "Downloading episode"
        assert update.total == 1000
        assert update.message == "Starting..."
        assert update.metadata["episode"] == "S01E01"

    @pytest.mark.asyncio
    async def test_start_progress_duplicate_id(self, service):
        """Test starting progress with duplicate ID raises error."""
        await service.start_progress(
            progress_id="test-1",
            progress_type=ProgressType.DOWNLOAD,
            title="Test",
        )

        with pytest.raises(ProgressServiceError, match="already exists"):
            await service.start_progress(
                progress_id="test-1",
                progress_type=ProgressType.DOWNLOAD,
                title="Test Duplicate",
            )

    @pytest.mark.asyncio
    async def test_update_progress(self, service):
        """Test updating an existing progress operation."""
        await service.start_progress(
            progress_id="test-1",
            progress_type=ProgressType.DOWNLOAD,
            title="Test",
            total=100,
        )

        update = await service.update_progress(
            progress_id="test-1",
            current=50,
            message="Half way",
        )

        assert update.current == 50
        assert update.total == 100
        assert update.percent == 50.0
        assert update.message == "Half way"
        assert update.status == ProgressStatus.IN_PROGRESS

    @pytest.mark.asyncio
    async def test_update_progress_not_found(self, service):
        """Test updating non-existent progress raises error."""
        with pytest.raises(ProgressServiceError, match="not found"):
            await service.update_progress(
                progress_id="nonexistent",
                current=50,
            )

    @pytest.mark.asyncio
    async def test_update_progress_percentage_calculation(self, service):
        """Test progress percentage is calculated correctly."""
        await service.start_progress(
            progress_id="test-1",
            progress_type=ProgressType.DOWNLOAD,
            title="Test",
            total=200,
        )

        await service.update_progress(progress_id="test-1", current=50)
        update = await service.get_progress("test-1")
        assert update.percent == 25.0

        await service.update_progress(progress_id="test-1", current=100)
        update = await service.get_progress("test-1")
        assert update.percent == 50.0

        await service.update_progress(progress_id="test-1", current=200)
        update = await service.get_progress("test-1")
        assert update.percent == 100.0

    @pytest.mark.asyncio
    async def test_complete_progress(self, service):
        """Test completing a progress operation."""
        await service.start_progress(
            progress_id="test-1",
            progress_type=ProgressType.SCAN,
            title="Test Scan",
            total=100,
        )

        await service.update_progress(progress_id="test-1", current=50)

        update = await service.complete_progress(
            progress_id="test-1",
            message="Scan completed successfully",
            metadata={"items_found": 42},
        )

        assert update.status == ProgressStatus.COMPLETED
        assert update.percent == 100.0
        assert update.current == update.total
        assert update.message == "Scan completed successfully"
        assert update.metadata["items_found"] == 42

        # Should be moved to history
        active_progress = await service.get_all_active_progress()
        assert "test-1" not in active_progress

    @pytest.mark.asyncio
    async def test_fail_progress(self, service):
        """Test failing a progress operation."""
        await service.start_progress(
            progress_id="test-1",
            progress_type=ProgressType.DOWNLOAD,
            title="Test Download",
        )

        update = await service.fail_progress(
            progress_id="test-1",
            error_message="Network timeout",
            metadata={"retry_count": 3},
        )

        assert update.status == ProgressStatus.FAILED
        assert update.message == "Network timeout"
        assert update.metadata["retry_count"] == 3

        # Should be moved to history
        active_progress = await service.get_all_active_progress()
        assert "test-1" not in active_progress

    @pytest.mark.asyncio
    async def test_cancel_progress(self, service):
        """Test cancelling a progress operation."""
        await service.start_progress(
            progress_id="test-1",
            progress_type=ProgressType.DOWNLOAD,
            title="Test Download",
        )

        update = await service.cancel_progress(
            progress_id="test-1",
            message="Cancelled by user",
        )

        assert update.status == ProgressStatus.CANCELLED
        assert update.message == "Cancelled by user"

        # Should be moved to history
        active_progress = await service.get_all_active_progress()
        assert "test-1" not in active_progress

    @pytest.mark.asyncio
    async def test_get_progress(self, service):
        """Test retrieving progress by ID."""
        await service.start_progress(
            progress_id="test-1",
            progress_type=ProgressType.SCAN,
            title="Test",
        )

        progress = await service.get_progress("test-1")
        assert progress is not None
        assert progress.id == "test-1"

        # Test non-existent progress
        progress = await service.get_progress("nonexistent")
        assert progress is None

    @pytest.mark.asyncio
    async def test_get_all_active_progress(self, service):
        """Test retrieving all active progress operations."""
        await service.start_progress(
            progress_id="download-1",
            progress_type=ProgressType.DOWNLOAD,
            title="Download 1",
        )
        await service.start_progress(
            progress_id="download-2",
            progress_type=ProgressType.DOWNLOAD,
            title="Download 2",
        )
        await service.start_progress(
            progress_id="scan-1",
            progress_type=ProgressType.SCAN,
            title="Scan 1",
        )

        all_progress = await service.get_all_active_progress()
        assert len(all_progress) == 3
        assert "download-1" in all_progress
        assert "download-2" in all_progress
        assert "scan-1" in all_progress

    @pytest.mark.asyncio
    async def test_get_all_active_progress_filtered(self, service):
        """Test retrieving active progress filtered by type."""
        await service.start_progress(
            progress_id="download-1",
            progress_type=ProgressType.DOWNLOAD,
            title="Download 1",
        )
        await service.start_progress(
            progress_id="download-2",
            progress_type=ProgressType.DOWNLOAD,
            title="Download 2",
        )
        await service.start_progress(
            progress_id="scan-1",
            progress_type=ProgressType.SCAN,
            title="Scan 1",
        )

        download_progress = await service.get_all_active_progress(
            progress_type=ProgressType.DOWNLOAD
        )
        assert len(download_progress) == 2
        assert "download-1" in download_progress
        assert "download-2" in download_progress
        assert "scan-1" not in download_progress

    @pytest.mark.asyncio
    async def test_history_management(self, service):
        """Test progress history is maintained with size limit."""
        # Start and complete multiple progress operations
        for i in range(60):  # More than max_history_size (50)
            await service.start_progress(
                progress_id=f"test-{i}",
                progress_type=ProgressType.DOWNLOAD,
                title=f"Test {i}",
            )
            await service.complete_progress(
                progress_id=f"test-{i}",
                message="Completed",
            )

        # Check that oldest entries were removed
        history = service._history
        assert len(history) <= 50

        # Most recent should be in history
        recent_progress = await service.get_progress("test-59")
        assert recent_progress is not None

    @pytest.mark.asyncio
    async def test_broadcast_callback(self, service, mock_broadcast):
        """Test broadcast callback is invoked correctly."""
        service.set_broadcast_callback(mock_broadcast)

        await service.start_progress(
            progress_id="test-1",
            progress_type=ProgressType.DOWNLOAD,
            title="Test",
        )

        # Verify callback was called for start
        mock_broadcast.assert_called_once()
        call_args = mock_broadcast.call_args
        assert call_args[1]["message_type"] == "download_progress"
        assert call_args[1]["room"] == "download_progress"
        assert "test-1" in str(call_args[1]["data"])

    @pytest.mark.asyncio
    async def test_broadcast_on_update(self, service, mock_broadcast):
        """Test broadcast on progress update."""
        service.set_broadcast_callback(mock_broadcast)

        await service.start_progress(
            progress_id="test-1",
            progress_type=ProgressType.DOWNLOAD,
            title="Test",
            total=100,
        )
        mock_broadcast.reset_mock()

        # Update with significant change (>1%)
        await service.update_progress(
            progress_id="test-1",
            current=50,
            force_broadcast=True,
        )

        # Should have been called
        assert mock_broadcast.call_count >= 1

    @pytest.mark.asyncio
    async def test_broadcast_on_complete(self, service, mock_broadcast):
        """Test broadcast on progress completion."""
        service.set_broadcast_callback(mock_broadcast)

        await service.start_progress(
            progress_id="test-1",
            progress_type=ProgressType.SCAN,
            title="Test",
        )
        mock_broadcast.reset_mock()

        await service.complete_progress(
            progress_id="test-1",
            message="Done",
        )

        # Should have been called
        mock_broadcast.assert_called_once()
        call_args = mock_broadcast.call_args
        assert "completed" in str(call_args[1]["data"]).lower()

    @pytest.mark.asyncio
    async def test_broadcast_on_failure(self, service, mock_broadcast):
        """Test broadcast on progress failure."""
        service.set_broadcast_callback(mock_broadcast)

        await service.start_progress(
            progress_id="test-1",
            progress_type=ProgressType.DOWNLOAD,
            title="Test",
        )
        mock_broadcast.reset_mock()

        await service.fail_progress(
            progress_id="test-1",
            error_message="Test error",
        )

        # Should have been called
        mock_broadcast.assert_called_once()
        call_args = mock_broadcast.call_args
        assert "failed" in str(call_args[1]["data"]).lower()

    @pytest.mark.asyncio
    async def test_clear_history(self, service):
        """Test clearing progress history."""
        # Create and complete some progress
        for i in range(5):
            await service.start_progress(
                progress_id=f"test-{i}",
                progress_type=ProgressType.DOWNLOAD,
                title=f"Test {i}",
            )
            await service.complete_progress(
                progress_id=f"test-{i}",
                message="Done",
            )

        # History should not be empty
        assert len(service._history) > 0

        # Clear history
        await service.clear_history()

        # History should now be empty
        assert len(service._history) == 0

    @pytest.mark.asyncio
    async def test_concurrent_progress_operations(self, service):
        """Test handling multiple concurrent progress operations."""

        async def create_and_complete_progress(id_num: int):
            """Helper to create and complete a progress."""
            await service.start_progress(
                progress_id=f"test-{id_num}",
                progress_type=ProgressType.DOWNLOAD,
                title=f"Test {id_num}",
                total=100,
            )
            for i in range(0, 101, 10):
                await service.update_progress(
                    progress_id=f"test-{id_num}",
                    current=i,
                )
                await asyncio.sleep(0.01)
            await service.complete_progress(
                progress_id=f"test-{id_num}",
                message="Done",
            )

        # Run multiple concurrent operations
        tasks = [create_and_complete_progress(i) for i in range(10)]
        await asyncio.gather(*tasks)

        # All should be in history
        for i in range(10):
            progress = await service.get_progress(f"test-{i}")
            assert progress is not None
            assert progress.status == ProgressStatus.COMPLETED

    @pytest.mark.asyncio
    async def test_update_with_metadata(self, service):
        """Test updating progress with metadata."""
        await service.start_progress(
            progress_id="test-1",
            progress_type=ProgressType.DOWNLOAD,
            title="Test",
            metadata={"initial": "value"},
        )

        await service.update_progress(
            progress_id="test-1",
            current=50,
            metadata={"additional": "data", "speed": 1.5},
        )

        progress = await service.get_progress("test-1")
        assert progress.metadata["initial"] == "value"
        assert progress.metadata["additional"] == "data"
        assert progress.metadata["speed"] == 1.5
86
tests/unit/test_template_helpers.py
Normal file
@@ -0,0 +1,86 @@
"""
Tests for template helper utilities.

This module tests the template helper functions.
"""
from unittest.mock import Mock

import pytest

from src.server.utils.template_helpers import (
    get_base_context,
    list_available_templates,
    validate_template_exists,
)


class TestTemplateHelpers:
    """Test template helper utilities."""

    def test_get_base_context(self):
        """Test that base context is created correctly."""
        request = Mock()
        context = get_base_context(request, "Test Title")

        assert "request" in context
        assert context["request"] == request
        assert context["title"] == "Test Title"
        assert context["app_name"] == "Aniworld Download Manager"
        assert context["version"] == "1.0.0"

    def test_get_base_context_default_title(self):
        """Test that default title is used."""
        request = Mock()
        context = get_base_context(request)

        assert context["title"] == "Aniworld"

    def test_validate_template_exists_true(self):
        """Test template validation for existing template."""
        # index.html should exist
        exists = validate_template_exists("index.html")
        assert exists is True

    def test_validate_template_exists_false(self):
        """Test template validation for non-existing template."""
        exists = validate_template_exists("nonexistent.html")
        assert exists is False

    def test_list_available_templates(self):
        """Test listing available templates."""
        templates = list_available_templates()

        # Should be a list
        assert isinstance(templates, list)

        # Should contain at least the main templates
        expected_templates = [
            "index.html",
            "login.html",
            "setup.html",
            "queue.html",
            "error.html",
        ]
        for expected in expected_templates:
            assert expected in templates, (
                f"{expected} not found in templates list"
            )

    def test_list_available_templates_only_html(self):
        """Test that only HTML files are listed."""
        templates = list_available_templates()

        for template in templates:
            assert template.endswith(".html")

    @pytest.mark.parametrize("template_name", [
        "index.html",
        "login.html",
        "setup.html",
        "queue.html",
        "error.html",
    ])
    def test_all_required_templates_exist(self, template_name):
        """Test that all required templates exist."""
        assert validate_template_exists(template_name), \
            f"Required template {template_name} does not exist"
153
tests/unit/test_template_integration.py
Normal file
@@ -0,0 +1,153 @@
"""
Tests for template integration and rendering.

This module tests that all HTML templates are properly integrated with FastAPI
and can be rendered correctly.
"""
import pytest
from fastapi.testclient import TestClient

from src.server.fastapi_app import app


class TestTemplateIntegration:
    """Test template integration with FastAPI."""

    @pytest.fixture
    def client(self):
        """Create test client."""
        return TestClient(app)

    def test_index_template_renders(self, client):
        """Test that index.html renders successfully."""
        response = client.get("/")
        assert response.status_code == 200
        assert response.headers["content-type"].startswith("text/html")
        assert b"AniWorld Manager" in response.content
        assert b"/static/css/styles.css" in response.content

    def test_login_template_renders(self, client):
        """Test that login.html renders successfully."""
        response = client.get("/login")
        assert response.status_code == 200
        assert response.headers["content-type"].startswith("text/html")
        assert b"Login" in response.content
        assert b"/static/css/styles.css" in response.content

    def test_setup_template_renders(self, client):
        """Test that setup.html renders successfully."""
        response = client.get("/setup")
        assert response.status_code == 200
        assert response.headers["content-type"].startswith("text/html")
        assert b"Setup" in response.content
        assert b"/static/css/styles.css" in response.content

    def test_queue_template_renders(self, client):
        """Test that queue.html renders successfully."""
        response = client.get("/queue")
        assert response.status_code == 200
        assert response.headers["content-type"].startswith("text/html")
        assert b"Download Queue" in response.content
        assert b"/static/css/styles.css" in response.content

    def test_error_template_404(self, client):
        """Test that 404 error page renders correctly."""
        response = client.get("/nonexistent-page")
        assert response.status_code == 404
        assert response.headers["content-type"].startswith("text/html")
        assert b"Error 404" in response.content or b"404" in response.content

    def test_static_css_accessible(self, client):
        """Test that static CSS files are accessible."""
        response = client.get("/static/css/styles.css")
        assert response.status_code == 200
        assert "text/css" in response.headers.get("content-type", "")

    def test_static_js_accessible(self, client):
        """Test that static JavaScript files are accessible."""
        response = client.get("/static/js/app.js")
        assert response.status_code == 200

    def test_templates_include_theme_switching(self, client):
        """Test that templates include theme switching functionality."""
        response = client.get("/")
        assert response.status_code == 200
        # Check for theme toggle button
        assert b"theme-toggle" in response.content
        # Check for data-theme attribute
        assert b'data-theme="light"' in response.content

    def test_templates_include_responsive_meta(self, client):
        """Test that templates include responsive viewport meta tag."""
        response = client.get("/")
        assert response.status_code == 200
        assert b'name="viewport"' in response.content
        assert b"width=device-width" in response.content

    def test_templates_include_font_awesome(self, client):
        """Test that templates include Font Awesome icons."""
        response = client.get("/")
        assert response.status_code == 200
        assert b"font-awesome" in response.content.lower()

    def test_all_templates_have_correct_structure(self, client):
        """Test that all templates have correct HTML structure."""
        pages = ["/", "/login", "/setup", "/queue"]

        for page in pages:
            response = client.get(page)
            assert response.status_code == 200
            content = response.content

            # Check for essential HTML elements
            assert b"<!DOCTYPE html>" in content
            assert b"<html" in content
            assert b"<head>" in content
            assert b"<body>" in content
            assert b"</html>" in content

    def test_templates_load_required_javascript(self, client):
        """Test that index template loads all required JavaScript files."""
        response = client.get("/")
        assert response.status_code == 200
        content = response.content

        # Check for main app.js
        assert b"/static/js/app.js" in content

        # Check for localization.js
        assert b"/static/js/localization.js" in content

    def test_templates_load_ux_features_css(self, client):
        """Test that templates load UX features CSS."""
        response = client.get("/")
        assert response.status_code == 200
        assert b"/static/css/ux_features.css" in response.content

    def test_queue_template_has_websocket_script(self, client):
        """Test that queue template includes WebSocket support."""
        response = client.get("/queue")
        assert response.status_code == 200
        # Check for socket.io or WebSocket implementation
        assert (
            b"socket.io" in response.content or
            b"WebSocket" in response.content
        )

    def test_index_includes_search_functionality(self, client):
        """Test that index page includes search functionality."""
        response = client.get("/")
        assert response.status_code == 200
        content = response.content

        assert b"search-input" in content
        assert b"search-btn" in content

    def test_templates_accessibility_features(self, client):
        """Test that templates include accessibility features."""
        response = client.get("/")
        assert response.status_code == 200
        content = response.content

        # Check for ARIA labels or roles
        assert b"aria-" in content or b"role=" in content
423
tests/unit/test_websocket_service.py
Normal file
@@ -0,0 +1,423 @@
"""Unit tests for WebSocket service."""
from unittest.mock import AsyncMock

import pytest
from fastapi import WebSocket

from src.server.services.websocket_service import (
    ConnectionManager,
    WebSocketService,
    get_websocket_service,
)


class TestConnectionManager:
    """Test cases for ConnectionManager class."""

    @pytest.fixture
    def manager(self):
        """Create a ConnectionManager instance for testing."""
        return ConnectionManager()

    @pytest.fixture
    def mock_websocket(self):
        """Create a mock WebSocket instance."""
        ws = AsyncMock(spec=WebSocket)
        ws.accept = AsyncMock()
        ws.send_json = AsyncMock()
        return ws

    @pytest.mark.asyncio
    async def test_connect(self, manager, mock_websocket):
        """Test connecting a WebSocket client."""
        connection_id = "test-conn-1"
        metadata = {"user_id": "user123"}

        await manager.connect(mock_websocket, connection_id, metadata)

        mock_websocket.accept.assert_called_once()
        assert connection_id in manager._active_connections
        assert manager._connection_metadata[connection_id] == metadata

    @pytest.mark.asyncio
    async def test_connect_without_metadata(self, manager, mock_websocket):
        """Test connecting without metadata."""
        connection_id = "test-conn-2"

        await manager.connect(mock_websocket, connection_id)

        assert connection_id in manager._active_connections
        assert manager._connection_metadata[connection_id] == {}

    @pytest.mark.asyncio
    async def test_disconnect(self, manager, mock_websocket):
        """Test disconnecting a WebSocket client."""
        connection_id = "test-conn-3"
        await manager.connect(mock_websocket, connection_id)

        await manager.disconnect(connection_id)

        assert connection_id not in manager._active_connections
        assert connection_id not in manager._connection_metadata

    @pytest.mark.asyncio
    async def test_join_room(self, manager, mock_websocket):
        """Test joining a room."""
        connection_id = "test-conn-4"
        room = "downloads"

        await manager.connect(mock_websocket, connection_id)
        await manager.join_room(connection_id, room)

        assert connection_id in manager._rooms[room]

    @pytest.mark.asyncio
    async def test_join_room_inactive_connection(self, manager):
        """Test joining a room with inactive connection."""
        connection_id = "inactive-conn"
        room = "downloads"

        # Should not raise error, just log warning
        await manager.join_room(connection_id, room)

        assert connection_id not in manager._rooms.get(room, set())

    @pytest.mark.asyncio
    async def test_leave_room(self, manager, mock_websocket):
        """Test leaving a room."""
        connection_id = "test-conn-5"
        room = "downloads"

        await manager.connect(mock_websocket, connection_id)
        await manager.join_room(connection_id, room)
        await manager.leave_room(connection_id, room)

        assert connection_id not in manager._rooms.get(room, set())
        assert room not in manager._rooms  # Empty room should be removed

    @pytest.mark.asyncio
    async def test_disconnect_removes_from_all_rooms(
        self, manager, mock_websocket
    ):
        """Test that disconnect removes connection from all rooms."""
        connection_id = "test-conn-6"
        rooms = ["room1", "room2", "room3"]

        await manager.connect(mock_websocket, connection_id)
        for room in rooms:
            await manager.join_room(connection_id, room)

        await manager.disconnect(connection_id)

        for room in rooms:
            assert connection_id not in manager._rooms.get(room, set())

    @pytest.mark.asyncio
    async def test_send_personal_message(self, manager, mock_websocket):
        """Test sending a personal message to a connection."""
        connection_id = "test-conn-7"
        message = {"type": "test", "data": {"value": 123}}

        await manager.connect(mock_websocket, connection_id)
        await manager.send_personal_message(message, connection_id)

        mock_websocket.send_json.assert_called_once_with(message)

    @pytest.mark.asyncio
    async def test_send_personal_message_inactive_connection(
        self, manager, mock_websocket
    ):
        """Test sending message to inactive connection."""
        connection_id = "inactive-conn"
        message = {"type": "test", "data": {}}

        # Should not raise error, just log warning
        await manager.send_personal_message(message, connection_id)

        mock_websocket.send_json.assert_not_called()

    @pytest.mark.asyncio
    async def test_broadcast(self, manager):
        """Test broadcasting to all connections."""
        connections = {}
        for i in range(3):
            ws = AsyncMock(spec=WebSocket)
            ws.accept = AsyncMock()
            ws.send_json = AsyncMock()
            conn_id = f"conn-{i}"
            await manager.connect(ws, conn_id)
            connections[conn_id] = ws

        message = {"type": "broadcast", "data": {"value": 456}}
        await manager.broadcast(message)

        for ws in connections.values():
            ws.send_json.assert_called_once_with(message)

    @pytest.mark.asyncio
    async def test_broadcast_with_exclusion(self, manager):
        """Test broadcasting with excluded connections."""
        connections = {}
        for i in range(3):
            ws = AsyncMock(spec=WebSocket)
            ws.accept = AsyncMock()
            ws.send_json = AsyncMock()
            conn_id = f"conn-{i}"
            await manager.connect(ws, conn_id)
            connections[conn_id] = ws

        exclude = {"conn-1"}
        message = {"type": "broadcast", "data": {"value": 789}}
        await manager.broadcast(message, exclude=exclude)

        connections["conn-0"].send_json.assert_called_once_with(message)
        connections["conn-1"].send_json.assert_not_called()
        connections["conn-2"].send_json.assert_called_once_with(message)

    @pytest.mark.asyncio
    async def test_broadcast_to_room(self, manager):
        """Test broadcasting to a specific room."""
        # Setup connections
        room_members = {}
        non_members = {}

        for i in range(2):
            ws = AsyncMock(spec=WebSocket)
            ws.accept = AsyncMock()
            ws.send_json = AsyncMock()
            conn_id = f"member-{i}"
            await manager.connect(ws, conn_id)
            await manager.join_room(conn_id, "downloads")
            room_members[conn_id] = ws

        for i in range(2):
            ws = AsyncMock(spec=WebSocket)
            ws.accept = AsyncMock()
            ws.send_json = AsyncMock()
            conn_id = f"non-member-{i}"
            await manager.connect(ws, conn_id)
            non_members[conn_id] = ws

        message = {"type": "room_broadcast", "data": {"room": "downloads"}}
        await manager.broadcast_to_room(message, "downloads")

        # Room members should receive message
        for ws in room_members.values():
            ws.send_json.assert_called_once_with(message)

        # Non-members should not receive message
        for ws in non_members.values():
            ws.send_json.assert_not_called()

    @pytest.mark.asyncio
    async def test_get_connection_count(self, manager, mock_websocket):
        """Test getting connection count."""
        assert await manager.get_connection_count() == 0

        await manager.connect(mock_websocket, "conn-1")
        assert await manager.get_connection_count() == 1

        ws2 = AsyncMock(spec=WebSocket)
        ws2.accept = AsyncMock()
        await manager.connect(ws2, "conn-2")
        assert await manager.get_connection_count() == 2

        await manager.disconnect("conn-1")
        assert await manager.get_connection_count() == 1

    @pytest.mark.asyncio
    async def test_get_room_members(self, manager, mock_websocket):
        """Test getting room members."""
        room = "test-room"
        assert await manager.get_room_members(room) == []

        await manager.connect(mock_websocket, "conn-1")
        await manager.join_room("conn-1", room)

        members = await manager.get_room_members(room)
        assert "conn-1" in members
        assert len(members) == 1

    @pytest.mark.asyncio
    async def test_get_connection_metadata(self, manager, mock_websocket):
        """Test getting connection metadata."""
        connection_id = "test-conn"
        metadata = {"user_id": "user123", "ip": "127.0.0.1"}

        await manager.connect(mock_websocket, connection_id, metadata)

        result = await manager.get_connection_metadata(connection_id)
        assert result == metadata

    @pytest.mark.asyncio
    async def test_update_connection_metadata(self, manager, mock_websocket):
        """Test updating connection metadata."""
        connection_id = "test-conn"
        initial_metadata = {"user_id": "user123"}
        update = {"session_id": "session456"}

        await manager.connect(mock_websocket, connection_id, initial_metadata)
        await manager.update_connection_metadata(connection_id, update)

        result = await manager.get_connection_metadata(connection_id)
        assert result["user_id"] == "user123"
        assert result["session_id"] == "session456"

class TestWebSocketService:
    """Test cases for WebSocketService class."""

    @pytest.fixture
    def service(self):
        """Create a WebSocketService instance for testing."""
        return WebSocketService()

    @pytest.fixture
    def mock_websocket(self):
        """Create a mock WebSocket instance."""
        ws = AsyncMock(spec=WebSocket)
        ws.accept = AsyncMock()
        ws.send_json = AsyncMock()
        return ws

    @pytest.mark.asyncio
    async def test_connect(self, service, mock_websocket):
        """Test connecting a client."""
        connection_id = "test-conn"
        user_id = "user123"

        await service.connect(mock_websocket, connection_id, user_id)

        mock_websocket.accept.assert_called_once()
        assert connection_id in service._manager._active_connections
        metadata = await service._manager.get_connection_metadata(
            connection_id
        )
        assert metadata["user_id"] == user_id

    @pytest.mark.asyncio
    async def test_disconnect(self, service, mock_websocket):
        """Test disconnecting a client."""
        connection_id = "test-conn"

        await service.connect(mock_websocket, connection_id)
        await service.disconnect(connection_id)

        assert connection_id not in service._manager._active_connections

    @pytest.mark.asyncio
    async def test_broadcast_download_progress(self, service, mock_websocket):
        """Test broadcasting download progress."""
        connection_id = "test-conn"
        download_id = "download123"
        progress_data = {
            "percent": 50.0,
            "speed_mbps": 2.5,
            "eta_seconds": 120,
        }

        await service.connect(mock_websocket, connection_id)
        await service._manager.join_room(connection_id, "downloads")
        await service.broadcast_download_progress(download_id, progress_data)

        # Verify message was sent
        assert mock_websocket.send_json.called
        call_args = mock_websocket.send_json.call_args[0][0]
        assert call_args["type"] == "download_progress"
        assert call_args["data"]["download_id"] == download_id
        assert call_args["data"]["percent"] == 50.0

    @pytest.mark.asyncio
    async def test_broadcast_download_complete(self, service, mock_websocket):
        """Test broadcasting download completion."""
        connection_id = "test-conn"
        download_id = "download123"
        result_data = {"file_path": "/path/to/file.mp4"}

        await service.connect(mock_websocket, connection_id)
        await service._manager.join_room(connection_id, "downloads")
        await service.broadcast_download_complete(download_id, result_data)

        assert mock_websocket.send_json.called
        call_args = mock_websocket.send_json.call_args[0][0]
        assert call_args["type"] == "download_complete"
        assert call_args["data"]["download_id"] == download_id

    @pytest.mark.asyncio
    async def test_broadcast_download_failed(self, service, mock_websocket):
        """Test broadcasting download failure."""
        connection_id = "test-conn"
        download_id = "download123"
        error_data = {"error_message": "Network error"}

        await service.connect(mock_websocket, connection_id)
        await service._manager.join_room(connection_id, "downloads")
        await service.broadcast_download_failed(download_id, error_data)

        assert mock_websocket.send_json.called
        call_args = mock_websocket.send_json.call_args[0][0]
        assert call_args["type"] == "download_failed"
        assert call_args["data"]["download_id"] == download_id

    @pytest.mark.asyncio
    async def test_broadcast_queue_status(self, service, mock_websocket):
        """Test broadcasting queue status."""
        connection_id = "test-conn"
        status_data = {"active": 2, "pending": 5, "completed": 10}

        await service.connect(mock_websocket, connection_id)
        await service._manager.join_room(connection_id, "downloads")
        await service.broadcast_queue_status(status_data)

        assert mock_websocket.send_json.called
        call_args = mock_websocket.send_json.call_args[0][0]
        assert call_args["type"] == "queue_status"
        assert call_args["data"] == status_data

    @pytest.mark.asyncio
    async def test_broadcast_system_message(self, service, mock_websocket):
        """Test broadcasting system message."""
        connection_id = "test-conn"
        message_type = "maintenance"
        data = {"message": "System will be down for maintenance"}

        await service.connect(mock_websocket, connection_id)
        await service.broadcast_system_message(message_type, data)

        assert mock_websocket.send_json.called
        call_args = mock_websocket.send_json.call_args[0][0]
        assert call_args["type"] == f"system_{message_type}"
        assert call_args["data"] == data

    @pytest.mark.asyncio
    async def test_send_error(self, service, mock_websocket):
        """Test sending error message."""
        connection_id = "test-conn"
        error_message = "Invalid request"
        error_code = "INVALID_REQUEST"

        await service.connect(mock_websocket, connection_id)
        await service.send_error(connection_id, error_message, error_code)

        assert mock_websocket.send_json.called
        call_args = mock_websocket.send_json.call_args[0][0]
        assert call_args["type"] == "error"
        assert call_args["data"]["code"] == error_code
        assert call_args["data"]["message"] == error_message

class TestGetWebSocketService:
    """Test cases for get_websocket_service factory function."""

    def test_singleton_pattern(self):
        """Test that get_websocket_service returns singleton instance."""
        service1 = get_websocket_service()
        service2 = get_websocket_service()

        assert service1 is service2

    def test_returns_websocket_service(self):
        """Test that factory returns WebSocketService instance."""
        service = get_websocket_service()

        assert isinstance(service, WebSocketService)
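The unit tests above effectively pin down the `ConnectionManager` contract: accept on connect, track metadata, silently ignore inactive connections, and drop rooms once they empty. A minimal sketch consistent with that contract (the actual `src/server/services/websocket_service.py` may differ in detail; the duck-typed `websocket` only needs `accept()` and `send_json()`) could look like:

```python
import asyncio
from typing import Any, Dict, Optional, Set


class ConnectionManager:
    """Minimal sketch of the connection/room bookkeeping the tests exercise."""

    def __init__(self) -> None:
        self._active_connections: Dict[str, Any] = {}
        self._connection_metadata: Dict[str, Dict[str, Any]] = {}
        self._rooms: Dict[str, Set[str]] = {}

    async def connect(self, websocket: Any, connection_id: str,
                      metadata: Optional[Dict[str, Any]] = None) -> None:
        # Accept the handshake, then register the connection and its metadata.
        await websocket.accept()
        self._active_connections[connection_id] = websocket
        self._connection_metadata[connection_id] = metadata or {}

    async def disconnect(self, connection_id: str) -> None:
        self._active_connections.pop(connection_id, None)
        self._connection_metadata.pop(connection_id, None)
        # Remove the connection from every room; drop rooms that become empty.
        for room in list(self._rooms):
            self._rooms[room].discard(connection_id)
            if not self._rooms[room]:
                del self._rooms[room]

    async def join_room(self, connection_id: str, room: str) -> None:
        if connection_id not in self._active_connections:
            return  # inactive connection: ignore, per the tests
        self._rooms.setdefault(room, set()).add(connection_id)

    async def leave_room(self, connection_id: str, room: str) -> None:
        members = self._rooms.get(room)
        if members is None:
            return
        members.discard(connection_id)
        if not members:
            del self._rooms[room]  # empty rooms are removed

    async def broadcast_to_room(self, message: Dict[str, Any],
                                room: str) -> None:
        for connection_id in self._rooms.get(room, set()):
            await self._active_connections[connection_id].send_json(message)
```

`broadcast`, `send_personal_message`, and the metadata helpers follow the same pattern of looking the connection up in `_active_connections` and no-op'ing if it is missing.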