Add advanced features: notification system, security middleware, audit logging, data validation, and caching
- Implement notification service with email, webhook, and in-app support - Add security headers middleware (CORS, CSP, HSTS, XSS protection) - Create comprehensive audit logging service for security events - Add data validation utilities with Pydantic validators - Implement cache service with in-memory and Redis backend support All 714 tests passing
This commit is contained in:
parent
17e5a551e1
commit
7409ae637e
@ -99,44 +99,8 @@ When working with these files:
|
|||||||
- []Preserve existing WebSocket event handling
|
- []Preserve existing WebSocket event handling
|
||||||
- []Keep existing theme and responsive design features
|
- []Keep existing theme and responsive design features
|
||||||
|
|
||||||
### Advanced Features
|
|
||||||
|
|
||||||
#### [] Create notification system
|
|
||||||
|
|
||||||
- []Create `src/server/services/notification_service.py`
|
|
||||||
- []Implement email notifications for completed downloads
|
|
||||||
- []Add webhook support for external integrations
|
|
||||||
- []Include in-app notification system
|
|
||||||
- []Add notification preference management
|
|
||||||
|
|
||||||
### Security Enhancements
|
|
||||||
|
|
||||||
#### [] Add security headers
|
|
||||||
|
|
||||||
- []Create `src/server/middleware/security.py`
|
|
||||||
- []Implement CORS headers
|
|
||||||
- []Add CSP headers
|
|
||||||
- []Include security headers (HSTS, X-Frame-Options)
|
|
||||||
- []Add request sanitization
|
|
||||||
|
|
||||||
#### [] Create audit logging
|
|
||||||
|
|
||||||
- []Create `src/server/services/audit_service.py`
|
|
||||||
- []Log all authentication attempts
|
|
||||||
- []Track configuration changes
|
|
||||||
- []Monitor download activities
|
|
||||||
- []Include user action tracking
|
|
||||||
|
|
||||||
### Data Management
|
### Data Management
|
||||||
|
|
||||||
#### [] Implement data validation
|
|
||||||
|
|
||||||
- []Create `src/server/utils/validators.py`
|
|
||||||
- []Add Pydantic custom validators
|
|
||||||
- []Implement business rule validation
|
|
||||||
- []Include data integrity checks
|
|
||||||
- []Add format validation utilities
|
|
||||||
|
|
||||||
#### [] Create data migration tools
|
#### [] Create data migration tools
|
||||||
|
|
||||||
- []Create `src/server/database/migrations/`
|
- []Create `src/server/database/migrations/`
|
||||||
@ -145,14 +109,6 @@ When working with these files:
|
|||||||
- []Include rollback mechanisms
|
- []Include rollback mechanisms
|
||||||
- []Add migration validation
|
- []Add migration validation
|
||||||
|
|
||||||
#### [] Add caching layer
|
|
||||||
|
|
||||||
- []Create `src/server/services/cache_service.py`
|
|
||||||
- []Implement Redis caching
|
|
||||||
- []Add in-memory caching for frequent data
|
|
||||||
- []Include cache invalidation strategies
|
|
||||||
- []Add cache performance monitoring
|
|
||||||
|
|
||||||
### Integration Enhancements
|
### Integration Enhancements
|
||||||
|
|
||||||
#### [] Extend provider system
|
#### [] Extend provider system
|
||||||
|
|||||||
446
src/server/middleware/security.py
Normal file
446
src/server/middleware/security.py
Normal file
@ -0,0 +1,446 @@
|
|||||||
|
"""
|
||||||
|
Security Middleware for AniWorld.
|
||||||
|
|
||||||
|
This module provides security-related middleware including CORS, CSP,
|
||||||
|
security headers, and request sanitization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import Callable, List, Optional
|
||||||
|
|
||||||
|
from fastapi import FastAPI, Request, Response
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.types import ASGIApp
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""Middleware to add security headers to all responses."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
app: ASGIApp,
|
||||||
|
hsts_max_age: int = 31536000, # 1 year
|
||||||
|
hsts_include_subdomains: bool = True,
|
||||||
|
hsts_preload: bool = False,
|
||||||
|
frame_options: str = "DENY",
|
||||||
|
content_type_options: bool = True,
|
||||||
|
xss_protection: bool = True,
|
||||||
|
referrer_policy: str = "strict-origin-when-cross-origin",
|
||||||
|
permissions_policy: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize security headers middleware.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: ASGI application
|
||||||
|
hsts_max_age: HSTS max-age in seconds
|
||||||
|
hsts_include_subdomains: Include subdomains in HSTS
|
||||||
|
hsts_preload: Enable HSTS preload
|
||||||
|
frame_options: X-Frame-Options value (DENY, SAMEORIGIN, or ALLOW-FROM)
|
||||||
|
content_type_options: Enable X-Content-Type-Options: nosniff
|
||||||
|
xss_protection: Enable X-XSS-Protection
|
||||||
|
referrer_policy: Referrer-Policy value
|
||||||
|
permissions_policy: Permissions-Policy value
|
||||||
|
"""
|
||||||
|
super().__init__(app)
|
||||||
|
self.hsts_max_age = hsts_max_age
|
||||||
|
self.hsts_include_subdomains = hsts_include_subdomains
|
||||||
|
self.hsts_preload = hsts_preload
|
||||||
|
self.frame_options = frame_options
|
||||||
|
self.content_type_options = content_type_options
|
||||||
|
self.xss_protection = xss_protection
|
||||||
|
self.referrer_policy = referrer_policy
|
||||||
|
self.permissions_policy = permissions_policy
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||||
|
"""
|
||||||
|
Process request and add security headers to response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: Incoming request
|
||||||
|
call_next: Next middleware in chain
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Response with security headers
|
||||||
|
"""
|
||||||
|
response = await call_next(request)
|
||||||
|
|
||||||
|
# HSTS Header
|
||||||
|
hsts_value = f"max-age={self.hsts_max_age}"
|
||||||
|
if self.hsts_include_subdomains:
|
||||||
|
hsts_value += "; includeSubDomains"
|
||||||
|
if self.hsts_preload:
|
||||||
|
hsts_value += "; preload"
|
||||||
|
response.headers["Strict-Transport-Security"] = hsts_value
|
||||||
|
|
||||||
|
# X-Frame-Options
|
||||||
|
response.headers["X-Frame-Options"] = self.frame_options
|
||||||
|
|
||||||
|
# X-Content-Type-Options
|
||||||
|
if self.content_type_options:
|
||||||
|
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||||
|
|
||||||
|
# X-XSS-Protection (deprecated but still useful for older browsers)
|
||||||
|
if self.xss_protection:
|
||||||
|
response.headers["X-XSS-Protection"] = "1; mode=block"
|
||||||
|
|
||||||
|
# Referrer-Policy
|
||||||
|
response.headers["Referrer-Policy"] = self.referrer_policy
|
||||||
|
|
||||||
|
# Permissions-Policy
|
||||||
|
if self.permissions_policy:
|
||||||
|
response.headers["Permissions-Policy"] = self.permissions_policy
|
||||||
|
|
||||||
|
# Remove potentially revealing headers
|
||||||
|
response.headers.pop("Server", None)
|
||||||
|
response.headers.pop("X-Powered-By", None)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
class ContentSecurityPolicyMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""Middleware to add Content Security Policy headers."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
app: ASGIApp,
|
||||||
|
default_src: List[str] = None,
|
||||||
|
script_src: List[str] = None,
|
||||||
|
style_src: List[str] = None,
|
||||||
|
img_src: List[str] = None,
|
||||||
|
font_src: List[str] = None,
|
||||||
|
connect_src: List[str] = None,
|
||||||
|
frame_src: List[str] = None,
|
||||||
|
object_src: List[str] = None,
|
||||||
|
media_src: List[str] = None,
|
||||||
|
worker_src: List[str] = None,
|
||||||
|
form_action: List[str] = None,
|
||||||
|
frame_ancestors: List[str] = None,
|
||||||
|
base_uri: List[str] = None,
|
||||||
|
upgrade_insecure_requests: bool = True,
|
||||||
|
block_all_mixed_content: bool = True,
|
||||||
|
report_only: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize CSP middleware.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: ASGI application
|
||||||
|
default_src: default-src directive values
|
||||||
|
script_src: script-src directive values
|
||||||
|
style_src: style-src directive values
|
||||||
|
img_src: img-src directive values
|
||||||
|
font_src: font-src directive values
|
||||||
|
connect_src: connect-src directive values
|
||||||
|
frame_src: frame-src directive values
|
||||||
|
object_src: object-src directive values
|
||||||
|
media_src: media-src directive values
|
||||||
|
worker_src: worker-src directive values
|
||||||
|
form_action: form-action directive values
|
||||||
|
frame_ancestors: frame-ancestors directive values
|
||||||
|
base_uri: base-uri directive values
|
||||||
|
upgrade_insecure_requests: Enable upgrade-insecure-requests
|
||||||
|
block_all_mixed_content: Enable block-all-mixed-content
|
||||||
|
report_only: Use Content-Security-Policy-Report-Only header
|
||||||
|
"""
|
||||||
|
super().__init__(app)
|
||||||
|
|
||||||
|
# Default secure CSP
|
||||||
|
self.directives = {
|
||||||
|
"default-src": default_src or ["'self'"],
|
||||||
|
"script-src": script_src or ["'self'", "'unsafe-inline'"],
|
||||||
|
"style-src": style_src or ["'self'", "'unsafe-inline'"],
|
||||||
|
"img-src": img_src or ["'self'", "data:", "https:"],
|
||||||
|
"font-src": font_src or ["'self'", "data:"],
|
||||||
|
"connect-src": connect_src or ["'self'", "ws:", "wss:"],
|
||||||
|
"frame-src": frame_src or ["'none'"],
|
||||||
|
"object-src": object_src or ["'none'"],
|
||||||
|
"media-src": media_src or ["'self'"],
|
||||||
|
"worker-src": worker_src or ["'self'"],
|
||||||
|
"form-action": form_action or ["'self'"],
|
||||||
|
"frame-ancestors": frame_ancestors or ["'none'"],
|
||||||
|
"base-uri": base_uri or ["'self'"],
|
||||||
|
}
|
||||||
|
|
||||||
|
self.upgrade_insecure_requests = upgrade_insecure_requests
|
||||||
|
self.block_all_mixed_content = block_all_mixed_content
|
||||||
|
self.report_only = report_only
|
||||||
|
|
||||||
|
def _build_csp_header(self) -> str:
|
||||||
|
"""
|
||||||
|
Build the CSP header value.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CSP header string
|
||||||
|
"""
|
||||||
|
parts = []
|
||||||
|
|
||||||
|
for directive, values in self.directives.items():
|
||||||
|
if values:
|
||||||
|
parts.append(f"{directive} {' '.join(values)}")
|
||||||
|
|
||||||
|
if self.upgrade_insecure_requests:
|
||||||
|
parts.append("upgrade-insecure-requests")
|
||||||
|
|
||||||
|
if self.block_all_mixed_content:
|
||||||
|
parts.append("block-all-mixed-content")
|
||||||
|
|
||||||
|
return "; ".join(parts)
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||||
|
"""
|
||||||
|
Process request and add CSP header to response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: Incoming request
|
||||||
|
call_next: Next middleware in chain
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Response with CSP header
|
||||||
|
"""
|
||||||
|
response = await call_next(request)
|
||||||
|
|
||||||
|
header_name = (
|
||||||
|
"Content-Security-Policy-Report-Only"
|
||||||
|
if self.report_only
|
||||||
|
else "Content-Security-Policy"
|
||||||
|
)
|
||||||
|
response.headers[header_name] = self._build_csp_header()
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
class RequestSanitizationMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""Middleware to sanitize and validate incoming requests."""
|
||||||
|
|
||||||
|
# Common SQL injection patterns
|
||||||
|
SQL_INJECTION_PATTERNS = [
|
||||||
|
re.compile(r"(\bunion\b.*\bselect\b)", re.IGNORECASE),
|
||||||
|
re.compile(r"(\bselect\b.*\bfrom\b)", re.IGNORECASE),
|
||||||
|
re.compile(r"(\binsert\b.*\binto\b)", re.IGNORECASE),
|
||||||
|
re.compile(r"(\bupdate\b.*\bset\b)", re.IGNORECASE),
|
||||||
|
re.compile(r"(\bdelete\b.*\bfrom\b)", re.IGNORECASE),
|
||||||
|
re.compile(r"(\bdrop\b.*\btable\b)", re.IGNORECASE),
|
||||||
|
re.compile(r"(\bexec\b|\bexecute\b)", re.IGNORECASE),
|
||||||
|
re.compile(r"(--|\#|\/\*|\*\/)", re.IGNORECASE),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Common XSS patterns
|
||||||
|
XSS_PATTERNS = [
|
||||||
|
re.compile(r"<script[^>]*>.*?</script>", re.IGNORECASE | re.DOTALL),
|
||||||
|
re.compile(r"javascript:", re.IGNORECASE),
|
||||||
|
re.compile(r"on\w+\s*=", re.IGNORECASE), # Event handlers like onclick=
|
||||||
|
re.compile(r"<iframe[^>]*>", re.IGNORECASE),
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
app: ASGIApp,
|
||||||
|
check_sql_injection: bool = True,
|
||||||
|
check_xss: bool = True,
|
||||||
|
max_request_size: int = 10 * 1024 * 1024, # 10 MB
|
||||||
|
allowed_content_types: Optional[List[str]] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize request sanitization middleware.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: ASGI application
|
||||||
|
check_sql_injection: Enable SQL injection checks
|
||||||
|
check_xss: Enable XSS checks
|
||||||
|
max_request_size: Maximum request body size in bytes
|
||||||
|
allowed_content_types: List of allowed content types
|
||||||
|
"""
|
||||||
|
super().__init__(app)
|
||||||
|
self.check_sql_injection = check_sql_injection
|
||||||
|
self.check_xss = check_xss
|
||||||
|
self.max_request_size = max_request_size
|
||||||
|
self.allowed_content_types = allowed_content_types or [
|
||||||
|
"application/json",
|
||||||
|
"application/x-www-form-urlencoded",
|
||||||
|
"multipart/form-data",
|
||||||
|
"text/plain",
|
||||||
|
]
|
||||||
|
|
||||||
|
def _check_sql_injection(self, value: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if string contains SQL injection patterns.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: String to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if potential SQL injection detected
|
||||||
|
"""
|
||||||
|
for pattern in self.SQL_INJECTION_PATTERNS:
|
||||||
|
if pattern.search(value):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _check_xss(self, value: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if string contains XSS patterns.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: String to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if potential XSS detected
|
||||||
|
"""
|
||||||
|
for pattern in self.XSS_PATTERNS:
|
||||||
|
if pattern.search(value):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _sanitize_value(self, value: str) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Sanitize a string value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: Value to sanitize
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None if malicious content detected, sanitized value otherwise
|
||||||
|
"""
|
||||||
|
if self.check_sql_injection and self._check_sql_injection(value):
|
||||||
|
logger.warning(f"Potential SQL injection detected: {value[:100]}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if self.check_xss and self._check_xss(value):
|
||||||
|
logger.warning(f"Potential XSS detected: {value[:100]}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||||
|
"""
|
||||||
|
Process and sanitize request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: Incoming request
|
||||||
|
call_next: Next middleware in chain
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Response or error response if request is malicious
|
||||||
|
"""
|
||||||
|
# Check content type
|
||||||
|
content_type = request.headers.get("content-type", "").split(";")[0].strip()
|
||||||
|
if (
|
||||||
|
content_type
|
||||||
|
and not any(ct in content_type for ct in self.allowed_content_types)
|
||||||
|
):
|
||||||
|
logger.warning(f"Unsupported content type: {content_type}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=415,
|
||||||
|
content={"detail": "Unsupported Media Type"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check request size
|
||||||
|
content_length = request.headers.get("content-length")
|
||||||
|
if content_length and int(content_length) > self.max_request_size:
|
||||||
|
logger.warning(f"Request too large: {content_length} bytes")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=413,
|
||||||
|
content={"detail": "Request Entity Too Large"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check query parameters
|
||||||
|
for key, value in request.query_params.items():
|
||||||
|
if isinstance(value, str):
|
||||||
|
sanitized = self._sanitize_value(value)
|
||||||
|
if sanitized is None:
|
||||||
|
logger.warning(f"Malicious query parameter detected: {key}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=400,
|
||||||
|
content={"detail": "Malicious request detected"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check path parameters
|
||||||
|
for key, value in request.path_params.items():
|
||||||
|
if isinstance(value, str):
|
||||||
|
sanitized = self._sanitize_value(value)
|
||||||
|
if sanitized is None:
|
||||||
|
logger.warning(f"Malicious path parameter detected: {key}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=400,
|
||||||
|
content={"detail": "Malicious request detected"},
|
||||||
|
)
|
||||||
|
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
|
||||||
|
def configure_security_middleware(
|
||||||
|
app: FastAPI,
|
||||||
|
cors_origins: List[str] = None,
|
||||||
|
cors_allow_credentials: bool = True,
|
||||||
|
enable_hsts: bool = True,
|
||||||
|
enable_csp: bool = True,
|
||||||
|
enable_sanitization: bool = True,
|
||||||
|
csp_report_only: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Configure all security middleware for the FastAPI application.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: FastAPI application instance
|
||||||
|
cors_origins: List of allowed CORS origins
|
||||||
|
cors_allow_credentials: Allow credentials in CORS requests
|
||||||
|
enable_hsts: Enable HSTS and other security headers
|
||||||
|
enable_csp: Enable Content Security Policy
|
||||||
|
enable_sanitization: Enable request sanitization
|
||||||
|
csp_report_only: Use CSP in report-only mode
|
||||||
|
"""
|
||||||
|
# CORS Middleware
|
||||||
|
if cors_origins is None:
|
||||||
|
cors_origins = ["http://localhost:3000", "http://localhost:8000"]
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=cors_origins,
|
||||||
|
allow_credentials=cors_allow_credentials,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
expose_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Security Headers Middleware
|
||||||
|
if enable_hsts:
|
||||||
|
app.add_middleware(
|
||||||
|
SecurityHeadersMiddleware,
|
||||||
|
hsts_max_age=31536000,
|
||||||
|
hsts_include_subdomains=True,
|
||||||
|
frame_options="DENY",
|
||||||
|
content_type_options=True,
|
||||||
|
xss_protection=True,
|
||||||
|
referrer_policy="strict-origin-when-cross-origin",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Content Security Policy Middleware
|
||||||
|
if enable_csp:
|
||||||
|
app.add_middleware(
|
||||||
|
ContentSecurityPolicyMiddleware,
|
||||||
|
report_only=csp_report_only,
|
||||||
|
# Allow inline scripts and styles for development
|
||||||
|
# In production, use nonces or hashes
|
||||||
|
script_src=["'self'", "'unsafe-inline'", "'unsafe-eval'"],
|
||||||
|
style_src=["'self'", "'unsafe-inline'", "https://cdnjs.cloudflare.com"],
|
||||||
|
font_src=["'self'", "data:", "https://cdnjs.cloudflare.com"],
|
||||||
|
img_src=["'self'", "data:", "https:"],
|
||||||
|
connect_src=["'self'", "ws://localhost:*", "wss://localhost:*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Request Sanitization Middleware
|
||||||
|
if enable_sanitization:
|
||||||
|
app.add_middleware(
|
||||||
|
RequestSanitizationMiddleware,
|
||||||
|
check_sql_injection=True,
|
||||||
|
check_xss=True,
|
||||||
|
max_request_size=10 * 1024 * 1024, # 10 MB
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Security middleware configured successfully")
|
||||||
610
src/server/services/audit_service.py
Normal file
610
src/server/services/audit_service.py
Normal file
@ -0,0 +1,610 @@
|
|||||||
|
"""
|
||||||
|
Audit Service for AniWorld.
|
||||||
|
|
||||||
|
This module provides comprehensive audit logging for security-critical
|
||||||
|
operations including authentication, configuration changes, and downloads.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AuditEventType(str, Enum):
|
||||||
|
"""Types of audit events."""
|
||||||
|
|
||||||
|
# Authentication events
|
||||||
|
AUTH_SETUP = "auth.setup"
|
||||||
|
AUTH_LOGIN_SUCCESS = "auth.login.success"
|
||||||
|
AUTH_LOGIN_FAILURE = "auth.login.failure"
|
||||||
|
AUTH_LOGOUT = "auth.logout"
|
||||||
|
AUTH_TOKEN_REFRESH = "auth.token.refresh"
|
||||||
|
AUTH_TOKEN_INVALID = "auth.token.invalid"
|
||||||
|
|
||||||
|
# Configuration events
|
||||||
|
CONFIG_READ = "config.read"
|
||||||
|
CONFIG_UPDATE = "config.update"
|
||||||
|
CONFIG_BACKUP = "config.backup"
|
||||||
|
CONFIG_RESTORE = "config.restore"
|
||||||
|
CONFIG_DELETE = "config.delete"
|
||||||
|
|
||||||
|
# Download events
|
||||||
|
DOWNLOAD_ADDED = "download.added"
|
||||||
|
DOWNLOAD_STARTED = "download.started"
|
||||||
|
DOWNLOAD_COMPLETED = "download.completed"
|
||||||
|
DOWNLOAD_FAILED = "download.failed"
|
||||||
|
DOWNLOAD_CANCELLED = "download.cancelled"
|
||||||
|
DOWNLOAD_REMOVED = "download.removed"
|
||||||
|
|
||||||
|
# Queue events
|
||||||
|
QUEUE_STARTED = "queue.started"
|
||||||
|
QUEUE_STOPPED = "queue.stopped"
|
||||||
|
QUEUE_PAUSED = "queue.paused"
|
||||||
|
QUEUE_RESUMED = "queue.resumed"
|
||||||
|
QUEUE_CLEARED = "queue.cleared"
|
||||||
|
|
||||||
|
# System events
|
||||||
|
SYSTEM_STARTUP = "system.startup"
|
||||||
|
SYSTEM_SHUTDOWN = "system.shutdown"
|
||||||
|
SYSTEM_ERROR = "system.error"
|
||||||
|
|
||||||
|
|
||||||
|
class AuditEventSeverity(str, Enum):
|
||||||
|
"""Severity levels for audit events."""
|
||||||
|
|
||||||
|
DEBUG = "debug"
|
||||||
|
INFO = "info"
|
||||||
|
WARNING = "warning"
|
||||||
|
ERROR = "error"
|
||||||
|
CRITICAL = "critical"
|
||||||
|
|
||||||
|
|
||||||
|
class AuditEvent(BaseModel):
|
||||||
|
"""Audit event model."""
|
||||||
|
|
||||||
|
timestamp: datetime = Field(default_factory=datetime.utcnow)
|
||||||
|
event_type: AuditEventType
|
||||||
|
severity: AuditEventSeverity = AuditEventSeverity.INFO
|
||||||
|
user_id: Optional[str] = None
|
||||||
|
ip_address: Optional[str] = None
|
||||||
|
user_agent: Optional[str] = None
|
||||||
|
resource: Optional[str] = None
|
||||||
|
action: Optional[str] = None
|
||||||
|
status: str = "success"
|
||||||
|
message: str
|
||||||
|
details: Optional[Dict[str, Any]] = None
|
||||||
|
session_id: Optional[str] = None
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Pydantic config."""
|
||||||
|
|
||||||
|
json_encoders = {datetime: lambda v: v.isoformat()}
|
||||||
|
|
||||||
|
|
||||||
|
class AuditLogStorage:
|
||||||
|
"""Base class for audit log storage backends."""
|
||||||
|
|
||||||
|
async def write_event(self, event: AuditEvent) -> None:
|
||||||
|
"""
|
||||||
|
Write an audit event to storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event: Audit event to write
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def read_events(
|
||||||
|
self,
|
||||||
|
start_time: Optional[datetime] = None,
|
||||||
|
end_time: Optional[datetime] = None,
|
||||||
|
event_types: Optional[List[AuditEventType]] = None,
|
||||||
|
user_id: Optional[str] = None,
|
||||||
|
limit: int = 100,
|
||||||
|
) -> List[AuditEvent]:
|
||||||
|
"""
|
||||||
|
Read audit events from storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start_time: Start of time range
|
||||||
|
end_time: End of time range
|
||||||
|
event_types: Filter by event types
|
||||||
|
user_id: Filter by user ID
|
||||||
|
limit: Maximum number of events to return
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of audit events
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def cleanup_old_events(self, days: int = 90) -> int:
|
||||||
|
"""
|
||||||
|
Clean up audit events older than specified days.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
days: Number of days to retain
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of events deleted
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class FileAuditLogStorage(AuditLogStorage):
|
||||||
|
"""File-based audit log storage."""
|
||||||
|
|
||||||
|
def __init__(self, log_directory: str = "logs/audit"):
|
||||||
|
"""
|
||||||
|
Initialize file-based audit log storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
log_directory: Directory to store audit logs
|
||||||
|
"""
|
||||||
|
self.log_directory = Path(log_directory)
|
||||||
|
self.log_directory.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._current_date: Optional[str] = None
|
||||||
|
self._current_file: Optional[Path] = None
|
||||||
|
|
||||||
|
def _get_log_file(self, date: datetime) -> Path:
|
||||||
|
"""
|
||||||
|
Get log file path for a specific date.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
date: Date for log file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to log file
|
||||||
|
"""
|
||||||
|
date_str = date.strftime("%Y-%m-%d")
|
||||||
|
return self.log_directory / f"audit_{date_str}.jsonl"
|
||||||
|
|
||||||
|
async def write_event(self, event: AuditEvent) -> None:
|
||||||
|
"""
|
||||||
|
Write an audit event to file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event: Audit event to write
|
||||||
|
"""
|
||||||
|
log_file = self._get_log_file(event.timestamp)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(log_file, "a", encoding="utf-8") as f:
|
||||||
|
f.write(event.model_dump_json() + "\n")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to write audit event to file: {e}")
|
||||||
|
|
||||||
|
async def read_events(
|
||||||
|
self,
|
||||||
|
start_time: Optional[datetime] = None,
|
||||||
|
end_time: Optional[datetime] = None,
|
||||||
|
event_types: Optional[List[AuditEventType]] = None,
|
||||||
|
user_id: Optional[str] = None,
|
||||||
|
limit: int = 100,
|
||||||
|
) -> List[AuditEvent]:
|
||||||
|
"""
|
||||||
|
Read audit events from files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start_time: Start of time range
|
||||||
|
end_time: End of time range
|
||||||
|
event_types: Filter by event types
|
||||||
|
user_id: Filter by user ID
|
||||||
|
limit: Maximum number of events to return
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of audit events
|
||||||
|
"""
|
||||||
|
if start_time is None:
|
||||||
|
start_time = datetime.utcnow() - timedelta(days=7)
|
||||||
|
if end_time is None:
|
||||||
|
end_time = datetime.utcnow()
|
||||||
|
|
||||||
|
events: List[AuditEvent] = []
|
||||||
|
current_date = start_time.date()
|
||||||
|
end_date = end_time.date()
|
||||||
|
|
||||||
|
# Read from all log files in date range
|
||||||
|
while current_date <= end_date and len(events) < limit:
|
||||||
|
log_file = self._get_log_file(datetime.combine(current_date, datetime.min.time()))
|
||||||
|
|
||||||
|
if log_file.exists():
|
||||||
|
try:
|
||||||
|
with open(log_file, "r", encoding="utf-8") as f:
|
||||||
|
for line in f:
|
||||||
|
if len(events) >= limit:
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
event_data = json.loads(line.strip())
|
||||||
|
event = AuditEvent(**event_data)
|
||||||
|
|
||||||
|
# Apply filters
|
||||||
|
if event.timestamp < start_time or event.timestamp > end_time:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if event_types and event.event_type not in event_types:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if user_id and event.user_id != user_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
except (json.JSONDecodeError, ValueError) as e:
|
||||||
|
logger.warning(f"Failed to parse audit event: {e}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to read audit log file {log_file}: {e}")
|
||||||
|
|
||||||
|
current_date += timedelta(days=1)
|
||||||
|
|
||||||
|
# Sort by timestamp descending
|
||||||
|
events.sort(key=lambda e: e.timestamp, reverse=True)
|
||||||
|
return events[:limit]
|
||||||
|
|
||||||
|
async def cleanup_old_events(self, days: int = 90) -> int:
|
||||||
|
"""
|
||||||
|
Clean up audit events older than specified days.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
days: Number of days to retain
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of files deleted
|
||||||
|
"""
|
||||||
|
cutoff_date = datetime.utcnow() - timedelta(days=days)
|
||||||
|
deleted_count = 0
|
||||||
|
|
||||||
|
for log_file in self.log_directory.glob("audit_*.jsonl"):
|
||||||
|
try:
|
||||||
|
# Extract date from filename
|
||||||
|
date_str = log_file.stem.replace("audit_", "")
|
||||||
|
file_date = datetime.strptime(date_str, "%Y-%m-%d")
|
||||||
|
|
||||||
|
if file_date < cutoff_date:
|
||||||
|
log_file.unlink()
|
||||||
|
deleted_count += 1
|
||||||
|
logger.info(f"Deleted old audit log: {log_file}")
|
||||||
|
|
||||||
|
except (ValueError, OSError) as e:
|
||||||
|
logger.warning(f"Failed to process audit log file {log_file}: {e}")
|
||||||
|
|
||||||
|
return deleted_count
|
||||||
|
|
||||||
|
|
||||||
|
class AuditService:
|
||||||
|
"""Main audit service for logging security events."""
|
||||||
|
|
||||||
|
def __init__(self, storage: Optional[AuditLogStorage] = None):
|
||||||
|
"""
|
||||||
|
Initialize audit service.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
storage: Storage backend for audit logs
|
||||||
|
"""
|
||||||
|
self.storage = storage or FileAuditLogStorage()
|
||||||
|
|
||||||
|
async def log_event(
|
||||||
|
self,
|
||||||
|
event_type: AuditEventType,
|
||||||
|
message: str,
|
||||||
|
severity: AuditEventSeverity = AuditEventSeverity.INFO,
|
||||||
|
user_id: Optional[str] = None,
|
||||||
|
ip_address: Optional[str] = None,
|
||||||
|
user_agent: Optional[str] = None,
|
||||||
|
resource: Optional[str] = None,
|
||||||
|
action: Optional[str] = None,
|
||||||
|
status: str = "success",
|
||||||
|
details: Optional[Dict[str, Any]] = None,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Log an audit event.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_type: Type of event
|
||||||
|
message: Human-readable message
|
||||||
|
severity: Event severity
|
||||||
|
user_id: User identifier
|
||||||
|
ip_address: Client IP address
|
||||||
|
user_agent: Client user agent
|
||||||
|
resource: Resource being accessed
|
||||||
|
action: Action performed
|
||||||
|
status: Operation status
|
||||||
|
details: Additional details
|
||||||
|
session_id: Session identifier
|
||||||
|
"""
|
||||||
|
event = AuditEvent(
|
||||||
|
event_type=event_type,
|
||||||
|
severity=severity,
|
||||||
|
user_id=user_id,
|
||||||
|
ip_address=ip_address,
|
||||||
|
user_agent=user_agent,
|
||||||
|
resource=resource,
|
||||||
|
action=action,
|
||||||
|
status=status,
|
||||||
|
message=message,
|
||||||
|
details=details,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.storage.write_event(event)
|
||||||
|
|
||||||
|
# Also log to application logger for high severity events
|
||||||
|
if severity in [AuditEventSeverity.ERROR, AuditEventSeverity.CRITICAL]:
|
||||||
|
logger.error(f"Audit: {message}", extra={"audit_event": event.model_dump()})
|
||||||
|
elif severity == AuditEventSeverity.WARNING:
|
||||||
|
logger.warning(f"Audit: {message}", extra={"audit_event": event.model_dump()})
|
||||||
|
|
||||||
|
async def log_auth_setup(
|
||||||
|
self, user_id: str, ip_address: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
|
"""Log initial authentication setup."""
|
||||||
|
await self.log_event(
|
||||||
|
event_type=AuditEventType.AUTH_SETUP,
|
||||||
|
message=f"Authentication configured by user {user_id}",
|
||||||
|
user_id=user_id,
|
||||||
|
ip_address=ip_address,
|
||||||
|
action="setup",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def log_login_success(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
ip_address: Optional[str] = None,
|
||||||
|
user_agent: Optional[str] = None,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Log successful login."""
|
||||||
|
await self.log_event(
|
||||||
|
event_type=AuditEventType.AUTH_LOGIN_SUCCESS,
|
||||||
|
message=f"User {user_id} logged in successfully",
|
||||||
|
user_id=user_id,
|
||||||
|
ip_address=ip_address,
|
||||||
|
user_agent=user_agent,
|
||||||
|
session_id=session_id,
|
||||||
|
action="login",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def log_login_failure(
|
||||||
|
self,
|
||||||
|
user_id: Optional[str] = None,
|
||||||
|
ip_address: Optional[str] = None,
|
||||||
|
user_agent: Optional[str] = None,
|
||||||
|
reason: str = "Invalid credentials",
|
||||||
|
) -> None:
|
||||||
|
"""Log failed login attempt."""
|
||||||
|
await self.log_event(
|
||||||
|
event_type=AuditEventType.AUTH_LOGIN_FAILURE,
|
||||||
|
message=f"Login failed for user {user_id or 'unknown'}: {reason}",
|
||||||
|
severity=AuditEventSeverity.WARNING,
|
||||||
|
user_id=user_id,
|
||||||
|
ip_address=ip_address,
|
||||||
|
user_agent=user_agent,
|
||||||
|
status="failure",
|
||||||
|
action="login",
|
||||||
|
details={"reason": reason},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def log_logout(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
ip_address: Optional[str] = None,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Log user logout."""
|
||||||
|
await self.log_event(
|
||||||
|
event_type=AuditEventType.AUTH_LOGOUT,
|
||||||
|
message=f"User {user_id} logged out",
|
||||||
|
user_id=user_id,
|
||||||
|
ip_address=ip_address,
|
||||||
|
session_id=session_id,
|
||||||
|
action="logout",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def log_config_update(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
changes: Dict[str, Any],
|
||||||
|
ip_address: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Log configuration update."""
|
||||||
|
await self.log_event(
|
||||||
|
event_type=AuditEventType.CONFIG_UPDATE,
|
||||||
|
message=f"Configuration updated by user {user_id}",
|
||||||
|
user_id=user_id,
|
||||||
|
ip_address=ip_address,
|
||||||
|
resource="config",
|
||||||
|
action="update",
|
||||||
|
details={"changes": changes},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def log_config_backup(
|
||||||
|
self, user_id: str, backup_file: str, ip_address: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
|
"""Log configuration backup."""
|
||||||
|
await self.log_event(
|
||||||
|
event_type=AuditEventType.CONFIG_BACKUP,
|
||||||
|
message=f"Configuration backed up by user {user_id}",
|
||||||
|
user_id=user_id,
|
||||||
|
ip_address=ip_address,
|
||||||
|
resource="config",
|
||||||
|
action="backup",
|
||||||
|
details={"backup_file": backup_file},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def log_config_restore(
|
||||||
|
self, user_id: str, backup_file: str, ip_address: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
|
"""Log configuration restore."""
|
||||||
|
await self.log_event(
|
||||||
|
event_type=AuditEventType.CONFIG_RESTORE,
|
||||||
|
message=f"Configuration restored by user {user_id}",
|
||||||
|
user_id=user_id,
|
||||||
|
ip_address=ip_address,
|
||||||
|
resource="config",
|
||||||
|
action="restore",
|
||||||
|
details={"backup_file": backup_file},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def log_download_added(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
series_name: str,
|
||||||
|
episodes: List[str],
|
||||||
|
ip_address: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Log download added to queue."""
|
||||||
|
await self.log_event(
|
||||||
|
event_type=AuditEventType.DOWNLOAD_ADDED,
|
||||||
|
message=f"Download added by user {user_id}: {series_name}",
|
||||||
|
user_id=user_id,
|
||||||
|
ip_address=ip_address,
|
||||||
|
resource=series_name,
|
||||||
|
action="add",
|
||||||
|
details={"episodes": episodes},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def log_download_completed(
|
||||||
|
self, series_name: str, episode: str, file_path: str
|
||||||
|
) -> None:
|
||||||
|
"""Log completed download."""
|
||||||
|
await self.log_event(
|
||||||
|
event_type=AuditEventType.DOWNLOAD_COMPLETED,
|
||||||
|
message=f"Download completed: {series_name} - {episode}",
|
||||||
|
resource=series_name,
|
||||||
|
action="download",
|
||||||
|
details={"episode": episode, "file_path": file_path},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def log_download_failed(
|
||||||
|
self, series_name: str, episode: str, error: str
|
||||||
|
) -> None:
|
||||||
|
"""Log failed download."""
|
||||||
|
await self.log_event(
|
||||||
|
event_type=AuditEventType.DOWNLOAD_FAILED,
|
||||||
|
message=f"Download failed: {series_name} - {episode}",
|
||||||
|
severity=AuditEventSeverity.ERROR,
|
||||||
|
resource=series_name,
|
||||||
|
action="download",
|
||||||
|
status="failure",
|
||||||
|
details={"episode": episode, "error": error},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def log_queue_operation(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
operation: str,
|
||||||
|
ip_address: Optional[str] = None,
|
||||||
|
details: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Log queue operation."""
|
||||||
|
event_type_map = {
|
||||||
|
"start": AuditEventType.QUEUE_STARTED,
|
||||||
|
"stop": AuditEventType.QUEUE_STOPPED,
|
||||||
|
"pause": AuditEventType.QUEUE_PAUSED,
|
||||||
|
"resume": AuditEventType.QUEUE_RESUMED,
|
||||||
|
"clear": AuditEventType.QUEUE_CLEARED,
|
||||||
|
}
|
||||||
|
|
||||||
|
event_type = event_type_map.get(operation, AuditEventType.SYSTEM_ERROR)
|
||||||
|
await self.log_event(
|
||||||
|
event_type=event_type,
|
||||||
|
message=f"Queue {operation} by user {user_id}",
|
||||||
|
user_id=user_id,
|
||||||
|
ip_address=ip_address,
|
||||||
|
resource="queue",
|
||||||
|
action=operation,
|
||||||
|
details=details,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def log_system_error(
|
||||||
|
self, error: str, details: Optional[Dict[str, Any]] = None
|
||||||
|
) -> None:
|
||||||
|
"""Log system error."""
|
||||||
|
await self.log_event(
|
||||||
|
event_type=AuditEventType.SYSTEM_ERROR,
|
||||||
|
message=f"System error: {error}",
|
||||||
|
severity=AuditEventSeverity.ERROR,
|
||||||
|
status="error",
|
||||||
|
details=details,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_events(
|
||||||
|
self,
|
||||||
|
start_time: Optional[datetime] = None,
|
||||||
|
end_time: Optional[datetime] = None,
|
||||||
|
event_types: Optional[List[AuditEventType]] = None,
|
||||||
|
user_id: Optional[str] = None,
|
||||||
|
limit: int = 100,
|
||||||
|
) -> List[AuditEvent]:
|
||||||
|
"""
|
||||||
|
Get audit events with filters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start_time: Start of time range
|
||||||
|
end_time: End of time range
|
||||||
|
event_types: Filter by event types
|
||||||
|
user_id: Filter by user ID
|
||||||
|
limit: Maximum number of events to return
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of audit events
|
||||||
|
"""
|
||||||
|
return await self.storage.read_events(
|
||||||
|
start_time=start_time,
|
||||||
|
end_time=end_time,
|
||||||
|
event_types=event_types,
|
||||||
|
user_id=user_id,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def cleanup_old_events(self, days: int = 90) -> int:
|
||||||
|
"""
|
||||||
|
Clean up old audit events.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
days: Number of days to retain
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of events deleted
|
||||||
|
"""
|
||||||
|
return await self.storage.cleanup_old_events(days)
|
||||||
|
|
||||||
|
|
||||||
|
# Global audit service instance
|
||||||
|
_audit_service: Optional[AuditService] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_audit_service() -> AuditService:
|
||||||
|
"""
|
||||||
|
Get the global audit service instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AuditService instance
|
||||||
|
"""
|
||||||
|
global _audit_service
|
||||||
|
if _audit_service is None:
|
||||||
|
_audit_service = AuditService()
|
||||||
|
return _audit_service
|
||||||
|
|
||||||
|
|
||||||
|
def configure_audit_service(storage: Optional[AuditLogStorage] = None) -> AuditService:
|
||||||
|
"""
|
||||||
|
Configure the global audit service.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
storage: Custom storage backend
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured AuditService instance
|
||||||
|
"""
|
||||||
|
global _audit_service
|
||||||
|
_audit_service = AuditService(storage=storage)
|
||||||
|
return _audit_service
|
||||||
723
src/server/services/cache_service.py
Normal file
723
src/server/services/cache_service.py
Normal file
@ -0,0 +1,723 @@
|
|||||||
|
"""
|
||||||
|
Cache Service for AniWorld.
|
||||||
|
|
||||||
|
This module provides caching functionality with support for both
|
||||||
|
in-memory and Redis backends to improve application performance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
import pickle
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CacheBackend(ABC):
|
||||||
|
"""Abstract base class for cache backends."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get(self, key: str) -> Optional[Any]:
|
||||||
|
"""
|
||||||
|
Get value from cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Cache key
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cached value or None if not found
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def set(
|
||||||
|
self, key: str, value: Any, ttl: Optional[int] = None
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Set value in cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Cache key
|
||||||
|
value: Value to cache
|
||||||
|
ttl: Time to live in seconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if successful
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def delete(self, key: str) -> bool:
|
||||||
|
"""
|
||||||
|
Delete value from cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Cache key
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if key was deleted
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def exists(self, key: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if key exists in cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Cache key
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if key exists
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def clear(self) -> bool:
|
||||||
|
"""
|
||||||
|
Clear all cached values.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if successful
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_many(self, keys: List[str]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get multiple values from cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
keys: List of cache keys
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping keys to values
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def set_many(
|
||||||
|
self, items: Dict[str, Any], ttl: Optional[int] = None
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Set multiple values in cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
items: Dictionary of key-value pairs
|
||||||
|
ttl: Time to live in seconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if successful
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def delete_pattern(self, pattern: str) -> int:
|
||||||
|
"""
|
||||||
|
Delete all keys matching pattern.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pattern: Pattern to match (supports wildcards)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of keys deleted
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InMemoryCacheBackend(CacheBackend):
|
||||||
|
"""In-memory cache backend using dictionary."""
|
||||||
|
|
||||||
|
def __init__(self, max_size: int = 1000):
|
||||||
|
"""
|
||||||
|
Initialize in-memory cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_size: Maximum number of items to cache
|
||||||
|
"""
|
||||||
|
self.cache: Dict[str, Dict[str, Any]] = {}
|
||||||
|
self.max_size = max_size
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
def _is_expired(self, item: Dict[str, Any]) -> bool:
|
||||||
|
"""
|
||||||
|
Check if cache item is expired.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
item: Cache item with expiry
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if expired
|
||||||
|
"""
|
||||||
|
if item.get("expiry") is None:
|
||||||
|
return False
|
||||||
|
return datetime.utcnow() > item["expiry"]
|
||||||
|
|
||||||
|
def _evict_oldest(self) -> None:
|
||||||
|
"""Evict oldest cache item when cache is full."""
|
||||||
|
if len(self.cache) >= self.max_size:
|
||||||
|
# Remove oldest item
|
||||||
|
oldest_key = min(
|
||||||
|
self.cache.keys(),
|
||||||
|
key=lambda k: self.cache[k].get("created", datetime.utcnow()),
|
||||||
|
)
|
||||||
|
del self.cache[oldest_key]
|
||||||
|
|
||||||
|
async def get(self, key: str) -> Optional[Any]:
|
||||||
|
"""Get value from cache."""
|
||||||
|
async with self._lock:
|
||||||
|
if key not in self.cache:
|
||||||
|
return None
|
||||||
|
|
||||||
|
item = self.cache[key]
|
||||||
|
|
||||||
|
if self._is_expired(item):
|
||||||
|
del self.cache[key]
|
||||||
|
return None
|
||||||
|
|
||||||
|
return item["value"]
|
||||||
|
|
||||||
|
async def set(
|
||||||
|
self, key: str, value: Any, ttl: Optional[int] = None
|
||||||
|
) -> bool:
|
||||||
|
"""Set value in cache."""
|
||||||
|
async with self._lock:
|
||||||
|
self._evict_oldest()
|
||||||
|
|
||||||
|
expiry = None
|
||||||
|
if ttl:
|
||||||
|
expiry = datetime.utcnow() + timedelta(seconds=ttl)
|
||||||
|
|
||||||
|
self.cache[key] = {
|
||||||
|
"value": value,
|
||||||
|
"expiry": expiry,
|
||||||
|
"created": datetime.utcnow(),
|
||||||
|
}
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def delete(self, key: str) -> bool:
|
||||||
|
"""Delete value from cache."""
|
||||||
|
async with self._lock:
|
||||||
|
if key in self.cache:
|
||||||
|
del self.cache[key]
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def exists(self, key: str) -> bool:
|
||||||
|
"""Check if key exists in cache."""
|
||||||
|
async with self._lock:
|
||||||
|
if key not in self.cache:
|
||||||
|
return False
|
||||||
|
|
||||||
|
item = self.cache[key]
|
||||||
|
if self._is_expired(item):
|
||||||
|
del self.cache[key]
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def clear(self) -> bool:
|
||||||
|
"""Clear all cached values."""
|
||||||
|
async with self._lock:
|
||||||
|
self.cache.clear()
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def get_many(self, keys: List[str]) -> Dict[str, Any]:
|
||||||
|
"""Get multiple values from cache."""
|
||||||
|
result = {}
|
||||||
|
for key in keys:
|
||||||
|
value = await self.get(key)
|
||||||
|
if value is not None:
|
||||||
|
result[key] = value
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def set_many(
|
||||||
|
self, items: Dict[str, Any], ttl: Optional[int] = None
|
||||||
|
) -> bool:
|
||||||
|
"""Set multiple values in cache."""
|
||||||
|
for key, value in items.items():
|
||||||
|
await self.set(key, value, ttl)
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def delete_pattern(self, pattern: str) -> int:
|
||||||
|
"""Delete all keys matching pattern."""
|
||||||
|
import fnmatch
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
keys_to_delete = [
|
||||||
|
key for key in self.cache.keys() if fnmatch.fnmatch(key, pattern)
|
||||||
|
]
|
||||||
|
for key in keys_to_delete:
|
||||||
|
del self.cache[key]
|
||||||
|
return len(keys_to_delete)
|
||||||
|
|
||||||
|
|
||||||
|
class RedisCacheBackend(CacheBackend):
|
||||||
|
"""Redis cache backend."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
redis_url: str = "redis://localhost:6379",
|
||||||
|
prefix: str = "aniworld:",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize Redis cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
redis_url: Redis connection URL
|
||||||
|
prefix: Key prefix for namespacing
|
||||||
|
"""
|
||||||
|
self.redis_url = redis_url
|
||||||
|
self.prefix = prefix
|
||||||
|
self._redis = None
|
||||||
|
|
||||||
|
async def _get_redis(self):
|
||||||
|
"""Get Redis connection."""
|
||||||
|
if self._redis is None:
|
||||||
|
try:
|
||||||
|
import aioredis
|
||||||
|
|
||||||
|
self._redis = await aioredis.create_redis_pool(self.redis_url)
|
||||||
|
except ImportError:
|
||||||
|
logger.error(
|
||||||
|
"aioredis not installed. Install with: pip install aioredis"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to connect to Redis: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
return self._redis
|
||||||
|
|
||||||
|
def _make_key(self, key: str) -> str:
|
||||||
|
"""Add prefix to key."""
|
||||||
|
return f"{self.prefix}{key}"
|
||||||
|
|
||||||
|
async def get(self, key: str) -> Optional[Any]:
|
||||||
|
"""Get value from cache."""
|
||||||
|
try:
|
||||||
|
redis = await self._get_redis()
|
||||||
|
data = await redis.get(self._make_key(key))
|
||||||
|
|
||||||
|
if data is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return pickle.loads(data)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Redis get error: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def set(
|
||||||
|
self, key: str, value: Any, ttl: Optional[int] = None
|
||||||
|
) -> bool:
|
||||||
|
"""Set value in cache."""
|
||||||
|
try:
|
||||||
|
redis = await self._get_redis()
|
||||||
|
data = pickle.dumps(value)
|
||||||
|
|
||||||
|
if ttl:
|
||||||
|
await redis.setex(self._make_key(key), ttl, data)
|
||||||
|
else:
|
||||||
|
await redis.set(self._make_key(key), data)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Redis set error: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def delete(self, key: str) -> bool:
|
||||||
|
"""Delete value from cache."""
|
||||||
|
try:
|
||||||
|
redis = await self._get_redis()
|
||||||
|
result = await redis.delete(self._make_key(key))
|
||||||
|
return result > 0
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Redis delete error: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def exists(self, key: str) -> bool:
|
||||||
|
"""Check if key exists in cache."""
|
||||||
|
try:
|
||||||
|
redis = await self._get_redis()
|
||||||
|
return await redis.exists(self._make_key(key))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Redis exists error: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def clear(self) -> bool:
|
||||||
|
"""Clear all cached values with prefix."""
|
||||||
|
try:
|
||||||
|
redis = await self._get_redis()
|
||||||
|
keys = await redis.keys(f"{self.prefix}*")
|
||||||
|
if keys:
|
||||||
|
await redis.delete(*keys)
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Redis clear error: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def get_many(self, keys: List[str]) -> Dict[str, Any]:
|
||||||
|
"""Get multiple values from cache."""
|
||||||
|
try:
|
||||||
|
redis = await self._get_redis()
|
||||||
|
prefixed_keys = [self._make_key(k) for k in keys]
|
||||||
|
values = await redis.mget(*prefixed_keys)
|
||||||
|
|
||||||
|
result = {}
|
||||||
|
for key, value in zip(keys, values):
|
||||||
|
if value is not None:
|
||||||
|
result[key] = pickle.loads(value)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Redis get_many error: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def set_many(
|
||||||
|
self, items: Dict[str, Any], ttl: Optional[int] = None
|
||||||
|
) -> bool:
|
||||||
|
"""Set multiple values in cache."""
|
||||||
|
try:
|
||||||
|
for key, value in items.items():
|
||||||
|
await self.set(key, value, ttl)
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Redis set_many error: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def delete_pattern(self, pattern: str) -> int:
|
||||||
|
"""Delete all keys matching pattern."""
|
||||||
|
try:
|
||||||
|
redis = await self._get_redis()
|
||||||
|
full_pattern = f"{self.prefix}{pattern}"
|
||||||
|
keys = await redis.keys(full_pattern)
|
||||||
|
|
||||||
|
if keys:
|
||||||
|
await redis.delete(*keys)
|
||||||
|
return len(keys)
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Redis delete_pattern error: {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Close Redis connection."""
|
||||||
|
if self._redis:
|
||||||
|
self._redis.close()
|
||||||
|
await self._redis.wait_closed()
|
||||||
|
|
||||||
|
|
||||||
|
class CacheService:
|
||||||
|
"""Main cache service with automatic key generation and TTL management."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
backend: Optional[CacheBackend] = None,
|
||||||
|
default_ttl: int = 3600,
|
||||||
|
key_prefix: str = "",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize cache service.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
backend: Cache backend to use
|
||||||
|
default_ttl: Default time to live in seconds
|
||||||
|
key_prefix: Prefix for all cache keys
|
||||||
|
"""
|
||||||
|
self.backend = backend or InMemoryCacheBackend()
|
||||||
|
self.default_ttl = default_ttl
|
||||||
|
self.key_prefix = key_prefix
|
||||||
|
|
||||||
|
def _make_key(self, *args: Any, **kwargs: Any) -> str:
|
||||||
|
"""
|
||||||
|
Generate cache key from arguments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*args: Positional arguments
|
||||||
|
**kwargs: Keyword arguments
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cache key string
|
||||||
|
"""
|
||||||
|
# Create a stable key from arguments
|
||||||
|
key_parts = [str(arg) for arg in args]
|
||||||
|
key_parts.extend(f"{k}={v}" for k, v in sorted(kwargs.items()))
|
||||||
|
key_str = ":".join(key_parts)
|
||||||
|
|
||||||
|
# Hash long keys
|
||||||
|
if len(key_str) > 200:
|
||||||
|
key_hash = hashlib.md5(key_str.encode()).hexdigest()
|
||||||
|
return f"{self.key_prefix}{key_hash}"
|
||||||
|
|
||||||
|
return f"{self.key_prefix}{key_str}"
|
||||||
|
|
||||||
|
async def get(
|
||||||
|
self, key: str, default: Optional[Any] = None
|
||||||
|
) -> Optional[Any]:
|
||||||
|
"""
|
||||||
|
Get value from cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Cache key
|
||||||
|
default: Default value if not found
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cached value or default
|
||||||
|
"""
|
||||||
|
value = await self.backend.get(key)
|
||||||
|
return value if value is not None else default
|
||||||
|
|
||||||
|
async def set(
|
||||||
|
self, key: str, value: Any, ttl: Optional[int] = None
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Set value in cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Cache key
|
||||||
|
value: Value to cache
|
||||||
|
ttl: Time to live in seconds (uses default if None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if successful
|
||||||
|
"""
|
||||||
|
if ttl is None:
|
||||||
|
ttl = self.default_ttl
|
||||||
|
return await self.backend.set(key, value, ttl)
|
||||||
|
|
||||||
|
async def delete(self, key: str) -> bool:
|
||||||
|
"""
|
||||||
|
Delete value from cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Cache key
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if deleted
|
||||||
|
"""
|
||||||
|
return await self.backend.delete(key)
|
||||||
|
|
||||||
|
async def exists(self, key: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if key exists in cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Cache key
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if exists
|
||||||
|
"""
|
||||||
|
return await self.backend.exists(key)
|
||||||
|
|
||||||
|
async def clear(self) -> bool:
|
||||||
|
"""
|
||||||
|
Clear all cached values.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if successful
|
||||||
|
"""
|
||||||
|
return await self.backend.clear()
|
||||||
|
|
||||||
|
async def get_or_set(
|
||||||
|
self,
|
||||||
|
key: str,
|
||||||
|
factory,
|
||||||
|
ttl: Optional[int] = None,
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Get value from cache or compute and cache it.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Cache key
|
||||||
|
factory: Callable to compute value if not cached
|
||||||
|
ttl: Time to live in seconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cached or computed value
|
||||||
|
"""
|
||||||
|
value = await self.get(key)
|
||||||
|
|
||||||
|
if value is None:
|
||||||
|
# Compute value
|
||||||
|
if asyncio.iscoroutinefunction(factory):
|
||||||
|
value = await factory()
|
||||||
|
else:
|
||||||
|
value = factory()
|
||||||
|
|
||||||
|
# Cache it
|
||||||
|
await self.set(key, value, ttl)
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
async def invalidate_pattern(self, pattern: str) -> int:
|
||||||
|
"""
|
||||||
|
Invalidate all keys matching pattern.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pattern: Pattern to match
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of keys invalidated
|
||||||
|
"""
|
||||||
|
return await self.backend.delete_pattern(pattern)
|
||||||
|
|
||||||
|
async def cache_anime_list(
|
||||||
|
self, anime_list: List[Dict[str, Any]], ttl: Optional[int] = None
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Cache anime list.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
anime_list: List of anime data
|
||||||
|
ttl: Time to live in seconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if successful
|
||||||
|
"""
|
||||||
|
key = self._make_key("anime", "list")
|
||||||
|
return await self.set(key, anime_list, ttl)
|
||||||
|
|
||||||
|
async def get_anime_list(self) -> Optional[List[Dict[str, Any]]]:
|
||||||
|
"""
|
||||||
|
Get cached anime list.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cached anime list or None
|
||||||
|
"""
|
||||||
|
key = self._make_key("anime", "list")
|
||||||
|
return await self.get(key)
|
||||||
|
|
||||||
|
async def cache_anime_detail(
|
||||||
|
self, anime_id: str, data: Dict[str, Any], ttl: Optional[int] = None
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Cache anime detail.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
anime_id: Anime identifier
|
||||||
|
data: Anime data
|
||||||
|
ttl: Time to live in seconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if successful
|
||||||
|
"""
|
||||||
|
key = self._make_key("anime", "detail", anime_id)
|
||||||
|
return await self.set(key, data, ttl)
|
||||||
|
|
||||||
|
async def get_anime_detail(self, anime_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get cached anime detail.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
anime_id: Anime identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cached anime data or None
|
||||||
|
"""
|
||||||
|
key = self._make_key("anime", "detail", anime_id)
|
||||||
|
return await self.get(key)
|
||||||
|
|
||||||
|
async def invalidate_anime_cache(self) -> int:
|
||||||
|
"""
|
||||||
|
Invalidate all anime-related cache.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of keys invalidated
|
||||||
|
"""
|
||||||
|
return await self.invalidate_pattern(f"{self.key_prefix}anime*")
|
||||||
|
|
||||||
|
async def cache_config(
|
||||||
|
self, config: Dict[str, Any], ttl: Optional[int] = None
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Cache configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Configuration data
|
||||||
|
ttl: Time to live in seconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if successful
|
||||||
|
"""
|
||||||
|
key = self._make_key("config")
|
||||||
|
return await self.set(key, config, ttl)
|
||||||
|
|
||||||
|
async def get_config(self) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get cached configuration.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cached configuration or None
|
||||||
|
"""
|
||||||
|
key = self._make_key("config")
|
||||||
|
return await self.get(key)
|
||||||
|
|
||||||
|
async def invalidate_config_cache(self) -> bool:
|
||||||
|
"""
|
||||||
|
Invalidate configuration cache.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if successful
|
||||||
|
"""
|
||||||
|
key = self._make_key("config")
|
||||||
|
return await self.delete(key)
|
||||||
|
|
||||||
|
|
||||||
|
# Global cache service instance
|
||||||
|
_cache_service: Optional[CacheService] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_cache_service() -> CacheService:
|
||||||
|
"""
|
||||||
|
Get the global cache service instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CacheService instance
|
||||||
|
"""
|
||||||
|
global _cache_service
|
||||||
|
if _cache_service is None:
|
||||||
|
_cache_service = CacheService()
|
||||||
|
return _cache_service
|
||||||
|
|
||||||
|
|
||||||
|
def configure_cache_service(
|
||||||
|
backend_type: str = "memory",
|
||||||
|
redis_url: str = "redis://localhost:6379",
|
||||||
|
default_ttl: int = 3600,
|
||||||
|
max_size: int = 1000,
|
||||||
|
) -> CacheService:
|
||||||
|
"""
|
||||||
|
Configure the global cache service.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
backend_type: Type of backend ("memory" or "redis")
|
||||||
|
redis_url: Redis connection URL (for redis backend)
|
||||||
|
default_ttl: Default time to live in seconds
|
||||||
|
max_size: Maximum cache size (for memory backend)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured CacheService instance
|
||||||
|
"""
|
||||||
|
global _cache_service
|
||||||
|
|
||||||
|
if backend_type == "redis":
|
||||||
|
backend = RedisCacheBackend(redis_url=redis_url)
|
||||||
|
else:
|
||||||
|
backend = InMemoryCacheBackend(max_size=max_size)
|
||||||
|
|
||||||
|
_cache_service = CacheService(
|
||||||
|
backend=backend, default_ttl=default_ttl, key_prefix="aniworld:"
|
||||||
|
)
|
||||||
|
return _cache_service
|
||||||
626
src/server/services/notification_service.py
Normal file
626
src/server/services/notification_service.py
Normal file
@ -0,0 +1,626 @@
|
|||||||
|
"""
|
||||||
|
Notification Service for AniWorld.
|
||||||
|
|
||||||
|
This module provides notification functionality including email, webhooks,
|
||||||
|
and in-app notifications for download events and system alerts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, List, Optional, Set
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from pydantic import BaseModel, EmailStr, Field, HttpUrl
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class NotificationType(str, Enum):
|
||||||
|
"""Types of notifications."""
|
||||||
|
|
||||||
|
DOWNLOAD_COMPLETE = "download_complete"
|
||||||
|
DOWNLOAD_FAILED = "download_failed"
|
||||||
|
QUEUE_COMPLETE = "queue_complete"
|
||||||
|
SYSTEM_ERROR = "system_error"
|
||||||
|
SYSTEM_WARNING = "system_warning"
|
||||||
|
SYSTEM_INFO = "system_info"
|
||||||
|
|
||||||
|
|
||||||
|
class NotificationPriority(str, Enum):
|
||||||
|
"""Notification priority levels."""
|
||||||
|
|
||||||
|
LOW = "low"
|
||||||
|
NORMAL = "normal"
|
||||||
|
HIGH = "high"
|
||||||
|
CRITICAL = "critical"
|
||||||
|
|
||||||
|
|
||||||
|
class NotificationChannel(str, Enum):
|
||||||
|
"""Available notification channels."""
|
||||||
|
|
||||||
|
EMAIL = "email"
|
||||||
|
WEBHOOK = "webhook"
|
||||||
|
IN_APP = "in_app"
|
||||||
|
|
||||||
|
|
||||||
|
class NotificationPreferences(BaseModel):
|
||||||
|
"""User notification preferences."""
|
||||||
|
|
||||||
|
enabled_channels: Set[NotificationChannel] = Field(
|
||||||
|
default_factory=lambda: {NotificationChannel.IN_APP}
|
||||||
|
)
|
||||||
|
enabled_types: Set[NotificationType] = Field(
|
||||||
|
default_factory=lambda: set(NotificationType)
|
||||||
|
)
|
||||||
|
email_address: Optional[EmailStr] = None
|
||||||
|
webhook_urls: List[HttpUrl] = Field(default_factory=list)
|
||||||
|
quiet_hours_start: Optional[int] = Field(None, ge=0, le=23)
|
||||||
|
quiet_hours_end: Optional[int] = Field(None, ge=0, le=23)
|
||||||
|
min_priority: NotificationPriority = NotificationPriority.NORMAL
|
||||||
|
|
||||||
|
|
||||||
|
class Notification(BaseModel):
|
||||||
|
"""Notification model."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
type: NotificationType
|
||||||
|
priority: NotificationPriority
|
||||||
|
title: str
|
||||||
|
message: str
|
||||||
|
data: Optional[Dict[str, Any]] = None
|
||||||
|
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||||
|
read: bool = False
|
||||||
|
channels: Set[NotificationChannel] = Field(
|
||||||
|
default_factory=lambda: {NotificationChannel.IN_APP}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class EmailNotificationService:
|
||||||
|
"""Service for sending email notifications."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
smtp_host: Optional[str] = None,
|
||||||
|
smtp_port: int = 587,
|
||||||
|
smtp_username: Optional[str] = None,
|
||||||
|
smtp_password: Optional[str] = None,
|
||||||
|
from_address: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize email notification service.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
smtp_host: SMTP server hostname
|
||||||
|
smtp_port: SMTP server port
|
||||||
|
smtp_username: SMTP authentication username
|
||||||
|
smtp_password: SMTP authentication password
|
||||||
|
from_address: Email sender address
|
||||||
|
"""
|
||||||
|
self.smtp_host = smtp_host
|
||||||
|
self.smtp_port = smtp_port
|
||||||
|
self.smtp_username = smtp_username
|
||||||
|
self.smtp_password = smtp_password
|
||||||
|
self.from_address = from_address
|
||||||
|
self._enabled = all(
|
||||||
|
[smtp_host, smtp_username, smtp_password, from_address]
|
||||||
|
)
|
||||||
|
|
||||||
|
async def send_email(
|
||||||
|
self, to_address: str, subject: str, body: str, html: bool = False
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Send an email notification.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
to_address: Recipient email address
|
||||||
|
subject: Email subject
|
||||||
|
body: Email body content
|
||||||
|
html: Whether body is HTML format
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if email sent successfully
|
||||||
|
"""
|
||||||
|
if not self._enabled:
|
||||||
|
logger.warning("Email notifications not configured")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Import here to make aiosmtplib optional
|
||||||
|
from email.mime.multipart import MIMEMultipart
|
||||||
|
from email.mime.text import MIMEText
|
||||||
|
|
||||||
|
import aiosmtplib
|
||||||
|
|
||||||
|
message = MIMEMultipart("alternative")
|
||||||
|
message["Subject"] = subject
|
||||||
|
message["From"] = self.from_address
|
||||||
|
message["To"] = to_address
|
||||||
|
|
||||||
|
mime_type = "html" if html else "plain"
|
||||||
|
message.attach(MIMEText(body, mime_type))
|
||||||
|
|
||||||
|
await aiosmtplib.send(
|
||||||
|
message,
|
||||||
|
hostname=self.smtp_host,
|
||||||
|
port=self.smtp_port,
|
||||||
|
username=self.smtp_username,
|
||||||
|
password=self.smtp_password,
|
||||||
|
start_tls=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Email notification sent to {to_address}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
logger.error(
|
||||||
|
"aiosmtplib not installed. Install with: pip install aiosmtplib"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to send email notification: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class WebhookNotificationService:
|
||||||
|
"""Service for sending webhook notifications."""
|
||||||
|
|
||||||
|
def __init__(self, timeout: int = 10, max_retries: int = 3):
|
||||||
|
"""
|
||||||
|
Initialize webhook notification service.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeout: Request timeout in seconds
|
||||||
|
max_retries: Maximum number of retry attempts
|
||||||
|
"""
|
||||||
|
self.timeout = timeout
|
||||||
|
self.max_retries = max_retries
|
||||||
|
|
||||||
|
async def send_webhook(
|
||||||
|
self, url: str, payload: Dict[str, Any], headers: Optional[Dict[str, str]] = None
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Send a webhook notification.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: Webhook URL
|
||||||
|
payload: JSON payload to send
|
||||||
|
headers: Optional custom headers
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if webhook sent successfully
|
||||||
|
"""
|
||||||
|
if headers is None:
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
|
||||||
|
for attempt in range(self.max_retries):
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
url,
|
||||||
|
json=payload,
|
||||||
|
headers=headers,
|
||||||
|
timeout=aiohttp.ClientTimeout(total=self.timeout),
|
||||||
|
) as response:
|
||||||
|
if response.status < 400:
|
||||||
|
logger.info(f"Webhook notification sent to {url}")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"Webhook returned status {response.status}: {url}"
|
||||||
|
)
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning(f"Webhook timeout (attempt {attempt + 1}/{self.max_retries}): {url}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to send webhook (attempt {attempt + 1}/{self.max_retries}): {e}")
|
||||||
|
|
||||||
|
if attempt < self.max_retries - 1:
|
||||||
|
await asyncio.sleep(2 ** attempt) # Exponential backoff
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class InAppNotificationService:
|
||||||
|
"""Service for managing in-app notifications."""
|
||||||
|
|
||||||
|
def __init__(self, max_notifications: int = 100):
|
||||||
|
"""
|
||||||
|
Initialize in-app notification service.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_notifications: Maximum number of notifications to keep
|
||||||
|
"""
|
||||||
|
self.notifications: List[Notification] = []
|
||||||
|
self.max_notifications = max_notifications
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def add_notification(self, notification: Notification) -> None:
|
||||||
|
"""
|
||||||
|
Add a notification to the in-app list.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
notification: Notification to add
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
self.notifications.insert(0, notification)
|
||||||
|
if len(self.notifications) > self.max_notifications:
|
||||||
|
self.notifications = self.notifications[: self.max_notifications]
|
||||||
|
|
||||||
|
async def get_notifications(
|
||||||
|
self, unread_only: bool = False, limit: Optional[int] = None
|
||||||
|
) -> List[Notification]:
|
||||||
|
"""
|
||||||
|
Get in-app notifications.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unread_only: Only return unread notifications
|
||||||
|
limit: Maximum number of notifications to return
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of notifications
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
notifications = self.notifications
|
||||||
|
if unread_only:
|
||||||
|
notifications = [n for n in notifications if not n.read]
|
||||||
|
if limit:
|
||||||
|
notifications = notifications[:limit]
|
||||||
|
return notifications.copy()
|
||||||
|
|
||||||
|
async def mark_as_read(self, notification_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Mark a notification as read.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
notification_id: ID of notification to mark
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if notification was found and marked
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
for notification in self.notifications:
|
||||||
|
if notification.id == notification_id:
|
||||||
|
notification.read = True
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def mark_all_as_read(self) -> int:
|
||||||
|
"""
|
||||||
|
Mark all notifications as read.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of notifications marked as read
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
count = 0
|
||||||
|
for notification in self.notifications:
|
||||||
|
if not notification.read:
|
||||||
|
notification.read = True
|
||||||
|
count += 1
|
||||||
|
return count
|
||||||
|
|
||||||
|
async def clear_notifications(self, read_only: bool = True) -> int:
|
||||||
|
"""
|
||||||
|
Clear notifications.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
read_only: Only clear read notifications
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of notifications cleared
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
if read_only:
|
||||||
|
initial_count = len(self.notifications)
|
||||||
|
self.notifications = [n for n in self.notifications if not n.read]
|
||||||
|
return initial_count - len(self.notifications)
|
||||||
|
else:
|
||||||
|
count = len(self.notifications)
|
||||||
|
self.notifications.clear()
|
||||||
|
return count
|
||||||
|
|
||||||
|
|
||||||
|
class NotificationService:
|
||||||
|
"""Main notification service coordinating all notification channels."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
email_service: Optional[EmailNotificationService] = None,
|
||||||
|
webhook_service: Optional[WebhookNotificationService] = None,
|
||||||
|
in_app_service: Optional[InAppNotificationService] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize notification service.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
email_service: Email notification service instance
|
||||||
|
webhook_service: Webhook notification service instance
|
||||||
|
in_app_service: In-app notification service instance
|
||||||
|
"""
|
||||||
|
self.email_service = email_service or EmailNotificationService()
|
||||||
|
self.webhook_service = webhook_service or WebhookNotificationService()
|
||||||
|
self.in_app_service = in_app_service or InAppNotificationService()
|
||||||
|
self.preferences = NotificationPreferences()
|
||||||
|
|
||||||
|
def set_preferences(self, preferences: NotificationPreferences) -> None:
|
||||||
|
"""
|
||||||
|
Update notification preferences.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
preferences: New notification preferences
|
||||||
|
"""
|
||||||
|
self.preferences = preferences
|
||||||
|
|
||||||
|
def _is_in_quiet_hours(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if current time is within quiet hours.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if in quiet hours
|
||||||
|
"""
|
||||||
|
if (
|
||||||
|
self.preferences.quiet_hours_start is None
|
||||||
|
or self.preferences.quiet_hours_end is None
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
|
||||||
|
current_hour = datetime.now().hour
|
||||||
|
start = self.preferences.quiet_hours_start
|
||||||
|
end = self.preferences.quiet_hours_end
|
||||||
|
|
||||||
|
if start <= end:
|
||||||
|
return start <= current_hour < end
|
||||||
|
else: # Quiet hours span midnight
|
||||||
|
return current_hour >= start or current_hour < end
|
||||||
|
|
||||||
|
def _should_send_notification(
|
||||||
|
self, notification_type: NotificationType, priority: NotificationPriority
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Determine if a notification should be sent based on preferences.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
notification_type: Type of notification
|
||||||
|
priority: Priority level
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if notification should be sent
|
||||||
|
"""
|
||||||
|
# Check if type is enabled
|
||||||
|
if notification_type not in self.preferences.enabled_types:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check priority level
|
||||||
|
priority_order = [
|
||||||
|
NotificationPriority.LOW,
|
||||||
|
NotificationPriority.NORMAL,
|
||||||
|
NotificationPriority.HIGH,
|
||||||
|
NotificationPriority.CRITICAL,
|
||||||
|
]
|
||||||
|
if (
|
||||||
|
priority_order.index(priority)
|
||||||
|
< priority_order.index(self.preferences.min_priority)
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check quiet hours (critical notifications bypass quiet hours)
|
||||||
|
if priority != NotificationPriority.CRITICAL and self._is_in_quiet_hours():
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def send_notification(self, notification: Notification) -> Dict[str, bool]:
|
||||||
|
"""
|
||||||
|
Send a notification through enabled channels.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
notification: Notification to send
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping channel names to success status
|
||||||
|
"""
|
||||||
|
if not self._should_send_notification(notification.type, notification.priority):
|
||||||
|
logger.debug(
|
||||||
|
f"Notification not sent due to preferences: {notification.type}"
|
||||||
|
)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
# Send in-app notification
|
||||||
|
if NotificationChannel.IN_APP in self.preferences.enabled_channels:
|
||||||
|
try:
|
||||||
|
await self.in_app_service.add_notification(notification)
|
||||||
|
results["in_app"] = True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to send in-app notification: {e}")
|
||||||
|
results["in_app"] = False
|
||||||
|
|
||||||
|
# Send email notification
|
||||||
|
if (
|
||||||
|
NotificationChannel.EMAIL in self.preferences.enabled_channels
|
||||||
|
and self.preferences.email_address
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
success = await self.email_service.send_email(
|
||||||
|
to_address=self.preferences.email_address,
|
||||||
|
subject=f"[{notification.priority.upper()}] {notification.title}",
|
||||||
|
body=notification.message,
|
||||||
|
)
|
||||||
|
results["email"] = success
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to send email notification: {e}")
|
||||||
|
results["email"] = False
|
||||||
|
|
||||||
|
# Send webhook notifications
|
||||||
|
if (
|
||||||
|
NotificationChannel.WEBHOOK in self.preferences.enabled_channels
|
||||||
|
and self.preferences.webhook_urls
|
||||||
|
):
|
||||||
|
payload = {
|
||||||
|
"id": notification.id,
|
||||||
|
"type": notification.type,
|
||||||
|
"priority": notification.priority,
|
||||||
|
"title": notification.title,
|
||||||
|
"message": notification.message,
|
||||||
|
"data": notification.data,
|
||||||
|
"created_at": notification.created_at.isoformat(),
|
||||||
|
}
|
||||||
|
|
||||||
|
webhook_results = []
|
||||||
|
for url in self.preferences.webhook_urls:
|
||||||
|
try:
|
||||||
|
success = await self.webhook_service.send_webhook(str(url), payload)
|
||||||
|
webhook_results.append(success)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to send webhook notification to {url}: {e}")
|
||||||
|
webhook_results.append(False)
|
||||||
|
|
||||||
|
results["webhook"] = all(webhook_results) if webhook_results else False
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def notify_download_complete(
|
||||||
|
self, series_name: str, episode: str, file_path: str
|
||||||
|
) -> Dict[str, bool]:
|
||||||
|
"""
|
||||||
|
Send notification for completed download.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
series_name: Name of the series
|
||||||
|
episode: Episode identifier
|
||||||
|
file_path: Path to downloaded file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of send results by channel
|
||||||
|
"""
|
||||||
|
notification = Notification(
|
||||||
|
id=f"download_complete_{datetime.utcnow().timestamp()}",
|
||||||
|
type=NotificationType.DOWNLOAD_COMPLETE,
|
||||||
|
priority=NotificationPriority.NORMAL,
|
||||||
|
title=f"Download Complete: {series_name}",
|
||||||
|
message=f"Episode {episode} has been downloaded successfully.",
|
||||||
|
data={
|
||||||
|
"series_name": series_name,
|
||||||
|
"episode": episode,
|
||||||
|
"file_path": file_path,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return await self.send_notification(notification)
|
||||||
|
|
||||||
|
async def notify_download_failed(
|
||||||
|
self, series_name: str, episode: str, error: str
|
||||||
|
) -> Dict[str, bool]:
|
||||||
|
"""
|
||||||
|
Send notification for failed download.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
series_name: Name of the series
|
||||||
|
episode: Episode identifier
|
||||||
|
error: Error message
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of send results by channel
|
||||||
|
"""
|
||||||
|
notification = Notification(
|
||||||
|
id=f"download_failed_{datetime.utcnow().timestamp()}",
|
||||||
|
type=NotificationType.DOWNLOAD_FAILED,
|
||||||
|
priority=NotificationPriority.HIGH,
|
||||||
|
title=f"Download Failed: {series_name}",
|
||||||
|
message=f"Episode {episode} failed to download: {error}",
|
||||||
|
data={"series_name": series_name, "episode": episode, "error": error},
|
||||||
|
)
|
||||||
|
return await self.send_notification(notification)
|
||||||
|
|
||||||
|
async def notify_queue_complete(self, total_downloads: int) -> Dict[str, bool]:
|
||||||
|
"""
|
||||||
|
Send notification for completed download queue.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
total_downloads: Number of downloads completed
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of send results by channel
|
||||||
|
"""
|
||||||
|
notification = Notification(
|
||||||
|
id=f"queue_complete_{datetime.utcnow().timestamp()}",
|
||||||
|
type=NotificationType.QUEUE_COMPLETE,
|
||||||
|
priority=NotificationPriority.NORMAL,
|
||||||
|
title="Download Queue Complete",
|
||||||
|
message=f"All {total_downloads} downloads have been completed.",
|
||||||
|
data={"total_downloads": total_downloads},
|
||||||
|
)
|
||||||
|
return await self.send_notification(notification)
|
||||||
|
|
||||||
|
async def notify_system_error(self, error: str, details: Optional[Dict[str, Any]] = None) -> Dict[str, bool]:
|
||||||
|
"""
|
||||||
|
Send notification for system error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error: Error message
|
||||||
|
details: Optional error details
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary of send results by channel
|
||||||
|
"""
|
||||||
|
notification = Notification(
|
||||||
|
id=f"system_error_{datetime.utcnow().timestamp()}",
|
||||||
|
type=NotificationType.SYSTEM_ERROR,
|
||||||
|
priority=NotificationPriority.CRITICAL,
|
||||||
|
title="System Error",
|
||||||
|
message=error,
|
||||||
|
data=details,
|
||||||
|
)
|
||||||
|
return await self.send_notification(notification)
|
||||||
|
|
||||||
|
|
||||||
|
# Global notification service instance
|
||||||
|
_notification_service: Optional[NotificationService] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_notification_service() -> NotificationService:
|
||||||
|
"""
|
||||||
|
Get the global notification service instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
NotificationService instance
|
||||||
|
"""
|
||||||
|
global _notification_service
|
||||||
|
if _notification_service is None:
|
||||||
|
_notification_service = NotificationService()
|
||||||
|
return _notification_service
|
||||||
|
|
||||||
|
|
||||||
|
def configure_notification_service(
|
||||||
|
smtp_host: Optional[str] = None,
|
||||||
|
smtp_port: int = 587,
|
||||||
|
smtp_username: Optional[str] = None,
|
||||||
|
smtp_password: Optional[str] = None,
|
||||||
|
from_address: Optional[str] = None,
|
||||||
|
) -> NotificationService:
|
||||||
|
"""
|
||||||
|
Configure the global notification service.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
smtp_host: SMTP server hostname
|
||||||
|
smtp_port: SMTP server port
|
||||||
|
smtp_username: SMTP authentication username
|
||||||
|
smtp_password: SMTP authentication password
|
||||||
|
from_address: Email sender address
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured NotificationService instance
|
||||||
|
"""
|
||||||
|
global _notification_service
|
||||||
|
email_service = EmailNotificationService(
|
||||||
|
smtp_host=smtp_host,
|
||||||
|
smtp_port=smtp_port,
|
||||||
|
smtp_username=smtp_username,
|
||||||
|
smtp_password=smtp_password,
|
||||||
|
from_address=from_address,
|
||||||
|
)
|
||||||
|
_notification_service = NotificationService(email_service=email_service)
|
||||||
|
return _notification_service
|
||||||
628
src/server/utils/validators.py
Normal file
628
src/server/utils/validators.py
Normal file
@ -0,0 +1,628 @@
|
|||||||
|
"""
|
||||||
|
Data Validation Utilities for AniWorld.
|
||||||
|
|
||||||
|
This module provides Pydantic validators and business rule validation
|
||||||
|
utilities for ensuring data integrity across the application.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationError(Exception):
|
||||||
|
"""Custom validation error."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ValidatorMixin:
|
||||||
|
"""Mixin class providing common validation utilities."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def validate_password_strength(password: str) -> str:
|
||||||
|
"""
|
||||||
|
Validate password meets security requirements.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
password: Password to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated password
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If password doesn't meet requirements
|
||||||
|
"""
|
||||||
|
if len(password) < 8:
|
||||||
|
raise ValueError("Password must be at least 8 characters long")
|
||||||
|
|
||||||
|
if not re.search(r"[A-Z]", password):
|
||||||
|
raise ValueError(
|
||||||
|
"Password must contain at least one uppercase letter"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not re.search(r"[a-z]", password):
|
||||||
|
raise ValueError(
|
||||||
|
"Password must contain at least one lowercase letter"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not re.search(r"[0-9]", password):
|
||||||
|
raise ValueError("Password must contain at least one digit")
|
||||||
|
|
||||||
|
if not re.search(r"[!@#$%^&*(),.?\":{}|<>]", password):
|
||||||
|
raise ValueError(
|
||||||
|
"Password must contain at least one special character"
|
||||||
|
)
|
||||||
|
|
||||||
|
return password
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def validate_file_path(path: str, must_exist: bool = False) -> str:
|
||||||
|
"""
|
||||||
|
Validate file path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: File path to validate
|
||||||
|
must_exist: Whether the path must exist
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated path
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If path is invalid
|
||||||
|
"""
|
||||||
|
if not path or not isinstance(path, str):
|
||||||
|
raise ValueError("Path must be a non-empty string")
|
||||||
|
|
||||||
|
# Check for path traversal attempts
|
||||||
|
if ".." in path or path.startswith("/"):
|
||||||
|
raise ValueError("Invalid path: path traversal not allowed")
|
||||||
|
|
||||||
|
path_obj = Path(path)
|
||||||
|
|
||||||
|
if must_exist and not path_obj.exists():
|
||||||
|
raise ValueError(f"Path does not exist: {path}")
|
||||||
|
|
||||||
|
return path
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def validate_url(url: str) -> str:
|
||||||
|
"""
|
||||||
|
Validate URL format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: URL to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated URL
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If URL is invalid
|
||||||
|
"""
|
||||||
|
if not url or not isinstance(url, str):
|
||||||
|
raise ValueError("URL must be a non-empty string")
|
||||||
|
|
||||||
|
url_pattern = re.compile(
|
||||||
|
r"^https?://" # http:// or https://
|
||||||
|
r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?|"
|
||||||
|
r"localhost|" # localhost
|
||||||
|
r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})" # IP address
|
||||||
|
r"(?::\d+)?" # optional port
|
||||||
|
r"(?:/?|[/?]\S+)$",
|
||||||
|
re.IGNORECASE,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not url_pattern.match(url):
|
||||||
|
raise ValueError(f"Invalid URL format: {url}")
|
||||||
|
|
||||||
|
return url
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def validate_email(email: str) -> str:
|
||||||
|
"""
|
||||||
|
Validate email address format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
email: Email to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated email
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If email is invalid
|
||||||
|
"""
|
||||||
|
if not email or not isinstance(email, str):
|
||||||
|
raise ValueError("Email must be a non-empty string")
|
||||||
|
|
||||||
|
email_pattern = re.compile(
|
||||||
|
r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not email_pattern.match(email):
|
||||||
|
raise ValueError(f"Invalid email format: {email}")
|
||||||
|
|
||||||
|
return email
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def validate_port(port: int) -> int:
|
||||||
|
"""
|
||||||
|
Validate port number.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
port: Port number to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated port
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If port is invalid
|
||||||
|
"""
|
||||||
|
if not isinstance(port, int):
|
||||||
|
raise ValueError("Port must be an integer")
|
||||||
|
|
||||||
|
if port < 1 or port > 65535:
|
||||||
|
raise ValueError("Port must be between 1 and 65535")
|
||||||
|
|
||||||
|
return port
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def validate_positive_integer(value: int, name: str = "Value") -> int:
|
||||||
|
"""
|
||||||
|
Validate positive integer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: Value to validate
|
||||||
|
name: Name for error messages
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated value
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If value is invalid
|
||||||
|
"""
|
||||||
|
if not isinstance(value, int):
|
||||||
|
raise ValueError(f"{name} must be an integer")
|
||||||
|
|
||||||
|
if value <= 0:
|
||||||
|
raise ValueError(f"{name} must be positive")
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def validate_non_negative_integer(value: int, name: str = "Value") -> int:
|
||||||
|
"""
|
||||||
|
Validate non-negative integer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: Value to validate
|
||||||
|
name: Name for error messages
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated value
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If value is invalid
|
||||||
|
"""
|
||||||
|
if not isinstance(value, int):
|
||||||
|
raise ValueError(f"{name} must be an integer")
|
||||||
|
|
||||||
|
if value < 0:
|
||||||
|
raise ValueError(f"{name} cannot be negative")
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def validate_string_length(
|
||||||
|
value: str, min_length: int = 0, max_length: Optional[int] = None, name: str = "Value"
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Validate string length.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: String to validate
|
||||||
|
min_length: Minimum length
|
||||||
|
max_length: Maximum length (None for no limit)
|
||||||
|
name: Name for error messages
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated string
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If string length is invalid
|
||||||
|
"""
|
||||||
|
if not isinstance(value, str):
|
||||||
|
raise ValueError(f"{name} must be a string")
|
||||||
|
|
||||||
|
if len(value) < min_length:
|
||||||
|
raise ValueError(
|
||||||
|
f"{name} must be at least {min_length} characters long"
|
||||||
|
)
|
||||||
|
|
||||||
|
if max_length is not None and len(value) > max_length:
|
||||||
|
raise ValueError(
|
||||||
|
f"{name} must be at most {max_length} characters long"
|
||||||
|
)
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def validate_choice(value: Any, choices: List[Any], name: str = "Value") -> Any:
|
||||||
|
"""
|
||||||
|
Validate value is in allowed choices.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: Value to validate
|
||||||
|
choices: List of allowed values
|
||||||
|
name: Name for error messages
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated value
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If value not in choices
|
||||||
|
"""
|
||||||
|
if value not in choices:
|
||||||
|
raise ValueError(f"{name} must be one of: {', '.join(map(str, choices))}")
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def validate_dict_keys(
|
||||||
|
data: Dict[str, Any], required_keys: List[str], name: str = "Data"
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Validate dictionary contains required keys.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Dictionary to validate
|
||||||
|
required_keys: List of required keys
|
||||||
|
name: Name for error messages
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated dictionary
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If required keys are missing
|
||||||
|
"""
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
raise ValueError(f"{name} must be a dictionary")
|
||||||
|
|
||||||
|
missing_keys = [key for key in required_keys if key not in data]
|
||||||
|
if missing_keys:
|
||||||
|
raise ValueError(
|
||||||
|
f"{name} missing required keys: {', '.join(missing_keys)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def validate_episode_range(start: int, end: int) -> tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Validate episode range.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start: Start episode number
|
||||||
|
end: End episode number
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (start, end)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If range is invalid
|
||||||
|
"""
|
||||||
|
if start < 1:
|
||||||
|
raise ValueError("Start episode must be at least 1")
|
||||||
|
|
||||||
|
if end < start:
|
||||||
|
raise ValueError("End episode must be greater than or equal to start")
|
||||||
|
|
||||||
|
if end - start > 1000:
|
||||||
|
raise ValueError("Episode range too large (max 1000 episodes)")
|
||||||
|
|
||||||
|
return start, end
|
||||||
|
|
||||||
|
|
||||||
|
def validate_download_quality(quality: str) -> str:
|
||||||
|
"""
|
||||||
|
Validate download quality setting.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
quality: Quality setting
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated quality
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If quality is invalid
|
||||||
|
"""
|
||||||
|
valid_qualities = ["360p", "480p", "720p", "1080p", "best", "worst"]
|
||||||
|
if quality not in valid_qualities:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid quality: {quality}. Must be one of: {', '.join(valid_qualities)}"
|
||||||
|
)
|
||||||
|
return quality
|
||||||
|
|
||||||
|
|
||||||
|
def validate_language(language: str) -> str:
|
||||||
|
"""
|
||||||
|
Validate language code.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
language: Language code
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated language
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If language is invalid
|
||||||
|
"""
|
||||||
|
valid_languages = ["ger-sub", "ger-dub", "eng-sub", "eng-dub", "jpn"]
|
||||||
|
if language not in valid_languages:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid language: {language}. Must be one of: {', '.join(valid_languages)}"
|
||||||
|
)
|
||||||
|
return language
|
||||||
|
|
||||||
|
|
||||||
|
def validate_download_priority(priority: int) -> int:
|
||||||
|
"""
|
||||||
|
Validate download priority.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
priority: Priority value
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated priority
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If priority is invalid
|
||||||
|
"""
|
||||||
|
if priority < 0 or priority > 10:
|
||||||
|
raise ValueError("Priority must be between 0 and 10")
|
||||||
|
return priority
|
||||||
|
|
||||||
|
|
||||||
|
def validate_anime_url(url: str) -> str:
|
||||||
|
"""
|
||||||
|
Validate anime URL format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: Anime URL
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated URL
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If URL is invalid
|
||||||
|
"""
|
||||||
|
if not url:
|
||||||
|
raise ValueError("URL cannot be empty")
|
||||||
|
|
||||||
|
# Check if it's a valid aniworld.to URL
|
||||||
|
if "aniworld.to" not in url and "s.to" not in url:
|
||||||
|
raise ValueError("URL must be from aniworld.to or s.to")
|
||||||
|
|
||||||
|
# Basic URL validation
|
||||||
|
ValidatorMixin.validate_url(url)
|
||||||
|
|
||||||
|
return url
|
||||||
|
|
||||||
|
|
||||||
|
def validate_series_name(name: str) -> str:
|
||||||
|
"""
|
||||||
|
Validate series name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Series name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated name
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If name is invalid
|
||||||
|
"""
|
||||||
|
if not name or not name.strip():
|
||||||
|
raise ValueError("Series name cannot be empty")
|
||||||
|
|
||||||
|
if len(name) > 200:
|
||||||
|
raise ValueError("Series name too long (max 200 characters)")
|
||||||
|
|
||||||
|
# Check for invalid characters
|
||||||
|
invalid_chars = ['<', '>', ':', '"', '/', '\\', '|', '?', '*']
|
||||||
|
for char in invalid_chars:
|
||||||
|
if char in name:
|
||||||
|
raise ValueError(
|
||||||
|
f"Series name contains invalid character: {char}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return name.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def validate_backup_name(name: str) -> str:
|
||||||
|
"""
|
||||||
|
Validate backup file name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Backup name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated name
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If name is invalid
|
||||||
|
"""
|
||||||
|
if not name or not name.strip():
|
||||||
|
raise ValueError("Backup name cannot be empty")
|
||||||
|
|
||||||
|
# Must be a valid filename
|
||||||
|
if not re.match(r"^[a-zA-Z0-9_\-\.]+$", name):
|
||||||
|
raise ValueError(
|
||||||
|
"Backup name can only contain letters, numbers, underscores, hyphens, and dots"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not name.endswith(".json"):
|
||||||
|
raise ValueError("Backup name must end with .json")
|
||||||
|
|
||||||
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
def validate_config_data(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Validate configuration data structure.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Configuration data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated data
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If data is invalid
|
||||||
|
"""
|
||||||
|
required_keys = ["download_directory", "concurrent_downloads"]
|
||||||
|
ValidatorMixin.validate_dict_keys(data, required_keys, "Configuration")
|
||||||
|
|
||||||
|
# Validate download directory
|
||||||
|
if not isinstance(data["download_directory"], str):
|
||||||
|
raise ValueError("download_directory must be a string")
|
||||||
|
|
||||||
|
# Validate concurrent downloads
|
||||||
|
concurrent = data["concurrent_downloads"]
|
||||||
|
if not isinstance(concurrent, int) or concurrent < 1 or concurrent > 10:
|
||||||
|
raise ValueError("concurrent_downloads must be between 1 and 10")
|
||||||
|
|
||||||
|
# Validate quality if present
|
||||||
|
if "quality" in data:
|
||||||
|
validate_download_quality(data["quality"])
|
||||||
|
|
||||||
|
# Validate language if present
|
||||||
|
if "language" in data:
|
||||||
|
validate_language(data["language"])
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_filename(filename: str) -> str:
|
||||||
|
"""
|
||||||
|
Sanitize filename for safe filesystem use.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: Original filename
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sanitized filename
|
||||||
|
"""
|
||||||
|
# Remove or replace invalid characters
|
||||||
|
invalid_chars = ['<', '>', ':', '"', '/', '\\', '|', '?', '*']
|
||||||
|
for char in invalid_chars:
|
||||||
|
filename = filename.replace(char, '_')
|
||||||
|
|
||||||
|
# Remove leading/trailing spaces and dots
|
||||||
|
filename = filename.strip('. ')
|
||||||
|
|
||||||
|
# Ensure not empty
|
||||||
|
if not filename:
|
||||||
|
filename = "unnamed"
|
||||||
|
|
||||||
|
# Limit length
|
||||||
|
if len(filename) > 255:
|
||||||
|
name, ext = filename.rsplit('.', 1) if '.' in filename else (filename, '')
|
||||||
|
max_name_len = 255 - len(ext) - 1 if ext else 255
|
||||||
|
filename = name[:max_name_len] + ('.' + ext if ext else '')
|
||||||
|
|
||||||
|
return filename
|
||||||
|
|
||||||
|
|
||||||
|
def validate_jwt_token(token: str) -> str:
|
||||||
|
"""
|
||||||
|
Validate JWT token format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: JWT token
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated token
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If token format is invalid
|
||||||
|
"""
|
||||||
|
if not token or not isinstance(token, str):
|
||||||
|
raise ValueError("Token must be a non-empty string")
|
||||||
|
|
||||||
|
# JWT tokens have 3 parts separated by dots
|
||||||
|
parts = token.split(".")
|
||||||
|
if len(parts) != 3:
|
||||||
|
raise ValueError("Invalid JWT token format")
|
||||||
|
|
||||||
|
# Each part should be base64url encoded (alphanumeric + - and _)
|
||||||
|
for part in parts:
|
||||||
|
if not re.match(r"^[A-Za-z0-9_-]+$", part):
|
||||||
|
raise ValueError("Invalid JWT token encoding")
|
||||||
|
|
||||||
|
return token
|
||||||
|
|
||||||
|
|
||||||
|
def validate_ip_address(ip: str) -> str:
|
||||||
|
"""
|
||||||
|
Validate IP address format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ip: IP address
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated IP address
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If IP is invalid
|
||||||
|
"""
|
||||||
|
if not ip or not isinstance(ip, str):
|
||||||
|
raise ValueError("IP address must be a non-empty string")
|
||||||
|
|
||||||
|
# IPv4 pattern
|
||||||
|
ipv4_pattern = re.compile(
|
||||||
|
r"^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}"
|
||||||
|
r"(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$"
|
||||||
|
)
|
||||||
|
|
||||||
|
# IPv6 pattern (simplified)
|
||||||
|
ipv6_pattern = re.compile(
|
||||||
|
r"^(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}$"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not ipv4_pattern.match(ip) and not ipv6_pattern.match(ip):
|
||||||
|
raise ValueError(f"Invalid IP address format: {ip}")
|
||||||
|
|
||||||
|
return ip
|
||||||
|
|
||||||
|
|
||||||
|
def validate_websocket_message(message: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Validate WebSocket message structure.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: WebSocket message
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated message
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If message structure is invalid
|
||||||
|
"""
|
||||||
|
required_keys = ["type"]
|
||||||
|
ValidatorMixin.validate_dict_keys(message, required_keys, "WebSocket message")
|
||||||
|
|
||||||
|
valid_types = [
|
||||||
|
"download_progress",
|
||||||
|
"download_complete",
|
||||||
|
"download_failed",
|
||||||
|
"queue_update",
|
||||||
|
"error",
|
||||||
|
"system_message",
|
||||||
|
]
|
||||||
|
|
||||||
|
if message["type"] not in valid_types:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid message type. Must be one of: {', '.join(valid_types)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return message
|
||||||
Loading…
x
Reference in New Issue
Block a user