Add advanced features: notification system, security middleware, audit logging, data validation, and caching
- Implement notification service with email, webhook, and in-app support - Add security headers middleware (CORS, CSP, HSTS, XSS protection) - Create comprehensive audit logging service for security events - Add data validation utilities with Pydantic validators - Implement cache service with in-memory and Redis backend support All 714 tests passing
This commit is contained in:
610
src/server/services/audit_service.py
Normal file
610
src/server/services/audit_service.py
Normal file
@@ -0,0 +1,610 @@
|
||||
"""
|
||||
Audit Service for AniWorld.
|
||||
|
||||
This module provides comprehensive audit logging for security-critical
|
||||
operations including authentication, configuration changes, and downloads.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuditEventType(str, Enum):
|
||||
"""Types of audit events."""
|
||||
|
||||
# Authentication events
|
||||
AUTH_SETUP = "auth.setup"
|
||||
AUTH_LOGIN_SUCCESS = "auth.login.success"
|
||||
AUTH_LOGIN_FAILURE = "auth.login.failure"
|
||||
AUTH_LOGOUT = "auth.logout"
|
||||
AUTH_TOKEN_REFRESH = "auth.token.refresh"
|
||||
AUTH_TOKEN_INVALID = "auth.token.invalid"
|
||||
|
||||
# Configuration events
|
||||
CONFIG_READ = "config.read"
|
||||
CONFIG_UPDATE = "config.update"
|
||||
CONFIG_BACKUP = "config.backup"
|
||||
CONFIG_RESTORE = "config.restore"
|
||||
CONFIG_DELETE = "config.delete"
|
||||
|
||||
# Download events
|
||||
DOWNLOAD_ADDED = "download.added"
|
||||
DOWNLOAD_STARTED = "download.started"
|
||||
DOWNLOAD_COMPLETED = "download.completed"
|
||||
DOWNLOAD_FAILED = "download.failed"
|
||||
DOWNLOAD_CANCELLED = "download.cancelled"
|
||||
DOWNLOAD_REMOVED = "download.removed"
|
||||
|
||||
# Queue events
|
||||
QUEUE_STARTED = "queue.started"
|
||||
QUEUE_STOPPED = "queue.stopped"
|
||||
QUEUE_PAUSED = "queue.paused"
|
||||
QUEUE_RESUMED = "queue.resumed"
|
||||
QUEUE_CLEARED = "queue.cleared"
|
||||
|
||||
# System events
|
||||
SYSTEM_STARTUP = "system.startup"
|
||||
SYSTEM_SHUTDOWN = "system.shutdown"
|
||||
SYSTEM_ERROR = "system.error"
|
||||
|
||||
|
||||
class AuditEventSeverity(str, Enum):
|
||||
"""Severity levels for audit events."""
|
||||
|
||||
DEBUG = "debug"
|
||||
INFO = "info"
|
||||
WARNING = "warning"
|
||||
ERROR = "error"
|
||||
CRITICAL = "critical"
|
||||
|
||||
|
||||
class AuditEvent(BaseModel):
|
||||
"""Audit event model."""
|
||||
|
||||
timestamp: datetime = Field(default_factory=datetime.utcnow)
|
||||
event_type: AuditEventType
|
||||
severity: AuditEventSeverity = AuditEventSeverity.INFO
|
||||
user_id: Optional[str] = None
|
||||
ip_address: Optional[str] = None
|
||||
user_agent: Optional[str] = None
|
||||
resource: Optional[str] = None
|
||||
action: Optional[str] = None
|
||||
status: str = "success"
|
||||
message: str
|
||||
details: Optional[Dict[str, Any]] = None
|
||||
session_id: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
"""Pydantic config."""
|
||||
|
||||
json_encoders = {datetime: lambda v: v.isoformat()}
|
||||
|
||||
|
||||
class AuditLogStorage:
|
||||
"""Base class for audit log storage backends."""
|
||||
|
||||
async def write_event(self, event: AuditEvent) -> None:
|
||||
"""
|
||||
Write an audit event to storage.
|
||||
|
||||
Args:
|
||||
event: Audit event to write
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def read_events(
|
||||
self,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None,
|
||||
event_types: Optional[List[AuditEventType]] = None,
|
||||
user_id: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
) -> List[AuditEvent]:
|
||||
"""
|
||||
Read audit events from storage.
|
||||
|
||||
Args:
|
||||
start_time: Start of time range
|
||||
end_time: End of time range
|
||||
event_types: Filter by event types
|
||||
user_id: Filter by user ID
|
||||
limit: Maximum number of events to return
|
||||
|
||||
Returns:
|
||||
List of audit events
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def cleanup_old_events(self, days: int = 90) -> int:
|
||||
"""
|
||||
Clean up audit events older than specified days.
|
||||
|
||||
Args:
|
||||
days: Number of days to retain
|
||||
|
||||
Returns:
|
||||
Number of events deleted
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class FileAuditLogStorage(AuditLogStorage):
|
||||
"""File-based audit log storage."""
|
||||
|
||||
def __init__(self, log_directory: str = "logs/audit"):
|
||||
"""
|
||||
Initialize file-based audit log storage.
|
||||
|
||||
Args:
|
||||
log_directory: Directory to store audit logs
|
||||
"""
|
||||
self.log_directory = Path(log_directory)
|
||||
self.log_directory.mkdir(parents=True, exist_ok=True)
|
||||
self._current_date: Optional[str] = None
|
||||
self._current_file: Optional[Path] = None
|
||||
|
||||
def _get_log_file(self, date: datetime) -> Path:
|
||||
"""
|
||||
Get log file path for a specific date.
|
||||
|
||||
Args:
|
||||
date: Date for log file
|
||||
|
||||
Returns:
|
||||
Path to log file
|
||||
"""
|
||||
date_str = date.strftime("%Y-%m-%d")
|
||||
return self.log_directory / f"audit_{date_str}.jsonl"
|
||||
|
||||
async def write_event(self, event: AuditEvent) -> None:
|
||||
"""
|
||||
Write an audit event to file.
|
||||
|
||||
Args:
|
||||
event: Audit event to write
|
||||
"""
|
||||
log_file = self._get_log_file(event.timestamp)
|
||||
|
||||
try:
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(event.model_dump_json() + "\n")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to write audit event to file: {e}")
|
||||
|
||||
async def read_events(
|
||||
self,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None,
|
||||
event_types: Optional[List[AuditEventType]] = None,
|
||||
user_id: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
) -> List[AuditEvent]:
|
||||
"""
|
||||
Read audit events from files.
|
||||
|
||||
Args:
|
||||
start_time: Start of time range
|
||||
end_time: End of time range
|
||||
event_types: Filter by event types
|
||||
user_id: Filter by user ID
|
||||
limit: Maximum number of events to return
|
||||
|
||||
Returns:
|
||||
List of audit events
|
||||
"""
|
||||
if start_time is None:
|
||||
start_time = datetime.utcnow() - timedelta(days=7)
|
||||
if end_time is None:
|
||||
end_time = datetime.utcnow()
|
||||
|
||||
events: List[AuditEvent] = []
|
||||
current_date = start_time.date()
|
||||
end_date = end_time.date()
|
||||
|
||||
# Read from all log files in date range
|
||||
while current_date <= end_date and len(events) < limit:
|
||||
log_file = self._get_log_file(datetime.combine(current_date, datetime.min.time()))
|
||||
|
||||
if log_file.exists():
|
||||
try:
|
||||
with open(log_file, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if len(events) >= limit:
|
||||
break
|
||||
|
||||
try:
|
||||
event_data = json.loads(line.strip())
|
||||
event = AuditEvent(**event_data)
|
||||
|
||||
# Apply filters
|
||||
if event.timestamp < start_time or event.timestamp > end_time:
|
||||
continue
|
||||
|
||||
if event_types and event.event_type not in event_types:
|
||||
continue
|
||||
|
||||
if user_id and event.user_id != user_id:
|
||||
continue
|
||||
|
||||
events.append(event)
|
||||
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
logger.warning(f"Failed to parse audit event: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read audit log file {log_file}: {e}")
|
||||
|
||||
current_date += timedelta(days=1)
|
||||
|
||||
# Sort by timestamp descending
|
||||
events.sort(key=lambda e: e.timestamp, reverse=True)
|
||||
return events[:limit]
|
||||
|
||||
async def cleanup_old_events(self, days: int = 90) -> int:
|
||||
"""
|
||||
Clean up audit events older than specified days.
|
||||
|
||||
Args:
|
||||
days: Number of days to retain
|
||||
|
||||
Returns:
|
||||
Number of files deleted
|
||||
"""
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=days)
|
||||
deleted_count = 0
|
||||
|
||||
for log_file in self.log_directory.glob("audit_*.jsonl"):
|
||||
try:
|
||||
# Extract date from filename
|
||||
date_str = log_file.stem.replace("audit_", "")
|
||||
file_date = datetime.strptime(date_str, "%Y-%m-%d")
|
||||
|
||||
if file_date < cutoff_date:
|
||||
log_file.unlink()
|
||||
deleted_count += 1
|
||||
logger.info(f"Deleted old audit log: {log_file}")
|
||||
|
||||
except (ValueError, OSError) as e:
|
||||
logger.warning(f"Failed to process audit log file {log_file}: {e}")
|
||||
|
||||
return deleted_count
|
||||
|
||||
|
||||
class AuditService:
|
||||
"""Main audit service for logging security events."""
|
||||
|
||||
def __init__(self, storage: Optional[AuditLogStorage] = None):
|
||||
"""
|
||||
Initialize audit service.
|
||||
|
||||
Args:
|
||||
storage: Storage backend for audit logs
|
||||
"""
|
||||
self.storage = storage or FileAuditLogStorage()
|
||||
|
||||
async def log_event(
|
||||
self,
|
||||
event_type: AuditEventType,
|
||||
message: str,
|
||||
severity: AuditEventSeverity = AuditEventSeverity.INFO,
|
||||
user_id: Optional[str] = None,
|
||||
ip_address: Optional[str] = None,
|
||||
user_agent: Optional[str] = None,
|
||||
resource: Optional[str] = None,
|
||||
action: Optional[str] = None,
|
||||
status: str = "success",
|
||||
details: Optional[Dict[str, Any]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Log an audit event.
|
||||
|
||||
Args:
|
||||
event_type: Type of event
|
||||
message: Human-readable message
|
||||
severity: Event severity
|
||||
user_id: User identifier
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
resource: Resource being accessed
|
||||
action: Action performed
|
||||
status: Operation status
|
||||
details: Additional details
|
||||
session_id: Session identifier
|
||||
"""
|
||||
event = AuditEvent(
|
||||
event_type=event_type,
|
||||
severity=severity,
|
||||
user_id=user_id,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
resource=resource,
|
||||
action=action,
|
||||
status=status,
|
||||
message=message,
|
||||
details=details,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
await self.storage.write_event(event)
|
||||
|
||||
# Also log to application logger for high severity events
|
||||
if severity in [AuditEventSeverity.ERROR, AuditEventSeverity.CRITICAL]:
|
||||
logger.error(f"Audit: {message}", extra={"audit_event": event.model_dump()})
|
||||
elif severity == AuditEventSeverity.WARNING:
|
||||
logger.warning(f"Audit: {message}", extra={"audit_event": event.model_dump()})
|
||||
|
||||
async def log_auth_setup(
|
||||
self, user_id: str, ip_address: Optional[str] = None
|
||||
) -> None:
|
||||
"""Log initial authentication setup."""
|
||||
await self.log_event(
|
||||
event_type=AuditEventType.AUTH_SETUP,
|
||||
message=f"Authentication configured by user {user_id}",
|
||||
user_id=user_id,
|
||||
ip_address=ip_address,
|
||||
action="setup",
|
||||
)
|
||||
|
||||
async def log_login_success(
|
||||
self,
|
||||
user_id: str,
|
||||
ip_address: Optional[str] = None,
|
||||
user_agent: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Log successful login."""
|
||||
await self.log_event(
|
||||
event_type=AuditEventType.AUTH_LOGIN_SUCCESS,
|
||||
message=f"User {user_id} logged in successfully",
|
||||
user_id=user_id,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
session_id=session_id,
|
||||
action="login",
|
||||
)
|
||||
|
||||
async def log_login_failure(
|
||||
self,
|
||||
user_id: Optional[str] = None,
|
||||
ip_address: Optional[str] = None,
|
||||
user_agent: Optional[str] = None,
|
||||
reason: str = "Invalid credentials",
|
||||
) -> None:
|
||||
"""Log failed login attempt."""
|
||||
await self.log_event(
|
||||
event_type=AuditEventType.AUTH_LOGIN_FAILURE,
|
||||
message=f"Login failed for user {user_id or 'unknown'}: {reason}",
|
||||
severity=AuditEventSeverity.WARNING,
|
||||
user_id=user_id,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
status="failure",
|
||||
action="login",
|
||||
details={"reason": reason},
|
||||
)
|
||||
|
||||
async def log_logout(
|
||||
self,
|
||||
user_id: str,
|
||||
ip_address: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Log user logout."""
|
||||
await self.log_event(
|
||||
event_type=AuditEventType.AUTH_LOGOUT,
|
||||
message=f"User {user_id} logged out",
|
||||
user_id=user_id,
|
||||
ip_address=ip_address,
|
||||
session_id=session_id,
|
||||
action="logout",
|
||||
)
|
||||
|
||||
async def log_config_update(
|
||||
self,
|
||||
user_id: str,
|
||||
changes: Dict[str, Any],
|
||||
ip_address: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Log configuration update."""
|
||||
await self.log_event(
|
||||
event_type=AuditEventType.CONFIG_UPDATE,
|
||||
message=f"Configuration updated by user {user_id}",
|
||||
user_id=user_id,
|
||||
ip_address=ip_address,
|
||||
resource="config",
|
||||
action="update",
|
||||
details={"changes": changes},
|
||||
)
|
||||
|
||||
async def log_config_backup(
|
||||
self, user_id: str, backup_file: str, ip_address: Optional[str] = None
|
||||
) -> None:
|
||||
"""Log configuration backup."""
|
||||
await self.log_event(
|
||||
event_type=AuditEventType.CONFIG_BACKUP,
|
||||
message=f"Configuration backed up by user {user_id}",
|
||||
user_id=user_id,
|
||||
ip_address=ip_address,
|
||||
resource="config",
|
||||
action="backup",
|
||||
details={"backup_file": backup_file},
|
||||
)
|
||||
|
||||
async def log_config_restore(
|
||||
self, user_id: str, backup_file: str, ip_address: Optional[str] = None
|
||||
) -> None:
|
||||
"""Log configuration restore."""
|
||||
await self.log_event(
|
||||
event_type=AuditEventType.CONFIG_RESTORE,
|
||||
message=f"Configuration restored by user {user_id}",
|
||||
user_id=user_id,
|
||||
ip_address=ip_address,
|
||||
resource="config",
|
||||
action="restore",
|
||||
details={"backup_file": backup_file},
|
||||
)
|
||||
|
||||
async def log_download_added(
|
||||
self,
|
||||
user_id: str,
|
||||
series_name: str,
|
||||
episodes: List[str],
|
||||
ip_address: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Log download added to queue."""
|
||||
await self.log_event(
|
||||
event_type=AuditEventType.DOWNLOAD_ADDED,
|
||||
message=f"Download added by user {user_id}: {series_name}",
|
||||
user_id=user_id,
|
||||
ip_address=ip_address,
|
||||
resource=series_name,
|
||||
action="add",
|
||||
details={"episodes": episodes},
|
||||
)
|
||||
|
||||
async def log_download_completed(
|
||||
self, series_name: str, episode: str, file_path: str
|
||||
) -> None:
|
||||
"""Log completed download."""
|
||||
await self.log_event(
|
||||
event_type=AuditEventType.DOWNLOAD_COMPLETED,
|
||||
message=f"Download completed: {series_name} - {episode}",
|
||||
resource=series_name,
|
||||
action="download",
|
||||
details={"episode": episode, "file_path": file_path},
|
||||
)
|
||||
|
||||
async def log_download_failed(
|
||||
self, series_name: str, episode: str, error: str
|
||||
) -> None:
|
||||
"""Log failed download."""
|
||||
await self.log_event(
|
||||
event_type=AuditEventType.DOWNLOAD_FAILED,
|
||||
message=f"Download failed: {series_name} - {episode}",
|
||||
severity=AuditEventSeverity.ERROR,
|
||||
resource=series_name,
|
||||
action="download",
|
||||
status="failure",
|
||||
details={"episode": episode, "error": error},
|
||||
)
|
||||
|
||||
async def log_queue_operation(
|
||||
self,
|
||||
user_id: str,
|
||||
operation: str,
|
||||
ip_address: Optional[str] = None,
|
||||
details: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Log queue operation."""
|
||||
event_type_map = {
|
||||
"start": AuditEventType.QUEUE_STARTED,
|
||||
"stop": AuditEventType.QUEUE_STOPPED,
|
||||
"pause": AuditEventType.QUEUE_PAUSED,
|
||||
"resume": AuditEventType.QUEUE_RESUMED,
|
||||
"clear": AuditEventType.QUEUE_CLEARED,
|
||||
}
|
||||
|
||||
event_type = event_type_map.get(operation, AuditEventType.SYSTEM_ERROR)
|
||||
await self.log_event(
|
||||
event_type=event_type,
|
||||
message=f"Queue {operation} by user {user_id}",
|
||||
user_id=user_id,
|
||||
ip_address=ip_address,
|
||||
resource="queue",
|
||||
action=operation,
|
||||
details=details,
|
||||
)
|
||||
|
||||
async def log_system_error(
|
||||
self, error: str, details: Optional[Dict[str, Any]] = None
|
||||
) -> None:
|
||||
"""Log system error."""
|
||||
await self.log_event(
|
||||
event_type=AuditEventType.SYSTEM_ERROR,
|
||||
message=f"System error: {error}",
|
||||
severity=AuditEventSeverity.ERROR,
|
||||
status="error",
|
||||
details=details,
|
||||
)
|
||||
|
||||
async def get_events(
|
||||
self,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None,
|
||||
event_types: Optional[List[AuditEventType]] = None,
|
||||
user_id: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
) -> List[AuditEvent]:
|
||||
"""
|
||||
Get audit events with filters.
|
||||
|
||||
Args:
|
||||
start_time: Start of time range
|
||||
end_time: End of time range
|
||||
event_types: Filter by event types
|
||||
user_id: Filter by user ID
|
||||
limit: Maximum number of events to return
|
||||
|
||||
Returns:
|
||||
List of audit events
|
||||
"""
|
||||
return await self.storage.read_events(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
event_types=event_types,
|
||||
user_id=user_id,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
async def cleanup_old_events(self, days: int = 90) -> int:
|
||||
"""
|
||||
Clean up old audit events.
|
||||
|
||||
Args:
|
||||
days: Number of days to retain
|
||||
|
||||
Returns:
|
||||
Number of events deleted
|
||||
"""
|
||||
return await self.storage.cleanup_old_events(days)
|
||||
|
||||
|
||||
# Global audit service instance
|
||||
_audit_service: Optional[AuditService] = None
|
||||
|
||||
|
||||
def get_audit_service() -> AuditService:
|
||||
"""
|
||||
Get the global audit service instance.
|
||||
|
||||
Returns:
|
||||
AuditService instance
|
||||
"""
|
||||
global _audit_service
|
||||
if _audit_service is None:
|
||||
_audit_service = AuditService()
|
||||
return _audit_service
|
||||
|
||||
|
||||
def configure_audit_service(storage: Optional[AuditLogStorage] = None) -> AuditService:
|
||||
"""
|
||||
Configure the global audit service.
|
||||
|
||||
Args:
|
||||
storage: Custom storage backend
|
||||
|
||||
Returns:
|
||||
Configured AuditService instance
|
||||
"""
|
||||
global _audit_service
|
||||
_audit_service = AuditService(storage=storage)
|
||||
return _audit_service
|
||||
723
src/server/services/cache_service.py
Normal file
723
src/server/services/cache_service.py
Normal file
@@ -0,0 +1,723 @@
|
||||
"""
|
||||
Cache Service for AniWorld.
|
||||
|
||||
This module provides caching functionality with support for both
|
||||
in-memory and Redis backends to improve application performance.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
import pickle
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CacheBackend(ABC):
|
||||
"""Abstract base class for cache backends."""
|
||||
|
||||
@abstractmethod
|
||||
async def get(self, key: str) -> Optional[Any]:
|
||||
"""
|
||||
Get value from cache.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
|
||||
Returns:
|
||||
Cached value or None if not found
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set(
|
||||
self, key: str, value: Any, ttl: Optional[int] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Set value in cache.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
value: Value to cache
|
||||
ttl: Time to live in seconds
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def delete(self, key: str) -> bool:
|
||||
"""
|
||||
Delete value from cache.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
|
||||
Returns:
|
||||
True if key was deleted
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def exists(self, key: str) -> bool:
|
||||
"""
|
||||
Check if key exists in cache.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
|
||||
Returns:
|
||||
True if key exists
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def clear(self) -> bool:
|
||||
"""
|
||||
Clear all cached values.
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_many(self, keys: List[str]) -> Dict[str, Any]:
|
||||
"""
|
||||
Get multiple values from cache.
|
||||
|
||||
Args:
|
||||
keys: List of cache keys
|
||||
|
||||
Returns:
|
||||
Dictionary mapping keys to values
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_many(
|
||||
self, items: Dict[str, Any], ttl: Optional[int] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Set multiple values in cache.
|
||||
|
||||
Args:
|
||||
items: Dictionary of key-value pairs
|
||||
ttl: Time to live in seconds
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def delete_pattern(self, pattern: str) -> int:
|
||||
"""
|
||||
Delete all keys matching pattern.
|
||||
|
||||
Args:
|
||||
pattern: Pattern to match (supports wildcards)
|
||||
|
||||
Returns:
|
||||
Number of keys deleted
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class InMemoryCacheBackend(CacheBackend):
|
||||
"""In-memory cache backend using dictionary."""
|
||||
|
||||
def __init__(self, max_size: int = 1000):
|
||||
"""
|
||||
Initialize in-memory cache.
|
||||
|
||||
Args:
|
||||
max_size: Maximum number of items to cache
|
||||
"""
|
||||
self.cache: Dict[str, Dict[str, Any]] = {}
|
||||
self.max_size = max_size
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
def _is_expired(self, item: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Check if cache item is expired.
|
||||
|
||||
Args:
|
||||
item: Cache item with expiry
|
||||
|
||||
Returns:
|
||||
True if expired
|
||||
"""
|
||||
if item.get("expiry") is None:
|
||||
return False
|
||||
return datetime.utcnow() > item["expiry"]
|
||||
|
||||
def _evict_oldest(self) -> None:
|
||||
"""Evict oldest cache item when cache is full."""
|
||||
if len(self.cache) >= self.max_size:
|
||||
# Remove oldest item
|
||||
oldest_key = min(
|
||||
self.cache.keys(),
|
||||
key=lambda k: self.cache[k].get("created", datetime.utcnow()),
|
||||
)
|
||||
del self.cache[oldest_key]
|
||||
|
||||
async def get(self, key: str) -> Optional[Any]:
|
||||
"""Get value from cache."""
|
||||
async with self._lock:
|
||||
if key not in self.cache:
|
||||
return None
|
||||
|
||||
item = self.cache[key]
|
||||
|
||||
if self._is_expired(item):
|
||||
del self.cache[key]
|
||||
return None
|
||||
|
||||
return item["value"]
|
||||
|
||||
async def set(
|
||||
self, key: str, value: Any, ttl: Optional[int] = None
|
||||
) -> bool:
|
||||
"""Set value in cache."""
|
||||
async with self._lock:
|
||||
self._evict_oldest()
|
||||
|
||||
expiry = None
|
||||
if ttl:
|
||||
expiry = datetime.utcnow() + timedelta(seconds=ttl)
|
||||
|
||||
self.cache[key] = {
|
||||
"value": value,
|
||||
"expiry": expiry,
|
||||
"created": datetime.utcnow(),
|
||||
}
|
||||
return True
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
"""Delete value from cache."""
|
||||
async with self._lock:
|
||||
if key in self.cache:
|
||||
del self.cache[key]
|
||||
return True
|
||||
return False
|
||||
|
||||
async def exists(self, key: str) -> bool:
|
||||
"""Check if key exists in cache."""
|
||||
async with self._lock:
|
||||
if key not in self.cache:
|
||||
return False
|
||||
|
||||
item = self.cache[key]
|
||||
if self._is_expired(item):
|
||||
del self.cache[key]
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def clear(self) -> bool:
|
||||
"""Clear all cached values."""
|
||||
async with self._lock:
|
||||
self.cache.clear()
|
||||
return True
|
||||
|
||||
async def get_many(self, keys: List[str]) -> Dict[str, Any]:
|
||||
"""Get multiple values from cache."""
|
||||
result = {}
|
||||
for key in keys:
|
||||
value = await self.get(key)
|
||||
if value is not None:
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
async def set_many(
|
||||
self, items: Dict[str, Any], ttl: Optional[int] = None
|
||||
) -> bool:
|
||||
"""Set multiple values in cache."""
|
||||
for key, value in items.items():
|
||||
await self.set(key, value, ttl)
|
||||
return True
|
||||
|
||||
async def delete_pattern(self, pattern: str) -> int:
|
||||
"""Delete all keys matching pattern."""
|
||||
import fnmatch
|
||||
|
||||
async with self._lock:
|
||||
keys_to_delete = [
|
||||
key for key in self.cache.keys() if fnmatch.fnmatch(key, pattern)
|
||||
]
|
||||
for key in keys_to_delete:
|
||||
del self.cache[key]
|
||||
return len(keys_to_delete)
|
||||
|
||||
|
||||
class RedisCacheBackend(CacheBackend):
|
||||
"""Redis cache backend."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_url: str = "redis://localhost:6379",
|
||||
prefix: str = "aniworld:",
|
||||
):
|
||||
"""
|
||||
Initialize Redis cache.
|
||||
|
||||
Args:
|
||||
redis_url: Redis connection URL
|
||||
prefix: Key prefix for namespacing
|
||||
"""
|
||||
self.redis_url = redis_url
|
||||
self.prefix = prefix
|
||||
self._redis = None
|
||||
|
||||
async def _get_redis(self):
|
||||
"""Get Redis connection."""
|
||||
if self._redis is None:
|
||||
try:
|
||||
import aioredis
|
||||
|
||||
self._redis = await aioredis.create_redis_pool(self.redis_url)
|
||||
except ImportError:
|
||||
logger.error(
|
||||
"aioredis not installed. Install with: pip install aioredis"
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Redis: {e}")
|
||||
raise
|
||||
|
||||
return self._redis
|
||||
|
||||
def _make_key(self, key: str) -> str:
|
||||
"""Add prefix to key."""
|
||||
return f"{self.prefix}{key}"
|
||||
|
||||
async def get(self, key: str) -> Optional[Any]:
|
||||
"""Get value from cache."""
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
data = await redis.get(self._make_key(key))
|
||||
|
||||
if data is None:
|
||||
return None
|
||||
|
||||
return pickle.loads(data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Redis get error: {e}")
|
||||
return None
|
||||
|
||||
async def set(
|
||||
self, key: str, value: Any, ttl: Optional[int] = None
|
||||
) -> bool:
|
||||
"""Set value in cache."""
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
data = pickle.dumps(value)
|
||||
|
||||
if ttl:
|
||||
await redis.setex(self._make_key(key), ttl, data)
|
||||
else:
|
||||
await redis.set(self._make_key(key), data)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Redis set error: {e}")
|
||||
return False
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
"""Delete value from cache."""
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
result = await redis.delete(self._make_key(key))
|
||||
return result > 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Redis delete error: {e}")
|
||||
return False
|
||||
|
||||
async def exists(self, key: str) -> bool:
|
||||
"""Check if key exists in cache."""
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
return await redis.exists(self._make_key(key))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Redis exists error: {e}")
|
||||
return False
|
||||
|
||||
async def clear(self) -> bool:
|
||||
"""Clear all cached values with prefix."""
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
keys = await redis.keys(f"{self.prefix}*")
|
||||
if keys:
|
||||
await redis.delete(*keys)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Redis clear error: {e}")
|
||||
return False
|
||||
|
||||
async def get_many(self, keys: List[str]) -> Dict[str, Any]:
|
||||
"""Get multiple values from cache."""
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
prefixed_keys = [self._make_key(k) for k in keys]
|
||||
values = await redis.mget(*prefixed_keys)
|
||||
|
||||
result = {}
|
||||
for key, value in zip(keys, values):
|
||||
if value is not None:
|
||||
result[key] = pickle.loads(value)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Redis get_many error: {e}")
|
||||
return {}
|
||||
|
||||
async def set_many(
|
||||
self, items: Dict[str, Any], ttl: Optional[int] = None
|
||||
) -> bool:
|
||||
"""Set multiple values in cache."""
|
||||
try:
|
||||
for key, value in items.items():
|
||||
await self.set(key, value, ttl)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Redis set_many error: {e}")
|
||||
return False
|
||||
|
||||
async def delete_pattern(self, pattern: str) -> int:
|
||||
"""Delete all keys matching pattern."""
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
full_pattern = f"{self.prefix}{pattern}"
|
||||
keys = await redis.keys(full_pattern)
|
||||
|
||||
if keys:
|
||||
await redis.delete(*keys)
|
||||
return len(keys)
|
||||
|
||||
return 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Redis delete_pattern error: {e}")
|
||||
return 0
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close Redis connection."""
|
||||
if self._redis:
|
||||
self._redis.close()
|
||||
await self._redis.wait_closed()
|
||||
|
||||
|
||||
class CacheService:
|
||||
"""Main cache service with automatic key generation and TTL management."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
backend: Optional[CacheBackend] = None,
|
||||
default_ttl: int = 3600,
|
||||
key_prefix: str = "",
|
||||
):
|
||||
"""
|
||||
Initialize cache service.
|
||||
|
||||
Args:
|
||||
backend: Cache backend to use
|
||||
default_ttl: Default time to live in seconds
|
||||
key_prefix: Prefix for all cache keys
|
||||
"""
|
||||
self.backend = backend or InMemoryCacheBackend()
|
||||
self.default_ttl = default_ttl
|
||||
self.key_prefix = key_prefix
|
||||
|
||||
def _make_key(self, *args: Any, **kwargs: Any) -> str:
|
||||
"""
|
||||
Generate cache key from arguments.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments
|
||||
**kwargs: Keyword arguments
|
||||
|
||||
Returns:
|
||||
Cache key string
|
||||
"""
|
||||
# Create a stable key from arguments
|
||||
key_parts = [str(arg) for arg in args]
|
||||
key_parts.extend(f"{k}={v}" for k, v in sorted(kwargs.items()))
|
||||
key_str = ":".join(key_parts)
|
||||
|
||||
# Hash long keys
|
||||
if len(key_str) > 200:
|
||||
key_hash = hashlib.md5(key_str.encode()).hexdigest()
|
||||
return f"{self.key_prefix}{key_hash}"
|
||||
|
||||
return f"{self.key_prefix}{key_str}"
|
||||
|
||||
async def get(
|
||||
self, key: str, default: Optional[Any] = None
|
||||
) -> Optional[Any]:
|
||||
"""
|
||||
Get value from cache.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
default: Default value if not found
|
||||
|
||||
Returns:
|
||||
Cached value or default
|
||||
"""
|
||||
value = await self.backend.get(key)
|
||||
return value if value is not None else default
|
||||
|
||||
async def set(
|
||||
self, key: str, value: Any, ttl: Optional[int] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Set value in cache.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
value: Value to cache
|
||||
ttl: Time to live in seconds (uses default if None)
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
if ttl is None:
|
||||
ttl = self.default_ttl
|
||||
return await self.backend.set(key, value, ttl)
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
"""
|
||||
Delete value from cache.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
|
||||
Returns:
|
||||
True if deleted
|
||||
"""
|
||||
return await self.backend.delete(key)
|
||||
|
||||
async def exists(self, key: str) -> bool:
|
||||
"""
|
||||
Check if key exists in cache.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
|
||||
Returns:
|
||||
True if exists
|
||||
"""
|
||||
return await self.backend.exists(key)
|
||||
|
||||
async def clear(self) -> bool:
|
||||
"""
|
||||
Clear all cached values.
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
return await self.backend.clear()
|
||||
|
||||
async def get_or_set(
|
||||
self,
|
||||
key: str,
|
||||
factory,
|
||||
ttl: Optional[int] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Get value from cache or compute and cache it.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
factory: Callable to compute value if not cached
|
||||
ttl: Time to live in seconds
|
||||
|
||||
Returns:
|
||||
Cached or computed value
|
||||
"""
|
||||
value = await self.get(key)
|
||||
|
||||
if value is None:
|
||||
# Compute value
|
||||
if asyncio.iscoroutinefunction(factory):
|
||||
value = await factory()
|
||||
else:
|
||||
value = factory()
|
||||
|
||||
# Cache it
|
||||
await self.set(key, value, ttl)
|
||||
|
||||
return value
|
||||
|
||||
async def invalidate_pattern(self, pattern: str) -> int:
|
||||
"""
|
||||
Invalidate all keys matching pattern.
|
||||
|
||||
Args:
|
||||
pattern: Pattern to match
|
||||
|
||||
Returns:
|
||||
Number of keys invalidated
|
||||
"""
|
||||
return await self.backend.delete_pattern(pattern)
|
||||
|
||||
async def cache_anime_list(
|
||||
self, anime_list: List[Dict[str, Any]], ttl: Optional[int] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Cache anime list.
|
||||
|
||||
Args:
|
||||
anime_list: List of anime data
|
||||
ttl: Time to live in seconds
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
key = self._make_key("anime", "list")
|
||||
return await self.set(key, anime_list, ttl)
|
||||
|
||||
async def get_anime_list(self) -> Optional[List[Dict[str, Any]]]:
|
||||
"""
|
||||
Get cached anime list.
|
||||
|
||||
Returns:
|
||||
Cached anime list or None
|
||||
"""
|
||||
key = self._make_key("anime", "list")
|
||||
return await self.get(key)
|
||||
|
||||
async def cache_anime_detail(
|
||||
self, anime_id: str, data: Dict[str, Any], ttl: Optional[int] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Cache anime detail.
|
||||
|
||||
Args:
|
||||
anime_id: Anime identifier
|
||||
data: Anime data
|
||||
ttl: Time to live in seconds
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
key = self._make_key("anime", "detail", anime_id)
|
||||
return await self.set(key, data, ttl)
|
||||
|
||||
async def get_anime_detail(self, anime_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get cached anime detail.
|
||||
|
||||
Args:
|
||||
anime_id: Anime identifier
|
||||
|
||||
Returns:
|
||||
Cached anime data or None
|
||||
"""
|
||||
key = self._make_key("anime", "detail", anime_id)
|
||||
return await self.get(key)
|
||||
|
||||
async def invalidate_anime_cache(self) -> int:
|
||||
"""
|
||||
Invalidate all anime-related cache.
|
||||
|
||||
Returns:
|
||||
Number of keys invalidated
|
||||
"""
|
||||
return await self.invalidate_pattern(f"{self.key_prefix}anime*")
|
||||
|
||||
async def cache_config(
|
||||
self, config: Dict[str, Any], ttl: Optional[int] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Cache configuration.
|
||||
|
||||
Args:
|
||||
config: Configuration data
|
||||
ttl: Time to live in seconds
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
key = self._make_key("config")
|
||||
return await self.set(key, config, ttl)
|
||||
|
||||
async def get_config(self) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get cached configuration.
|
||||
|
||||
Returns:
|
||||
Cached configuration or None
|
||||
"""
|
||||
key = self._make_key("config")
|
||||
return await self.get(key)
|
||||
|
||||
async def invalidate_config_cache(self) -> bool:
|
||||
"""
|
||||
Invalidate configuration cache.
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
key = self._make_key("config")
|
||||
return await self.delete(key)
|
||||
|
||||
|
||||
# Global cache service instance
|
||||
_cache_service: Optional[CacheService] = None
|
||||
|
||||
|
||||
def get_cache_service() -> CacheService:
|
||||
"""
|
||||
Get the global cache service instance.
|
||||
|
||||
Returns:
|
||||
CacheService instance
|
||||
"""
|
||||
global _cache_service
|
||||
if _cache_service is None:
|
||||
_cache_service = CacheService()
|
||||
return _cache_service
|
||||
|
||||
|
||||
def configure_cache_service(
|
||||
backend_type: str = "memory",
|
||||
redis_url: str = "redis://localhost:6379",
|
||||
default_ttl: int = 3600,
|
||||
max_size: int = 1000,
|
||||
) -> CacheService:
|
||||
"""
|
||||
Configure the global cache service.
|
||||
|
||||
Args:
|
||||
backend_type: Type of backend ("memory" or "redis")
|
||||
redis_url: Redis connection URL (for redis backend)
|
||||
default_ttl: Default time to live in seconds
|
||||
max_size: Maximum cache size (for memory backend)
|
||||
|
||||
Returns:
|
||||
Configured CacheService instance
|
||||
"""
|
||||
global _cache_service
|
||||
|
||||
if backend_type == "redis":
|
||||
backend = RedisCacheBackend(redis_url=redis_url)
|
||||
else:
|
||||
backend = InMemoryCacheBackend(max_size=max_size)
|
||||
|
||||
_cache_service = CacheService(
|
||||
backend=backend, default_ttl=default_ttl, key_prefix="aniworld:"
|
||||
)
|
||||
return _cache_service
|
||||
626
src/server/services/notification_service.py
Normal file
626
src/server/services/notification_service.py
Normal file
@@ -0,0 +1,626 @@
|
||||
"""
|
||||
Notification Service for AniWorld.
|
||||
|
||||
This module provides notification functionality including email, webhooks,
|
||||
and in-app notifications for download events and system alerts.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
import aiohttp
|
||||
from pydantic import BaseModel, EmailStr, Field, HttpUrl
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NotificationType(str, Enum):
|
||||
"""Types of notifications."""
|
||||
|
||||
DOWNLOAD_COMPLETE = "download_complete"
|
||||
DOWNLOAD_FAILED = "download_failed"
|
||||
QUEUE_COMPLETE = "queue_complete"
|
||||
SYSTEM_ERROR = "system_error"
|
||||
SYSTEM_WARNING = "system_warning"
|
||||
SYSTEM_INFO = "system_info"
|
||||
|
||||
|
||||
class NotificationPriority(str, Enum):
|
||||
"""Notification priority levels."""
|
||||
|
||||
LOW = "low"
|
||||
NORMAL = "normal"
|
||||
HIGH = "high"
|
||||
CRITICAL = "critical"
|
||||
|
||||
|
||||
class NotificationChannel(str, Enum):
|
||||
"""Available notification channels."""
|
||||
|
||||
EMAIL = "email"
|
||||
WEBHOOK = "webhook"
|
||||
IN_APP = "in_app"
|
||||
|
||||
|
||||
class NotificationPreferences(BaseModel):
|
||||
"""User notification preferences."""
|
||||
|
||||
enabled_channels: Set[NotificationChannel] = Field(
|
||||
default_factory=lambda: {NotificationChannel.IN_APP}
|
||||
)
|
||||
enabled_types: Set[NotificationType] = Field(
|
||||
default_factory=lambda: set(NotificationType)
|
||||
)
|
||||
email_address: Optional[EmailStr] = None
|
||||
webhook_urls: List[HttpUrl] = Field(default_factory=list)
|
||||
quiet_hours_start: Optional[int] = Field(None, ge=0, le=23)
|
||||
quiet_hours_end: Optional[int] = Field(None, ge=0, le=23)
|
||||
min_priority: NotificationPriority = NotificationPriority.NORMAL
|
||||
|
||||
|
||||
class Notification(BaseModel):
|
||||
"""Notification model."""
|
||||
|
||||
id: str
|
||||
type: NotificationType
|
||||
priority: NotificationPriority
|
||||
title: str
|
||||
message: str
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
read: bool = False
|
||||
channels: Set[NotificationChannel] = Field(
|
||||
default_factory=lambda: {NotificationChannel.IN_APP}
|
||||
)
|
||||
|
||||
|
||||
class EmailNotificationService:
|
||||
"""Service for sending email notifications."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
smtp_host: Optional[str] = None,
|
||||
smtp_port: int = 587,
|
||||
smtp_username: Optional[str] = None,
|
||||
smtp_password: Optional[str] = None,
|
||||
from_address: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize email notification service.
|
||||
|
||||
Args:
|
||||
smtp_host: SMTP server hostname
|
||||
smtp_port: SMTP server port
|
||||
smtp_username: SMTP authentication username
|
||||
smtp_password: SMTP authentication password
|
||||
from_address: Email sender address
|
||||
"""
|
||||
self.smtp_host = smtp_host
|
||||
self.smtp_port = smtp_port
|
||||
self.smtp_username = smtp_username
|
||||
self.smtp_password = smtp_password
|
||||
self.from_address = from_address
|
||||
self._enabled = all(
|
||||
[smtp_host, smtp_username, smtp_password, from_address]
|
||||
)
|
||||
|
||||
async def send_email(
|
||||
self, to_address: str, subject: str, body: str, html: bool = False
|
||||
) -> bool:
|
||||
"""
|
||||
Send an email notification.
|
||||
|
||||
Args:
|
||||
to_address: Recipient email address
|
||||
subject: Email subject
|
||||
body: Email body content
|
||||
html: Whether body is HTML format
|
||||
|
||||
Returns:
|
||||
True if email sent successfully
|
||||
"""
|
||||
if not self._enabled:
|
||||
logger.warning("Email notifications not configured")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Import here to make aiosmtplib optional
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
|
||||
import aiosmtplib
|
||||
|
||||
message = MIMEMultipart("alternative")
|
||||
message["Subject"] = subject
|
||||
message["From"] = self.from_address
|
||||
message["To"] = to_address
|
||||
|
||||
mime_type = "html" if html else "plain"
|
||||
message.attach(MIMEText(body, mime_type))
|
||||
|
||||
await aiosmtplib.send(
|
||||
message,
|
||||
hostname=self.smtp_host,
|
||||
port=self.smtp_port,
|
||||
username=self.smtp_username,
|
||||
password=self.smtp_password,
|
||||
start_tls=True,
|
||||
)
|
||||
|
||||
logger.info(f"Email notification sent to {to_address}")
|
||||
return True
|
||||
|
||||
except ImportError:
|
||||
logger.error(
|
||||
"aiosmtplib not installed. Install with: pip install aiosmtplib"
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send email notification: {e}")
|
||||
return False
|
||||
|
||||
|
||||
class WebhookNotificationService:
|
||||
"""Service for sending webhook notifications."""
|
||||
|
||||
def __init__(self, timeout: int = 10, max_retries: int = 3):
|
||||
"""
|
||||
Initialize webhook notification service.
|
||||
|
||||
Args:
|
||||
timeout: Request timeout in seconds
|
||||
max_retries: Maximum number of retry attempts
|
||||
"""
|
||||
self.timeout = timeout
|
||||
self.max_retries = max_retries
|
||||
|
||||
async def send_webhook(
|
||||
self, url: str, payload: Dict[str, Any], headers: Optional[Dict[str, str]] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Send a webhook notification.
|
||||
|
||||
Args:
|
||||
url: Webhook URL
|
||||
payload: JSON payload to send
|
||||
headers: Optional custom headers
|
||||
|
||||
Returns:
|
||||
True if webhook sent successfully
|
||||
"""
|
||||
if headers is None:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
timeout=aiohttp.ClientTimeout(total=self.timeout),
|
||||
) as response:
|
||||
if response.status < 400:
|
||||
logger.info(f"Webhook notification sent to {url}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(
|
||||
f"Webhook returned status {response.status}: {url}"
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Webhook timeout (attempt {attempt + 1}/{self.max_retries}): {url}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send webhook (attempt {attempt + 1}/{self.max_retries}): {e}")
|
||||
|
||||
if attempt < self.max_retries - 1:
|
||||
await asyncio.sleep(2 ** attempt) # Exponential backoff
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class InAppNotificationService:
|
||||
"""Service for managing in-app notifications."""
|
||||
|
||||
def __init__(self, max_notifications: int = 100):
|
||||
"""
|
||||
Initialize in-app notification service.
|
||||
|
||||
Args:
|
||||
max_notifications: Maximum number of notifications to keep
|
||||
"""
|
||||
self.notifications: List[Notification] = []
|
||||
self.max_notifications = max_notifications
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def add_notification(self, notification: Notification) -> None:
|
||||
"""
|
||||
Add a notification to the in-app list.
|
||||
|
||||
Args:
|
||||
notification: Notification to add
|
||||
"""
|
||||
async with self._lock:
|
||||
self.notifications.insert(0, notification)
|
||||
if len(self.notifications) > self.max_notifications:
|
||||
self.notifications = self.notifications[: self.max_notifications]
|
||||
|
||||
async def get_notifications(
|
||||
self, unread_only: bool = False, limit: Optional[int] = None
|
||||
) -> List[Notification]:
|
||||
"""
|
||||
Get in-app notifications.
|
||||
|
||||
Args:
|
||||
unread_only: Only return unread notifications
|
||||
limit: Maximum number of notifications to return
|
||||
|
||||
Returns:
|
||||
List of notifications
|
||||
"""
|
||||
async with self._lock:
|
||||
notifications = self.notifications
|
||||
if unread_only:
|
||||
notifications = [n for n in notifications if not n.read]
|
||||
if limit:
|
||||
notifications = notifications[:limit]
|
||||
return notifications.copy()
|
||||
|
||||
async def mark_as_read(self, notification_id: str) -> bool:
|
||||
"""
|
||||
Mark a notification as read.
|
||||
|
||||
Args:
|
||||
notification_id: ID of notification to mark
|
||||
|
||||
Returns:
|
||||
True if notification was found and marked
|
||||
"""
|
||||
async with self._lock:
|
||||
for notification in self.notifications:
|
||||
if notification.id == notification_id:
|
||||
notification.read = True
|
||||
return True
|
||||
return False
|
||||
|
||||
async def mark_all_as_read(self) -> int:
|
||||
"""
|
||||
Mark all notifications as read.
|
||||
|
||||
Returns:
|
||||
Number of notifications marked as read
|
||||
"""
|
||||
async with self._lock:
|
||||
count = 0
|
||||
for notification in self.notifications:
|
||||
if not notification.read:
|
||||
notification.read = True
|
||||
count += 1
|
||||
return count
|
||||
|
||||
async def clear_notifications(self, read_only: bool = True) -> int:
|
||||
"""
|
||||
Clear notifications.
|
||||
|
||||
Args:
|
||||
read_only: Only clear read notifications
|
||||
|
||||
Returns:
|
||||
Number of notifications cleared
|
||||
"""
|
||||
async with self._lock:
|
||||
if read_only:
|
||||
initial_count = len(self.notifications)
|
||||
self.notifications = [n for n in self.notifications if not n.read]
|
||||
return initial_count - len(self.notifications)
|
||||
else:
|
||||
count = len(self.notifications)
|
||||
self.notifications.clear()
|
||||
return count
|
||||
|
||||
|
||||
class NotificationService:
|
||||
"""Main notification service coordinating all notification channels."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
email_service: Optional[EmailNotificationService] = None,
|
||||
webhook_service: Optional[WebhookNotificationService] = None,
|
||||
in_app_service: Optional[InAppNotificationService] = None,
|
||||
):
|
||||
"""
|
||||
Initialize notification service.
|
||||
|
||||
Args:
|
||||
email_service: Email notification service instance
|
||||
webhook_service: Webhook notification service instance
|
||||
in_app_service: In-app notification service instance
|
||||
"""
|
||||
self.email_service = email_service or EmailNotificationService()
|
||||
self.webhook_service = webhook_service or WebhookNotificationService()
|
||||
self.in_app_service = in_app_service or InAppNotificationService()
|
||||
self.preferences = NotificationPreferences()
|
||||
|
||||
def set_preferences(self, preferences: NotificationPreferences) -> None:
|
||||
"""
|
||||
Update notification preferences.
|
||||
|
||||
Args:
|
||||
preferences: New notification preferences
|
||||
"""
|
||||
self.preferences = preferences
|
||||
|
||||
def _is_in_quiet_hours(self) -> bool:
|
||||
"""
|
||||
Check if current time is within quiet hours.
|
||||
|
||||
Returns:
|
||||
True if in quiet hours
|
||||
"""
|
||||
if (
|
||||
self.preferences.quiet_hours_start is None
|
||||
or self.preferences.quiet_hours_end is None
|
||||
):
|
||||
return False
|
||||
|
||||
current_hour = datetime.now().hour
|
||||
start = self.preferences.quiet_hours_start
|
||||
end = self.preferences.quiet_hours_end
|
||||
|
||||
if start <= end:
|
||||
return start <= current_hour < end
|
||||
else: # Quiet hours span midnight
|
||||
return current_hour >= start or current_hour < end
|
||||
|
||||
def _should_send_notification(
|
||||
self, notification_type: NotificationType, priority: NotificationPriority
|
||||
) -> bool:
|
||||
"""
|
||||
Determine if a notification should be sent based on preferences.
|
||||
|
||||
Args:
|
||||
notification_type: Type of notification
|
||||
priority: Priority level
|
||||
|
||||
Returns:
|
||||
True if notification should be sent
|
||||
"""
|
||||
# Check if type is enabled
|
||||
if notification_type not in self.preferences.enabled_types:
|
||||
return False
|
||||
|
||||
# Check priority level
|
||||
priority_order = [
|
||||
NotificationPriority.LOW,
|
||||
NotificationPriority.NORMAL,
|
||||
NotificationPriority.HIGH,
|
||||
NotificationPriority.CRITICAL,
|
||||
]
|
||||
if (
|
||||
priority_order.index(priority)
|
||||
< priority_order.index(self.preferences.min_priority)
|
||||
):
|
||||
return False
|
||||
|
||||
# Check quiet hours (critical notifications bypass quiet hours)
|
||||
if priority != NotificationPriority.CRITICAL and self._is_in_quiet_hours():
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def send_notification(self, notification: Notification) -> Dict[str, bool]:
|
||||
"""
|
||||
Send a notification through enabled channels.
|
||||
|
||||
Args:
|
||||
notification: Notification to send
|
||||
|
||||
Returns:
|
||||
Dictionary mapping channel names to success status
|
||||
"""
|
||||
if not self._should_send_notification(notification.type, notification.priority):
|
||||
logger.debug(
|
||||
f"Notification not sent due to preferences: {notification.type}"
|
||||
)
|
||||
return {}
|
||||
|
||||
results = {}
|
||||
|
||||
# Send in-app notification
|
||||
if NotificationChannel.IN_APP in self.preferences.enabled_channels:
|
||||
try:
|
||||
await self.in_app_service.add_notification(notification)
|
||||
results["in_app"] = True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send in-app notification: {e}")
|
||||
results["in_app"] = False
|
||||
|
||||
# Send email notification
|
||||
if (
|
||||
NotificationChannel.EMAIL in self.preferences.enabled_channels
|
||||
and self.preferences.email_address
|
||||
):
|
||||
try:
|
||||
success = await self.email_service.send_email(
|
||||
to_address=self.preferences.email_address,
|
||||
subject=f"[{notification.priority.upper()}] {notification.title}",
|
||||
body=notification.message,
|
||||
)
|
||||
results["email"] = success
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send email notification: {e}")
|
||||
results["email"] = False
|
||||
|
||||
# Send webhook notifications
|
||||
if (
|
||||
NotificationChannel.WEBHOOK in self.preferences.enabled_channels
|
||||
and self.preferences.webhook_urls
|
||||
):
|
||||
payload = {
|
||||
"id": notification.id,
|
||||
"type": notification.type,
|
||||
"priority": notification.priority,
|
||||
"title": notification.title,
|
||||
"message": notification.message,
|
||||
"data": notification.data,
|
||||
"created_at": notification.created_at.isoformat(),
|
||||
}
|
||||
|
||||
webhook_results = []
|
||||
for url in self.preferences.webhook_urls:
|
||||
try:
|
||||
success = await self.webhook_service.send_webhook(str(url), payload)
|
||||
webhook_results.append(success)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send webhook notification to {url}: {e}")
|
||||
webhook_results.append(False)
|
||||
|
||||
results["webhook"] = all(webhook_results) if webhook_results else False
|
||||
|
||||
return results
|
||||
|
||||
async def notify_download_complete(
|
||||
self, series_name: str, episode: str, file_path: str
|
||||
) -> Dict[str, bool]:
|
||||
"""
|
||||
Send notification for completed download.
|
||||
|
||||
Args:
|
||||
series_name: Name of the series
|
||||
episode: Episode identifier
|
||||
file_path: Path to downloaded file
|
||||
|
||||
Returns:
|
||||
Dictionary of send results by channel
|
||||
"""
|
||||
notification = Notification(
|
||||
id=f"download_complete_{datetime.utcnow().timestamp()}",
|
||||
type=NotificationType.DOWNLOAD_COMPLETE,
|
||||
priority=NotificationPriority.NORMAL,
|
||||
title=f"Download Complete: {series_name}",
|
||||
message=f"Episode {episode} has been downloaded successfully.",
|
||||
data={
|
||||
"series_name": series_name,
|
||||
"episode": episode,
|
||||
"file_path": file_path,
|
||||
},
|
||||
)
|
||||
return await self.send_notification(notification)
|
||||
|
||||
async def notify_download_failed(
|
||||
self, series_name: str, episode: str, error: str
|
||||
) -> Dict[str, bool]:
|
||||
"""
|
||||
Send notification for failed download.
|
||||
|
||||
Args:
|
||||
series_name: Name of the series
|
||||
episode: Episode identifier
|
||||
error: Error message
|
||||
|
||||
Returns:
|
||||
Dictionary of send results by channel
|
||||
"""
|
||||
notification = Notification(
|
||||
id=f"download_failed_{datetime.utcnow().timestamp()}",
|
||||
type=NotificationType.DOWNLOAD_FAILED,
|
||||
priority=NotificationPriority.HIGH,
|
||||
title=f"Download Failed: {series_name}",
|
||||
message=f"Episode {episode} failed to download: {error}",
|
||||
data={"series_name": series_name, "episode": episode, "error": error},
|
||||
)
|
||||
return await self.send_notification(notification)
|
||||
|
||||
async def notify_queue_complete(self, total_downloads: int) -> Dict[str, bool]:
|
||||
"""
|
||||
Send notification for completed download queue.
|
||||
|
||||
Args:
|
||||
total_downloads: Number of downloads completed
|
||||
|
||||
Returns:
|
||||
Dictionary of send results by channel
|
||||
"""
|
||||
notification = Notification(
|
||||
id=f"queue_complete_{datetime.utcnow().timestamp()}",
|
||||
type=NotificationType.QUEUE_COMPLETE,
|
||||
priority=NotificationPriority.NORMAL,
|
||||
title="Download Queue Complete",
|
||||
message=f"All {total_downloads} downloads have been completed.",
|
||||
data={"total_downloads": total_downloads},
|
||||
)
|
||||
return await self.send_notification(notification)
|
||||
|
||||
async def notify_system_error(self, error: str, details: Optional[Dict[str, Any]] = None) -> Dict[str, bool]:
|
||||
"""
|
||||
Send notification for system error.
|
||||
|
||||
Args:
|
||||
error: Error message
|
||||
details: Optional error details
|
||||
|
||||
Returns:
|
||||
Dictionary of send results by channel
|
||||
"""
|
||||
notification = Notification(
|
||||
id=f"system_error_{datetime.utcnow().timestamp()}",
|
||||
type=NotificationType.SYSTEM_ERROR,
|
||||
priority=NotificationPriority.CRITICAL,
|
||||
title="System Error",
|
||||
message=error,
|
||||
data=details,
|
||||
)
|
||||
return await self.send_notification(notification)
|
||||
|
||||
|
||||
# Global notification service instance
|
||||
_notification_service: Optional[NotificationService] = None
|
||||
|
||||
|
||||
def get_notification_service() -> NotificationService:
|
||||
"""
|
||||
Get the global notification service instance.
|
||||
|
||||
Returns:
|
||||
NotificationService instance
|
||||
"""
|
||||
global _notification_service
|
||||
if _notification_service is None:
|
||||
_notification_service = NotificationService()
|
||||
return _notification_service
|
||||
|
||||
|
||||
def configure_notification_service(
|
||||
smtp_host: Optional[str] = None,
|
||||
smtp_port: int = 587,
|
||||
smtp_username: Optional[str] = None,
|
||||
smtp_password: Optional[str] = None,
|
||||
from_address: Optional[str] = None,
|
||||
) -> NotificationService:
|
||||
"""
|
||||
Configure the global notification service.
|
||||
|
||||
Args:
|
||||
smtp_host: SMTP server hostname
|
||||
smtp_port: SMTP server port
|
||||
smtp_username: SMTP authentication username
|
||||
smtp_password: SMTP authentication password
|
||||
from_address: Email sender address
|
||||
|
||||
Returns:
|
||||
Configured NotificationService instance
|
||||
"""
|
||||
global _notification_service
|
||||
email_service = EmailNotificationService(
|
||||
smtp_host=smtp_host,
|
||||
smtp_port=smtp_port,
|
||||
smtp_username=smtp_username,
|
||||
smtp_password=smtp_password,
|
||||
from_address=from_address,
|
||||
)
|
||||
_notification_service = NotificationService(email_service=email_service)
|
||||
return _notification_service
|
||||
Reference in New Issue
Block a user