Add advanced features: notification system, security middleware, audit logging, data validation, and caching
- Implement notification service with email, webhook, and in-app support - Add security headers middleware (CORS, CSP, HSTS, XSS protection) - Create comprehensive audit logging service for security events - Add data validation utilities with Pydantic validators - Implement cache service with in-memory and Redis backend support All 714 tests passing
This commit is contained in:
parent
17e5a551e1
commit
7409ae637e
@ -99,44 +99,8 @@ When working with these files:
|
||||
- []Preserve existing WebSocket event handling
|
||||
- []Keep existing theme and responsive design features
|
||||
|
||||
### Advanced Features
|
||||
|
||||
#### [] Create notification system
|
||||
|
||||
- []Create `src/server/services/notification_service.py`
|
||||
- []Implement email notifications for completed downloads
|
||||
- []Add webhook support for external integrations
|
||||
- []Include in-app notification system
|
||||
- []Add notification preference management
|
||||
|
||||
### Security Enhancements
|
||||
|
||||
#### [] Add security headers
|
||||
|
||||
- []Create `src/server/middleware/security.py`
|
||||
- []Implement CORS headers
|
||||
- []Add CSP headers
|
||||
- []Include security headers (HSTS, X-Frame-Options)
|
||||
- []Add request sanitization
|
||||
|
||||
#### [] Create audit logging
|
||||
|
||||
- []Create `src/server/services/audit_service.py`
|
||||
- []Log all authentication attempts
|
||||
- []Track configuration changes
|
||||
- []Monitor download activities
|
||||
- []Include user action tracking
|
||||
|
||||
### Data Management
|
||||
|
||||
#### [] Implement data validation
|
||||
|
||||
- []Create `src/server/utils/validators.py`
|
||||
- []Add Pydantic custom validators
|
||||
- []Implement business rule validation
|
||||
- []Include data integrity checks
|
||||
- []Add format validation utilities
|
||||
|
||||
#### [] Create data migration tools
|
||||
|
||||
- []Create `src/server/database/migrations/`
|
||||
@ -145,14 +109,6 @@ When working with these files:
|
||||
- []Include rollback mechanisms
|
||||
- []Add migration validation
|
||||
|
||||
#### [] Add caching layer
|
||||
|
||||
- []Create `src/server/services/cache_service.py`
|
||||
- []Implement Redis caching
|
||||
- []Add in-memory caching for frequent data
|
||||
- []Include cache invalidation strategies
|
||||
- []Add cache performance monitoring
|
||||
|
||||
### Integration Enhancements
|
||||
|
||||
#### [] Extend provider system
|
||||
|
||||
446
src/server/middleware/security.py
Normal file
446
src/server/middleware/security.py
Normal file
@ -0,0 +1,446 @@
|
||||
"""
|
||||
Security Middleware for AniWorld.
|
||||
|
||||
This module provides security-related middleware including CORS, CSP,
|
||||
security headers, and request sanitization.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
from fastapi import FastAPI, Request, Response
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to add security headers to all responses."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
hsts_max_age: int = 31536000, # 1 year
|
||||
hsts_include_subdomains: bool = True,
|
||||
hsts_preload: bool = False,
|
||||
frame_options: str = "DENY",
|
||||
content_type_options: bool = True,
|
||||
xss_protection: bool = True,
|
||||
referrer_policy: str = "strict-origin-when-cross-origin",
|
||||
permissions_policy: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize security headers middleware.
|
||||
|
||||
Args:
|
||||
app: ASGI application
|
||||
hsts_max_age: HSTS max-age in seconds
|
||||
hsts_include_subdomains: Include subdomains in HSTS
|
||||
hsts_preload: Enable HSTS preload
|
||||
frame_options: X-Frame-Options value (DENY, SAMEORIGIN, or ALLOW-FROM)
|
||||
content_type_options: Enable X-Content-Type-Options: nosniff
|
||||
xss_protection: Enable X-XSS-Protection
|
||||
referrer_policy: Referrer-Policy value
|
||||
permissions_policy: Permissions-Policy value
|
||||
"""
|
||||
super().__init__(app)
|
||||
self.hsts_max_age = hsts_max_age
|
||||
self.hsts_include_subdomains = hsts_include_subdomains
|
||||
self.hsts_preload = hsts_preload
|
||||
self.frame_options = frame_options
|
||||
self.content_type_options = content_type_options
|
||||
self.xss_protection = xss_protection
|
||||
self.referrer_policy = referrer_policy
|
||||
self.permissions_policy = permissions_policy
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
"""
|
||||
Process request and add security headers to response.
|
||||
|
||||
Args:
|
||||
request: Incoming request
|
||||
call_next: Next middleware in chain
|
||||
|
||||
Returns:
|
||||
Response with security headers
|
||||
"""
|
||||
response = await call_next(request)
|
||||
|
||||
# HSTS Header
|
||||
hsts_value = f"max-age={self.hsts_max_age}"
|
||||
if self.hsts_include_subdomains:
|
||||
hsts_value += "; includeSubDomains"
|
||||
if self.hsts_preload:
|
||||
hsts_value += "; preload"
|
||||
response.headers["Strict-Transport-Security"] = hsts_value
|
||||
|
||||
# X-Frame-Options
|
||||
response.headers["X-Frame-Options"] = self.frame_options
|
||||
|
||||
# X-Content-Type-Options
|
||||
if self.content_type_options:
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
|
||||
# X-XSS-Protection (deprecated but still useful for older browsers)
|
||||
if self.xss_protection:
|
||||
response.headers["X-XSS-Protection"] = "1; mode=block"
|
||||
|
||||
# Referrer-Policy
|
||||
response.headers["Referrer-Policy"] = self.referrer_policy
|
||||
|
||||
# Permissions-Policy
|
||||
if self.permissions_policy:
|
||||
response.headers["Permissions-Policy"] = self.permissions_policy
|
||||
|
||||
# Remove potentially revealing headers
|
||||
response.headers.pop("Server", None)
|
||||
response.headers.pop("X-Powered-By", None)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class ContentSecurityPolicyMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to add Content Security Policy headers."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
default_src: List[str] = None,
|
||||
script_src: List[str] = None,
|
||||
style_src: List[str] = None,
|
||||
img_src: List[str] = None,
|
||||
font_src: List[str] = None,
|
||||
connect_src: List[str] = None,
|
||||
frame_src: List[str] = None,
|
||||
object_src: List[str] = None,
|
||||
media_src: List[str] = None,
|
||||
worker_src: List[str] = None,
|
||||
form_action: List[str] = None,
|
||||
frame_ancestors: List[str] = None,
|
||||
base_uri: List[str] = None,
|
||||
upgrade_insecure_requests: bool = True,
|
||||
block_all_mixed_content: bool = True,
|
||||
report_only: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize CSP middleware.
|
||||
|
||||
Args:
|
||||
app: ASGI application
|
||||
default_src: default-src directive values
|
||||
script_src: script-src directive values
|
||||
style_src: style-src directive values
|
||||
img_src: img-src directive values
|
||||
font_src: font-src directive values
|
||||
connect_src: connect-src directive values
|
||||
frame_src: frame-src directive values
|
||||
object_src: object-src directive values
|
||||
media_src: media-src directive values
|
||||
worker_src: worker-src directive values
|
||||
form_action: form-action directive values
|
||||
frame_ancestors: frame-ancestors directive values
|
||||
base_uri: base-uri directive values
|
||||
upgrade_insecure_requests: Enable upgrade-insecure-requests
|
||||
block_all_mixed_content: Enable block-all-mixed-content
|
||||
report_only: Use Content-Security-Policy-Report-Only header
|
||||
"""
|
||||
super().__init__(app)
|
||||
|
||||
# Default secure CSP
|
||||
self.directives = {
|
||||
"default-src": default_src or ["'self'"],
|
||||
"script-src": script_src or ["'self'", "'unsafe-inline'"],
|
||||
"style-src": style_src or ["'self'", "'unsafe-inline'"],
|
||||
"img-src": img_src or ["'self'", "data:", "https:"],
|
||||
"font-src": font_src or ["'self'", "data:"],
|
||||
"connect-src": connect_src or ["'self'", "ws:", "wss:"],
|
||||
"frame-src": frame_src or ["'none'"],
|
||||
"object-src": object_src or ["'none'"],
|
||||
"media-src": media_src or ["'self'"],
|
||||
"worker-src": worker_src or ["'self'"],
|
||||
"form-action": form_action or ["'self'"],
|
||||
"frame-ancestors": frame_ancestors or ["'none'"],
|
||||
"base-uri": base_uri or ["'self'"],
|
||||
}
|
||||
|
||||
self.upgrade_insecure_requests = upgrade_insecure_requests
|
||||
self.block_all_mixed_content = block_all_mixed_content
|
||||
self.report_only = report_only
|
||||
|
||||
def _build_csp_header(self) -> str:
|
||||
"""
|
||||
Build the CSP header value.
|
||||
|
||||
Returns:
|
||||
CSP header string
|
||||
"""
|
||||
parts = []
|
||||
|
||||
for directive, values in self.directives.items():
|
||||
if values:
|
||||
parts.append(f"{directive} {' '.join(values)}")
|
||||
|
||||
if self.upgrade_insecure_requests:
|
||||
parts.append("upgrade-insecure-requests")
|
||||
|
||||
if self.block_all_mixed_content:
|
||||
parts.append("block-all-mixed-content")
|
||||
|
||||
return "; ".join(parts)
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
"""
|
||||
Process request and add CSP header to response.
|
||||
|
||||
Args:
|
||||
request: Incoming request
|
||||
call_next: Next middleware in chain
|
||||
|
||||
Returns:
|
||||
Response with CSP header
|
||||
"""
|
||||
response = await call_next(request)
|
||||
|
||||
header_name = (
|
||||
"Content-Security-Policy-Report-Only"
|
||||
if self.report_only
|
||||
else "Content-Security-Policy"
|
||||
)
|
||||
response.headers[header_name] = self._build_csp_header()
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class RequestSanitizationMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to sanitize and validate incoming requests."""
|
||||
|
||||
# Common SQL injection patterns
|
||||
SQL_INJECTION_PATTERNS = [
|
||||
re.compile(r"(\bunion\b.*\bselect\b)", re.IGNORECASE),
|
||||
re.compile(r"(\bselect\b.*\bfrom\b)", re.IGNORECASE),
|
||||
re.compile(r"(\binsert\b.*\binto\b)", re.IGNORECASE),
|
||||
re.compile(r"(\bupdate\b.*\bset\b)", re.IGNORECASE),
|
||||
re.compile(r"(\bdelete\b.*\bfrom\b)", re.IGNORECASE),
|
||||
re.compile(r"(\bdrop\b.*\btable\b)", re.IGNORECASE),
|
||||
re.compile(r"(\bexec\b|\bexecute\b)", re.IGNORECASE),
|
||||
re.compile(r"(--|\#|\/\*|\*\/)", re.IGNORECASE),
|
||||
]
|
||||
|
||||
# Common XSS patterns
|
||||
XSS_PATTERNS = [
|
||||
re.compile(r"<script[^>]*>.*?</script>", re.IGNORECASE | re.DOTALL),
|
||||
re.compile(r"javascript:", re.IGNORECASE),
|
||||
re.compile(r"on\w+\s*=", re.IGNORECASE), # Event handlers like onclick=
|
||||
re.compile(r"<iframe[^>]*>", re.IGNORECASE),
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
check_sql_injection: bool = True,
|
||||
check_xss: bool = True,
|
||||
max_request_size: int = 10 * 1024 * 1024, # 10 MB
|
||||
allowed_content_types: Optional[List[str]] = None,
|
||||
):
|
||||
"""
|
||||
Initialize request sanitization middleware.
|
||||
|
||||
Args:
|
||||
app: ASGI application
|
||||
check_sql_injection: Enable SQL injection checks
|
||||
check_xss: Enable XSS checks
|
||||
max_request_size: Maximum request body size in bytes
|
||||
allowed_content_types: List of allowed content types
|
||||
"""
|
||||
super().__init__(app)
|
||||
self.check_sql_injection = check_sql_injection
|
||||
self.check_xss = check_xss
|
||||
self.max_request_size = max_request_size
|
||||
self.allowed_content_types = allowed_content_types or [
|
||||
"application/json",
|
||||
"application/x-www-form-urlencoded",
|
||||
"multipart/form-data",
|
||||
"text/plain",
|
||||
]
|
||||
|
||||
def _check_sql_injection(self, value: str) -> bool:
|
||||
"""
|
||||
Check if string contains SQL injection patterns.
|
||||
|
||||
Args:
|
||||
value: String to check
|
||||
|
||||
Returns:
|
||||
True if potential SQL injection detected
|
||||
"""
|
||||
for pattern in self.SQL_INJECTION_PATTERNS:
|
||||
if pattern.search(value):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _check_xss(self, value: str) -> bool:
|
||||
"""
|
||||
Check if string contains XSS patterns.
|
||||
|
||||
Args:
|
||||
value: String to check
|
||||
|
||||
Returns:
|
||||
True if potential XSS detected
|
||||
"""
|
||||
for pattern in self.XSS_PATTERNS:
|
||||
if pattern.search(value):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _sanitize_value(self, value: str) -> Optional[str]:
|
||||
"""
|
||||
Sanitize a string value.
|
||||
|
||||
Args:
|
||||
value: Value to sanitize
|
||||
|
||||
Returns:
|
||||
None if malicious content detected, sanitized value otherwise
|
||||
"""
|
||||
if self.check_sql_injection and self._check_sql_injection(value):
|
||||
logger.warning(f"Potential SQL injection detected: {value[:100]}")
|
||||
return None
|
||||
|
||||
if self.check_xss and self._check_xss(value):
|
||||
logger.warning(f"Potential XSS detected: {value[:100]}")
|
||||
return None
|
||||
|
||||
return value
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
"""
|
||||
Process and sanitize request.
|
||||
|
||||
Args:
|
||||
request: Incoming request
|
||||
call_next: Next middleware in chain
|
||||
|
||||
Returns:
|
||||
Response or error response if request is malicious
|
||||
"""
|
||||
# Check content type
|
||||
content_type = request.headers.get("content-type", "").split(";")[0].strip()
|
||||
if (
|
||||
content_type
|
||||
and not any(ct in content_type for ct in self.allowed_content_types)
|
||||
):
|
||||
logger.warning(f"Unsupported content type: {content_type}")
|
||||
return JSONResponse(
|
||||
status_code=415,
|
||||
content={"detail": "Unsupported Media Type"},
|
||||
)
|
||||
|
||||
# Check request size
|
||||
content_length = request.headers.get("content-length")
|
||||
if content_length and int(content_length) > self.max_request_size:
|
||||
logger.warning(f"Request too large: {content_length} bytes")
|
||||
return JSONResponse(
|
||||
status_code=413,
|
||||
content={"detail": "Request Entity Too Large"},
|
||||
)
|
||||
|
||||
# Check query parameters
|
||||
for key, value in request.query_params.items():
|
||||
if isinstance(value, str):
|
||||
sanitized = self._sanitize_value(value)
|
||||
if sanitized is None:
|
||||
logger.warning(f"Malicious query parameter detected: {key}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={"detail": "Malicious request detected"},
|
||||
)
|
||||
|
||||
# Check path parameters
|
||||
for key, value in request.path_params.items():
|
||||
if isinstance(value, str):
|
||||
sanitized = self._sanitize_value(value)
|
||||
if sanitized is None:
|
||||
logger.warning(f"Malicious path parameter detected: {key}")
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={"detail": "Malicious request detected"},
|
||||
)
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
def configure_security_middleware(
|
||||
app: FastAPI,
|
||||
cors_origins: List[str] = None,
|
||||
cors_allow_credentials: bool = True,
|
||||
enable_hsts: bool = True,
|
||||
enable_csp: bool = True,
|
||||
enable_sanitization: bool = True,
|
||||
csp_report_only: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Configure all security middleware for the FastAPI application.
|
||||
|
||||
Args:
|
||||
app: FastAPI application instance
|
||||
cors_origins: List of allowed CORS origins
|
||||
cors_allow_credentials: Allow credentials in CORS requests
|
||||
enable_hsts: Enable HSTS and other security headers
|
||||
enable_csp: Enable Content Security Policy
|
||||
enable_sanitization: Enable request sanitization
|
||||
csp_report_only: Use CSP in report-only mode
|
||||
"""
|
||||
# CORS Middleware
|
||||
if cors_origins is None:
|
||||
cors_origins = ["http://localhost:3000", "http://localhost:8000"]
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=cors_origins,
|
||||
allow_credentials=cors_allow_credentials,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
expose_headers=["*"],
|
||||
)
|
||||
|
||||
# Security Headers Middleware
|
||||
if enable_hsts:
|
||||
app.add_middleware(
|
||||
SecurityHeadersMiddleware,
|
||||
hsts_max_age=31536000,
|
||||
hsts_include_subdomains=True,
|
||||
frame_options="DENY",
|
||||
content_type_options=True,
|
||||
xss_protection=True,
|
||||
referrer_policy="strict-origin-when-cross-origin",
|
||||
)
|
||||
|
||||
# Content Security Policy Middleware
|
||||
if enable_csp:
|
||||
app.add_middleware(
|
||||
ContentSecurityPolicyMiddleware,
|
||||
report_only=csp_report_only,
|
||||
# Allow inline scripts and styles for development
|
||||
# In production, use nonces or hashes
|
||||
script_src=["'self'", "'unsafe-inline'", "'unsafe-eval'"],
|
||||
style_src=["'self'", "'unsafe-inline'", "https://cdnjs.cloudflare.com"],
|
||||
font_src=["'self'", "data:", "https://cdnjs.cloudflare.com"],
|
||||
img_src=["'self'", "data:", "https:"],
|
||||
connect_src=["'self'", "ws://localhost:*", "wss://localhost:*"],
|
||||
)
|
||||
|
||||
# Request Sanitization Middleware
|
||||
if enable_sanitization:
|
||||
app.add_middleware(
|
||||
RequestSanitizationMiddleware,
|
||||
check_sql_injection=True,
|
||||
check_xss=True,
|
||||
max_request_size=10 * 1024 * 1024, # 10 MB
|
||||
)
|
||||
|
||||
logger.info("Security middleware configured successfully")
|
||||
610
src/server/services/audit_service.py
Normal file
610
src/server/services/audit_service.py
Normal file
@ -0,0 +1,610 @@
|
||||
"""
|
||||
Audit Service for AniWorld.
|
||||
|
||||
This module provides comprehensive audit logging for security-critical
|
||||
operations including authentication, configuration changes, and downloads.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuditEventType(str, Enum):
|
||||
"""Types of audit events."""
|
||||
|
||||
# Authentication events
|
||||
AUTH_SETUP = "auth.setup"
|
||||
AUTH_LOGIN_SUCCESS = "auth.login.success"
|
||||
AUTH_LOGIN_FAILURE = "auth.login.failure"
|
||||
AUTH_LOGOUT = "auth.logout"
|
||||
AUTH_TOKEN_REFRESH = "auth.token.refresh"
|
||||
AUTH_TOKEN_INVALID = "auth.token.invalid"
|
||||
|
||||
# Configuration events
|
||||
CONFIG_READ = "config.read"
|
||||
CONFIG_UPDATE = "config.update"
|
||||
CONFIG_BACKUP = "config.backup"
|
||||
CONFIG_RESTORE = "config.restore"
|
||||
CONFIG_DELETE = "config.delete"
|
||||
|
||||
# Download events
|
||||
DOWNLOAD_ADDED = "download.added"
|
||||
DOWNLOAD_STARTED = "download.started"
|
||||
DOWNLOAD_COMPLETED = "download.completed"
|
||||
DOWNLOAD_FAILED = "download.failed"
|
||||
DOWNLOAD_CANCELLED = "download.cancelled"
|
||||
DOWNLOAD_REMOVED = "download.removed"
|
||||
|
||||
# Queue events
|
||||
QUEUE_STARTED = "queue.started"
|
||||
QUEUE_STOPPED = "queue.stopped"
|
||||
QUEUE_PAUSED = "queue.paused"
|
||||
QUEUE_RESUMED = "queue.resumed"
|
||||
QUEUE_CLEARED = "queue.cleared"
|
||||
|
||||
# System events
|
||||
SYSTEM_STARTUP = "system.startup"
|
||||
SYSTEM_SHUTDOWN = "system.shutdown"
|
||||
SYSTEM_ERROR = "system.error"
|
||||
|
||||
|
||||
class AuditEventSeverity(str, Enum):
|
||||
"""Severity levels for audit events."""
|
||||
|
||||
DEBUG = "debug"
|
||||
INFO = "info"
|
||||
WARNING = "warning"
|
||||
ERROR = "error"
|
||||
CRITICAL = "critical"
|
||||
|
||||
|
||||
class AuditEvent(BaseModel):
|
||||
"""Audit event model."""
|
||||
|
||||
timestamp: datetime = Field(default_factory=datetime.utcnow)
|
||||
event_type: AuditEventType
|
||||
severity: AuditEventSeverity = AuditEventSeverity.INFO
|
||||
user_id: Optional[str] = None
|
||||
ip_address: Optional[str] = None
|
||||
user_agent: Optional[str] = None
|
||||
resource: Optional[str] = None
|
||||
action: Optional[str] = None
|
||||
status: str = "success"
|
||||
message: str
|
||||
details: Optional[Dict[str, Any]] = None
|
||||
session_id: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
"""Pydantic config."""
|
||||
|
||||
json_encoders = {datetime: lambda v: v.isoformat()}
|
||||
|
||||
|
||||
class AuditLogStorage:
|
||||
"""Base class for audit log storage backends."""
|
||||
|
||||
async def write_event(self, event: AuditEvent) -> None:
|
||||
"""
|
||||
Write an audit event to storage.
|
||||
|
||||
Args:
|
||||
event: Audit event to write
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def read_events(
|
||||
self,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None,
|
||||
event_types: Optional[List[AuditEventType]] = None,
|
||||
user_id: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
) -> List[AuditEvent]:
|
||||
"""
|
||||
Read audit events from storage.
|
||||
|
||||
Args:
|
||||
start_time: Start of time range
|
||||
end_time: End of time range
|
||||
event_types: Filter by event types
|
||||
user_id: Filter by user ID
|
||||
limit: Maximum number of events to return
|
||||
|
||||
Returns:
|
||||
List of audit events
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def cleanup_old_events(self, days: int = 90) -> int:
|
||||
"""
|
||||
Clean up audit events older than specified days.
|
||||
|
||||
Args:
|
||||
days: Number of days to retain
|
||||
|
||||
Returns:
|
||||
Number of events deleted
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class FileAuditLogStorage(AuditLogStorage):
|
||||
"""File-based audit log storage."""
|
||||
|
||||
def __init__(self, log_directory: str = "logs/audit"):
|
||||
"""
|
||||
Initialize file-based audit log storage.
|
||||
|
||||
Args:
|
||||
log_directory: Directory to store audit logs
|
||||
"""
|
||||
self.log_directory = Path(log_directory)
|
||||
self.log_directory.mkdir(parents=True, exist_ok=True)
|
||||
self._current_date: Optional[str] = None
|
||||
self._current_file: Optional[Path] = None
|
||||
|
||||
def _get_log_file(self, date: datetime) -> Path:
|
||||
"""
|
||||
Get log file path for a specific date.
|
||||
|
||||
Args:
|
||||
date: Date for log file
|
||||
|
||||
Returns:
|
||||
Path to log file
|
||||
"""
|
||||
date_str = date.strftime("%Y-%m-%d")
|
||||
return self.log_directory / f"audit_{date_str}.jsonl"
|
||||
|
||||
async def write_event(self, event: AuditEvent) -> None:
|
||||
"""
|
||||
Write an audit event to file.
|
||||
|
||||
Args:
|
||||
event: Audit event to write
|
||||
"""
|
||||
log_file = self._get_log_file(event.timestamp)
|
||||
|
||||
try:
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(event.model_dump_json() + "\n")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to write audit event to file: {e}")
|
||||
|
||||
async def read_events(
|
||||
self,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None,
|
||||
event_types: Optional[List[AuditEventType]] = None,
|
||||
user_id: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
) -> List[AuditEvent]:
|
||||
"""
|
||||
Read audit events from files.
|
||||
|
||||
Args:
|
||||
start_time: Start of time range
|
||||
end_time: End of time range
|
||||
event_types: Filter by event types
|
||||
user_id: Filter by user ID
|
||||
limit: Maximum number of events to return
|
||||
|
||||
Returns:
|
||||
List of audit events
|
||||
"""
|
||||
if start_time is None:
|
||||
start_time = datetime.utcnow() - timedelta(days=7)
|
||||
if end_time is None:
|
||||
end_time = datetime.utcnow()
|
||||
|
||||
events: List[AuditEvent] = []
|
||||
current_date = start_time.date()
|
||||
end_date = end_time.date()
|
||||
|
||||
# Read from all log files in date range
|
||||
while current_date <= end_date and len(events) < limit:
|
||||
log_file = self._get_log_file(datetime.combine(current_date, datetime.min.time()))
|
||||
|
||||
if log_file.exists():
|
||||
try:
|
||||
with open(log_file, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if len(events) >= limit:
|
||||
break
|
||||
|
||||
try:
|
||||
event_data = json.loads(line.strip())
|
||||
event = AuditEvent(**event_data)
|
||||
|
||||
# Apply filters
|
||||
if event.timestamp < start_time or event.timestamp > end_time:
|
||||
continue
|
||||
|
||||
if event_types and event.event_type not in event_types:
|
||||
continue
|
||||
|
||||
if user_id and event.user_id != user_id:
|
||||
continue
|
||||
|
||||
events.append(event)
|
||||
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
logger.warning(f"Failed to parse audit event: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read audit log file {log_file}: {e}")
|
||||
|
||||
current_date += timedelta(days=1)
|
||||
|
||||
# Sort by timestamp descending
|
||||
events.sort(key=lambda e: e.timestamp, reverse=True)
|
||||
return events[:limit]
|
||||
|
||||
async def cleanup_old_events(self, days: int = 90) -> int:
|
||||
"""
|
||||
Clean up audit events older than specified days.
|
||||
|
||||
Args:
|
||||
days: Number of days to retain
|
||||
|
||||
Returns:
|
||||
Number of files deleted
|
||||
"""
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=days)
|
||||
deleted_count = 0
|
||||
|
||||
for log_file in self.log_directory.glob("audit_*.jsonl"):
|
||||
try:
|
||||
# Extract date from filename
|
||||
date_str = log_file.stem.replace("audit_", "")
|
||||
file_date = datetime.strptime(date_str, "%Y-%m-%d")
|
||||
|
||||
if file_date < cutoff_date:
|
||||
log_file.unlink()
|
||||
deleted_count += 1
|
||||
logger.info(f"Deleted old audit log: {log_file}")
|
||||
|
||||
except (ValueError, OSError) as e:
|
||||
logger.warning(f"Failed to process audit log file {log_file}: {e}")
|
||||
|
||||
return deleted_count
|
||||
|
||||
|
||||
class AuditService:
|
||||
"""Main audit service for logging security events."""
|
||||
|
||||
def __init__(self, storage: Optional[AuditLogStorage] = None):
|
||||
"""
|
||||
Initialize audit service.
|
||||
|
||||
Args:
|
||||
storage: Storage backend for audit logs
|
||||
"""
|
||||
self.storage = storage or FileAuditLogStorage()
|
||||
|
||||
async def log_event(
|
||||
self,
|
||||
event_type: AuditEventType,
|
||||
message: str,
|
||||
severity: AuditEventSeverity = AuditEventSeverity.INFO,
|
||||
user_id: Optional[str] = None,
|
||||
ip_address: Optional[str] = None,
|
||||
user_agent: Optional[str] = None,
|
||||
resource: Optional[str] = None,
|
||||
action: Optional[str] = None,
|
||||
status: str = "success",
|
||||
details: Optional[Dict[str, Any]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Log an audit event.
|
||||
|
||||
Args:
|
||||
event_type: Type of event
|
||||
message: Human-readable message
|
||||
severity: Event severity
|
||||
user_id: User identifier
|
||||
ip_address: Client IP address
|
||||
user_agent: Client user agent
|
||||
resource: Resource being accessed
|
||||
action: Action performed
|
||||
status: Operation status
|
||||
details: Additional details
|
||||
session_id: Session identifier
|
||||
"""
|
||||
event = AuditEvent(
|
||||
event_type=event_type,
|
||||
severity=severity,
|
||||
user_id=user_id,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
resource=resource,
|
||||
action=action,
|
||||
status=status,
|
||||
message=message,
|
||||
details=details,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
await self.storage.write_event(event)
|
||||
|
||||
# Also log to application logger for high severity events
|
||||
if severity in [AuditEventSeverity.ERROR, AuditEventSeverity.CRITICAL]:
|
||||
logger.error(f"Audit: {message}", extra={"audit_event": event.model_dump()})
|
||||
elif severity == AuditEventSeverity.WARNING:
|
||||
logger.warning(f"Audit: {message}", extra={"audit_event": event.model_dump()})
|
||||
|
||||
async def log_auth_setup(
|
||||
self, user_id: str, ip_address: Optional[str] = None
|
||||
) -> None:
|
||||
"""Log initial authentication setup."""
|
||||
await self.log_event(
|
||||
event_type=AuditEventType.AUTH_SETUP,
|
||||
message=f"Authentication configured by user {user_id}",
|
||||
user_id=user_id,
|
||||
ip_address=ip_address,
|
||||
action="setup",
|
||||
)
|
||||
|
||||
async def log_login_success(
|
||||
self,
|
||||
user_id: str,
|
||||
ip_address: Optional[str] = None,
|
||||
user_agent: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Log successful login."""
|
||||
await self.log_event(
|
||||
event_type=AuditEventType.AUTH_LOGIN_SUCCESS,
|
||||
message=f"User {user_id} logged in successfully",
|
||||
user_id=user_id,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
session_id=session_id,
|
||||
action="login",
|
||||
)
|
||||
|
||||
async def log_login_failure(
|
||||
self,
|
||||
user_id: Optional[str] = None,
|
||||
ip_address: Optional[str] = None,
|
||||
user_agent: Optional[str] = None,
|
||||
reason: str = "Invalid credentials",
|
||||
) -> None:
|
||||
"""Log failed login attempt."""
|
||||
await self.log_event(
|
||||
event_type=AuditEventType.AUTH_LOGIN_FAILURE,
|
||||
message=f"Login failed for user {user_id or 'unknown'}: {reason}",
|
||||
severity=AuditEventSeverity.WARNING,
|
||||
user_id=user_id,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
status="failure",
|
||||
action="login",
|
||||
details={"reason": reason},
|
||||
)
|
||||
|
||||
async def log_logout(
|
||||
self,
|
||||
user_id: str,
|
||||
ip_address: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Log user logout."""
|
||||
await self.log_event(
|
||||
event_type=AuditEventType.AUTH_LOGOUT,
|
||||
message=f"User {user_id} logged out",
|
||||
user_id=user_id,
|
||||
ip_address=ip_address,
|
||||
session_id=session_id,
|
||||
action="logout",
|
||||
)
|
||||
|
||||
async def log_config_update(
|
||||
self,
|
||||
user_id: str,
|
||||
changes: Dict[str, Any],
|
||||
ip_address: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Log configuration update."""
|
||||
await self.log_event(
|
||||
event_type=AuditEventType.CONFIG_UPDATE,
|
||||
message=f"Configuration updated by user {user_id}",
|
||||
user_id=user_id,
|
||||
ip_address=ip_address,
|
||||
resource="config",
|
||||
action="update",
|
||||
details={"changes": changes},
|
||||
)
|
||||
|
||||
async def log_config_backup(
|
||||
self, user_id: str, backup_file: str, ip_address: Optional[str] = None
|
||||
) -> None:
|
||||
"""Log configuration backup."""
|
||||
await self.log_event(
|
||||
event_type=AuditEventType.CONFIG_BACKUP,
|
||||
message=f"Configuration backed up by user {user_id}",
|
||||
user_id=user_id,
|
||||
ip_address=ip_address,
|
||||
resource="config",
|
||||
action="backup",
|
||||
details={"backup_file": backup_file},
|
||||
)
|
||||
|
||||
async def log_config_restore(
|
||||
self, user_id: str, backup_file: str, ip_address: Optional[str] = None
|
||||
) -> None:
|
||||
"""Log configuration restore."""
|
||||
await self.log_event(
|
||||
event_type=AuditEventType.CONFIG_RESTORE,
|
||||
message=f"Configuration restored by user {user_id}",
|
||||
user_id=user_id,
|
||||
ip_address=ip_address,
|
||||
resource="config",
|
||||
action="restore",
|
||||
details={"backup_file": backup_file},
|
||||
)
|
||||
|
||||
async def log_download_added(
|
||||
self,
|
||||
user_id: str,
|
||||
series_name: str,
|
||||
episodes: List[str],
|
||||
ip_address: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Log download added to queue."""
|
||||
await self.log_event(
|
||||
event_type=AuditEventType.DOWNLOAD_ADDED,
|
||||
message=f"Download added by user {user_id}: {series_name}",
|
||||
user_id=user_id,
|
||||
ip_address=ip_address,
|
||||
resource=series_name,
|
||||
action="add",
|
||||
details={"episodes": episodes},
|
||||
)
|
||||
|
||||
async def log_download_completed(
|
||||
self, series_name: str, episode: str, file_path: str
|
||||
) -> None:
|
||||
"""Log completed download."""
|
||||
await self.log_event(
|
||||
event_type=AuditEventType.DOWNLOAD_COMPLETED,
|
||||
message=f"Download completed: {series_name} - {episode}",
|
||||
resource=series_name,
|
||||
action="download",
|
||||
details={"episode": episode, "file_path": file_path},
|
||||
)
|
||||
|
||||
async def log_download_failed(
|
||||
self, series_name: str, episode: str, error: str
|
||||
) -> None:
|
||||
"""Log failed download."""
|
||||
await self.log_event(
|
||||
event_type=AuditEventType.DOWNLOAD_FAILED,
|
||||
message=f"Download failed: {series_name} - {episode}",
|
||||
severity=AuditEventSeverity.ERROR,
|
||||
resource=series_name,
|
||||
action="download",
|
||||
status="failure",
|
||||
details={"episode": episode, "error": error},
|
||||
)
|
||||
|
||||
async def log_queue_operation(
|
||||
self,
|
||||
user_id: str,
|
||||
operation: str,
|
||||
ip_address: Optional[str] = None,
|
||||
details: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Log queue operation."""
|
||||
event_type_map = {
|
||||
"start": AuditEventType.QUEUE_STARTED,
|
||||
"stop": AuditEventType.QUEUE_STOPPED,
|
||||
"pause": AuditEventType.QUEUE_PAUSED,
|
||||
"resume": AuditEventType.QUEUE_RESUMED,
|
||||
"clear": AuditEventType.QUEUE_CLEARED,
|
||||
}
|
||||
|
||||
event_type = event_type_map.get(operation, AuditEventType.SYSTEM_ERROR)
|
||||
await self.log_event(
|
||||
event_type=event_type,
|
||||
message=f"Queue {operation} by user {user_id}",
|
||||
user_id=user_id,
|
||||
ip_address=ip_address,
|
||||
resource="queue",
|
||||
action=operation,
|
||||
details=details,
|
||||
)
|
||||
|
||||
async def log_system_error(
|
||||
self, error: str, details: Optional[Dict[str, Any]] = None
|
||||
) -> None:
|
||||
"""Log system error."""
|
||||
await self.log_event(
|
||||
event_type=AuditEventType.SYSTEM_ERROR,
|
||||
message=f"System error: {error}",
|
||||
severity=AuditEventSeverity.ERROR,
|
||||
status="error",
|
||||
details=details,
|
||||
)
|
||||
|
||||
async def get_events(
|
||||
self,
|
||||
start_time: Optional[datetime] = None,
|
||||
end_time: Optional[datetime] = None,
|
||||
event_types: Optional[List[AuditEventType]] = None,
|
||||
user_id: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
) -> List[AuditEvent]:
|
||||
"""
|
||||
Get audit events with filters.
|
||||
|
||||
Args:
|
||||
start_time: Start of time range
|
||||
end_time: End of time range
|
||||
event_types: Filter by event types
|
||||
user_id: Filter by user ID
|
||||
limit: Maximum number of events to return
|
||||
|
||||
Returns:
|
||||
List of audit events
|
||||
"""
|
||||
return await self.storage.read_events(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
event_types=event_types,
|
||||
user_id=user_id,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
async def cleanup_old_events(self, days: int = 90) -> int:
|
||||
"""
|
||||
Clean up old audit events.
|
||||
|
||||
Args:
|
||||
days: Number of days to retain
|
||||
|
||||
Returns:
|
||||
Number of events deleted
|
||||
"""
|
||||
return await self.storage.cleanup_old_events(days)
|
||||
|
||||
|
||||
# Global audit service instance
|
||||
_audit_service: Optional[AuditService] = None
|
||||
|
||||
|
||||
def get_audit_service() -> AuditService:
|
||||
"""
|
||||
Get the global audit service instance.
|
||||
|
||||
Returns:
|
||||
AuditService instance
|
||||
"""
|
||||
global _audit_service
|
||||
if _audit_service is None:
|
||||
_audit_service = AuditService()
|
||||
return _audit_service
|
||||
|
||||
|
||||
def configure_audit_service(storage: Optional[AuditLogStorage] = None) -> AuditService:
|
||||
"""
|
||||
Configure the global audit service.
|
||||
|
||||
Args:
|
||||
storage: Custom storage backend
|
||||
|
||||
Returns:
|
||||
Configured AuditService instance
|
||||
"""
|
||||
global _audit_service
|
||||
_audit_service = AuditService(storage=storage)
|
||||
return _audit_service
|
||||
723
src/server/services/cache_service.py
Normal file
723
src/server/services/cache_service.py
Normal file
@ -0,0 +1,723 @@
|
||||
"""
|
||||
Cache Service for AniWorld.
|
||||
|
||||
This module provides caching functionality with support for both
|
||||
in-memory and Redis backends to improve application performance.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
import pickle
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CacheBackend(ABC):
|
||||
"""Abstract base class for cache backends."""
|
||||
|
||||
@abstractmethod
|
||||
async def get(self, key: str) -> Optional[Any]:
|
||||
"""
|
||||
Get value from cache.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
|
||||
Returns:
|
||||
Cached value or None if not found
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set(
|
||||
self, key: str, value: Any, ttl: Optional[int] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Set value in cache.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
value: Value to cache
|
||||
ttl: Time to live in seconds
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def delete(self, key: str) -> bool:
|
||||
"""
|
||||
Delete value from cache.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
|
||||
Returns:
|
||||
True if key was deleted
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def exists(self, key: str) -> bool:
|
||||
"""
|
||||
Check if key exists in cache.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
|
||||
Returns:
|
||||
True if key exists
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def clear(self) -> bool:
|
||||
"""
|
||||
Clear all cached values.
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_many(self, keys: List[str]) -> Dict[str, Any]:
|
||||
"""
|
||||
Get multiple values from cache.
|
||||
|
||||
Args:
|
||||
keys: List of cache keys
|
||||
|
||||
Returns:
|
||||
Dictionary mapping keys to values
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_many(
|
||||
self, items: Dict[str, Any], ttl: Optional[int] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Set multiple values in cache.
|
||||
|
||||
Args:
|
||||
items: Dictionary of key-value pairs
|
||||
ttl: Time to live in seconds
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def delete_pattern(self, pattern: str) -> int:
|
||||
"""
|
||||
Delete all keys matching pattern.
|
||||
|
||||
Args:
|
||||
pattern: Pattern to match (supports wildcards)
|
||||
|
||||
Returns:
|
||||
Number of keys deleted
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class InMemoryCacheBackend(CacheBackend):
|
||||
"""In-memory cache backend using dictionary."""
|
||||
|
||||
def __init__(self, max_size: int = 1000):
|
||||
"""
|
||||
Initialize in-memory cache.
|
||||
|
||||
Args:
|
||||
max_size: Maximum number of items to cache
|
||||
"""
|
||||
self.cache: Dict[str, Dict[str, Any]] = {}
|
||||
self.max_size = max_size
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
def _is_expired(self, item: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Check if cache item is expired.
|
||||
|
||||
Args:
|
||||
item: Cache item with expiry
|
||||
|
||||
Returns:
|
||||
True if expired
|
||||
"""
|
||||
if item.get("expiry") is None:
|
||||
return False
|
||||
return datetime.utcnow() > item["expiry"]
|
||||
|
||||
def _evict_oldest(self) -> None:
|
||||
"""Evict oldest cache item when cache is full."""
|
||||
if len(self.cache) >= self.max_size:
|
||||
# Remove oldest item
|
||||
oldest_key = min(
|
||||
self.cache.keys(),
|
||||
key=lambda k: self.cache[k].get("created", datetime.utcnow()),
|
||||
)
|
||||
del self.cache[oldest_key]
|
||||
|
||||
async def get(self, key: str) -> Optional[Any]:
|
||||
"""Get value from cache."""
|
||||
async with self._lock:
|
||||
if key not in self.cache:
|
||||
return None
|
||||
|
||||
item = self.cache[key]
|
||||
|
||||
if self._is_expired(item):
|
||||
del self.cache[key]
|
||||
return None
|
||||
|
||||
return item["value"]
|
||||
|
||||
async def set(
|
||||
self, key: str, value: Any, ttl: Optional[int] = None
|
||||
) -> bool:
|
||||
"""Set value in cache."""
|
||||
async with self._lock:
|
||||
self._evict_oldest()
|
||||
|
||||
expiry = None
|
||||
if ttl:
|
||||
expiry = datetime.utcnow() + timedelta(seconds=ttl)
|
||||
|
||||
self.cache[key] = {
|
||||
"value": value,
|
||||
"expiry": expiry,
|
||||
"created": datetime.utcnow(),
|
||||
}
|
||||
return True
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
"""Delete value from cache."""
|
||||
async with self._lock:
|
||||
if key in self.cache:
|
||||
del self.cache[key]
|
||||
return True
|
||||
return False
|
||||
|
||||
async def exists(self, key: str) -> bool:
|
||||
"""Check if key exists in cache."""
|
||||
async with self._lock:
|
||||
if key not in self.cache:
|
||||
return False
|
||||
|
||||
item = self.cache[key]
|
||||
if self._is_expired(item):
|
||||
del self.cache[key]
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def clear(self) -> bool:
|
||||
"""Clear all cached values."""
|
||||
async with self._lock:
|
||||
self.cache.clear()
|
||||
return True
|
||||
|
||||
async def get_many(self, keys: List[str]) -> Dict[str, Any]:
|
||||
"""Get multiple values from cache."""
|
||||
result = {}
|
||||
for key in keys:
|
||||
value = await self.get(key)
|
||||
if value is not None:
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
async def set_many(
|
||||
self, items: Dict[str, Any], ttl: Optional[int] = None
|
||||
) -> bool:
|
||||
"""Set multiple values in cache."""
|
||||
for key, value in items.items():
|
||||
await self.set(key, value, ttl)
|
||||
return True
|
||||
|
||||
async def delete_pattern(self, pattern: str) -> int:
|
||||
"""Delete all keys matching pattern."""
|
||||
import fnmatch
|
||||
|
||||
async with self._lock:
|
||||
keys_to_delete = [
|
||||
key for key in self.cache.keys() if fnmatch.fnmatch(key, pattern)
|
||||
]
|
||||
for key in keys_to_delete:
|
||||
del self.cache[key]
|
||||
return len(keys_to_delete)
|
||||
|
||||
|
||||
class RedisCacheBackend(CacheBackend):
|
||||
"""Redis cache backend."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_url: str = "redis://localhost:6379",
|
||||
prefix: str = "aniworld:",
|
||||
):
|
||||
"""
|
||||
Initialize Redis cache.
|
||||
|
||||
Args:
|
||||
redis_url: Redis connection URL
|
||||
prefix: Key prefix for namespacing
|
||||
"""
|
||||
self.redis_url = redis_url
|
||||
self.prefix = prefix
|
||||
self._redis = None
|
||||
|
||||
async def _get_redis(self):
|
||||
"""Get Redis connection."""
|
||||
if self._redis is None:
|
||||
try:
|
||||
import aioredis
|
||||
|
||||
self._redis = await aioredis.create_redis_pool(self.redis_url)
|
||||
except ImportError:
|
||||
logger.error(
|
||||
"aioredis not installed. Install with: pip install aioredis"
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Redis: {e}")
|
||||
raise
|
||||
|
||||
return self._redis
|
||||
|
||||
def _make_key(self, key: str) -> str:
|
||||
"""Add prefix to key."""
|
||||
return f"{self.prefix}{key}"
|
||||
|
||||
async def get(self, key: str) -> Optional[Any]:
|
||||
"""Get value from cache."""
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
data = await redis.get(self._make_key(key))
|
||||
|
||||
if data is None:
|
||||
return None
|
||||
|
||||
return pickle.loads(data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Redis get error: {e}")
|
||||
return None
|
||||
|
||||
async def set(
|
||||
self, key: str, value: Any, ttl: Optional[int] = None
|
||||
) -> bool:
|
||||
"""Set value in cache."""
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
data = pickle.dumps(value)
|
||||
|
||||
if ttl:
|
||||
await redis.setex(self._make_key(key), ttl, data)
|
||||
else:
|
||||
await redis.set(self._make_key(key), data)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Redis set error: {e}")
|
||||
return False
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
"""Delete value from cache."""
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
result = await redis.delete(self._make_key(key))
|
||||
return result > 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Redis delete error: {e}")
|
||||
return False
|
||||
|
||||
async def exists(self, key: str) -> bool:
|
||||
"""Check if key exists in cache."""
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
return await redis.exists(self._make_key(key))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Redis exists error: {e}")
|
||||
return False
|
||||
|
||||
async def clear(self) -> bool:
|
||||
"""Clear all cached values with prefix."""
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
keys = await redis.keys(f"{self.prefix}*")
|
||||
if keys:
|
||||
await redis.delete(*keys)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Redis clear error: {e}")
|
||||
return False
|
||||
|
||||
async def get_many(self, keys: List[str]) -> Dict[str, Any]:
|
||||
"""Get multiple values from cache."""
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
prefixed_keys = [self._make_key(k) for k in keys]
|
||||
values = await redis.mget(*prefixed_keys)
|
||||
|
||||
result = {}
|
||||
for key, value in zip(keys, values):
|
||||
if value is not None:
|
||||
result[key] = pickle.loads(value)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Redis get_many error: {e}")
|
||||
return {}
|
||||
|
||||
async def set_many(
|
||||
self, items: Dict[str, Any], ttl: Optional[int] = None
|
||||
) -> bool:
|
||||
"""Set multiple values in cache."""
|
||||
try:
|
||||
for key, value in items.items():
|
||||
await self.set(key, value, ttl)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Redis set_many error: {e}")
|
||||
return False
|
||||
|
||||
async def delete_pattern(self, pattern: str) -> int:
|
||||
"""Delete all keys matching pattern."""
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
full_pattern = f"{self.prefix}{pattern}"
|
||||
keys = await redis.keys(full_pattern)
|
||||
|
||||
if keys:
|
||||
await redis.delete(*keys)
|
||||
return len(keys)
|
||||
|
||||
return 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Redis delete_pattern error: {e}")
|
||||
return 0
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close Redis connection."""
|
||||
if self._redis:
|
||||
self._redis.close()
|
||||
await self._redis.wait_closed()
|
||||
|
||||
|
||||
class CacheService:
|
||||
"""Main cache service with automatic key generation and TTL management."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
backend: Optional[CacheBackend] = None,
|
||||
default_ttl: int = 3600,
|
||||
key_prefix: str = "",
|
||||
):
|
||||
"""
|
||||
Initialize cache service.
|
||||
|
||||
Args:
|
||||
backend: Cache backend to use
|
||||
default_ttl: Default time to live in seconds
|
||||
key_prefix: Prefix for all cache keys
|
||||
"""
|
||||
self.backend = backend or InMemoryCacheBackend()
|
||||
self.default_ttl = default_ttl
|
||||
self.key_prefix = key_prefix
|
||||
|
||||
def _make_key(self, *args: Any, **kwargs: Any) -> str:
|
||||
"""
|
||||
Generate cache key from arguments.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments
|
||||
**kwargs: Keyword arguments
|
||||
|
||||
Returns:
|
||||
Cache key string
|
||||
"""
|
||||
# Create a stable key from arguments
|
||||
key_parts = [str(arg) for arg in args]
|
||||
key_parts.extend(f"{k}={v}" for k, v in sorted(kwargs.items()))
|
||||
key_str = ":".join(key_parts)
|
||||
|
||||
# Hash long keys
|
||||
if len(key_str) > 200:
|
||||
key_hash = hashlib.md5(key_str.encode()).hexdigest()
|
||||
return f"{self.key_prefix}{key_hash}"
|
||||
|
||||
return f"{self.key_prefix}{key_str}"
|
||||
|
||||
async def get(
|
||||
self, key: str, default: Optional[Any] = None
|
||||
) -> Optional[Any]:
|
||||
"""
|
||||
Get value from cache.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
default: Default value if not found
|
||||
|
||||
Returns:
|
||||
Cached value or default
|
||||
"""
|
||||
value = await self.backend.get(key)
|
||||
return value if value is not None else default
|
||||
|
||||
async def set(
|
||||
self, key: str, value: Any, ttl: Optional[int] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Set value in cache.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
value: Value to cache
|
||||
ttl: Time to live in seconds (uses default if None)
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
if ttl is None:
|
||||
ttl = self.default_ttl
|
||||
return await self.backend.set(key, value, ttl)
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
"""
|
||||
Delete value from cache.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
|
||||
Returns:
|
||||
True if deleted
|
||||
"""
|
||||
return await self.backend.delete(key)
|
||||
|
||||
async def exists(self, key: str) -> bool:
|
||||
"""
|
||||
Check if key exists in cache.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
|
||||
Returns:
|
||||
True if exists
|
||||
"""
|
||||
return await self.backend.exists(key)
|
||||
|
||||
async def clear(self) -> bool:
|
||||
"""
|
||||
Clear all cached values.
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
return await self.backend.clear()
|
||||
|
||||
async def get_or_set(
|
||||
self,
|
||||
key: str,
|
||||
factory,
|
||||
ttl: Optional[int] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Get value from cache or compute and cache it.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
factory: Callable to compute value if not cached
|
||||
ttl: Time to live in seconds
|
||||
|
||||
Returns:
|
||||
Cached or computed value
|
||||
"""
|
||||
value = await self.get(key)
|
||||
|
||||
if value is None:
|
||||
# Compute value
|
||||
if asyncio.iscoroutinefunction(factory):
|
||||
value = await factory()
|
||||
else:
|
||||
value = factory()
|
||||
|
||||
# Cache it
|
||||
await self.set(key, value, ttl)
|
||||
|
||||
return value
|
||||
|
||||
async def invalidate_pattern(self, pattern: str) -> int:
|
||||
"""
|
||||
Invalidate all keys matching pattern.
|
||||
|
||||
Args:
|
||||
pattern: Pattern to match
|
||||
|
||||
Returns:
|
||||
Number of keys invalidated
|
||||
"""
|
||||
return await self.backend.delete_pattern(pattern)
|
||||
|
||||
async def cache_anime_list(
|
||||
self, anime_list: List[Dict[str, Any]], ttl: Optional[int] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Cache anime list.
|
||||
|
||||
Args:
|
||||
anime_list: List of anime data
|
||||
ttl: Time to live in seconds
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
key = self._make_key("anime", "list")
|
||||
return await self.set(key, anime_list, ttl)
|
||||
|
||||
async def get_anime_list(self) -> Optional[List[Dict[str, Any]]]:
|
||||
"""
|
||||
Get cached anime list.
|
||||
|
||||
Returns:
|
||||
Cached anime list or None
|
||||
"""
|
||||
key = self._make_key("anime", "list")
|
||||
return await self.get(key)
|
||||
|
||||
async def cache_anime_detail(
|
||||
self, anime_id: str, data: Dict[str, Any], ttl: Optional[int] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Cache anime detail.
|
||||
|
||||
Args:
|
||||
anime_id: Anime identifier
|
||||
data: Anime data
|
||||
ttl: Time to live in seconds
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
key = self._make_key("anime", "detail", anime_id)
|
||||
return await self.set(key, data, ttl)
|
||||
|
||||
async def get_anime_detail(self, anime_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get cached anime detail.
|
||||
|
||||
Args:
|
||||
anime_id: Anime identifier
|
||||
|
||||
Returns:
|
||||
Cached anime data or None
|
||||
"""
|
||||
key = self._make_key("anime", "detail", anime_id)
|
||||
return await self.get(key)
|
||||
|
||||
async def invalidate_anime_cache(self) -> int:
|
||||
"""
|
||||
Invalidate all anime-related cache.
|
||||
|
||||
Returns:
|
||||
Number of keys invalidated
|
||||
"""
|
||||
return await self.invalidate_pattern(f"{self.key_prefix}anime*")
|
||||
|
||||
async def cache_config(
|
||||
self, config: Dict[str, Any], ttl: Optional[int] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Cache configuration.
|
||||
|
||||
Args:
|
||||
config: Configuration data
|
||||
ttl: Time to live in seconds
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
key = self._make_key("config")
|
||||
return await self.set(key, config, ttl)
|
||||
|
||||
async def get_config(self) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get cached configuration.
|
||||
|
||||
Returns:
|
||||
Cached configuration or None
|
||||
"""
|
||||
key = self._make_key("config")
|
||||
return await self.get(key)
|
||||
|
||||
async def invalidate_config_cache(self) -> bool:
|
||||
"""
|
||||
Invalidate configuration cache.
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
key = self._make_key("config")
|
||||
return await self.delete(key)
|
||||
|
||||
|
||||
# Global cache service instance
|
||||
_cache_service: Optional[CacheService] = None
|
||||
|
||||
|
||||
def get_cache_service() -> CacheService:
|
||||
"""
|
||||
Get the global cache service instance.
|
||||
|
||||
Returns:
|
||||
CacheService instance
|
||||
"""
|
||||
global _cache_service
|
||||
if _cache_service is None:
|
||||
_cache_service = CacheService()
|
||||
return _cache_service
|
||||
|
||||
|
||||
def configure_cache_service(
|
||||
backend_type: str = "memory",
|
||||
redis_url: str = "redis://localhost:6379",
|
||||
default_ttl: int = 3600,
|
||||
max_size: int = 1000,
|
||||
) -> CacheService:
|
||||
"""
|
||||
Configure the global cache service.
|
||||
|
||||
Args:
|
||||
backend_type: Type of backend ("memory" or "redis")
|
||||
redis_url: Redis connection URL (for redis backend)
|
||||
default_ttl: Default time to live in seconds
|
||||
max_size: Maximum cache size (for memory backend)
|
||||
|
||||
Returns:
|
||||
Configured CacheService instance
|
||||
"""
|
||||
global _cache_service
|
||||
|
||||
if backend_type == "redis":
|
||||
backend = RedisCacheBackend(redis_url=redis_url)
|
||||
else:
|
||||
backend = InMemoryCacheBackend(max_size=max_size)
|
||||
|
||||
_cache_service = CacheService(
|
||||
backend=backend, default_ttl=default_ttl, key_prefix="aniworld:"
|
||||
)
|
||||
return _cache_service
|
||||
626
src/server/services/notification_service.py
Normal file
626
src/server/services/notification_service.py
Normal file
@ -0,0 +1,626 @@
|
||||
"""
|
||||
Notification Service for AniWorld.
|
||||
|
||||
This module provides notification functionality including email, webhooks,
|
||||
and in-app notifications for download events and system alerts.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
import aiohttp
|
||||
from pydantic import BaseModel, EmailStr, Field, HttpUrl
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NotificationType(str, Enum):
|
||||
"""Types of notifications."""
|
||||
|
||||
DOWNLOAD_COMPLETE = "download_complete"
|
||||
DOWNLOAD_FAILED = "download_failed"
|
||||
QUEUE_COMPLETE = "queue_complete"
|
||||
SYSTEM_ERROR = "system_error"
|
||||
SYSTEM_WARNING = "system_warning"
|
||||
SYSTEM_INFO = "system_info"
|
||||
|
||||
|
||||
class NotificationPriority(str, Enum):
|
||||
"""Notification priority levels."""
|
||||
|
||||
LOW = "low"
|
||||
NORMAL = "normal"
|
||||
HIGH = "high"
|
||||
CRITICAL = "critical"
|
||||
|
||||
|
||||
class NotificationChannel(str, Enum):
|
||||
"""Available notification channels."""
|
||||
|
||||
EMAIL = "email"
|
||||
WEBHOOK = "webhook"
|
||||
IN_APP = "in_app"
|
||||
|
||||
|
||||
class NotificationPreferences(BaseModel):
|
||||
"""User notification preferences."""
|
||||
|
||||
enabled_channels: Set[NotificationChannel] = Field(
|
||||
default_factory=lambda: {NotificationChannel.IN_APP}
|
||||
)
|
||||
enabled_types: Set[NotificationType] = Field(
|
||||
default_factory=lambda: set(NotificationType)
|
||||
)
|
||||
email_address: Optional[EmailStr] = None
|
||||
webhook_urls: List[HttpUrl] = Field(default_factory=list)
|
||||
quiet_hours_start: Optional[int] = Field(None, ge=0, le=23)
|
||||
quiet_hours_end: Optional[int] = Field(None, ge=0, le=23)
|
||||
min_priority: NotificationPriority = NotificationPriority.NORMAL
|
||||
|
||||
|
||||
class Notification(BaseModel):
|
||||
"""Notification model."""
|
||||
|
||||
id: str
|
||||
type: NotificationType
|
||||
priority: NotificationPriority
|
||||
title: str
|
||||
message: str
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
read: bool = False
|
||||
channels: Set[NotificationChannel] = Field(
|
||||
default_factory=lambda: {NotificationChannel.IN_APP}
|
||||
)
|
||||
|
||||
|
||||
class EmailNotificationService:
|
||||
"""Service for sending email notifications."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
smtp_host: Optional[str] = None,
|
||||
smtp_port: int = 587,
|
||||
smtp_username: Optional[str] = None,
|
||||
smtp_password: Optional[str] = None,
|
||||
from_address: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize email notification service.
|
||||
|
||||
Args:
|
||||
smtp_host: SMTP server hostname
|
||||
smtp_port: SMTP server port
|
||||
smtp_username: SMTP authentication username
|
||||
smtp_password: SMTP authentication password
|
||||
from_address: Email sender address
|
||||
"""
|
||||
self.smtp_host = smtp_host
|
||||
self.smtp_port = smtp_port
|
||||
self.smtp_username = smtp_username
|
||||
self.smtp_password = smtp_password
|
||||
self.from_address = from_address
|
||||
self._enabled = all(
|
||||
[smtp_host, smtp_username, smtp_password, from_address]
|
||||
)
|
||||
|
||||
async def send_email(
|
||||
self, to_address: str, subject: str, body: str, html: bool = False
|
||||
) -> bool:
|
||||
"""
|
||||
Send an email notification.
|
||||
|
||||
Args:
|
||||
to_address: Recipient email address
|
||||
subject: Email subject
|
||||
body: Email body content
|
||||
html: Whether body is HTML format
|
||||
|
||||
Returns:
|
||||
True if email sent successfully
|
||||
"""
|
||||
if not self._enabled:
|
||||
logger.warning("Email notifications not configured")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Import here to make aiosmtplib optional
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
|
||||
import aiosmtplib
|
||||
|
||||
message = MIMEMultipart("alternative")
|
||||
message["Subject"] = subject
|
||||
message["From"] = self.from_address
|
||||
message["To"] = to_address
|
||||
|
||||
mime_type = "html" if html else "plain"
|
||||
message.attach(MIMEText(body, mime_type))
|
||||
|
||||
await aiosmtplib.send(
|
||||
message,
|
||||
hostname=self.smtp_host,
|
||||
port=self.smtp_port,
|
||||
username=self.smtp_username,
|
||||
password=self.smtp_password,
|
||||
start_tls=True,
|
||||
)
|
||||
|
||||
logger.info(f"Email notification sent to {to_address}")
|
||||
return True
|
||||
|
||||
except ImportError:
|
||||
logger.error(
|
||||
"aiosmtplib not installed. Install with: pip install aiosmtplib"
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send email notification: {e}")
|
||||
return False
|
||||
|
||||
|
||||
class WebhookNotificationService:
|
||||
"""Service for sending webhook notifications."""
|
||||
|
||||
def __init__(self, timeout: int = 10, max_retries: int = 3):
|
||||
"""
|
||||
Initialize webhook notification service.
|
||||
|
||||
Args:
|
||||
timeout: Request timeout in seconds
|
||||
max_retries: Maximum number of retry attempts
|
||||
"""
|
||||
self.timeout = timeout
|
||||
self.max_retries = max_retries
|
||||
|
||||
async def send_webhook(
|
||||
self, url: str, payload: Dict[str, Any], headers: Optional[Dict[str, str]] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Send a webhook notification.
|
||||
|
||||
Args:
|
||||
url: Webhook URL
|
||||
payload: JSON payload to send
|
||||
headers: Optional custom headers
|
||||
|
||||
Returns:
|
||||
True if webhook sent successfully
|
||||
"""
|
||||
if headers is None:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
timeout=aiohttp.ClientTimeout(total=self.timeout),
|
||||
) as response:
|
||||
if response.status < 400:
|
||||
logger.info(f"Webhook notification sent to {url}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(
|
||||
f"Webhook returned status {response.status}: {url}"
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Webhook timeout (attempt {attempt + 1}/{self.max_retries}): {url}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send webhook (attempt {attempt + 1}/{self.max_retries}): {e}")
|
||||
|
||||
if attempt < self.max_retries - 1:
|
||||
await asyncio.sleep(2 ** attempt) # Exponential backoff
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class InAppNotificationService:
|
||||
"""Service for managing in-app notifications."""
|
||||
|
||||
def __init__(self, max_notifications: int = 100):
|
||||
"""
|
||||
Initialize in-app notification service.
|
||||
|
||||
Args:
|
||||
max_notifications: Maximum number of notifications to keep
|
||||
"""
|
||||
self.notifications: List[Notification] = []
|
||||
self.max_notifications = max_notifications
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def add_notification(self, notification: Notification) -> None:
|
||||
"""
|
||||
Add a notification to the in-app list.
|
||||
|
||||
Args:
|
||||
notification: Notification to add
|
||||
"""
|
||||
async with self._lock:
|
||||
self.notifications.insert(0, notification)
|
||||
if len(self.notifications) > self.max_notifications:
|
||||
self.notifications = self.notifications[: self.max_notifications]
|
||||
|
||||
async def get_notifications(
|
||||
self, unread_only: bool = False, limit: Optional[int] = None
|
||||
) -> List[Notification]:
|
||||
"""
|
||||
Get in-app notifications.
|
||||
|
||||
Args:
|
||||
unread_only: Only return unread notifications
|
||||
limit: Maximum number of notifications to return
|
||||
|
||||
Returns:
|
||||
List of notifications
|
||||
"""
|
||||
async with self._lock:
|
||||
notifications = self.notifications
|
||||
if unread_only:
|
||||
notifications = [n for n in notifications if not n.read]
|
||||
if limit:
|
||||
notifications = notifications[:limit]
|
||||
return notifications.copy()
|
||||
|
||||
async def mark_as_read(self, notification_id: str) -> bool:
|
||||
"""
|
||||
Mark a notification as read.
|
||||
|
||||
Args:
|
||||
notification_id: ID of notification to mark
|
||||
|
||||
Returns:
|
||||
True if notification was found and marked
|
||||
"""
|
||||
async with self._lock:
|
||||
for notification in self.notifications:
|
||||
if notification.id == notification_id:
|
||||
notification.read = True
|
||||
return True
|
||||
return False
|
||||
|
||||
async def mark_all_as_read(self) -> int:
|
||||
"""
|
||||
Mark all notifications as read.
|
||||
|
||||
Returns:
|
||||
Number of notifications marked as read
|
||||
"""
|
||||
async with self._lock:
|
||||
count = 0
|
||||
for notification in self.notifications:
|
||||
if not notification.read:
|
||||
notification.read = True
|
||||
count += 1
|
||||
return count
|
||||
|
||||
async def clear_notifications(self, read_only: bool = True) -> int:
|
||||
"""
|
||||
Clear notifications.
|
||||
|
||||
Args:
|
||||
read_only: Only clear read notifications
|
||||
|
||||
Returns:
|
||||
Number of notifications cleared
|
||||
"""
|
||||
async with self._lock:
|
||||
if read_only:
|
||||
initial_count = len(self.notifications)
|
||||
self.notifications = [n for n in self.notifications if not n.read]
|
||||
return initial_count - len(self.notifications)
|
||||
else:
|
||||
count = len(self.notifications)
|
||||
self.notifications.clear()
|
||||
return count
|
||||
|
||||
|
||||
class NotificationService:
|
||||
"""Main notification service coordinating all notification channels."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
email_service: Optional[EmailNotificationService] = None,
|
||||
webhook_service: Optional[WebhookNotificationService] = None,
|
||||
in_app_service: Optional[InAppNotificationService] = None,
|
||||
):
|
||||
"""
|
||||
Initialize notification service.
|
||||
|
||||
Args:
|
||||
email_service: Email notification service instance
|
||||
webhook_service: Webhook notification service instance
|
||||
in_app_service: In-app notification service instance
|
||||
"""
|
||||
self.email_service = email_service or EmailNotificationService()
|
||||
self.webhook_service = webhook_service or WebhookNotificationService()
|
||||
self.in_app_service = in_app_service or InAppNotificationService()
|
||||
self.preferences = NotificationPreferences()
|
||||
|
||||
def set_preferences(self, preferences: NotificationPreferences) -> None:
|
||||
"""
|
||||
Update notification preferences.
|
||||
|
||||
Args:
|
||||
preferences: New notification preferences
|
||||
"""
|
||||
self.preferences = preferences
|
||||
|
||||
def _is_in_quiet_hours(self) -> bool:
|
||||
"""
|
||||
Check if current time is within quiet hours.
|
||||
|
||||
Returns:
|
||||
True if in quiet hours
|
||||
"""
|
||||
if (
|
||||
self.preferences.quiet_hours_start is None
|
||||
or self.preferences.quiet_hours_end is None
|
||||
):
|
||||
return False
|
||||
|
||||
current_hour = datetime.now().hour
|
||||
start = self.preferences.quiet_hours_start
|
||||
end = self.preferences.quiet_hours_end
|
||||
|
||||
if start <= end:
|
||||
return start <= current_hour < end
|
||||
else: # Quiet hours span midnight
|
||||
return current_hour >= start or current_hour < end
|
||||
|
||||
def _should_send_notification(
|
||||
self, notification_type: NotificationType, priority: NotificationPriority
|
||||
) -> bool:
|
||||
"""
|
||||
Determine if a notification should be sent based on preferences.
|
||||
|
||||
Args:
|
||||
notification_type: Type of notification
|
||||
priority: Priority level
|
||||
|
||||
Returns:
|
||||
True if notification should be sent
|
||||
"""
|
||||
# Check if type is enabled
|
||||
if notification_type not in self.preferences.enabled_types:
|
||||
return False
|
||||
|
||||
# Check priority level
|
||||
priority_order = [
|
||||
NotificationPriority.LOW,
|
||||
NotificationPriority.NORMAL,
|
||||
NotificationPriority.HIGH,
|
||||
NotificationPriority.CRITICAL,
|
||||
]
|
||||
if (
|
||||
priority_order.index(priority)
|
||||
< priority_order.index(self.preferences.min_priority)
|
||||
):
|
||||
return False
|
||||
|
||||
# Check quiet hours (critical notifications bypass quiet hours)
|
||||
if priority != NotificationPriority.CRITICAL and self._is_in_quiet_hours():
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def send_notification(self, notification: Notification) -> Dict[str, bool]:
|
||||
"""
|
||||
Send a notification through enabled channels.
|
||||
|
||||
Args:
|
||||
notification: Notification to send
|
||||
|
||||
Returns:
|
||||
Dictionary mapping channel names to success status
|
||||
"""
|
||||
if not self._should_send_notification(notification.type, notification.priority):
|
||||
logger.debug(
|
||||
f"Notification not sent due to preferences: {notification.type}"
|
||||
)
|
||||
return {}
|
||||
|
||||
results = {}
|
||||
|
||||
# Send in-app notification
|
||||
if NotificationChannel.IN_APP in self.preferences.enabled_channels:
|
||||
try:
|
||||
await self.in_app_service.add_notification(notification)
|
||||
results["in_app"] = True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send in-app notification: {e}")
|
||||
results["in_app"] = False
|
||||
|
||||
# Send email notification
|
||||
if (
|
||||
NotificationChannel.EMAIL in self.preferences.enabled_channels
|
||||
and self.preferences.email_address
|
||||
):
|
||||
try:
|
||||
success = await self.email_service.send_email(
|
||||
to_address=self.preferences.email_address,
|
||||
subject=f"[{notification.priority.upper()}] {notification.title}",
|
||||
body=notification.message,
|
||||
)
|
||||
results["email"] = success
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send email notification: {e}")
|
||||
results["email"] = False
|
||||
|
||||
# Send webhook notifications
|
||||
if (
|
||||
NotificationChannel.WEBHOOK in self.preferences.enabled_channels
|
||||
and self.preferences.webhook_urls
|
||||
):
|
||||
payload = {
|
||||
"id": notification.id,
|
||||
"type": notification.type,
|
||||
"priority": notification.priority,
|
||||
"title": notification.title,
|
||||
"message": notification.message,
|
||||
"data": notification.data,
|
||||
"created_at": notification.created_at.isoformat(),
|
||||
}
|
||||
|
||||
webhook_results = []
|
||||
for url in self.preferences.webhook_urls:
|
||||
try:
|
||||
success = await self.webhook_service.send_webhook(str(url), payload)
|
||||
webhook_results.append(success)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send webhook notification to {url}: {e}")
|
||||
webhook_results.append(False)
|
||||
|
||||
results["webhook"] = all(webhook_results) if webhook_results else False
|
||||
|
||||
return results
|
||||
|
||||
async def notify_download_complete(
|
||||
self, series_name: str, episode: str, file_path: str
|
||||
) -> Dict[str, bool]:
|
||||
"""
|
||||
Send notification for completed download.
|
||||
|
||||
Args:
|
||||
series_name: Name of the series
|
||||
episode: Episode identifier
|
||||
file_path: Path to downloaded file
|
||||
|
||||
Returns:
|
||||
Dictionary of send results by channel
|
||||
"""
|
||||
notification = Notification(
|
||||
id=f"download_complete_{datetime.utcnow().timestamp()}",
|
||||
type=NotificationType.DOWNLOAD_COMPLETE,
|
||||
priority=NotificationPriority.NORMAL,
|
||||
title=f"Download Complete: {series_name}",
|
||||
message=f"Episode {episode} has been downloaded successfully.",
|
||||
data={
|
||||
"series_name": series_name,
|
||||
"episode": episode,
|
||||
"file_path": file_path,
|
||||
},
|
||||
)
|
||||
return await self.send_notification(notification)
|
||||
|
||||
async def notify_download_failed(
|
||||
self, series_name: str, episode: str, error: str
|
||||
) -> Dict[str, bool]:
|
||||
"""
|
||||
Send notification for failed download.
|
||||
|
||||
Args:
|
||||
series_name: Name of the series
|
||||
episode: Episode identifier
|
||||
error: Error message
|
||||
|
||||
Returns:
|
||||
Dictionary of send results by channel
|
||||
"""
|
||||
notification = Notification(
|
||||
id=f"download_failed_{datetime.utcnow().timestamp()}",
|
||||
type=NotificationType.DOWNLOAD_FAILED,
|
||||
priority=NotificationPriority.HIGH,
|
||||
title=f"Download Failed: {series_name}",
|
||||
message=f"Episode {episode} failed to download: {error}",
|
||||
data={"series_name": series_name, "episode": episode, "error": error},
|
||||
)
|
||||
return await self.send_notification(notification)
|
||||
|
||||
async def notify_queue_complete(self, total_downloads: int) -> Dict[str, bool]:
|
||||
"""
|
||||
Send notification for completed download queue.
|
||||
|
||||
Args:
|
||||
total_downloads: Number of downloads completed
|
||||
|
||||
Returns:
|
||||
Dictionary of send results by channel
|
||||
"""
|
||||
notification = Notification(
|
||||
id=f"queue_complete_{datetime.utcnow().timestamp()}",
|
||||
type=NotificationType.QUEUE_COMPLETE,
|
||||
priority=NotificationPriority.NORMAL,
|
||||
title="Download Queue Complete",
|
||||
message=f"All {total_downloads} downloads have been completed.",
|
||||
data={"total_downloads": total_downloads},
|
||||
)
|
||||
return await self.send_notification(notification)
|
||||
|
||||
async def notify_system_error(self, error: str, details: Optional[Dict[str, Any]] = None) -> Dict[str, bool]:
|
||||
"""
|
||||
Send notification for system error.
|
||||
|
||||
Args:
|
||||
error: Error message
|
||||
details: Optional error details
|
||||
|
||||
Returns:
|
||||
Dictionary of send results by channel
|
||||
"""
|
||||
notification = Notification(
|
||||
id=f"system_error_{datetime.utcnow().timestamp()}",
|
||||
type=NotificationType.SYSTEM_ERROR,
|
||||
priority=NotificationPriority.CRITICAL,
|
||||
title="System Error",
|
||||
message=error,
|
||||
data=details,
|
||||
)
|
||||
return await self.send_notification(notification)
|
||||
|
||||
|
||||
# Global notification service instance
|
||||
_notification_service: Optional[NotificationService] = None
|
||||
|
||||
|
||||
def get_notification_service() -> NotificationService:
|
||||
"""
|
||||
Get the global notification service instance.
|
||||
|
||||
Returns:
|
||||
NotificationService instance
|
||||
"""
|
||||
global _notification_service
|
||||
if _notification_service is None:
|
||||
_notification_service = NotificationService()
|
||||
return _notification_service
|
||||
|
||||
|
||||
def configure_notification_service(
|
||||
smtp_host: Optional[str] = None,
|
||||
smtp_port: int = 587,
|
||||
smtp_username: Optional[str] = None,
|
||||
smtp_password: Optional[str] = None,
|
||||
from_address: Optional[str] = None,
|
||||
) -> NotificationService:
|
||||
"""
|
||||
Configure the global notification service.
|
||||
|
||||
Args:
|
||||
smtp_host: SMTP server hostname
|
||||
smtp_port: SMTP server port
|
||||
smtp_username: SMTP authentication username
|
||||
smtp_password: SMTP authentication password
|
||||
from_address: Email sender address
|
||||
|
||||
Returns:
|
||||
Configured NotificationService instance
|
||||
"""
|
||||
global _notification_service
|
||||
email_service = EmailNotificationService(
|
||||
smtp_host=smtp_host,
|
||||
smtp_port=smtp_port,
|
||||
smtp_username=smtp_username,
|
||||
smtp_password=smtp_password,
|
||||
from_address=from_address,
|
||||
)
|
||||
_notification_service = NotificationService(email_service=email_service)
|
||||
return _notification_service
|
||||
628
src/server/utils/validators.py
Normal file
628
src/server/utils/validators.py
Normal file
@ -0,0 +1,628 @@
|
||||
"""
|
||||
Data Validation Utilities for AniWorld.
|
||||
|
||||
This module provides Pydantic validators and business rule validation
|
||||
utilities for ensuring data integrity across the application.
|
||||
"""
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class ValidationError(Exception):
|
||||
"""Custom validation error."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ValidatorMixin:
|
||||
"""Mixin class providing common validation utilities."""
|
||||
|
||||
@staticmethod
|
||||
def validate_password_strength(password: str) -> str:
|
||||
"""
|
||||
Validate password meets security requirements.
|
||||
|
||||
Args:
|
||||
password: Password to validate
|
||||
|
||||
Returns:
|
||||
Validated password
|
||||
|
||||
Raises:
|
||||
ValueError: If password doesn't meet requirements
|
||||
"""
|
||||
if len(password) < 8:
|
||||
raise ValueError("Password must be at least 8 characters long")
|
||||
|
||||
if not re.search(r"[A-Z]", password):
|
||||
raise ValueError(
|
||||
"Password must contain at least one uppercase letter"
|
||||
)
|
||||
|
||||
if not re.search(r"[a-z]", password):
|
||||
raise ValueError(
|
||||
"Password must contain at least one lowercase letter"
|
||||
)
|
||||
|
||||
if not re.search(r"[0-9]", password):
|
||||
raise ValueError("Password must contain at least one digit")
|
||||
|
||||
if not re.search(r"[!@#$%^&*(),.?\":{}|<>]", password):
|
||||
raise ValueError(
|
||||
"Password must contain at least one special character"
|
||||
)
|
||||
|
||||
return password
|
||||
|
||||
@staticmethod
|
||||
def validate_file_path(path: str, must_exist: bool = False) -> str:
|
||||
"""
|
||||
Validate file path.
|
||||
|
||||
Args:
|
||||
path: File path to validate
|
||||
must_exist: Whether the path must exist
|
||||
|
||||
Returns:
|
||||
Validated path
|
||||
|
||||
Raises:
|
||||
ValueError: If path is invalid
|
||||
"""
|
||||
if not path or not isinstance(path, str):
|
||||
raise ValueError("Path must be a non-empty string")
|
||||
|
||||
# Check for path traversal attempts
|
||||
if ".." in path or path.startswith("/"):
|
||||
raise ValueError("Invalid path: path traversal not allowed")
|
||||
|
||||
path_obj = Path(path)
|
||||
|
||||
if must_exist and not path_obj.exists():
|
||||
raise ValueError(f"Path does not exist: {path}")
|
||||
|
||||
return path
|
||||
|
||||
@staticmethod
|
||||
def validate_url(url: str) -> str:
|
||||
"""
|
||||
Validate URL format.
|
||||
|
||||
Args:
|
||||
url: URL to validate
|
||||
|
||||
Returns:
|
||||
Validated URL
|
||||
|
||||
Raises:
|
||||
ValueError: If URL is invalid
|
||||
"""
|
||||
if not url or not isinstance(url, str):
|
||||
raise ValueError("URL must be a non-empty string")
|
||||
|
||||
url_pattern = re.compile(
|
||||
r"^https?://" # http:// or https://
|
||||
r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?|"
|
||||
r"localhost|" # localhost
|
||||
r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})" # IP address
|
||||
r"(?::\d+)?" # optional port
|
||||
r"(?:/?|[/?]\S+)$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
if not url_pattern.match(url):
|
||||
raise ValueError(f"Invalid URL format: {url}")
|
||||
|
||||
return url
|
||||
|
||||
@staticmethod
|
||||
def validate_email(email: str) -> str:
|
||||
"""
|
||||
Validate email address format.
|
||||
|
||||
Args:
|
||||
email: Email to validate
|
||||
|
||||
Returns:
|
||||
Validated email
|
||||
|
||||
Raises:
|
||||
ValueError: If email is invalid
|
||||
"""
|
||||
if not email or not isinstance(email, str):
|
||||
raise ValueError("Email must be a non-empty string")
|
||||
|
||||
email_pattern = re.compile(
|
||||
r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
|
||||
)
|
||||
|
||||
if not email_pattern.match(email):
|
||||
raise ValueError(f"Invalid email format: {email}")
|
||||
|
||||
return email
|
||||
|
||||
@staticmethod
|
||||
def validate_port(port: int) -> int:
|
||||
"""
|
||||
Validate port number.
|
||||
|
||||
Args:
|
||||
port: Port number to validate
|
||||
|
||||
Returns:
|
||||
Validated port
|
||||
|
||||
Raises:
|
||||
ValueError: If port is invalid
|
||||
"""
|
||||
if not isinstance(port, int):
|
||||
raise ValueError("Port must be an integer")
|
||||
|
||||
if port < 1 or port > 65535:
|
||||
raise ValueError("Port must be between 1 and 65535")
|
||||
|
||||
return port
|
||||
|
||||
@staticmethod
|
||||
def validate_positive_integer(value: int, name: str = "Value") -> int:
|
||||
"""
|
||||
Validate positive integer.
|
||||
|
||||
Args:
|
||||
value: Value to validate
|
||||
name: Name for error messages
|
||||
|
||||
Returns:
|
||||
Validated value
|
||||
|
||||
Raises:
|
||||
ValueError: If value is invalid
|
||||
"""
|
||||
if not isinstance(value, int):
|
||||
raise ValueError(f"{name} must be an integer")
|
||||
|
||||
if value <= 0:
|
||||
raise ValueError(f"{name} must be positive")
|
||||
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def validate_non_negative_integer(value: int, name: str = "Value") -> int:
|
||||
"""
|
||||
Validate non-negative integer.
|
||||
|
||||
Args:
|
||||
value: Value to validate
|
||||
name: Name for error messages
|
||||
|
||||
Returns:
|
||||
Validated value
|
||||
|
||||
Raises:
|
||||
ValueError: If value is invalid
|
||||
"""
|
||||
if not isinstance(value, int):
|
||||
raise ValueError(f"{name} must be an integer")
|
||||
|
||||
if value < 0:
|
||||
raise ValueError(f"{name} cannot be negative")
|
||||
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def validate_string_length(
|
||||
value: str, min_length: int = 0, max_length: Optional[int] = None, name: str = "Value"
|
||||
) -> str:
|
||||
"""
|
||||
Validate string length.
|
||||
|
||||
Args:
|
||||
value: String to validate
|
||||
min_length: Minimum length
|
||||
max_length: Maximum length (None for no limit)
|
||||
name: Name for error messages
|
||||
|
||||
Returns:
|
||||
Validated string
|
||||
|
||||
Raises:
|
||||
ValueError: If string length is invalid
|
||||
"""
|
||||
if not isinstance(value, str):
|
||||
raise ValueError(f"{name} must be a string")
|
||||
|
||||
if len(value) < min_length:
|
||||
raise ValueError(
|
||||
f"{name} must be at least {min_length} characters long"
|
||||
)
|
||||
|
||||
if max_length is not None and len(value) > max_length:
|
||||
raise ValueError(
|
||||
f"{name} must be at most {max_length} characters long"
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def validate_choice(value: Any, choices: List[Any], name: str = "Value") -> Any:
|
||||
"""
|
||||
Validate value is in allowed choices.
|
||||
|
||||
Args:
|
||||
value: Value to validate
|
||||
choices: List of allowed values
|
||||
name: Name for error messages
|
||||
|
||||
Returns:
|
||||
Validated value
|
||||
|
||||
Raises:
|
||||
ValueError: If value not in choices
|
||||
"""
|
||||
if value not in choices:
|
||||
raise ValueError(f"{name} must be one of: {', '.join(map(str, choices))}")
|
||||
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def validate_dict_keys(
|
||||
data: Dict[str, Any], required_keys: List[str], name: str = "Data"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Validate dictionary contains required keys.
|
||||
|
||||
Args:
|
||||
data: Dictionary to validate
|
||||
required_keys: List of required keys
|
||||
name: Name for error messages
|
||||
|
||||
Returns:
|
||||
Validated dictionary
|
||||
|
||||
Raises:
|
||||
ValueError: If required keys are missing
|
||||
"""
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError(f"{name} must be a dictionary")
|
||||
|
||||
missing_keys = [key for key in required_keys if key not in data]
|
||||
if missing_keys:
|
||||
raise ValueError(
|
||||
f"{name} missing required keys: {', '.join(missing_keys)}"
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def validate_episode_range(start: int, end: int) -> tuple[int, int]:
|
||||
"""
|
||||
Validate episode range.
|
||||
|
||||
Args:
|
||||
start: Start episode number
|
||||
end: End episode number
|
||||
|
||||
Returns:
|
||||
Tuple of (start, end)
|
||||
|
||||
Raises:
|
||||
ValueError: If range is invalid
|
||||
"""
|
||||
if start < 1:
|
||||
raise ValueError("Start episode must be at least 1")
|
||||
|
||||
if end < start:
|
||||
raise ValueError("End episode must be greater than or equal to start")
|
||||
|
||||
if end - start > 1000:
|
||||
raise ValueError("Episode range too large (max 1000 episodes)")
|
||||
|
||||
return start, end
|
||||
|
||||
|
||||
def validate_download_quality(quality: str) -> str:
|
||||
"""
|
||||
Validate download quality setting.
|
||||
|
||||
Args:
|
||||
quality: Quality setting
|
||||
|
||||
Returns:
|
||||
Validated quality
|
||||
|
||||
Raises:
|
||||
ValueError: If quality is invalid
|
||||
"""
|
||||
valid_qualities = ["360p", "480p", "720p", "1080p", "best", "worst"]
|
||||
if quality not in valid_qualities:
|
||||
raise ValueError(
|
||||
f"Invalid quality: {quality}. Must be one of: {', '.join(valid_qualities)}"
|
||||
)
|
||||
return quality
|
||||
|
||||
|
||||
def validate_language(language: str) -> str:
|
||||
"""
|
||||
Validate language code.
|
||||
|
||||
Args:
|
||||
language: Language code
|
||||
|
||||
Returns:
|
||||
Validated language
|
||||
|
||||
Raises:
|
||||
ValueError: If language is invalid
|
||||
"""
|
||||
valid_languages = ["ger-sub", "ger-dub", "eng-sub", "eng-dub", "jpn"]
|
||||
if language not in valid_languages:
|
||||
raise ValueError(
|
||||
f"Invalid language: {language}. Must be one of: {', '.join(valid_languages)}"
|
||||
)
|
||||
return language
|
||||
|
||||
|
||||
def validate_download_priority(priority: int) -> int:
|
||||
"""
|
||||
Validate download priority.
|
||||
|
||||
Args:
|
||||
priority: Priority value
|
||||
|
||||
Returns:
|
||||
Validated priority
|
||||
|
||||
Raises:
|
||||
ValueError: If priority is invalid
|
||||
"""
|
||||
if priority < 0 or priority > 10:
|
||||
raise ValueError("Priority must be between 0 and 10")
|
||||
return priority
|
||||
|
||||
|
||||
def validate_anime_url(url: str) -> str:
|
||||
"""
|
||||
Validate anime URL format.
|
||||
|
||||
Args:
|
||||
url: Anime URL
|
||||
|
||||
Returns:
|
||||
Validated URL
|
||||
|
||||
Raises:
|
||||
ValueError: If URL is invalid
|
||||
"""
|
||||
if not url:
|
||||
raise ValueError("URL cannot be empty")
|
||||
|
||||
# Check if it's a valid aniworld.to URL
|
||||
if "aniworld.to" not in url and "s.to" not in url:
|
||||
raise ValueError("URL must be from aniworld.to or s.to")
|
||||
|
||||
# Basic URL validation
|
||||
ValidatorMixin.validate_url(url)
|
||||
|
||||
return url
|
||||
|
||||
|
||||
def validate_series_name(name: str) -> str:
|
||||
"""
|
||||
Validate series name.
|
||||
|
||||
Args:
|
||||
name: Series name
|
||||
|
||||
Returns:
|
||||
Validated name
|
||||
|
||||
Raises:
|
||||
ValueError: If name is invalid
|
||||
"""
|
||||
if not name or not name.strip():
|
||||
raise ValueError("Series name cannot be empty")
|
||||
|
||||
if len(name) > 200:
|
||||
raise ValueError("Series name too long (max 200 characters)")
|
||||
|
||||
# Check for invalid characters
|
||||
invalid_chars = ['<', '>', ':', '"', '/', '\\', '|', '?', '*']
|
||||
for char in invalid_chars:
|
||||
if char in name:
|
||||
raise ValueError(
|
||||
f"Series name contains invalid character: {char}"
|
||||
)
|
||||
|
||||
return name.strip()
|
||||
|
||||
|
||||
def validate_backup_name(name: str) -> str:
|
||||
"""
|
||||
Validate backup file name.
|
||||
|
||||
Args:
|
||||
name: Backup name
|
||||
|
||||
Returns:
|
||||
Validated name
|
||||
|
||||
Raises:
|
||||
ValueError: If name is invalid
|
||||
"""
|
||||
if not name or not name.strip():
|
||||
raise ValueError("Backup name cannot be empty")
|
||||
|
||||
# Must be a valid filename
|
||||
if not re.match(r"^[a-zA-Z0-9_\-\.]+$", name):
|
||||
raise ValueError(
|
||||
"Backup name can only contain letters, numbers, underscores, hyphens, and dots"
|
||||
)
|
||||
|
||||
if not name.endswith(".json"):
|
||||
raise ValueError("Backup name must end with .json")
|
||||
|
||||
return name
|
||||
|
||||
|
||||
def validate_config_data(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Validate configuration data structure.
|
||||
|
||||
Args:
|
||||
data: Configuration data
|
||||
|
||||
Returns:
|
||||
Validated data
|
||||
|
||||
Raises:
|
||||
ValueError: If data is invalid
|
||||
"""
|
||||
required_keys = ["download_directory", "concurrent_downloads"]
|
||||
ValidatorMixin.validate_dict_keys(data, required_keys, "Configuration")
|
||||
|
||||
# Validate download directory
|
||||
if not isinstance(data["download_directory"], str):
|
||||
raise ValueError("download_directory must be a string")
|
||||
|
||||
# Validate concurrent downloads
|
||||
concurrent = data["concurrent_downloads"]
|
||||
if not isinstance(concurrent, int) or concurrent < 1 or concurrent > 10:
|
||||
raise ValueError("concurrent_downloads must be between 1 and 10")
|
||||
|
||||
# Validate quality if present
|
||||
if "quality" in data:
|
||||
validate_download_quality(data["quality"])
|
||||
|
||||
# Validate language if present
|
||||
if "language" in data:
|
||||
validate_language(data["language"])
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def sanitize_filename(filename: str) -> str:
|
||||
"""
|
||||
Sanitize filename for safe filesystem use.
|
||||
|
||||
Args:
|
||||
filename: Original filename
|
||||
|
||||
Returns:
|
||||
Sanitized filename
|
||||
"""
|
||||
# Remove or replace invalid characters
|
||||
invalid_chars = ['<', '>', ':', '"', '/', '\\', '|', '?', '*']
|
||||
for char in invalid_chars:
|
||||
filename = filename.replace(char, '_')
|
||||
|
||||
# Remove leading/trailing spaces and dots
|
||||
filename = filename.strip('. ')
|
||||
|
||||
# Ensure not empty
|
||||
if not filename:
|
||||
filename = "unnamed"
|
||||
|
||||
# Limit length
|
||||
if len(filename) > 255:
|
||||
name, ext = filename.rsplit('.', 1) if '.' in filename else (filename, '')
|
||||
max_name_len = 255 - len(ext) - 1 if ext else 255
|
||||
filename = name[:max_name_len] + ('.' + ext if ext else '')
|
||||
|
||||
return filename
|
||||
|
||||
|
||||
def validate_jwt_token(token: str) -> str:
|
||||
"""
|
||||
Validate JWT token format.
|
||||
|
||||
Args:
|
||||
token: JWT token
|
||||
|
||||
Returns:
|
||||
Validated token
|
||||
|
||||
Raises:
|
||||
ValueError: If token format is invalid
|
||||
"""
|
||||
if not token or not isinstance(token, str):
|
||||
raise ValueError("Token must be a non-empty string")
|
||||
|
||||
# JWT tokens have 3 parts separated by dots
|
||||
parts = token.split(".")
|
||||
if len(parts) != 3:
|
||||
raise ValueError("Invalid JWT token format")
|
||||
|
||||
# Each part should be base64url encoded (alphanumeric + - and _)
|
||||
for part in parts:
|
||||
if not re.match(r"^[A-Za-z0-9_-]+$", part):
|
||||
raise ValueError("Invalid JWT token encoding")
|
||||
|
||||
return token
|
||||
|
||||
|
||||
def validate_ip_address(ip: str) -> str:
|
||||
"""
|
||||
Validate IP address format.
|
||||
|
||||
Args:
|
||||
ip: IP address
|
||||
|
||||
Returns:
|
||||
Validated IP address
|
||||
|
||||
Raises:
|
||||
ValueError: If IP is invalid
|
||||
"""
|
||||
if not ip or not isinstance(ip, str):
|
||||
raise ValueError("IP address must be a non-empty string")
|
||||
|
||||
# IPv4 pattern
|
||||
ipv4_pattern = re.compile(
|
||||
r"^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}"
|
||||
r"(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$"
|
||||
)
|
||||
|
||||
# IPv6 pattern (simplified)
|
||||
ipv6_pattern = re.compile(
|
||||
r"^(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}$"
|
||||
)
|
||||
|
||||
if not ipv4_pattern.match(ip) and not ipv6_pattern.match(ip):
|
||||
raise ValueError(f"Invalid IP address format: {ip}")
|
||||
|
||||
return ip
|
||||
|
||||
|
||||
def validate_websocket_message(message: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Validate WebSocket message structure.
|
||||
|
||||
Args:
|
||||
message: WebSocket message
|
||||
|
||||
Returns:
|
||||
Validated message
|
||||
|
||||
Raises:
|
||||
ValueError: If message structure is invalid
|
||||
"""
|
||||
required_keys = ["type"]
|
||||
ValidatorMixin.validate_dict_keys(message, required_keys, "WebSocket message")
|
||||
|
||||
valid_types = [
|
||||
"download_progress",
|
||||
"download_complete",
|
||||
"download_failed",
|
||||
"queue_update",
|
||||
"error",
|
||||
"system_message",
|
||||
]
|
||||
|
||||
if message["type"] not in valid_types:
|
||||
raise ValueError(
|
||||
f"Invalid message type. Must be one of: {', '.join(valid_types)}"
|
||||
)
|
||||
|
||||
return message
|
||||
Loading…
x
Reference in New Issue
Block a user