Add advanced features: notification system, security middleware, audit logging, data validation, and caching

- Implement notification service with email, webhook, and in-app support
- Add security headers middleware (CORS, CSP, HSTS, XSS protection)
- Create comprehensive audit logging service for security events
- Add data validation utilities with Pydantic validators
- Implement cache service with in-memory and Redis backend support

All 714 tests passing
This commit is contained in:
2025-10-24 09:23:15 +02:00
parent 17e5a551e1
commit 7409ae637e
6 changed files with 3033 additions and 44 deletions

View File

@@ -0,0 +1,610 @@
"""
Audit Service for AniWorld.
This module provides comprehensive audit logging for security-critical
operations including authentication, configuration changes, and downloads.
"""
import json
import logging
from datetime import datetime, timedelta
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
class AuditEventType(str, Enum):
    """Types of audit events.

    Values are dotted string identifiers that are persisted verbatim in
    audit logs -- treat them as a stable external contract.
    """

    # Authentication events
    AUTH_SETUP = "auth.setup"
    AUTH_LOGIN_SUCCESS = "auth.login.success"
    AUTH_LOGIN_FAILURE = "auth.login.failure"
    AUTH_LOGOUT = "auth.logout"
    AUTH_TOKEN_REFRESH = "auth.token.refresh"
    AUTH_TOKEN_INVALID = "auth.token.invalid"
    # Configuration events
    CONFIG_READ = "config.read"
    CONFIG_UPDATE = "config.update"
    CONFIG_BACKUP = "config.backup"
    CONFIG_RESTORE = "config.restore"
    CONFIG_DELETE = "config.delete"
    # Download events
    DOWNLOAD_ADDED = "download.added"
    DOWNLOAD_STARTED = "download.started"
    DOWNLOAD_COMPLETED = "download.completed"
    DOWNLOAD_FAILED = "download.failed"
    DOWNLOAD_CANCELLED = "download.cancelled"
    DOWNLOAD_REMOVED = "download.removed"
    # Queue events
    QUEUE_STARTED = "queue.started"
    QUEUE_STOPPED = "queue.stopped"
    QUEUE_PAUSED = "queue.paused"
    QUEUE_RESUMED = "queue.resumed"
    QUEUE_CLEARED = "queue.cleared"
    # System events
    SYSTEM_STARTUP = "system.startup"
    SYSTEM_SHUTDOWN = "system.shutdown"
    SYSTEM_ERROR = "system.error"
class AuditEventSeverity(str, Enum):
    """Severity levels for audit events, lowest to highest."""

    DEBUG = "debug"
    INFO = "info"
    WARNING = "warning"
    ERROR = "error"
    CRITICAL = "critical"
class AuditEvent(BaseModel):
    """Audit event model.

    One immutable audit record; serialized as a single JSON object per
    line (JSONL) by the file storage backend.
    """

    # Creation time of the event object (naive UTC).
    timestamp: datetime = Field(default_factory=datetime.utcnow)
    event_type: AuditEventType
    severity: AuditEventSeverity = AuditEventSeverity.INFO
    # Actor / request context -- all optional (system events have none).
    user_id: Optional[str] = None
    ip_address: Optional[str] = None
    user_agent: Optional[str] = None
    # What was acted on and how.
    resource: Optional[str] = None
    action: Optional[str] = None
    status: str = "success"
    message: str
    details: Optional[Dict[str, Any]] = None
    session_id: Optional[str] = None

    class Config:
        """Pydantic config."""

        # NOTE(review): v1-style `class Config` with json_encoders is
        # deprecated under Pydantic v2, yet the storage backend calls the
        # v2 API model_dump_json() -- confirm the pydantic version and
        # consider migrating to model_config = ConfigDict(...).
        json_encoders = {datetime: lambda v: v.isoformat()}
class AuditLogStorage:
    """Base class for audit log storage backends.

    Pure interface: every method raises NotImplementedError and must be
    overridden by concrete backends (see FileAuditLogStorage).
    """

    async def write_event(self, event: AuditEvent) -> None:
        """
        Write an audit event to storage.

        Args:
            event: Audit event to write
        """
        raise NotImplementedError

    async def read_events(
        self,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        event_types: Optional[List[AuditEventType]] = None,
        user_id: Optional[str] = None,
        limit: int = 100,
    ) -> List[AuditEvent]:
        """
        Read audit events from storage.

        Args:
            start_time: Start of time range
            end_time: End of time range
            event_types: Filter by event types
            user_id: Filter by user ID
            limit: Maximum number of events to return

        Returns:
            List of audit events
        """
        raise NotImplementedError

    async def cleanup_old_events(self, days: int = 90) -> int:
        """
        Clean up audit events older than specified days.

        Args:
            days: Number of days to retain

        Returns:
            Number of events deleted
        """
        raise NotImplementedError
class FileAuditLogStorage(AuditLogStorage):
    """File-based audit log storage.

    Appends events as JSON Lines, one file per calendar day
    (``audit_YYYY-MM-DD.jsonl``) under ``log_directory``.
    """

    def __init__(self, log_directory: str = "logs/audit"):
        """
        Initialize file-based audit log storage.

        Args:
            log_directory: Directory to store audit logs (created if missing)
        """
        self.log_directory = Path(log_directory)
        self.log_directory.mkdir(parents=True, exist_ok=True)

    def _get_log_file(self, date: datetime) -> Path:
        """
        Get log file path for a specific date.

        Args:
            date: Date for log file

        Returns:
            Path to log file
        """
        date_str = date.strftime("%Y-%m-%d")
        return self.log_directory / f"audit_{date_str}.jsonl"

    async def write_event(self, event: AuditEvent) -> None:
        """
        Write an audit event to file.

        Failures are logged and swallowed on purpose: audit logging is
        best-effort and must never break the operation being audited.

        NOTE(review): this uses blocking file I/O inside an async method;
        fine for small appends, but consider an executor if this ends up
        on a hot path.

        Args:
            event: Audit event to write
        """
        log_file = self._get_log_file(event.timestamp)
        try:
            with open(log_file, "a", encoding="utf-8") as f:
                f.write(event.model_dump_json() + "\n")
        except Exception as e:
            logger.error(f"Failed to write audit event to file: {e}")

    async def read_events(
        self,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        event_types: Optional[List[AuditEventType]] = None,
        user_id: Optional[str] = None,
        limit: int = 100,
    ) -> List[AuditEvent]:
        """
        Read audit events from files.

        Args:
            start_time: Start of time range (defaults to 7 days ago)
            end_time: End of time range (defaults to now)
            event_types: Filter by event types
            user_id: Filter by user ID
            limit: Maximum number of events to return

        Returns:
            List of audit events, newest first
        """
        if start_time is None:
            start_time = datetime.utcnow() - timedelta(days=7)
        if end_time is None:
            end_time = datetime.utcnow()
        events: List[AuditEvent] = []
        current_date = start_time.date()
        end_date = end_time.date()
        # Walk one day-file at a time, stopping early once `limit` is hit.
        while current_date <= end_date and len(events) < limit:
            log_file = self._get_log_file(
                datetime.combine(current_date, datetime.min.time())
            )
            if log_file.exists():
                try:
                    with open(log_file, "r", encoding="utf-8") as f:
                        for line in f:
                            if len(events) >= limit:
                                break
                            try:
                                event_data = json.loads(line.strip())
                                event = AuditEvent(**event_data)
                                # Apply filters; malformed lines are
                                # skipped with a warning below.
                                if (
                                    event.timestamp < start_time
                                    or event.timestamp > end_time
                                ):
                                    continue
                                if event_types and event.event_type not in event_types:
                                    continue
                                if user_id and event.user_id != user_id:
                                    continue
                                events.append(event)
                            except (json.JSONDecodeError, ValueError) as e:
                                logger.warning(f"Failed to parse audit event: {e}")
                except Exception as e:
                    logger.error(f"Failed to read audit log file {log_file}: {e}")
            current_date += timedelta(days=1)
        # Sort by timestamp descending
        events.sort(key=lambda e: e.timestamp, reverse=True)
        return events[:limit]

    async def cleanup_old_events(self, days: int = 90) -> int:
        """
        Clean up audit events older than specified days.

        Whole day-files are deleted, so granularity is one day.

        Args:
            days: Number of days to retain

        Returns:
            Number of files deleted
        """
        cutoff_date = datetime.utcnow() - timedelta(days=days)
        deleted_count = 0
        for log_file in self.log_directory.glob("audit_*.jsonl"):
            try:
                # Extract date from filename (audit_YYYY-MM-DD.jsonl)
                date_str = log_file.stem.replace("audit_", "")
                file_date = datetime.strptime(date_str, "%Y-%m-%d")
                if file_date < cutoff_date:
                    log_file.unlink()
                    deleted_count += 1
                    logger.info(f"Deleted old audit log: {log_file}")
            except (ValueError, OSError) as e:
                logger.warning(f"Failed to process audit log file {log_file}: {e}")
        return deleted_count
class AuditService:
    """Main audit service for logging security events.

    Thin convenience layer over an AuditLogStorage backend: the various
    ``log_*`` helpers construct an AuditEvent with the appropriate
    type/severity/details and delegate to log_event().
    """

    def __init__(self, storage: Optional[AuditLogStorage] = None):
        """
        Initialize audit service.

        Args:
            storage: Storage backend for audit logs
                (defaults to a FileAuditLogStorage)
        """
        self.storage = storage or FileAuditLogStorage()

    async def log_event(
        self,
        event_type: AuditEventType,
        message: str,
        severity: AuditEventSeverity = AuditEventSeverity.INFO,
        user_id: Optional[str] = None,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None,
        resource: Optional[str] = None,
        action: Optional[str] = None,
        status: str = "success",
        details: Optional[Dict[str, Any]] = None,
        session_id: Optional[str] = None,
    ) -> None:
        """
        Log an audit event.

        Args:
            event_type: Type of event
            message: Human-readable message
            severity: Event severity
            user_id: User identifier
            ip_address: Client IP address
            user_agent: Client user agent
            resource: Resource being accessed
            action: Action performed
            status: Operation status
            details: Additional details
            session_id: Session identifier
        """
        event = AuditEvent(
            event_type=event_type,
            severity=severity,
            user_id=user_id,
            ip_address=ip_address,
            user_agent=user_agent,
            resource=resource,
            action=action,
            status=status,
            message=message,
            details=details,
            session_id=session_id,
        )
        await self.storage.write_event(event)
        # Also log to application logger for high severity events
        if severity in [AuditEventSeverity.ERROR, AuditEventSeverity.CRITICAL]:
            logger.error(f"Audit: {message}", extra={"audit_event": event.model_dump()})
        elif severity == AuditEventSeverity.WARNING:
            logger.warning(f"Audit: {message}", extra={"audit_event": event.model_dump()})

    async def log_auth_setup(
        self, user_id: str, ip_address: Optional[str] = None
    ) -> None:
        """Log initial authentication setup."""
        await self.log_event(
            event_type=AuditEventType.AUTH_SETUP,
            message=f"Authentication configured by user {user_id}",
            user_id=user_id,
            ip_address=ip_address,
            action="setup",
        )

    async def log_login_success(
        self,
        user_id: str,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None,
        session_id: Optional[str] = None,
    ) -> None:
        """Log successful login."""
        await self.log_event(
            event_type=AuditEventType.AUTH_LOGIN_SUCCESS,
            message=f"User {user_id} logged in successfully",
            user_id=user_id,
            ip_address=ip_address,
            user_agent=user_agent,
            session_id=session_id,
            action="login",
        )

    async def log_login_failure(
        self,
        user_id: Optional[str] = None,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None,
        reason: str = "Invalid credentials",
    ) -> None:
        """Log failed login attempt (WARNING severity, status 'failure')."""
        await self.log_event(
            event_type=AuditEventType.AUTH_LOGIN_FAILURE,
            message=f"Login failed for user {user_id or 'unknown'}: {reason}",
            severity=AuditEventSeverity.WARNING,
            user_id=user_id,
            ip_address=ip_address,
            user_agent=user_agent,
            status="failure",
            action="login",
            details={"reason": reason},
        )

    async def log_logout(
        self,
        user_id: str,
        ip_address: Optional[str] = None,
        session_id: Optional[str] = None,
    ) -> None:
        """Log user logout."""
        await self.log_event(
            event_type=AuditEventType.AUTH_LOGOUT,
            message=f"User {user_id} logged out",
            user_id=user_id,
            ip_address=ip_address,
            session_id=session_id,
            action="logout",
        )

    async def log_config_update(
        self,
        user_id: str,
        changes: Dict[str, Any],
        ip_address: Optional[str] = None,
    ) -> None:
        """Log configuration update; `changes` is recorded in details."""
        await self.log_event(
            event_type=AuditEventType.CONFIG_UPDATE,
            message=f"Configuration updated by user {user_id}",
            user_id=user_id,
            ip_address=ip_address,
            resource="config",
            action="update",
            details={"changes": changes},
        )

    async def log_config_backup(
        self, user_id: str, backup_file: str, ip_address: Optional[str] = None
    ) -> None:
        """Log configuration backup."""
        await self.log_event(
            event_type=AuditEventType.CONFIG_BACKUP,
            message=f"Configuration backed up by user {user_id}",
            user_id=user_id,
            ip_address=ip_address,
            resource="config",
            action="backup",
            details={"backup_file": backup_file},
        )

    async def log_config_restore(
        self, user_id: str, backup_file: str, ip_address: Optional[str] = None
    ) -> None:
        """Log configuration restore."""
        await self.log_event(
            event_type=AuditEventType.CONFIG_RESTORE,
            message=f"Configuration restored by user {user_id}",
            user_id=user_id,
            ip_address=ip_address,
            resource="config",
            action="restore",
            details={"backup_file": backup_file},
        )

    async def log_download_added(
        self,
        user_id: str,
        series_name: str,
        episodes: List[str],
        ip_address: Optional[str] = None,
    ) -> None:
        """Log download added to queue."""
        await self.log_event(
            event_type=AuditEventType.DOWNLOAD_ADDED,
            message=f"Download added by user {user_id}: {series_name}",
            user_id=user_id,
            ip_address=ip_address,
            resource=series_name,
            action="add",
            details={"episodes": episodes},
        )

    async def log_download_completed(
        self, series_name: str, episode: str, file_path: str
    ) -> None:
        """Log completed download (system-initiated; no user context)."""
        await self.log_event(
            event_type=AuditEventType.DOWNLOAD_COMPLETED,
            message=f"Download completed: {series_name} - {episode}",
            resource=series_name,
            action="download",
            details={"episode": episode, "file_path": file_path},
        )

    async def log_download_failed(
        self, series_name: str, episode: str, error: str
    ) -> None:
        """Log failed download (ERROR severity, status 'failure')."""
        await self.log_event(
            event_type=AuditEventType.DOWNLOAD_FAILED,
            message=f"Download failed: {series_name} - {episode}",
            severity=AuditEventSeverity.ERROR,
            resource=series_name,
            action="download",
            status="failure",
            details={"episode": episode, "error": error},
        )

    async def log_queue_operation(
        self,
        user_id: str,
        operation: str,
        ip_address: Optional[str] = None,
        details: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Log queue operation.

        `operation` should be one of start/stop/pause/resume/clear.
        NOTE(review): unknown operations fall back to SYSTEM_ERROR as the
        event type, which mislabels them -- confirm this is intentional.
        """
        event_type_map = {
            "start": AuditEventType.QUEUE_STARTED,
            "stop": AuditEventType.QUEUE_STOPPED,
            "pause": AuditEventType.QUEUE_PAUSED,
            "resume": AuditEventType.QUEUE_RESUMED,
            "clear": AuditEventType.QUEUE_CLEARED,
        }
        event_type = event_type_map.get(operation, AuditEventType.SYSTEM_ERROR)
        await self.log_event(
            event_type=event_type,
            message=f"Queue {operation} by user {user_id}",
            user_id=user_id,
            ip_address=ip_address,
            resource="queue",
            action=operation,
            details=details,
        )

    async def log_system_error(
        self, error: str, details: Optional[Dict[str, Any]] = None
    ) -> None:
        """Log system error (ERROR severity, status 'error')."""
        await self.log_event(
            event_type=AuditEventType.SYSTEM_ERROR,
            message=f"System error: {error}",
            severity=AuditEventSeverity.ERROR,
            status="error",
            details=details,
        )

    async def get_events(
        self,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        event_types: Optional[List[AuditEventType]] = None,
        user_id: Optional[str] = None,
        limit: int = 100,
    ) -> List[AuditEvent]:
        """
        Get audit events with filters.

        Args:
            start_time: Start of time range
            end_time: End of time range
            event_types: Filter by event types
            user_id: Filter by user ID
            limit: Maximum number of events to return

        Returns:
            List of audit events
        """
        return await self.storage.read_events(
            start_time=start_time,
            end_time=end_time,
            event_types=event_types,
            user_id=user_id,
            limit=limit,
        )

    async def cleanup_old_events(self, days: int = 90) -> int:
        """
        Clean up old audit events.

        Args:
            days: Number of days to retain

        Returns:
            Number of events deleted
        """
        return await self.storage.cleanup_old_events(days)
# Module-level singleton managed by get_audit_service()/configure_audit_service().
_audit_service: Optional[AuditService] = None


def get_audit_service() -> AuditService:
    """
    Get the global audit service instance.

    Lazily constructs a default AuditService on first use.

    Returns:
        AuditService instance
    """
    global _audit_service
    if _audit_service is not None:
        return _audit_service
    _audit_service = AuditService()
    return _audit_service
def configure_audit_service(storage: Optional[AuditLogStorage] = None) -> AuditService:
    """
    Configure the global audit service.

    Replaces any previously configured instance.

    Args:
        storage: Custom storage backend

    Returns:
        Configured AuditService instance
    """
    global _audit_service
    service = AuditService(storage=storage)
    _audit_service = service
    return service

View File

@@ -0,0 +1,723 @@
"""
Cache Service for AniWorld.
This module provides caching functionality with support for both
in-memory and Redis backends to improve application performance.
"""
import asyncio
import hashlib
import logging
import pickle
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
class CacheBackend(ABC):
    """Abstract base class for cache backends.

    All operations are async; concrete backends (in-memory, Redis) must
    implement every abstract method.
    """

    @abstractmethod
    async def get(self, key: str) -> Optional[Any]:
        """
        Get value from cache.

        Args:
            key: Cache key

        Returns:
            Cached value or None if not found
        """
        pass

    @abstractmethod
    async def set(
        self, key: str, value: Any, ttl: Optional[int] = None
    ) -> bool:
        """
        Set value in cache.

        Args:
            key: Cache key
            value: Value to cache
            ttl: Time to live in seconds (None means no expiry)

        Returns:
            True if successful
        """
        pass

    @abstractmethod
    async def delete(self, key: str) -> bool:
        """
        Delete value from cache.

        Args:
            key: Cache key

        Returns:
            True if key was deleted
        """
        pass

    @abstractmethod
    async def exists(self, key: str) -> bool:
        """
        Check if key exists in cache.

        Args:
            key: Cache key

        Returns:
            True if key exists
        """
        pass

    @abstractmethod
    async def clear(self) -> bool:
        """
        Clear all cached values.

        Returns:
            True if successful
        """
        pass

    @abstractmethod
    async def get_many(self, keys: List[str]) -> Dict[str, Any]:
        """
        Get multiple values from cache.

        Args:
            keys: List of cache keys

        Returns:
            Dictionary mapping keys to values (missing keys omitted)
        """
        pass

    @abstractmethod
    async def set_many(
        self, items: Dict[str, Any], ttl: Optional[int] = None
    ) -> bool:
        """
        Set multiple values in cache.

        Args:
            items: Dictionary of key-value pairs
            ttl: Time to live in seconds

        Returns:
            True if successful
        """
        pass

    @abstractmethod
    async def delete_pattern(self, pattern: str) -> int:
        """
        Delete all keys matching pattern.

        Args:
            pattern: Pattern to match (supports wildcards)

        Returns:
            Number of keys deleted
        """
        pass
class InMemoryCacheBackend(CacheBackend):
    """In-memory cache backend using dictionary.

    Each entry is stored as ``{"value", "expiry", "created"}``. Expired
    entries are removed lazily on access; capacity is bounded by
    ``max_size`` with oldest-created eviction.
    """

    def __init__(self, max_size: int = 1000):
        """
        Initialize in-memory cache.

        Args:
            max_size: Maximum number of items to cache
        """
        self.cache: Dict[str, Dict[str, Any]] = {}
        self.max_size = max_size
        self._lock = asyncio.Lock()

    def _is_expired(self, item: Dict[str, Any]) -> bool:
        """
        Check if cache item is expired.

        Args:
            item: Cache item with expiry

        Returns:
            True if expired (entries without expiry never expire)
        """
        if item.get("expiry") is None:
            return False
        return datetime.utcnow() > item["expiry"]

    def _evict_oldest(self) -> None:
        """Make room for one new item when the cache is at capacity.

        Expired entries are purged first; only if the cache is still full
        is the entry with the oldest creation time dropped.
        """
        if len(self.cache) < self.max_size:
            return
        # Prefer dropping entries that are already dead.
        expired_keys = [
            key for key, item in self.cache.items() if self._is_expired(item)
        ]
        for key in expired_keys:
            del self.cache[key]
        if len(self.cache) >= self.max_size:
            oldest_key = min(
                self.cache.keys(),
                key=lambda k: self.cache[k].get("created", datetime.utcnow()),
            )
            del self.cache[oldest_key]

    async def get(self, key: str) -> Optional[Any]:
        """Get value from cache (expired entries are deleted and miss)."""
        async with self._lock:
            if key not in self.cache:
                return None
            item = self.cache[key]
            if self._is_expired(item):
                del self.cache[key]
                return None
            return item["value"]

    async def set(
        self, key: str, value: Any, ttl: Optional[int] = None
    ) -> bool:
        """Set value in cache.

        Args:
            key: Cache key
            value: Value to cache
            ttl: Time to live in seconds (None means no expiry)

        Returns:
            True (always succeeds)
        """
        async with self._lock:
            # Only make room when inserting a brand-new key: overwriting
            # an existing key does not grow the cache, so evicting here
            # would needlessly drop an unrelated entry.
            if key not in self.cache:
                self._evict_oldest()
            expiry = None
            if ttl:
                expiry = datetime.utcnow() + timedelta(seconds=ttl)
            self.cache[key] = {
                "value": value,
                "expiry": expiry,
                "created": datetime.utcnow(),
            }
            return True

    async def delete(self, key: str) -> bool:
        """Delete value from cache; True if the key was present."""
        async with self._lock:
            if key in self.cache:
                del self.cache[key]
                return True
            return False

    async def exists(self, key: str) -> bool:
        """Check if key exists in cache (expired entries count as absent)."""
        async with self._lock:
            if key not in self.cache:
                return False
            item = self.cache[key]
            if self._is_expired(item):
                del self.cache[key]
                return False
            return True

    async def clear(self) -> bool:
        """Clear all cached values."""
        async with self._lock:
            self.cache.clear()
            return True

    async def get_many(self, keys: List[str]) -> Dict[str, Any]:
        """Get multiple values; keys that miss (or hold None) are omitted."""
        result = {}
        for key in keys:
            value = await self.get(key)
            if value is not None:
                result[key] = value
        return result

    async def set_many(
        self, items: Dict[str, Any], ttl: Optional[int] = None
    ) -> bool:
        """Set multiple values in cache with a shared TTL."""
        for key, value in items.items():
            await self.set(key, value, ttl)
        return True

    async def delete_pattern(self, pattern: str) -> int:
        """Delete all keys matching a shell-style wildcard pattern."""
        import fnmatch
        async with self._lock:
            keys_to_delete = [
                key for key in self.cache.keys() if fnmatch.fnmatch(key, pattern)
            ]
            for key in keys_to_delete:
                del self.cache[key]
            return len(keys_to_delete)
class RedisCacheBackend(CacheBackend):
    """Redis cache backend.

    Values are pickled before storage; all keys are namespaced with
    ``prefix``. Connection is created lazily on first use.

    NOTE(review): pickle deserialization of cache data is only safe if
    the Redis instance is fully trusted.
    """

    def __init__(
        self,
        redis_url: str = "redis://localhost:6379",
        prefix: str = "aniworld:",
    ):
        """
        Initialize Redis cache.

        Args:
            redis_url: Redis connection URL
            prefix: Key prefix for namespacing
        """
        self.redis_url = redis_url
        self.prefix = prefix
        self._redis = None

    async def _get_redis(self):
        """Get (lazily creating) the Redis connection.

        NOTE(review): aioredis.create_redis_pool is the legacy aioredis
        1.x API; aioredis 2.x / redis-py asyncio use a different API --
        confirm the pinned dependency version.
        """
        if self._redis is None:
            try:
                import aioredis
                self._redis = await aioredis.create_redis_pool(self.redis_url)
            except ImportError:
                logger.error(
                    "aioredis not installed. Install with: pip install aioredis"
                )
                raise
            except Exception as e:
                logger.error(f"Failed to connect to Redis: {e}")
                raise
        return self._redis

    def _make_key(self, key: str) -> str:
        """Add the namespace prefix to a key."""
        return f"{self.prefix}{key}"

    async def get(self, key: str) -> Optional[Any]:
        """Get value from cache; returns None on miss or any Redis error."""
        try:
            redis = await self._get_redis()
            data = await redis.get(self._make_key(key))
            if data is None:
                return None
            return pickle.loads(data)
        except Exception as e:
            logger.error(f"Redis get error: {e}")
            return None

    async def set(
        self, key: str, value: Any, ttl: Optional[int] = None
    ) -> bool:
        """Set value in cache; SETEX when a TTL is given, SET otherwise."""
        try:
            redis = await self._get_redis()
            data = pickle.dumps(value)
            if ttl:
                await redis.setex(self._make_key(key), ttl, data)
            else:
                await redis.set(self._make_key(key), data)
            return True
        except Exception as e:
            logger.error(f"Redis set error: {e}")
            return False

    async def delete(self, key: str) -> bool:
        """Delete value from cache; True if at least one key was removed."""
        try:
            redis = await self._get_redis()
            result = await redis.delete(self._make_key(key))
            return result > 0
        except Exception as e:
            logger.error(f"Redis delete error: {e}")
            return False

    async def exists(self, key: str) -> bool:
        """Check if key exists in cache.

        Returns:
            True if key exists (Redis EXISTS returns an integer count;
            coerce to bool to honor the annotated return type)
        """
        try:
            redis = await self._get_redis()
            return bool(await redis.exists(self._make_key(key)))
        except Exception as e:
            logger.error(f"Redis exists error: {e}")
            return False

    async def clear(self) -> bool:
        """Clear all cached values with this backend's prefix.

        NOTE(review): KEYS blocks the Redis server on large keyspaces;
        SCAN would be preferable in production.
        """
        try:
            redis = await self._get_redis()
            keys = await redis.keys(f"{self.prefix}*")
            if keys:
                await redis.delete(*keys)
            return True
        except Exception as e:
            logger.error(f"Redis clear error: {e}")
            return False

    async def get_many(self, keys: List[str]) -> Dict[str, Any]:
        """Get multiple values via MGET; missing keys are omitted."""
        try:
            redis = await self._get_redis()
            prefixed_keys = [self._make_key(k) for k in keys]
            values = await redis.mget(*prefixed_keys)
            result = {}
            for key, value in zip(keys, values):
                if value is not None:
                    result[key] = pickle.loads(value)
            return result
        except Exception as e:
            logger.error(f"Redis get_many error: {e}")
            return {}

    async def set_many(
        self, items: Dict[str, Any], ttl: Optional[int] = None
    ) -> bool:
        """Set multiple values (sequential SETs; not atomic)."""
        try:
            for key, value in items.items():
                await self.set(key, value, ttl)
            return True
        except Exception as e:
            logger.error(f"Redis set_many error: {e}")
            return False

    async def delete_pattern(self, pattern: str) -> int:
        """Delete all keys matching pattern (prefix is prepended).

        NOTE(review): uses KEYS; see clear() about SCAN for production.
        """
        try:
            redis = await self._get_redis()
            full_pattern = f"{self.prefix}{pattern}"
            keys = await redis.keys(full_pattern)
            if keys:
                await redis.delete(*keys)
                return len(keys)
            return 0
        except Exception as e:
            logger.error(f"Redis delete_pattern error: {e}")
            return 0

    async def close(self) -> None:
        """Close Redis connection (aioredis 1.x close/wait_closed pair)."""
        if self._redis:
            self._redis.close()
            await self._redis.wait_closed()
class CacheService:
    """Main cache service with automatic key generation and TTL management.

    Wraps a CacheBackend, adds a default TTL, a key prefix, and typed
    convenience helpers for the app's common cached objects.
    """

    def __init__(
        self,
        backend: Optional[CacheBackend] = None,
        default_ttl: int = 3600,
        key_prefix: str = "",
    ):
        """
        Initialize cache service.

        Args:
            backend: Cache backend to use (defaults to in-memory)
            default_ttl: Default time to live in seconds
            key_prefix: Prefix for all cache keys
        """
        self.backend = backend or InMemoryCacheBackend()
        self.default_ttl = default_ttl
        self.key_prefix = key_prefix

    def _make_key(self, *args: Any, **kwargs: Any) -> str:
        """
        Generate cache key from arguments.

        Args:
            *args: Positional arguments
            **kwargs: Keyword arguments (sorted for a stable key)

        Returns:
            Cache key string
        """
        # Create a stable key from arguments
        key_parts = [str(arg) for arg in args]
        key_parts.extend(f"{k}={v}" for k, v in sorted(kwargs.items()))
        key_str = ":".join(key_parts)
        # Hash long keys (md5 is fine here -- key derivation, not security)
        if len(key_str) > 200:
            key_hash = hashlib.md5(key_str.encode()).hexdigest()
            return f"{self.key_prefix}{key_hash}"
        return f"{self.key_prefix}{key_str}"

    async def get(
        self, key: str, default: Optional[Any] = None
    ) -> Optional[Any]:
        """
        Get value from cache.

        Args:
            key: Cache key
            default: Default value if not found

        Returns:
            Cached value or default (a cached None is indistinguishable
            from a miss)
        """
        value = await self.backend.get(key)
        return value if value is not None else default

    async def set(
        self, key: str, value: Any, ttl: Optional[int] = None
    ) -> bool:
        """
        Set value in cache.

        Args:
            key: Cache key
            value: Value to cache
            ttl: Time to live in seconds (uses default if None)

        Returns:
            True if successful
        """
        if ttl is None:
            ttl = self.default_ttl
        return await self.backend.set(key, value, ttl)

    async def delete(self, key: str) -> bool:
        """
        Delete value from cache.

        Args:
            key: Cache key

        Returns:
            True if deleted
        """
        return await self.backend.delete(key)

    async def exists(self, key: str) -> bool:
        """
        Check if key exists in cache.

        Args:
            key: Cache key

        Returns:
            True if exists
        """
        return await self.backend.exists(key)

    async def clear(self) -> bool:
        """
        Clear all cached values.

        Returns:
            True if successful
        """
        return await self.backend.clear()

    async def get_or_set(
        self,
        key: str,
        factory,
        ttl: Optional[int] = None,
    ) -> Any:
        """
        Get value from cache or compute and cache it.

        Args:
            key: Cache key
            factory: Callable (sync or async) to compute value if not cached
            ttl: Time to live in seconds

        Returns:
            Cached or computed value

        Note:
            Because None marks a cache miss, a factory returning None is
            re-invoked on every call.
        """
        value = await self.get(key)
        if value is None:
            # Compute value
            if asyncio.iscoroutinefunction(factory):
                value = await factory()
            else:
                value = factory()
            # Cache it
            await self.set(key, value, ttl)
        return value

    async def invalidate_pattern(self, pattern: str) -> int:
        """
        Invalidate all keys matching pattern.

        Args:
            pattern: Pattern to match

        Returns:
            Number of keys invalidated
        """
        return await self.backend.delete_pattern(pattern)

    async def cache_anime_list(
        self, anime_list: List[Dict[str, Any]], ttl: Optional[int] = None
    ) -> bool:
        """
        Cache anime list.

        Args:
            anime_list: List of anime data
            ttl: Time to live in seconds

        Returns:
            True if successful
        """
        key = self._make_key("anime", "list")
        return await self.set(key, anime_list, ttl)

    async def get_anime_list(self) -> Optional[List[Dict[str, Any]]]:
        """
        Get cached anime list.

        Returns:
            Cached anime list or None
        """
        key = self._make_key("anime", "list")
        return await self.get(key)

    async def cache_anime_detail(
        self, anime_id: str, data: Dict[str, Any], ttl: Optional[int] = None
    ) -> bool:
        """
        Cache anime detail.

        Args:
            anime_id: Anime identifier
            data: Anime data
            ttl: Time to live in seconds

        Returns:
            True if successful
        """
        key = self._make_key("anime", "detail", anime_id)
        return await self.set(key, data, ttl)

    async def get_anime_detail(self, anime_id: str) -> Optional[Dict[str, Any]]:
        """
        Get cached anime detail.

        Args:
            anime_id: Anime identifier

        Returns:
            Cached anime data or None
        """
        key = self._make_key("anime", "detail", anime_id)
        return await self.get(key)

    async def invalidate_anime_cache(self) -> int:
        """
        Invalidate all anime-related cache (list and detail entries).

        Returns:
            Number of keys invalidated
        """
        return await self.invalidate_pattern(f"{self.key_prefix}anime*")

    async def cache_config(
        self, config: Dict[str, Any], ttl: Optional[int] = None
    ) -> bool:
        """
        Cache configuration.

        Args:
            config: Configuration data
            ttl: Time to live in seconds

        Returns:
            True if successful
        """
        key = self._make_key("config")
        return await self.set(key, config, ttl)

    async def get_config(self) -> Optional[Dict[str, Any]]:
        """
        Get cached configuration.

        Returns:
            Cached configuration or None
        """
        key = self._make_key("config")
        return await self.get(key)

    async def invalidate_config_cache(self) -> bool:
        """
        Invalidate configuration cache.

        Returns:
            True if successful
        """
        key = self._make_key("config")
        return await self.delete(key)
# Module-level singleton managed by get_cache_service()/configure_cache_service().
_cache_service: Optional[CacheService] = None


def get_cache_service() -> CacheService:
    """
    Get the global cache service instance.

    Lazily constructs a default (in-memory) CacheService on first use.

    Returns:
        CacheService instance
    """
    global _cache_service
    if _cache_service is not None:
        return _cache_service
    _cache_service = CacheService()
    return _cache_service
def configure_cache_service(
    backend_type: str = "memory",
    redis_url: str = "redis://localhost:6379",
    default_ttl: int = 3600,
    max_size: int = 1000,
) -> CacheService:
    """
    Configure the global cache service.

    Any backend_type other than "redis" selects the in-memory backend.

    Args:
        backend_type: Type of backend ("memory" or "redis")
        redis_url: Redis connection URL (for redis backend)
        default_ttl: Default time to live in seconds
        max_size: Maximum cache size (for memory backend)

    Returns:
        Configured CacheService instance
    """
    global _cache_service
    backend = (
        RedisCacheBackend(redis_url=redis_url)
        if backend_type == "redis"
        else InMemoryCacheBackend(max_size=max_size)
    )
    service = CacheService(
        backend=backend, default_ttl=default_ttl, key_prefix="aniworld:"
    )
    _cache_service = service
    return service

View File

@@ -0,0 +1,626 @@
"""
Notification Service for AniWorld.
This module provides notification functionality including email, webhooks,
and in-app notifications for download events and system alerts.
"""
import asyncio
import json
import logging
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Set
import aiohttp
from pydantic import BaseModel, EmailStr, Field, HttpUrl
logger = logging.getLogger(__name__)
class NotificationType(str, Enum):
    """Types of notifications (download lifecycle and system alerts)."""

    DOWNLOAD_COMPLETE = "download_complete"
    DOWNLOAD_FAILED = "download_failed"
    QUEUE_COMPLETE = "queue_complete"
    SYSTEM_ERROR = "system_error"
    SYSTEM_WARNING = "system_warning"
    SYSTEM_INFO = "system_info"
class NotificationPriority(str, Enum):
    """Notification priority levels, lowest to highest."""

    LOW = "low"
    NORMAL = "normal"
    HIGH = "high"
    CRITICAL = "critical"
class NotificationChannel(str, Enum):
    """Available notification delivery channels."""

    EMAIL = "email"
    WEBHOOK = "webhook"
    IN_APP = "in_app"
class NotificationPreferences(BaseModel):
    """User notification preferences.

    NOTE(review): enforcement of quiet hours and min_priority is not
    visible in this module chunk -- presumably applied by the dispatcher;
    verify against the caller.
    """

    # Channels the user wants notifications on (in-app by default).
    enabled_channels: Set[NotificationChannel] = Field(
        default_factory=lambda: {NotificationChannel.IN_APP}
    )
    # All notification types enabled by default.
    enabled_types: Set[NotificationType] = Field(
        default_factory=lambda: set(NotificationType)
    )
    email_address: Optional[EmailStr] = None
    webhook_urls: List[HttpUrl] = Field(default_factory=list)
    # Hour-of-day bounds (0-23) for suppressing notifications.
    quiet_hours_start: Optional[int] = Field(None, ge=0, le=23)
    quiet_hours_end: Optional[int] = Field(None, ge=0, le=23)
    min_priority: NotificationPriority = NotificationPriority.NORMAL
class Notification(BaseModel):
    """Notification model delivered through one or more channels."""

    id: str
    type: NotificationType
    priority: NotificationPriority
    title: str
    message: str
    # Optional structured payload (e.g. series/episode info).
    data: Optional[Dict[str, Any]] = None
    created_at: datetime = Field(default_factory=datetime.utcnow)
    # Read flag for the in-app inbox.
    read: bool = False
    # Channels this notification should be delivered on.
    channels: Set[NotificationChannel] = Field(
        default_factory=lambda: {NotificationChannel.IN_APP}
    )
class EmailNotificationService:
    """Service for sending email notifications via SMTP (aiosmtplib).

    The service is disabled (send_email returns False) unless host,
    username, password and from_address are all configured.
    """

    def __init__(
        self,
        smtp_host: Optional[str] = None,
        smtp_port: int = 587,
        smtp_username: Optional[str] = None,
        smtp_password: Optional[str] = None,
        from_address: Optional[str] = None,
    ):
        """
        Initialize email notification service.

        Args:
            smtp_host: SMTP server hostname
            smtp_port: SMTP server port (587 = submission with STARTTLS)
            smtp_username: SMTP authentication username
            smtp_password: SMTP authentication password
            from_address: Email sender address
        """
        self.smtp_host = smtp_host
        self.smtp_port = smtp_port
        self.smtp_username = smtp_username
        self.smtp_password = smtp_password
        self.from_address = from_address
        # Enabled only if every required credential/address is present
        # (smtp_port always has a value, so it is not checked).
        self._enabled = all(
            [smtp_host, smtp_username, smtp_password, from_address]
        )

    async def send_email(
        self, to_address: str, subject: str, body: str, html: bool = False
    ) -> bool:
        """
        Send an email notification.

        Args:
            to_address: Recipient email address
            subject: Email subject
            body: Email body content
            html: Whether body is HTML format

        Returns:
            True if email sent successfully; False when the service is
            unconfigured, aiosmtplib is missing, or sending fails
        """
        if not self._enabled:
            logger.warning("Email notifications not configured")
            return False
        try:
            # Import here to make aiosmtplib optional
            from email.mime.multipart import MIMEMultipart
            from email.mime.text import MIMEText
            import aiosmtplib
            message = MIMEMultipart("alternative")
            message["Subject"] = subject
            message["From"] = self.from_address
            message["To"] = to_address
            mime_type = "html" if html else "plain"
            message.attach(MIMEText(body, mime_type))
            # STARTTLS upgrade on the configured port.
            await aiosmtplib.send(
                message,
                hostname=self.smtp_host,
                port=self.smtp_port,
                username=self.smtp_username,
                password=self.smtp_password,
                start_tls=True,
            )
            logger.info(f"Email notification sent to {to_address}")
            return True
        except ImportError:
            logger.error(
                "aiosmtplib not installed. Install with: pip install aiosmtplib"
            )
            return False
        except Exception as e:
            logger.error(f"Failed to send email notification: {e}")
            return False
class WebhookNotificationService:
    """Service for POSTing JSON webhook notifications with retries.

    Each delivery attempt is retried with exponential backoff; any response
    with a status below 400 counts as success.
    """

    def __init__(self, timeout: int = 10, max_retries: int = 3):
        """
        Initialize webhook notification service.

        Args:
            timeout: Request timeout in seconds
            max_retries: Maximum number of retry attempts
        """
        self.timeout = timeout
        self.max_retries = max_retries

    async def send_webhook(
        self, url: str, payload: Dict[str, Any], headers: Optional[Dict[str, str]] = None
    ) -> bool:
        """
        Send a webhook notification.

        Args:
            url: Webhook URL
            payload: JSON payload to send
            headers: Optional custom headers

        Returns:
            True if webhook sent successfully
        """
        request_headers = {"Content-Type": "application/json"} if headers is None else headers
        for attempt in range(self.max_retries):
            try:
                async with aiohttp.ClientSession() as session:
                    async with session.post(
                        url,
                        json=payload,
                        headers=request_headers,
                        timeout=aiohttp.ClientTimeout(total=self.timeout),
                    ) as response:
                        if response.status < 400:
                            logger.info(f"Webhook notification sent to {url}")
                            return True
                        # Error responses fall through to the retry loop.
                        logger.warning(
                            f"Webhook returned status {response.status}: {url}"
                        )
            except asyncio.TimeoutError:
                logger.warning(f"Webhook timeout (attempt {attempt + 1}/{self.max_retries}): {url}")
            except Exception as e:
                logger.error(f"Failed to send webhook (attempt {attempt + 1}/{self.max_retries}): {e}")
            if attempt < self.max_retries - 1:
                await asyncio.sleep(2 ** attempt)  # Exponential backoff
        return False
class InAppNotificationService:
    """Service for managing in-app notifications.

    Notifications are kept newest-first in a bounded in-memory list; all
    reads and mutations are serialized through an asyncio lock.
    """

    def __init__(self, max_notifications: int = 100):
        """
        Initialize in-app notification service.

        Args:
            max_notifications: Maximum number of notifications to keep
        """
        # Newest notification is always at index 0.
        self.notifications: List["Notification"] = []
        self.max_notifications = max_notifications
        self._lock = asyncio.Lock()

    async def add_notification(self, notification: "Notification") -> None:
        """
        Add a notification to the in-app list.

        Args:
            notification: Notification to add
        """
        async with self._lock:
            self.notifications.insert(0, notification)
            # Drop the oldest entries once the bound is exceeded.
            if len(self.notifications) > self.max_notifications:
                self.notifications = self.notifications[: self.max_notifications]

    async def get_notifications(
        self, unread_only: bool = False, limit: Optional[int] = None
    ) -> List["Notification"]:
        """
        Get in-app notifications.

        Args:
            unread_only: Only return unread notifications
            limit: Maximum number of notifications to return; ``None`` means
                no limit, ``0`` returns an empty list

        Returns:
            List of notifications, newest first
        """
        async with self._lock:
            notifications = self.notifications
            if unread_only:
                notifications = [n for n in notifications if not n.read]
            # Explicit None check: `if limit:` would treat limit=0 as
            # "no limit" and return every notification instead of none.
            if limit is not None:
                notifications = notifications[:limit]
            return notifications.copy()

    async def mark_as_read(self, notification_id: str) -> bool:
        """
        Mark a notification as read.

        Args:
            notification_id: ID of notification to mark

        Returns:
            True if notification was found and marked
        """
        async with self._lock:
            for notification in self.notifications:
                if notification.id == notification_id:
                    notification.read = True
                    return True
            return False

    async def mark_all_as_read(self) -> int:
        """
        Mark all notifications as read.

        Returns:
            Number of notifications marked as read
        """
        async with self._lock:
            count = 0
            for notification in self.notifications:
                if not notification.read:
                    notification.read = True
                    count += 1
            return count

    async def clear_notifications(self, read_only: bool = True) -> int:
        """
        Clear notifications.

        Args:
            read_only: Only clear read notifications

        Returns:
            Number of notifications cleared
        """
        async with self._lock:
            if read_only:
                initial_count = len(self.notifications)
                self.notifications = [n for n in self.notifications if not n.read]
                return initial_count - len(self.notifications)
            else:
                count = len(self.notifications)
                self.notifications.clear()
                return count
class NotificationService:
    """Main notification service coordinating all notification channels.

    Routes notifications to the in-app, email, and webhook services based on
    the configured preferences (enabled channels/types, minimum priority,
    quiet hours). A failure on one channel never blocks the others.
    """

    def __init__(
        self,
        email_service: Optional[EmailNotificationService] = None,
        webhook_service: Optional[WebhookNotificationService] = None,
        in_app_service: Optional[InAppNotificationService] = None,
    ):
        """
        Initialize notification service.

        Args:
            email_service: Email notification service instance
            webhook_service: Webhook notification service instance
            in_app_service: In-app notification service instance
        """
        # Missing services are replaced by default-constructed (and, for
        # email, disabled) instances so every channel attribute is usable.
        self.email_service = email_service or EmailNotificationService()
        self.webhook_service = webhook_service or WebhookNotificationService()
        self.in_app_service = in_app_service or InAppNotificationService()
        self.preferences = NotificationPreferences()

    def set_preferences(self, preferences: NotificationPreferences) -> None:
        """
        Update notification preferences.

        Args:
            preferences: New notification preferences
        """
        self.preferences = preferences

    def _is_in_quiet_hours(self) -> bool:
        """
        Check if the current local hour is within configured quiet hours.

        Returns:
            True if in quiet hours; always False when either bound is unset
        """
        if (
            self.preferences.quiet_hours_start is None
            or self.preferences.quiet_hours_end is None
        ):
            return False
        current_hour = datetime.now().hour
        start = self.preferences.quiet_hours_start
        end = self.preferences.quiet_hours_end
        if start <= end:
            return start <= current_hour < end
        else:  # Quiet hours span midnight (e.g. start=22, end=6)
            return current_hour >= start or current_hour < end

    def _should_send_notification(
        self, notification_type: NotificationType, priority: NotificationPriority
    ) -> bool:
        """
        Determine if a notification should be sent based on preferences.

        Args:
            notification_type: Type of notification
            priority: Priority level

        Returns:
            True if notification should be sent
        """
        # Check if type is enabled
        if notification_type not in self.preferences.enabled_types:
            return False
        # Check priority level (list is ordered from least to most severe)
        priority_order = [
            NotificationPriority.LOW,
            NotificationPriority.NORMAL,
            NotificationPriority.HIGH,
            NotificationPriority.CRITICAL,
        ]
        if (
            priority_order.index(priority)
            < priority_order.index(self.preferences.min_priority)
        ):
            return False
        # Check quiet hours (critical notifications bypass quiet hours)
        if priority != NotificationPriority.CRITICAL and self._is_in_quiet_hours():
            return False
        return True

    async def send_notification(self, notification: Notification) -> Dict[str, bool]:
        """
        Send a notification through enabled channels.

        Args:
            notification: Notification to send

        Returns:
            Dictionary mapping channel names ("in_app", "email", "webhook")
            to success status; empty when preferences suppressed the send
        """
        if not self._should_send_notification(notification.type, notification.priority):
            logger.debug(
                f"Notification not sent due to preferences: {notification.type}"
            )
            return {}
        results = {}
        # Send in-app notification
        if NotificationChannel.IN_APP in self.preferences.enabled_channels:
            try:
                await self.in_app_service.add_notification(notification)
                results["in_app"] = True
            except Exception as e:
                logger.error(f"Failed to send in-app notification: {e}")
                results["in_app"] = False
        # Send email notification (requires a configured recipient address)
        if (
            NotificationChannel.EMAIL in self.preferences.enabled_channels
            and self.preferences.email_address
        ):
            try:
                # NOTE(review): assumes priority is a str-backed enum so
                # .upper() yields e.g. "HIGH" — confirm against its definition.
                success = await self.email_service.send_email(
                    to_address=self.preferences.email_address,
                    subject=f"[{notification.priority.upper()}] {notification.title}",
                    body=notification.message,
                )
                results["email"] = success
            except Exception as e:
                logger.error(f"Failed to send email notification: {e}")
                results["email"] = False
        # Send webhook notifications (one POST per configured URL)
        if (
            NotificationChannel.WEBHOOK in self.preferences.enabled_channels
            and self.preferences.webhook_urls
        ):
            payload = {
                "id": notification.id,
                "type": notification.type,
                "priority": notification.priority,
                "title": notification.title,
                "message": notification.message,
                "data": notification.data,
                "created_at": notification.created_at.isoformat(),
            }
            webhook_results = []
            for url in self.preferences.webhook_urls:
                try:
                    success = await self.webhook_service.send_webhook(str(url), payload)
                    webhook_results.append(success)
                except Exception as e:
                    logger.error(f"Failed to send webhook notification to {url}: {e}")
                    webhook_results.append(False)
            # The channel counts as successful only if every URL succeeded.
            results["webhook"] = all(webhook_results) if webhook_results else False
        return results

    async def notify_download_complete(
        self, series_name: str, episode: str, file_path: str
    ) -> Dict[str, bool]:
        """
        Send notification for completed download.

        Args:
            series_name: Name of the series
            episode: Episode identifier
            file_path: Path to downloaded file

        Returns:
            Dictionary of send results by channel
        """
        # datetime.now().timestamp() gives a correct POSIX timestamp for the
        # unique ID; utcnow().timestamp() is deprecated and, being naive,
        # would be misinterpreted as local time.
        notification = Notification(
            id=f"download_complete_{datetime.now().timestamp()}",
            type=NotificationType.DOWNLOAD_COMPLETE,
            priority=NotificationPriority.NORMAL,
            title=f"Download Complete: {series_name}",
            message=f"Episode {episode} has been downloaded successfully.",
            data={
                "series_name": series_name,
                "episode": episode,
                "file_path": file_path,
            },
        )
        return await self.send_notification(notification)

    async def notify_download_failed(
        self, series_name: str, episode: str, error: str
    ) -> Dict[str, bool]:
        """
        Send notification for failed download.

        Args:
            series_name: Name of the series
            episode: Episode identifier
            error: Error message

        Returns:
            Dictionary of send results by channel
        """
        notification = Notification(
            id=f"download_failed_{datetime.now().timestamp()}",
            type=NotificationType.DOWNLOAD_FAILED,
            priority=NotificationPriority.HIGH,
            title=f"Download Failed: {series_name}",
            message=f"Episode {episode} failed to download: {error}",
            data={"series_name": series_name, "episode": episode, "error": error},
        )
        return await self.send_notification(notification)

    async def notify_queue_complete(self, total_downloads: int) -> Dict[str, bool]:
        """
        Send notification for completed download queue.

        Args:
            total_downloads: Number of downloads completed

        Returns:
            Dictionary of send results by channel
        """
        notification = Notification(
            id=f"queue_complete_{datetime.now().timestamp()}",
            type=NotificationType.QUEUE_COMPLETE,
            priority=NotificationPriority.NORMAL,
            title="Download Queue Complete",
            message=f"All {total_downloads} downloads have been completed.",
            data={"total_downloads": total_downloads},
        )
        return await self.send_notification(notification)

    async def notify_system_error(self, error: str, details: Optional[Dict[str, Any]] = None) -> Dict[str, bool]:
        """
        Send notification for system error.

        Critical priority: bypasses quiet hours in _should_send_notification.

        Args:
            error: Error message
            details: Optional error details

        Returns:
            Dictionary of send results by channel
        """
        notification = Notification(
            id=f"system_error_{datetime.now().timestamp()}",
            type=NotificationType.SYSTEM_ERROR,
            priority=NotificationPriority.CRITICAL,
            title="System Error",
            message=error,
            data=details,
        )
        return await self.send_notification(notification)
# Global notification service instance (lazily created singleton)
_notification_service: Optional[NotificationService] = None


def get_notification_service() -> NotificationService:
    """
    Get the global notification service instance.

    The singleton is created on first access with default channel services.

    Returns:
        NotificationService instance
    """
    global _notification_service
    if _notification_service is None:
        _notification_service = NotificationService()
    return _notification_service
def configure_notification_service(
    smtp_host: Optional[str] = None,
    smtp_port: int = 587,
    smtp_username: Optional[str] = None,
    smtp_password: Optional[str] = None,
    from_address: Optional[str] = None,
) -> NotificationService:
    """
    Configure the global notification service.

    Replaces any existing global instance with one whose email channel is
    built from the supplied SMTP settings.

    Args:
        smtp_host: SMTP server hostname
        smtp_port: SMTP server port
        smtp_username: SMTP authentication username
        smtp_password: SMTP authentication password
        from_address: Email sender address

    Returns:
        Configured NotificationService instance
    """
    global _notification_service
    mailer = EmailNotificationService(
        smtp_host=smtp_host,
        smtp_port=smtp_port,
        smtp_username=smtp_username,
        smtp_password=smtp_password,
        from_address=from_address,
    )
    _notification_service = NotificationService(email_service=mailer)
    return _notification_service