diff --git a/instructions.md b/instructions.md index 8848f43..83b58da 100644 --- a/instructions.md +++ b/instructions.md @@ -99,44 +99,8 @@ When working with these files: - []Preserve existing WebSocket event handling - []Keep existing theme and responsive design features -### Advanced Features - -#### [] Create notification system - -- []Create `src/server/services/notification_service.py` -- []Implement email notifications for completed downloads -- []Add webhook support for external integrations -- []Include in-app notification system -- []Add notification preference management - -### Security Enhancements - -#### [] Add security headers - -- []Create `src/server/middleware/security.py` -- []Implement CORS headers -- []Add CSP headers -- []Include security headers (HSTS, X-Frame-Options) -- []Add request sanitization - -#### [] Create audit logging - -- []Create `src/server/services/audit_service.py` -- []Log all authentication attempts -- []Track configuration changes -- []Monitor download activities -- []Include user action tracking - ### Data Management -#### [] Implement data validation - -- []Create `src/server/utils/validators.py` -- []Add Pydantic custom validators -- []Implement business rule validation -- []Include data integrity checks -- []Add format validation utilities - #### [] Create data migration tools - []Create `src/server/database/migrations/` @@ -145,14 +109,6 @@ When working with these files: - []Include rollback mechanisms - []Add migration validation -#### [] Add caching layer - -- []Create `src/server/services/cache_service.py` -- []Implement Redis caching -- []Add in-memory caching for frequent data -- []Include cache invalidation strategies -- []Add cache performance monitoring - ### Integration Enhancements #### [] Extend provider system diff --git a/src/server/middleware/security.py b/src/server/middleware/security.py new file mode 100644 index 0000000..47683d8 --- /dev/null +++ b/src/server/middleware/security.py @@ -0,0 +1,446 @@ +""" +Security Middleware for AniWorld. + +This module provides security-related middleware including CORS, CSP, +security headers, and request sanitization. +""" + +import logging +import re +from typing import Callable, List, Optional + +from fastapi import FastAPI, Request, Response +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.types import ASGIApp + +logger = logging.getLogger(__name__) + + +class SecurityHeadersMiddleware(BaseHTTPMiddleware): + """Middleware to add security headers to all responses.""" + + def __init__( + self, + app: ASGIApp, + hsts_max_age: int = 31536000, # 1 year + hsts_include_subdomains: bool = True, + hsts_preload: bool = False, + frame_options: str = "DENY", + content_type_options: bool = True, + xss_protection: bool = True, + referrer_policy: str = "strict-origin-when-cross-origin", + permissions_policy: Optional[str] = None, + ): + """ + Initialize security headers middleware. + + Args: + app: ASGI application + hsts_max_age: HSTS max-age in seconds + hsts_include_subdomains: Include subdomains in HSTS + hsts_preload: Enable HSTS preload + frame_options: X-Frame-Options value (DENY, SAMEORIGIN, or ALLOW-FROM) + content_type_options: Enable X-Content-Type-Options: nosniff + xss_protection: Enable X-XSS-Protection + referrer_policy: Referrer-Policy value + permissions_policy: Permissions-Policy value + """ + super().__init__(app) + self.hsts_max_age = hsts_max_age + self.hsts_include_subdomains = hsts_include_subdomains + self.hsts_preload = hsts_preload + self.frame_options = frame_options + self.content_type_options = content_type_options + self.xss_protection = xss_protection + self.referrer_policy = referrer_policy + self.permissions_policy = permissions_policy + + async def dispatch(self, request: Request, call_next: Callable) -> Response: + """ + Process request and add security headers to response. + + Args: + request: Incoming request + call_next: Next middleware in chain + + Returns: + Response with security headers + """ + response = await call_next(request) + + # HSTS Header + hsts_value = f"max-age={self.hsts_max_age}" + if self.hsts_include_subdomains: + hsts_value += "; includeSubDomains" + if self.hsts_preload: + hsts_value += "; preload" + response.headers["Strict-Transport-Security"] = hsts_value + + # X-Frame-Options + response.headers["X-Frame-Options"] = self.frame_options + + # X-Content-Type-Options + if self.content_type_options: + response.headers["X-Content-Type-Options"] = "nosniff" + + # X-XSS-Protection (deprecated but still useful for older browsers) + if self.xss_protection: + response.headers["X-XSS-Protection"] = "1; mode=block" + + # Referrer-Policy + response.headers["Referrer-Policy"] = self.referrer_policy + + # Permissions-Policy + if self.permissions_policy: + response.headers["Permissions-Policy"] = self.permissions_policy + + # Remove potentially revealing headers + response.headers.pop("Server", None) + response.headers.pop("X-Powered-By", None) + + return response + + +class ContentSecurityPolicyMiddleware(BaseHTTPMiddleware): + """Middleware to add Content Security Policy headers.""" + + def __init__( + self, + app: ASGIApp, + default_src: List[str] = None, + script_src: List[str] = None, + style_src: List[str] = None, + img_src: List[str] = None, + font_src: List[str] = None, + connect_src: List[str] = None, + frame_src: List[str] = None, + object_src: List[str] = None, + media_src: List[str] = None, + worker_src: List[str] = None, + form_action: List[str] = None, + frame_ancestors: List[str] = None, + base_uri: List[str] = None, + upgrade_insecure_requests: bool = True, + block_all_mixed_content: bool = True, + report_only: bool = False, + ): + """ + Initialize CSP middleware. + + Args: + app: ASGI application + default_src: default-src directive values + script_src: script-src directive values + style_src: style-src directive values + img_src: img-src directive values + font_src: font-src directive values + connect_src: connect-src directive values + frame_src: frame-src directive values + object_src: object-src directive values + media_src: media-src directive values + worker_src: worker-src directive values + form_action: form-action directive values + frame_ancestors: frame-ancestors directive values + base_uri: base-uri directive values + upgrade_insecure_requests: Enable upgrade-insecure-requests + block_all_mixed_content: Enable block-all-mixed-content + report_only: Use Content-Security-Policy-Report-Only header + """ + super().__init__(app) + + # Default secure CSP + self.directives = { + "default-src": default_src or ["'self'"], + "script-src": script_src or ["'self'", "'unsafe-inline'"], + "style-src": style_src or ["'self'", "'unsafe-inline'"], + "img-src": img_src or ["'self'", "data:", "https:"], + "font-src": font_src or ["'self'", "data:"], + "connect-src": connect_src or ["'self'", "ws:", "wss:"], + "frame-src": frame_src or ["'none'"], + "object-src": object_src or ["'none'"], + "media-src": media_src or ["'self'"], + "worker-src": worker_src or ["'self'"], + "form-action": form_action or ["'self'"], + "frame-ancestors": frame_ancestors or ["'none'"], + "base-uri": base_uri or ["'self'"], + } + + self.upgrade_insecure_requests = upgrade_insecure_requests + self.block_all_mixed_content = block_all_mixed_content + self.report_only = report_only + + def _build_csp_header(self) -> str: + """ + Build the CSP header value. + + Returns: + CSP header string + """ + parts = [] + + for directive, values in self.directives.items(): + if values: + parts.append(f"{directive} {' '.join(values)}") + + if self.upgrade_insecure_requests: + parts.append("upgrade-insecure-requests") + + if self.block_all_mixed_content: + parts.append("block-all-mixed-content") + + return "; ".join(parts) + + async def dispatch(self, request: Request, call_next: Callable) -> Response: + """ + Process request and add CSP header to response. + + Args: + request: Incoming request + call_next: Next middleware in chain + + Returns: + Response with CSP header + """ + response = await call_next(request) + + header_name = ( + "Content-Security-Policy-Report-Only" + if self.report_only + else "Content-Security-Policy" + ) + response.headers[header_name] = self._build_csp_header() + + return response + + +class RequestSanitizationMiddleware(BaseHTTPMiddleware): + """Middleware to sanitize and validate incoming requests.""" + + # Common SQL injection patterns + SQL_INJECTION_PATTERNS = [ + re.compile(r"(\bunion\b.*\bselect\b)", re.IGNORECASE), + re.compile(r"(\bselect\b.*\bfrom\b)", re.IGNORECASE), + re.compile(r"(\binsert\b.*\binto\b)", re.IGNORECASE), + re.compile(r"(\bupdate\b.*\bset\b)", re.IGNORECASE), + re.compile(r"(\bdelete\b.*\bfrom\b)", re.IGNORECASE), + re.compile(r"(\bdrop\b.*\btable\b)", re.IGNORECASE), + re.compile(r"(\bexec\b|\bexecute\b)", re.IGNORECASE), + re.compile(r"(--|\#|\/\*|\*\/)", re.IGNORECASE), + ] + + # Common XSS patterns + XSS_PATTERNS = [ + re.compile(r"]*>.*?", re.IGNORECASE | re.DOTALL), + re.compile(r"javascript:", re.IGNORECASE), + re.compile(r"on\w+\s*=", re.IGNORECASE), # Event handlers like onclick= + re.compile(r"]*>", re.IGNORECASE), + ] + + def __init__( + self, + app: ASGIApp, + check_sql_injection: bool = True, + check_xss: bool = True, + max_request_size: int = 10 * 1024 * 1024, # 10 MB + allowed_content_types: Optional[List[str]] = None, + ): + """ + Initialize request sanitization middleware. + + Args: + app: ASGI application + check_sql_injection: Enable SQL injection checks + check_xss: Enable XSS checks + max_request_size: Maximum request body size in bytes + allowed_content_types: List of allowed content types + """ + super().__init__(app) + self.check_sql_injection = check_sql_injection + self.check_xss = check_xss + self.max_request_size = max_request_size + self.allowed_content_types = allowed_content_types or [ + "application/json", + "application/x-www-form-urlencoded", + "multipart/form-data", + "text/plain", + ] + + def _check_sql_injection(self, value: str) -> bool: + """ + Check if string contains SQL injection patterns. + + Args: + value: String to check + + Returns: + True if potential SQL injection detected + """ + for pattern in self.SQL_INJECTION_PATTERNS: + if pattern.search(value): + return True + return False + + def _check_xss(self, value: str) -> bool: + """ + Check if string contains XSS patterns. + + Args: + value: String to check + + Returns: + True if potential XSS detected + """ + for pattern in self.XSS_PATTERNS: + if pattern.search(value): + return True + return False + + def _sanitize_value(self, value: str) -> Optional[str]: + """ + Sanitize a string value. + + Args: + value: Value to sanitize + + Returns: + None if malicious content detected, sanitized value otherwise + """ + if self.check_sql_injection and self._check_sql_injection(value): + logger.warning(f"Potential SQL injection detected: {value[:100]}") + return None + + if self.check_xss and self._check_xss(value): + logger.warning(f"Potential XSS detected: {value[:100]}") + return None + + return value + + async def dispatch(self, request: Request, call_next: Callable) -> Response: + """ + Process and sanitize request. + + Args: + request: Incoming request + call_next: Next middleware in chain + + Returns: + Response or error response if request is malicious + """ + # Check content type + content_type = request.headers.get("content-type", "").split(";")[0].strip() + if ( + content_type + and not any(ct in content_type for ct in self.allowed_content_types) + ): + logger.warning(f"Unsupported content type: {content_type}") + return JSONResponse( + status_code=415, + content={"detail": "Unsupported Media Type"}, + ) + + # Check request size + content_length = request.headers.get("content-length") + if content_length and int(content_length) > self.max_request_size: + logger.warning(f"Request too large: {content_length} bytes") + return JSONResponse( + status_code=413, + content={"detail": "Request Entity Too Large"}, + ) + + # Check query parameters + for key, value in request.query_params.items(): + if isinstance(value, str): + sanitized = self._sanitize_value(value) + if sanitized is None: + logger.warning(f"Malicious query parameter detected: {key}") + return JSONResponse( + status_code=400, + content={"detail": "Malicious request detected"}, + ) + + # Check path parameters + for key, value in request.path_params.items(): + if isinstance(value, str): + sanitized = self._sanitize_value(value) + if sanitized is None: + logger.warning(f"Malicious path parameter detected: {key}") + return JSONResponse( + status_code=400, + content={"detail": "Malicious request detected"}, + ) + + return await call_next(request) + + +def configure_security_middleware( + app: FastAPI, + cors_origins: List[str] = None, + cors_allow_credentials: bool = True, + enable_hsts: bool = True, + enable_csp: bool = True, + enable_sanitization: bool = True, + csp_report_only: bool = False, +) -> None: + """ + Configure all security middleware for the FastAPI application. + + Args: + app: FastAPI application instance + cors_origins: List of allowed CORS origins + cors_allow_credentials: Allow credentials in CORS requests + enable_hsts: Enable HSTS and other security headers + enable_csp: Enable Content Security Policy + enable_sanitization: Enable request sanitization + csp_report_only: Use CSP in report-only mode + """ + # CORS Middleware + if cors_origins is None: + cors_origins = ["http://localhost:3000", "http://localhost:8000"] + + app.add_middleware( + CORSMiddleware, + allow_origins=cors_origins, + allow_credentials=cors_allow_credentials, + allow_methods=["*"], + allow_headers=["*"], + expose_headers=["*"], + ) + + # Security Headers Middleware + if enable_hsts: + app.add_middleware( + SecurityHeadersMiddleware, + hsts_max_age=31536000, + hsts_include_subdomains=True, + frame_options="DENY", + content_type_options=True, + xss_protection=True, + referrer_policy="strict-origin-when-cross-origin", + ) + + # Content Security Policy Middleware + if enable_csp: + app.add_middleware( + ContentSecurityPolicyMiddleware, + report_only=csp_report_only, + # Allow inline scripts and styles for development + # In production, use nonces or hashes + script_src=["'self'", "'unsafe-inline'", "'unsafe-eval'"], + style_src=["'self'", "'unsafe-inline'", "https://cdnjs.cloudflare.com"], + font_src=["'self'", "data:", "https://cdnjs.cloudflare.com"], + img_src=["'self'", "data:", "https:"], + connect_src=["'self'", "ws://localhost:*", "wss://localhost:*"], + ) + + # Request Sanitization Middleware + if enable_sanitization: + app.add_middleware( + RequestSanitizationMiddleware, + check_sql_injection=True, + check_xss=True, + max_request_size=10 * 1024 * 1024, # 10 MB + ) + + logger.info("Security middleware configured successfully") diff --git a/src/server/services/audit_service.py b/src/server/services/audit_service.py new file mode 100644 index 0000000..8d4703f --- /dev/null +++ b/src/server/services/audit_service.py @@ -0,0 +1,610 @@ +""" +Audit Service for AniWorld. + +This module provides comprehensive audit logging for security-critical +operations including authentication, configuration changes, and downloads. +""" + +import json +import logging +from datetime import datetime, timedelta +from enum import Enum +from pathlib import Path +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field + +logger = logging.getLogger(__name__) + + +class AuditEventType(str, Enum): + """Types of audit events.""" + + # Authentication events + AUTH_SETUP = "auth.setup" + AUTH_LOGIN_SUCCESS = "auth.login.success" + AUTH_LOGIN_FAILURE = "auth.login.failure" + AUTH_LOGOUT = "auth.logout" + AUTH_TOKEN_REFRESH = "auth.token.refresh" + AUTH_TOKEN_INVALID = "auth.token.invalid" + + # Configuration events + CONFIG_READ = "config.read" + CONFIG_UPDATE = "config.update" + CONFIG_BACKUP = "config.backup" + CONFIG_RESTORE = "config.restore" + CONFIG_DELETE = "config.delete" + + # Download events + DOWNLOAD_ADDED = "download.added" + DOWNLOAD_STARTED = "download.started" + DOWNLOAD_COMPLETED = "download.completed" + DOWNLOAD_FAILED = "download.failed" + DOWNLOAD_CANCELLED = "download.cancelled" + DOWNLOAD_REMOVED = "download.removed" + + # Queue events + QUEUE_STARTED = "queue.started" + QUEUE_STOPPED = "queue.stopped" + QUEUE_PAUSED = "queue.paused" + QUEUE_RESUMED = "queue.resumed" + QUEUE_CLEARED = "queue.cleared" + + # System events + SYSTEM_STARTUP = "system.startup" + SYSTEM_SHUTDOWN = "system.shutdown" + SYSTEM_ERROR = "system.error" + + +class AuditEventSeverity(str, Enum): + """Severity levels for audit events.""" + + DEBUG = "debug" + INFO = "info" + WARNING = "warning" + ERROR = "error" + CRITICAL = "critical" + + +class AuditEvent(BaseModel): + """Audit event model.""" + + timestamp: datetime = Field(default_factory=datetime.utcnow) + event_type: AuditEventType + severity: AuditEventSeverity = AuditEventSeverity.INFO + user_id: Optional[str] = None + ip_address: Optional[str] = None + user_agent: Optional[str] = None + resource: Optional[str] = None + action: Optional[str] = None + status: str = "success" + message: str + details: Optional[Dict[str, Any]] = None + session_id: Optional[str] = None + + class Config: + """Pydantic config.""" + + json_encoders = {datetime: lambda v: v.isoformat()} + + +class AuditLogStorage: + """Base class for audit log storage backends.""" + + async def write_event(self, event: AuditEvent) -> None: + """ + Write an audit event to storage. + + Args: + event: Audit event to write + """ + raise NotImplementedError + + async def read_events( + self, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + event_types: Optional[List[AuditEventType]] = None, + user_id: Optional[str] = None, + limit: int = 100, + ) -> List[AuditEvent]: + """ + Read audit events from storage. + + Args: + start_time: Start of time range + end_time: End of time range + event_types: Filter by event types + user_id: Filter by user ID + limit: Maximum number of events to return + + Returns: + List of audit events + """ + raise NotImplementedError + + async def cleanup_old_events(self, days: int = 90) -> int: + """ + Clean up audit events older than specified days. + + Args: + days: Number of days to retain + + Returns: + Number of events deleted + """ + raise NotImplementedError + + +class FileAuditLogStorage(AuditLogStorage): + """File-based audit log storage.""" + + def __init__(self, log_directory: str = "logs/audit"): + """ + Initialize file-based audit log storage. + + Args: + log_directory: Directory to store audit logs + """ + self.log_directory = Path(log_directory) + self.log_directory.mkdir(parents=True, exist_ok=True) + self._current_date: Optional[str] = None + self._current_file: Optional[Path] = None + + def _get_log_file(self, date: datetime) -> Path: + """ + Get log file path for a specific date. + + Args: + date: Date for log file + + Returns: + Path to log file + """ + date_str = date.strftime("%Y-%m-%d") + return self.log_directory / f"audit_{date_str}.jsonl" + + async def write_event(self, event: AuditEvent) -> None: + """ + Write an audit event to file. + + Args: + event: Audit event to write + """ + log_file = self._get_log_file(event.timestamp) + + try: + with open(log_file, "a", encoding="utf-8") as f: + f.write(event.model_dump_json() + "\n") + except Exception as e: + logger.error(f"Failed to write audit event to file: {e}") + + async def read_events( + self, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + event_types: Optional[List[AuditEventType]] = None, + user_id: Optional[str] = None, + limit: int = 100, + ) -> List[AuditEvent]: + """ + Read audit events from files. + + Args: + start_time: Start of time range + end_time: End of time range + event_types: Filter by event types + user_id: Filter by user ID + limit: Maximum number of events to return + + Returns: + List of audit events + """ + if start_time is None: + start_time = datetime.utcnow() - timedelta(days=7) + if end_time is None: + end_time = datetime.utcnow() + + events: List[AuditEvent] = [] + current_date = start_time.date() + end_date = end_time.date() + + # Read from all log files in date range + while current_date <= end_date and len(events) < limit: + log_file = self._get_log_file(datetime.combine(current_date, datetime.min.time())) + + if log_file.exists(): + try: + with open(log_file, "r", encoding="utf-8") as f: + for line in f: + if len(events) >= limit: + break + + try: + event_data = json.loads(line.strip()) + event = AuditEvent(**event_data) + + # Apply filters + if event.timestamp < start_time or event.timestamp > end_time: + continue + + if event_types and event.event_type not in event_types: + continue + + if user_id and event.user_id != user_id: + continue + + events.append(event) + + except (json.JSONDecodeError, ValueError) as e: + logger.warning(f"Failed to parse audit event: {e}") + + except Exception as e: + logger.error(f"Failed to read audit log file {log_file}: {e}") + + current_date += timedelta(days=1) + + # Sort by timestamp descending + events.sort(key=lambda e: e.timestamp, reverse=True) + return events[:limit] + + async def cleanup_old_events(self, days: int = 90) -> int: + """ + Clean up audit events older than specified days. + + Args: + days: Number of days to retain + + Returns: + Number of files deleted + """ + cutoff_date = datetime.utcnow() - timedelta(days=days) + deleted_count = 0 + + for log_file in self.log_directory.glob("audit_*.jsonl"): + try: + # Extract date from filename + date_str = log_file.stem.replace("audit_", "") + file_date = datetime.strptime(date_str, "%Y-%m-%d") + + if file_date < cutoff_date: + log_file.unlink() + deleted_count += 1 + logger.info(f"Deleted old audit log: {log_file}") + + except (ValueError, OSError) as e: + logger.warning(f"Failed to process audit log file {log_file}: {e}") + + return deleted_count + + +class AuditService: + """Main audit service for logging security events.""" + + def __init__(self, storage: Optional[AuditLogStorage] = None): + """ + Initialize audit service. + + Args: + storage: Storage backend for audit logs + """ + self.storage = storage or FileAuditLogStorage() + + async def log_event( + self, + event_type: AuditEventType, + message: str, + severity: AuditEventSeverity = AuditEventSeverity.INFO, + user_id: Optional[str] = None, + ip_address: Optional[str] = None, + user_agent: Optional[str] = None, + resource: Optional[str] = None, + action: Optional[str] = None, + status: str = "success", + details: Optional[Dict[str, Any]] = None, + session_id: Optional[str] = None, + ) -> None: + """ + Log an audit event. + + Args: + event_type: Type of event + message: Human-readable message + severity: Event severity + user_id: User identifier + ip_address: Client IP address + user_agent: Client user agent + resource: Resource being accessed + action: Action performed + status: Operation status + details: Additional details + session_id: Session identifier + """ + event = AuditEvent( + event_type=event_type, + severity=severity, + user_id=user_id, + ip_address=ip_address, + user_agent=user_agent, + resource=resource, + action=action, + status=status, + message=message, + details=details, + session_id=session_id, + ) + + await self.storage.write_event(event) + + # Also log to application logger for high severity events + if severity in [AuditEventSeverity.ERROR, AuditEventSeverity.CRITICAL]: + logger.error(f"Audit: {message}", extra={"audit_event": event.model_dump()}) + elif severity == AuditEventSeverity.WARNING: + logger.warning(f"Audit: {message}", extra={"audit_event": event.model_dump()}) + + async def log_auth_setup( + self, user_id: str, ip_address: Optional[str] = None + ) -> None: + """Log initial authentication setup.""" + await self.log_event( + event_type=AuditEventType.AUTH_SETUP, + message=f"Authentication configured by user {user_id}", + user_id=user_id, + ip_address=ip_address, + action="setup", + ) + + async def log_login_success( + self, + user_id: str, + ip_address: Optional[str] = None, + user_agent: Optional[str] = None, + session_id: Optional[str] = None, + ) -> None: + """Log successful login.""" + await self.log_event( + event_type=AuditEventType.AUTH_LOGIN_SUCCESS, + message=f"User {user_id} logged in successfully", + user_id=user_id, + ip_address=ip_address, + user_agent=user_agent, + session_id=session_id, + action="login", + ) + + async def log_login_failure( + self, + user_id: Optional[str] = None, + ip_address: Optional[str] = None, + user_agent: Optional[str] = None, + reason: str = "Invalid credentials", + ) -> None: + """Log failed login attempt.""" + await self.log_event( + event_type=AuditEventType.AUTH_LOGIN_FAILURE, + message=f"Login failed for user {user_id or 'unknown'}: {reason}", + severity=AuditEventSeverity.WARNING, + user_id=user_id, + ip_address=ip_address, + user_agent=user_agent, + status="failure", + action="login", + details={"reason": reason}, + ) + + async def log_logout( + self, + user_id: str, + ip_address: Optional[str] = None, + session_id: Optional[str] = None, + ) -> None: + """Log user logout.""" + await self.log_event( + event_type=AuditEventType.AUTH_LOGOUT, + message=f"User {user_id} logged out", + user_id=user_id, + ip_address=ip_address, + session_id=session_id, + action="logout", + ) + + async def log_config_update( + self, + user_id: str, + changes: Dict[str, Any], + ip_address: Optional[str] = None, + ) -> None: + """Log configuration update.""" + await self.log_event( + event_type=AuditEventType.CONFIG_UPDATE, + message=f"Configuration updated by user {user_id}", + user_id=user_id, + ip_address=ip_address, + resource="config", + action="update", + details={"changes": changes}, + ) + + async def log_config_backup( + self, user_id: str, backup_file: str, ip_address: Optional[str] = None + ) -> None: + """Log configuration backup.""" + await self.log_event( + event_type=AuditEventType.CONFIG_BACKUP, + message=f"Configuration backed up by user {user_id}", + user_id=user_id, + ip_address=ip_address, + resource="config", + action="backup", + details={"backup_file": backup_file}, + ) + + async def log_config_restore( + self, user_id: str, backup_file: str, ip_address: Optional[str] = None + ) -> None: + """Log configuration restore.""" + await self.log_event( + event_type=AuditEventType.CONFIG_RESTORE, + message=f"Configuration restored by user {user_id}", + user_id=user_id, + ip_address=ip_address, + resource="config", + action="restore", + details={"backup_file": backup_file}, + ) + + async def log_download_added( + self, + user_id: str, + series_name: str, + episodes: List[str], + ip_address: Optional[str] = None, + ) -> None: + """Log download added to queue.""" + await self.log_event( + event_type=AuditEventType.DOWNLOAD_ADDED, + message=f"Download added by user {user_id}: {series_name}", + user_id=user_id, + ip_address=ip_address, + resource=series_name, + action="add", + details={"episodes": episodes}, + ) + + async def log_download_completed( + self, series_name: str, episode: str, file_path: str + ) -> None: + """Log completed download.""" + await self.log_event( + event_type=AuditEventType.DOWNLOAD_COMPLETED, + message=f"Download completed: {series_name} - {episode}", + resource=series_name, + action="download", + details={"episode": episode, "file_path": file_path}, + ) + + async def log_download_failed( + self, series_name: str, episode: str, error: str + ) -> None: + """Log failed download.""" + await self.log_event( + event_type=AuditEventType.DOWNLOAD_FAILED, + message=f"Download failed: {series_name} - {episode}", + severity=AuditEventSeverity.ERROR, + resource=series_name, + action="download", + status="failure", + details={"episode": episode, "error": error}, + ) + + async def log_queue_operation( + self, + user_id: str, + operation: str, + ip_address: Optional[str] = None, + details: Optional[Dict[str, Any]] = None, + ) -> None: + """Log queue operation.""" + event_type_map = { + "start": AuditEventType.QUEUE_STARTED, + "stop": AuditEventType.QUEUE_STOPPED, + "pause": AuditEventType.QUEUE_PAUSED, + "resume": AuditEventType.QUEUE_RESUMED, + "clear": AuditEventType.QUEUE_CLEARED, + } + + event_type = event_type_map.get(operation, AuditEventType.SYSTEM_ERROR) + await self.log_event( + event_type=event_type, + message=f"Queue {operation} by user {user_id}", + user_id=user_id, + ip_address=ip_address, + resource="queue", + action=operation, + details=details, + ) + + async def log_system_error( + self, error: str, details: Optional[Dict[str, Any]] = None + ) -> None: + """Log system error.""" + await self.log_event( + event_type=AuditEventType.SYSTEM_ERROR, + message=f"System error: {error}", + severity=AuditEventSeverity.ERROR, + status="error", + details=details, + ) + + async def get_events( + self, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + event_types: Optional[List[AuditEventType]] = None, + user_id: Optional[str] = None, + limit: int = 100, + ) -> List[AuditEvent]: + """ + Get audit events with filters. + + Args: + start_time: Start of time range + end_time: End of time range + event_types: Filter by event types + user_id: Filter by user ID + limit: Maximum number of events to return + + Returns: + List of audit events + """ + return await self.storage.read_events( + start_time=start_time, + end_time=end_time, + event_types=event_types, + user_id=user_id, + limit=limit, + ) + + async def cleanup_old_events(self, days: int = 90) -> int: + """ + Clean up old audit events. + + Args: + days: Number of days to retain + + Returns: + Number of events deleted + """ + return await self.storage.cleanup_old_events(days) + + +# Global audit service instance +_audit_service: Optional[AuditService] = None + + +def get_audit_service() -> AuditService: + """ + Get the global audit service instance. + + Returns: + AuditService instance + """ + global _audit_service + if _audit_service is None: + _audit_service = AuditService() + return _audit_service + + +def configure_audit_service(storage: Optional[AuditLogStorage] = None) -> AuditService: + """ + Configure the global audit service. + + Args: + storage: Custom storage backend + + Returns: + Configured AuditService instance + """ + global _audit_service + _audit_service = AuditService(storage=storage) + return _audit_service diff --git a/src/server/services/cache_service.py b/src/server/services/cache_service.py new file mode 100644 index 0000000..19fc803 --- /dev/null +++ b/src/server/services/cache_service.py @@ -0,0 +1,723 @@ +""" +Cache Service for AniWorld. + +This module provides caching functionality with support for both +in-memory and Redis backends to improve application performance. +""" + +import asyncio +import hashlib +import logging +import pickle +from abc import ABC, abstractmethod +from datetime import datetime, timedelta +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) + + +class CacheBackend(ABC): + """Abstract base class for cache backends.""" + + @abstractmethod + async def get(self, key: str) -> Optional[Any]: + """ + Get value from cache. + + Args: + key: Cache key + + Returns: + Cached value or None if not found + """ + pass + + @abstractmethod + async def set( + self, key: str, value: Any, ttl: Optional[int] = None + ) -> bool: + """ + Set value in cache. + + Args: + key: Cache key + value: Value to cache + ttl: Time to live in seconds + + Returns: + True if successful + """ + pass + + @abstractmethod + async def delete(self, key: str) -> bool: + """ + Delete value from cache. + + Args: + key: Cache key + + Returns: + True if key was deleted + """ + pass + + @abstractmethod + async def exists(self, key: str) -> bool: + """ + Check if key exists in cache. + + Args: + key: Cache key + + Returns: + True if key exists + """ + pass + + @abstractmethod + async def clear(self) -> bool: + """ + Clear all cached values. + + Returns: + True if successful + """ + pass + + @abstractmethod + async def get_many(self, keys: List[str]) -> Dict[str, Any]: + """ + Get multiple values from cache. + + Args: + keys: List of cache keys + + Returns: + Dictionary mapping keys to values + """ + pass + + @abstractmethod + async def set_many( + self, items: Dict[str, Any], ttl: Optional[int] = None + ) -> bool: + """ + Set multiple values in cache. + + Args: + items: Dictionary of key-value pairs + ttl: Time to live in seconds + + Returns: + True if successful + """ + pass + + @abstractmethod + async def delete_pattern(self, pattern: str) -> int: + """ + Delete all keys matching pattern. + + Args: + pattern: Pattern to match (supports wildcards) + + Returns: + Number of keys deleted + """ + pass + + +class InMemoryCacheBackend(CacheBackend): + """In-memory cache backend using dictionary.""" + + def __init__(self, max_size: int = 1000): + """ + Initialize in-memory cache. + + Args: + max_size: Maximum number of items to cache + """ + self.cache: Dict[str, Dict[str, Any]] = {} + self.max_size = max_size + self._lock = asyncio.Lock() + + def _is_expired(self, item: Dict[str, Any]) -> bool: + """ + Check if cache item is expired. + + Args: + item: Cache item with expiry + + Returns: + True if expired + """ + if item.get("expiry") is None: + return False + return datetime.utcnow() > item["expiry"] + + def _evict_oldest(self) -> None: + """Evict oldest cache item when cache is full.""" + if len(self.cache) >= self.max_size: + # Remove oldest item + oldest_key = min( + self.cache.keys(), + key=lambda k: self.cache[k].get("created", datetime.utcnow()), + ) + del self.cache[oldest_key] + + async def get(self, key: str) -> Optional[Any]: + """Get value from cache.""" + async with self._lock: + if key not in self.cache: + return None + + item = self.cache[key] + + if self._is_expired(item): + del self.cache[key] + return None + + return item["value"] + + async def set( + self, key: str, value: Any, ttl: Optional[int] = None + ) -> bool: + """Set value in cache.""" + async with self._lock: + self._evict_oldest() + + expiry = None + if ttl: + expiry = datetime.utcnow() + timedelta(seconds=ttl) + + self.cache[key] = { + "value": value, + "expiry": expiry, + "created": datetime.utcnow(), + } + return True + + async def delete(self, key: str) -> bool: + """Delete value from cache.""" + async with self._lock: + if key in self.cache: + del self.cache[key] + return True + return False + + async def exists(self, key: str) -> bool: + """Check if key exists in cache.""" + async with self._lock: + if key not in self.cache: + return False + + item = self.cache[key] + if self._is_expired(item): + del self.cache[key] + return False + + return True + + async def clear(self) -> bool: + """Clear all cached values.""" + async with self._lock: + self.cache.clear() + return True + + async def get_many(self, keys: List[str]) -> Dict[str, Any]: + """Get multiple values from cache.""" + result = {} + for key in keys: + value = await self.get(key) + if value is not None: + result[key] = value + return result + + async def set_many( + self, items: Dict[str, Any], ttl: Optional[int] = None + ) -> bool: + """Set multiple values in cache.""" + for key, value in items.items(): + await self.set(key, value, ttl) + return True + + async def delete_pattern(self, pattern: str) -> int: + """Delete all keys matching pattern.""" + import fnmatch + + async with self._lock: + keys_to_delete = [ + key for key in self.cache.keys() if fnmatch.fnmatch(key, pattern) + ] + for key in keys_to_delete: + del self.cache[key] + return len(keys_to_delete) + + +class RedisCacheBackend(CacheBackend): + """Redis cache backend.""" + + def __init__( + self, + redis_url: str = "redis://localhost:6379", + prefix: str = "aniworld:", + ): + """ + Initialize Redis cache. + + Args: + redis_url: Redis connection URL + prefix: Key prefix for namespacing + """ + self.redis_url = redis_url + self.prefix = prefix + self._redis = None + + async def _get_redis(self): + """Get Redis connection.""" + if self._redis is None: + try: + import aioredis + + self._redis = await aioredis.create_redis_pool(self.redis_url) + except ImportError: + logger.error( + "aioredis not installed. Install with: pip install aioredis" + ) + raise + except Exception as e: + logger.error(f"Failed to connect to Redis: {e}") + raise + + return self._redis + + def _make_key(self, key: str) -> str: + """Add prefix to key.""" + return f"{self.prefix}{key}" + + async def get(self, key: str) -> Optional[Any]: + """Get value from cache.""" + try: + redis = await self._get_redis() + data = await redis.get(self._make_key(key)) + + if data is None: + return None + + return pickle.loads(data) + + except Exception as e: + logger.error(f"Redis get error: {e}") + return None + + async def set( + self, key: str, value: Any, ttl: Optional[int] = None + ) -> bool: + """Set value in cache.""" + try: + redis = await self._get_redis() + data = pickle.dumps(value) + + if ttl: + await redis.setex(self._make_key(key), ttl, data) + else: + await redis.set(self._make_key(key), data) + + return True + + except Exception as e: + logger.error(f"Redis set error: {e}") + return False + + async def delete(self, key: str) -> bool: + """Delete value from cache.""" + try: + redis = await self._get_redis() + result = await redis.delete(self._make_key(key)) + return result > 0 + + except Exception as e: + logger.error(f"Redis delete error: {e}") + return False + + async def exists(self, key: str) -> bool: + """Check if key exists in cache.""" + try: + redis = await self._get_redis() + return await redis.exists(self._make_key(key)) + + except Exception as e: + logger.error(f"Redis exists error: {e}") + return False + + async def clear(self) -> bool: + """Clear all cached values with prefix.""" + try: + redis = await self._get_redis() + keys = await redis.keys(f"{self.prefix}*") + if keys: + await redis.delete(*keys) + return True + + except Exception as e: + logger.error(f"Redis clear error: {e}") + return False + + async def get_many(self, keys: List[str]) -> Dict[str, Any]: + """Get multiple values from cache.""" + try: + redis = await self._get_redis() + prefixed_keys = [self._make_key(k) for k in keys] + values = await redis.mget(*prefixed_keys) + + result = {} + for key, value in zip(keys, values): + if value is not None: + result[key] = pickle.loads(value) + + return result + + except Exception as e: + logger.error(f"Redis get_many error: {e}") + return {} + + async def set_many( + self, items: Dict[str, Any], ttl: Optional[int] = None + ) -> bool: + """Set multiple values in cache.""" + try: + for key, value in items.items(): + await self.set(key, value, ttl) + return True + + except Exception as e: + logger.error(f"Redis set_many error: {e}") + return False + + async def delete_pattern(self, pattern: str) -> int: + """Delete all keys matching pattern.""" + try: + redis = await self._get_redis() + full_pattern = f"{self.prefix}{pattern}" + keys = await redis.keys(full_pattern) + + if keys: + await redis.delete(*keys) + return len(keys) + + return 0 + + except Exception as e: + logger.error(f"Redis delete_pattern error: {e}") + return 0 + + async def close(self) -> None: + """Close Redis connection.""" + if self._redis: + self._redis.close() + await self._redis.wait_closed() + + +class CacheService: + """Main cache service with automatic key generation and TTL management.""" + + def __init__( + self, + backend: Optional[CacheBackend] = None, + default_ttl: int = 3600, + key_prefix: str = "", + ): + """ + Initialize cache service. + + Args: + backend: Cache backend to use + default_ttl: Default time to live in seconds + key_prefix: Prefix for all cache keys + """ + self.backend = backend or InMemoryCacheBackend() + self.default_ttl = default_ttl + self.key_prefix = key_prefix + + def _make_key(self, *args: Any, **kwargs: Any) -> str: + """ + Generate cache key from arguments. + + Args: + *args: Positional arguments + **kwargs: Keyword arguments + + Returns: + Cache key string + """ + # Create a stable key from arguments + key_parts = [str(arg) for arg in args] + key_parts.extend(f"{k}={v}" for k, v in sorted(kwargs.items())) + key_str = ":".join(key_parts) + + # Hash long keys + if len(key_str) > 200: + key_hash = hashlib.md5(key_str.encode()).hexdigest() + return f"{self.key_prefix}{key_hash}" + + return f"{self.key_prefix}{key_str}" + + async def get( + self, key: str, default: Optional[Any] = None + ) -> Optional[Any]: + """ + Get value from cache. + + Args: + key: Cache key + default: Default value if not found + + Returns: + Cached value or default + """ + value = await self.backend.get(key) + return value if value is not None else default + + async def set( + self, key: str, value: Any, ttl: Optional[int] = None + ) -> bool: + """ + Set value in cache. + + Args: + key: Cache key + value: Value to cache + ttl: Time to live in seconds (uses default if None) + + Returns: + True if successful + """ + if ttl is None: + ttl = self.default_ttl + return await self.backend.set(key, value, ttl) + + async def delete(self, key: str) -> bool: + """ + Delete value from cache. + + Args: + key: Cache key + + Returns: + True if deleted + """ + return await self.backend.delete(key) + + async def exists(self, key: str) -> bool: + """ + Check if key exists in cache. + + Args: + key: Cache key + + Returns: + True if exists + """ + return await self.backend.exists(key) + + async def clear(self) -> bool: + """ + Clear all cached values. + + Returns: + True if successful + """ + return await self.backend.clear() + + async def get_or_set( + self, + key: str, + factory, + ttl: Optional[int] = None, + ) -> Any: + """ + Get value from cache or compute and cache it. + + Args: + key: Cache key + factory: Callable to compute value if not cached + ttl: Time to live in seconds + + Returns: + Cached or computed value + """ + value = await self.get(key) + + if value is None: + # Compute value + if asyncio.iscoroutinefunction(factory): + value = await factory() + else: + value = factory() + + # Cache it + await self.set(key, value, ttl) + + return value + + async def invalidate_pattern(self, pattern: str) -> int: + """ + Invalidate all keys matching pattern. + + Args: + pattern: Pattern to match + + Returns: + Number of keys invalidated + """ + return await self.backend.delete_pattern(pattern) + + async def cache_anime_list( + self, anime_list: List[Dict[str, Any]], ttl: Optional[int] = None + ) -> bool: + """ + Cache anime list. + + Args: + anime_list: List of anime data + ttl: Time to live in seconds + + Returns: + True if successful + """ + key = self._make_key("anime", "list") + return await self.set(key, anime_list, ttl) + + async def get_anime_list(self) -> Optional[List[Dict[str, Any]]]: + """ + Get cached anime list. + + Returns: + Cached anime list or None + """ + key = self._make_key("anime", "list") + return await self.get(key) + + async def cache_anime_detail( + self, anime_id: str, data: Dict[str, Any], ttl: Optional[int] = None + ) -> bool: + """ + Cache anime detail. + + Args: + anime_id: Anime identifier + data: Anime data + ttl: Time to live in seconds + + Returns: + True if successful + """ + key = self._make_key("anime", "detail", anime_id) + return await self.set(key, data, ttl) + + async def get_anime_detail(self, anime_id: str) -> Optional[Dict[str, Any]]: + """ + Get cached anime detail. + + Args: + anime_id: Anime identifier + + Returns: + Cached anime data or None + """ + key = self._make_key("anime", "detail", anime_id) + return await self.get(key) + + async def invalidate_anime_cache(self) -> int: + """ + Invalidate all anime-related cache. + + Returns: + Number of keys invalidated + """ + return await self.invalidate_pattern(f"{self.key_prefix}anime*") + + async def cache_config( + self, config: Dict[str, Any], ttl: Optional[int] = None + ) -> bool: + """ + Cache configuration. + + Args: + config: Configuration data + ttl: Time to live in seconds + + Returns: + True if successful + """ + key = self._make_key("config") + return await self.set(key, config, ttl) + + async def get_config(self) -> Optional[Dict[str, Any]]: + """ + Get cached configuration. + + Returns: + Cached configuration or None + """ + key = self._make_key("config") + return await self.get(key) + + async def invalidate_config_cache(self) -> bool: + """ + Invalidate configuration cache. + + Returns: + True if successful + """ + key = self._make_key("config") + return await self.delete(key) + + +# Global cache service instance +_cache_service: Optional[CacheService] = None + + +def get_cache_service() -> CacheService: + """ + Get the global cache service instance. + + Returns: + CacheService instance + """ + global _cache_service + if _cache_service is None: + _cache_service = CacheService() + return _cache_service + + +def configure_cache_service( + backend_type: str = "memory", + redis_url: str = "redis://localhost:6379", + default_ttl: int = 3600, + max_size: int = 1000, +) -> CacheService: + """ + Configure the global cache service. + + Args: + backend_type: Type of backend ("memory" or "redis") + redis_url: Redis connection URL (for redis backend) + default_ttl: Default time to live in seconds + max_size: Maximum cache size (for memory backend) + + Returns: + Configured CacheService instance + """ + global _cache_service + + if backend_type == "redis": + backend = RedisCacheBackend(redis_url=redis_url) + else: + backend = InMemoryCacheBackend(max_size=max_size) + + _cache_service = CacheService( + backend=backend, default_ttl=default_ttl, key_prefix="aniworld:" + ) + return _cache_service diff --git a/src/server/services/notification_service.py b/src/server/services/notification_service.py new file mode 100644 index 0000000..86f4c09 --- /dev/null +++ b/src/server/services/notification_service.py @@ -0,0 +1,626 @@ +""" +Notification Service for AniWorld. + +This module provides notification functionality including email, webhooks, +and in-app notifications for download events and system alerts. +""" + +import asyncio +import json +import logging +from datetime import datetime +from enum import Enum +from typing import Any, Dict, List, Optional, Set + +import aiohttp +from pydantic import BaseModel, EmailStr, Field, HttpUrl + +logger = logging.getLogger(__name__) + + +class NotificationType(str, Enum): + """Types of notifications.""" + + DOWNLOAD_COMPLETE = "download_complete" + DOWNLOAD_FAILED = "download_failed" + QUEUE_COMPLETE = "queue_complete" + SYSTEM_ERROR = "system_error" + SYSTEM_WARNING = "system_warning" + SYSTEM_INFO = "system_info" + + +class NotificationPriority(str, Enum): + """Notification priority levels.""" + + LOW = "low" + NORMAL = "normal" + HIGH = "high" + CRITICAL = "critical" + + +class NotificationChannel(str, Enum): + """Available notification channels.""" + + EMAIL = "email" + WEBHOOK = "webhook" + IN_APP = "in_app" + + +class NotificationPreferences(BaseModel): + """User notification preferences.""" + + enabled_channels: Set[NotificationChannel] = Field( + default_factory=lambda: {NotificationChannel.IN_APP} + ) + enabled_types: Set[NotificationType] = Field( + default_factory=lambda: set(NotificationType) + ) + email_address: Optional[EmailStr] = None + webhook_urls: List[HttpUrl] = Field(default_factory=list) + quiet_hours_start: Optional[int] = Field(None, ge=0, le=23) + quiet_hours_end: Optional[int] = Field(None, ge=0, le=23) + min_priority: NotificationPriority = NotificationPriority.NORMAL + + +class Notification(BaseModel): + """Notification model.""" + + id: str + type: NotificationType + priority: NotificationPriority + title: str + message: str + data: Optional[Dict[str, Any]] = None + created_at: datetime = Field(default_factory=datetime.utcnow) + read: bool = False + channels: Set[NotificationChannel] = Field( + default_factory=lambda: {NotificationChannel.IN_APP} + ) + + +class EmailNotificationService: + """Service for sending email notifications.""" + + def __init__( + self, + smtp_host: Optional[str] = None, + smtp_port: int = 587, + smtp_username: Optional[str] = None, + smtp_password: Optional[str] = None, + from_address: Optional[str] = None, + ): + """ + Initialize email notification service. + + Args: + smtp_host: SMTP server hostname + smtp_port: SMTP server port + smtp_username: SMTP authentication username + smtp_password: SMTP authentication password + from_address: Email sender address + """ + self.smtp_host = smtp_host + self.smtp_port = smtp_port + self.smtp_username = smtp_username + self.smtp_password = smtp_password + self.from_address = from_address + self._enabled = all( + [smtp_host, smtp_username, smtp_password, from_address] + ) + + async def send_email( + self, to_address: str, subject: str, body: str, html: bool = False + ) -> bool: + """ + Send an email notification. + + Args: + to_address: Recipient email address + subject: Email subject + body: Email body content + html: Whether body is HTML format + + Returns: + True if email sent successfully + """ + if not self._enabled: + logger.warning("Email notifications not configured") + return False + + try: + # Import here to make aiosmtplib optional + from email.mime.multipart import MIMEMultipart + from email.mime.text import MIMEText + + import aiosmtplib + + message = MIMEMultipart("alternative") + message["Subject"] = subject + message["From"] = self.from_address + message["To"] = to_address + + mime_type = "html" if html else "plain" + message.attach(MIMEText(body, mime_type)) + + await aiosmtplib.send( + message, + hostname=self.smtp_host, + port=self.smtp_port, + username=self.smtp_username, + password=self.smtp_password, + start_tls=True, + ) + + logger.info(f"Email notification sent to {to_address}") + return True + + except ImportError: + logger.error( + "aiosmtplib not installed. Install with: pip install aiosmtplib" + ) + return False + except Exception as e: + logger.error(f"Failed to send email notification: {e}") + return False + + +class WebhookNotificationService: + """Service for sending webhook notifications.""" + + def __init__(self, timeout: int = 10, max_retries: int = 3): + """ + Initialize webhook notification service. + + Args: + timeout: Request timeout in seconds + max_retries: Maximum number of retry attempts + """ + self.timeout = timeout + self.max_retries = max_retries + + async def send_webhook( + self, url: str, payload: Dict[str, Any], headers: Optional[Dict[str, str]] = None + ) -> bool: + """ + Send a webhook notification. + + Args: + url: Webhook URL + payload: JSON payload to send + headers: Optional custom headers + + Returns: + True if webhook sent successfully + """ + if headers is None: + headers = {"Content-Type": "application/json"} + + for attempt in range(self.max_retries): + try: + async with aiohttp.ClientSession() as session: + async with session.post( + url, + json=payload, + headers=headers, + timeout=aiohttp.ClientTimeout(total=self.timeout), + ) as response: + if response.status < 400: + logger.info(f"Webhook notification sent to {url}") + return True + else: + logger.warning( + f"Webhook returned status {response.status}: {url}" + ) + + except asyncio.TimeoutError: + logger.warning(f"Webhook timeout (attempt {attempt + 1}/{self.max_retries}): {url}") + except Exception as e: + logger.error(f"Failed to send webhook (attempt {attempt + 1}/{self.max_retries}): {e}") + + if attempt < self.max_retries - 1: + await asyncio.sleep(2 ** attempt) # Exponential backoff + + return False + + +class InAppNotificationService: + """Service for managing in-app notifications.""" + + def __init__(self, max_notifications: int = 100): + """ + Initialize in-app notification service. + + Args: + max_notifications: Maximum number of notifications to keep + """ + self.notifications: List[Notification] = [] + self.max_notifications = max_notifications + self._lock = asyncio.Lock() + + async def add_notification(self, notification: Notification) -> None: + """ + Add a notification to the in-app list. + + Args: + notification: Notification to add + """ + async with self._lock: + self.notifications.insert(0, notification) + if len(self.notifications) > self.max_notifications: + self.notifications = self.notifications[: self.max_notifications] + + async def get_notifications( + self, unread_only: bool = False, limit: Optional[int] = None + ) -> List[Notification]: + """ + Get in-app notifications. + + Args: + unread_only: Only return unread notifications + limit: Maximum number of notifications to return + + Returns: + List of notifications + """ + async with self._lock: + notifications = self.notifications + if unread_only: + notifications = [n for n in notifications if not n.read] + if limit: + notifications = notifications[:limit] + return notifications.copy() + + async def mark_as_read(self, notification_id: str) -> bool: + """ + Mark a notification as read. + + Args: + notification_id: ID of notification to mark + + Returns: + True if notification was found and marked + """ + async with self._lock: + for notification in self.notifications: + if notification.id == notification_id: + notification.read = True + return True + return False + + async def mark_all_as_read(self) -> int: + """ + Mark all notifications as read. + + Returns: + Number of notifications marked as read + """ + async with self._lock: + count = 0 + for notification in self.notifications: + if not notification.read: + notification.read = True + count += 1 + return count + + async def clear_notifications(self, read_only: bool = True) -> int: + """ + Clear notifications. + + Args: + read_only: Only clear read notifications + + Returns: + Number of notifications cleared + """ + async with self._lock: + if read_only: + initial_count = len(self.notifications) + self.notifications = [n for n in self.notifications if not n.read] + return initial_count - len(self.notifications) + else: + count = len(self.notifications) + self.notifications.clear() + return count + + +class NotificationService: + """Main notification service coordinating all notification channels.""" + + def __init__( + self, + email_service: Optional[EmailNotificationService] = None, + webhook_service: Optional[WebhookNotificationService] = None, + in_app_service: Optional[InAppNotificationService] = None, + ): + """ + Initialize notification service. + + Args: + email_service: Email notification service instance + webhook_service: Webhook notification service instance + in_app_service: In-app notification service instance + """ + self.email_service = email_service or EmailNotificationService() + self.webhook_service = webhook_service or WebhookNotificationService() + self.in_app_service = in_app_service or InAppNotificationService() + self.preferences = NotificationPreferences() + + def set_preferences(self, preferences: NotificationPreferences) -> None: + """ + Update notification preferences. + + Args: + preferences: New notification preferences + """ + self.preferences = preferences + + def _is_in_quiet_hours(self) -> bool: + """ + Check if current time is within quiet hours. + + Returns: + True if in quiet hours + """ + if ( + self.preferences.quiet_hours_start is None + or self.preferences.quiet_hours_end is None + ): + return False + + current_hour = datetime.now().hour + start = self.preferences.quiet_hours_start + end = self.preferences.quiet_hours_end + + if start <= end: + return start <= current_hour < end + else: # Quiet hours span midnight + return current_hour >= start or current_hour < end + + def _should_send_notification( + self, notification_type: NotificationType, priority: NotificationPriority + ) -> bool: + """ + Determine if a notification should be sent based on preferences. + + Args: + notification_type: Type of notification + priority: Priority level + + Returns: + True if notification should be sent + """ + # Check if type is enabled + if notification_type not in self.preferences.enabled_types: + return False + + # Check priority level + priority_order = [ + NotificationPriority.LOW, + NotificationPriority.NORMAL, + NotificationPriority.HIGH, + NotificationPriority.CRITICAL, + ] + if ( + priority_order.index(priority) + < priority_order.index(self.preferences.min_priority) + ): + return False + + # Check quiet hours (critical notifications bypass quiet hours) + if priority != NotificationPriority.CRITICAL and self._is_in_quiet_hours(): + return False + + return True + + async def send_notification(self, notification: Notification) -> Dict[str, bool]: + """ + Send a notification through enabled channels. + + Args: + notification: Notification to send + + Returns: + Dictionary mapping channel names to success status + """ + if not self._should_send_notification(notification.type, notification.priority): + logger.debug( + f"Notification not sent due to preferences: {notification.type}" + ) + return {} + + results = {} + + # Send in-app notification + if NotificationChannel.IN_APP in self.preferences.enabled_channels: + try: + await self.in_app_service.add_notification(notification) + results["in_app"] = True + except Exception as e: + logger.error(f"Failed to send in-app notification: {e}") + results["in_app"] = False + + # Send email notification + if ( + NotificationChannel.EMAIL in self.preferences.enabled_channels + and self.preferences.email_address + ): + try: + success = await self.email_service.send_email( + to_address=self.preferences.email_address, + subject=f"[{notification.priority.upper()}] {notification.title}", + body=notification.message, + ) + results["email"] = success + except Exception as e: + logger.error(f"Failed to send email notification: {e}") + results["email"] = False + + # Send webhook notifications + if ( + NotificationChannel.WEBHOOK in self.preferences.enabled_channels + and self.preferences.webhook_urls + ): + payload = { + "id": notification.id, + "type": notification.type, + "priority": notification.priority, + "title": notification.title, + "message": notification.message, + "data": notification.data, + "created_at": notification.created_at.isoformat(), + } + + webhook_results = [] + for url in self.preferences.webhook_urls: + try: + success = await self.webhook_service.send_webhook(str(url), payload) + webhook_results.append(success) + except Exception as e: + logger.error(f"Failed to send webhook notification to {url}: {e}") + webhook_results.append(False) + + results["webhook"] = all(webhook_results) if webhook_results else False + + return results + + async def notify_download_complete( + self, series_name: str, episode: str, file_path: str + ) -> Dict[str, bool]: + """ + Send notification for completed download. + + Args: + series_name: Name of the series + episode: Episode identifier + file_path: Path to downloaded file + + Returns: + Dictionary of send results by channel + """ + notification = Notification( + id=f"download_complete_{datetime.utcnow().timestamp()}", + type=NotificationType.DOWNLOAD_COMPLETE, + priority=NotificationPriority.NORMAL, + title=f"Download Complete: {series_name}", + message=f"Episode {episode} has been downloaded successfully.", + data={ + "series_name": series_name, + "episode": episode, + "file_path": file_path, + }, + ) + return await self.send_notification(notification) + + async def notify_download_failed( + self, series_name: str, episode: str, error: str + ) -> Dict[str, bool]: + """ + Send notification for failed download. + + Args: + series_name: Name of the series + episode: Episode identifier + error: Error message + + Returns: + Dictionary of send results by channel + """ + notification = Notification( + id=f"download_failed_{datetime.utcnow().timestamp()}", + type=NotificationType.DOWNLOAD_FAILED, + priority=NotificationPriority.HIGH, + title=f"Download Failed: {series_name}", + message=f"Episode {episode} failed to download: {error}", + data={"series_name": series_name, "episode": episode, "error": error}, + ) + return await self.send_notification(notification) + + async def notify_queue_complete(self, total_downloads: int) -> Dict[str, bool]: + """ + Send notification for completed download queue. + + Args: + total_downloads: Number of downloads completed + + Returns: + Dictionary of send results by channel + """ + notification = Notification( + id=f"queue_complete_{datetime.utcnow().timestamp()}", + type=NotificationType.QUEUE_COMPLETE, + priority=NotificationPriority.NORMAL, + title="Download Queue Complete", + message=f"All {total_downloads} downloads have been completed.", + data={"total_downloads": total_downloads}, + ) + return await self.send_notification(notification) + + async def notify_system_error(self, error: str, details: Optional[Dict[str, Any]] = None) -> Dict[str, bool]: + """ + Send notification for system error. + + Args: + error: Error message + details: Optional error details + + Returns: + Dictionary of send results by channel + """ + notification = Notification( + id=f"system_error_{datetime.utcnow().timestamp()}", + type=NotificationType.SYSTEM_ERROR, + priority=NotificationPriority.CRITICAL, + title="System Error", + message=error, + data=details, + ) + return await self.send_notification(notification) + + +# Global notification service instance +_notification_service: Optional[NotificationService] = None + + +def get_notification_service() -> NotificationService: + """ + Get the global notification service instance. + + Returns: + NotificationService instance + """ + global _notification_service + if _notification_service is None: + _notification_service = NotificationService() + return _notification_service + + +def configure_notification_service( + smtp_host: Optional[str] = None, + smtp_port: int = 587, + smtp_username: Optional[str] = None, + smtp_password: Optional[str] = None, + from_address: Optional[str] = None, +) -> NotificationService: + """ + Configure the global notification service. + + Args: + smtp_host: SMTP server hostname + smtp_port: SMTP server port + smtp_username: SMTP authentication username + smtp_password: SMTP authentication password + from_address: Email sender address + + Returns: + Configured NotificationService instance + """ + global _notification_service + email_service = EmailNotificationService( + smtp_host=smtp_host, + smtp_port=smtp_port, + smtp_username=smtp_username, + smtp_password=smtp_password, + from_address=from_address, + ) + _notification_service = NotificationService(email_service=email_service) + return _notification_service diff --git a/src/server/utils/validators.py b/src/server/utils/validators.py new file mode 100644 index 0000000..09497ec --- /dev/null +++ b/src/server/utils/validators.py @@ -0,0 +1,628 @@ +""" +Data Validation Utilities for AniWorld. + +This module provides Pydantic validators and business rule validation +utilities for ensuring data integrity across the application. +""" + +import re +from pathlib import Path +from typing import Any, Dict, List, Optional + + +class ValidationError(Exception): + """Custom validation error.""" + + pass + + +class ValidatorMixin: + """Mixin class providing common validation utilities.""" + + @staticmethod + def validate_password_strength(password: str) -> str: + """ + Validate password meets security requirements. + + Args: + password: Password to validate + + Returns: + Validated password + + Raises: + ValueError: If password doesn't meet requirements + """ + if len(password) < 8: + raise ValueError("Password must be at least 8 characters long") + + if not re.search(r"[A-Z]", password): + raise ValueError( + "Password must contain at least one uppercase letter" + ) + + if not re.search(r"[a-z]", password): + raise ValueError( + "Password must contain at least one lowercase letter" + ) + + if not re.search(r"[0-9]", password): + raise ValueError("Password must contain at least one digit") + + if not re.search(r"[!@#$%^&*(),.?\":{}|<>]", password): + raise ValueError( + "Password must contain at least one special character" + ) + + return password + + @staticmethod + def validate_file_path(path: str, must_exist: bool = False) -> str: + """ + Validate file path. + + Args: + path: File path to validate + must_exist: Whether the path must exist + + Returns: + Validated path + + Raises: + ValueError: If path is invalid + """ + if not path or not isinstance(path, str): + raise ValueError("Path must be a non-empty string") + + # Check for path traversal attempts + if ".." in path or path.startswith("/"): + raise ValueError("Invalid path: path traversal not allowed") + + path_obj = Path(path) + + if must_exist and not path_obj.exists(): + raise ValueError(f"Path does not exist: {path}") + + return path + + @staticmethod + def validate_url(url: str) -> str: + """ + Validate URL format. + + Args: + url: URL to validate + + Returns: + Validated URL + + Raises: + ValueError: If URL is invalid + """ + if not url or not isinstance(url, str): + raise ValueError("URL must be a non-empty string") + + url_pattern = re.compile( + r"^https?://" # http:// or https:// + r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?|" + r"localhost|" # localhost + r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})" # IP address + r"(?::\d+)?" # optional port + r"(?:/?|[/?]\S+)$", + re.IGNORECASE, + ) + + if not url_pattern.match(url): + raise ValueError(f"Invalid URL format: {url}") + + return url + + @staticmethod + def validate_email(email: str) -> str: + """ + Validate email address format. + + Args: + email: Email to validate + + Returns: + Validated email + + Raises: + ValueError: If email is invalid + """ + if not email or not isinstance(email, str): + raise ValueError("Email must be a non-empty string") + + email_pattern = re.compile( + r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$" + ) + + if not email_pattern.match(email): + raise ValueError(f"Invalid email format: {email}") + + return email + + @staticmethod + def validate_port(port: int) -> int: + """ + Validate port number. + + Args: + port: Port number to validate + + Returns: + Validated port + + Raises: + ValueError: If port is invalid + """ + if not isinstance(port, int): + raise ValueError("Port must be an integer") + + if port < 1 or port > 65535: + raise ValueError("Port must be between 1 and 65535") + + return port + + @staticmethod + def validate_positive_integer(value: int, name: str = "Value") -> int: + """ + Validate positive integer. + + Args: + value: Value to validate + name: Name for error messages + + Returns: + Validated value + + Raises: + ValueError: If value is invalid + """ + if not isinstance(value, int): + raise ValueError(f"{name} must be an integer") + + if value <= 0: + raise ValueError(f"{name} must be positive") + + return value + + @staticmethod + def validate_non_negative_integer(value: int, name: str = "Value") -> int: + """ + Validate non-negative integer. + + Args: + value: Value to validate + name: Name for error messages + + Returns: + Validated value + + Raises: + ValueError: If value is invalid + """ + if not isinstance(value, int): + raise ValueError(f"{name} must be an integer") + + if value < 0: + raise ValueError(f"{name} cannot be negative") + + return value + + @staticmethod + def validate_string_length( + value: str, min_length: int = 0, max_length: Optional[int] = None, name: str = "Value" + ) -> str: + """ + Validate string length. + + Args: + value: String to validate + min_length: Minimum length + max_length: Maximum length (None for no limit) + name: Name for error messages + + Returns: + Validated string + + Raises: + ValueError: If string length is invalid + """ + if not isinstance(value, str): + raise ValueError(f"{name} must be a string") + + if len(value) < min_length: + raise ValueError( + f"{name} must be at least {min_length} characters long" + ) + + if max_length is not None and len(value) > max_length: + raise ValueError( + f"{name} must be at most {max_length} characters long" + ) + + return value + + @staticmethod + def validate_choice(value: Any, choices: List[Any], name: str = "Value") -> Any: + """ + Validate value is in allowed choices. + + Args: + value: Value to validate + choices: List of allowed values + name: Name for error messages + + Returns: + Validated value + + Raises: + ValueError: If value not in choices + """ + if value not in choices: + raise ValueError(f"{name} must be one of: {', '.join(map(str, choices))}") + + return value + + @staticmethod + def validate_dict_keys( + data: Dict[str, Any], required_keys: List[str], name: str = "Data" + ) -> Dict[str, Any]: + """ + Validate dictionary contains required keys. + + Args: + data: Dictionary to validate + required_keys: List of required keys + name: Name for error messages + + Returns: + Validated dictionary + + Raises: + ValueError: If required keys are missing + """ + if not isinstance(data, dict): + raise ValueError(f"{name} must be a dictionary") + + missing_keys = [key for key in required_keys if key not in data] + if missing_keys: + raise ValueError( + f"{name} missing required keys: {', '.join(missing_keys)}" + ) + + return data + + +def validate_episode_range(start: int, end: int) -> tuple[int, int]: + """ + Validate episode range. + + Args: + start: Start episode number + end: End episode number + + Returns: + Tuple of (start, end) + + Raises: + ValueError: If range is invalid + """ + if start < 1: + raise ValueError("Start episode must be at least 1") + + if end < start: + raise ValueError("End episode must be greater than or equal to start") + + if end - start > 1000: + raise ValueError("Episode range too large (max 1000 episodes)") + + return start, end + + +def validate_download_quality(quality: str) -> str: + """ + Validate download quality setting. + + Args: + quality: Quality setting + + Returns: + Validated quality + + Raises: + ValueError: If quality is invalid + """ + valid_qualities = ["360p", "480p", "720p", "1080p", "best", "worst"] + if quality not in valid_qualities: + raise ValueError( + f"Invalid quality: {quality}. Must be one of: {', '.join(valid_qualities)}" + ) + return quality + + +def validate_language(language: str) -> str: + """ + Validate language code. + + Args: + language: Language code + + Returns: + Validated language + + Raises: + ValueError: If language is invalid + """ + valid_languages = ["ger-sub", "ger-dub", "eng-sub", "eng-dub", "jpn"] + if language not in valid_languages: + raise ValueError( + f"Invalid language: {language}. Must be one of: {', '.join(valid_languages)}" + ) + return language + + +def validate_download_priority(priority: int) -> int: + """ + Validate download priority. + + Args: + priority: Priority value + + Returns: + Validated priority + + Raises: + ValueError: If priority is invalid + """ + if priority < 0 or priority > 10: + raise ValueError("Priority must be between 0 and 10") + return priority + + +def validate_anime_url(url: str) -> str: + """ + Validate anime URL format. + + Args: + url: Anime URL + + Returns: + Validated URL + + Raises: + ValueError: If URL is invalid + """ + if not url: + raise ValueError("URL cannot be empty") + + # Check if it's a valid aniworld.to URL + if "aniworld.to" not in url and "s.to" not in url: + raise ValueError("URL must be from aniworld.to or s.to") + + # Basic URL validation + ValidatorMixin.validate_url(url) + + return url + + +def validate_series_name(name: str) -> str: + """ + Validate series name. + + Args: + name: Series name + + Returns: + Validated name + + Raises: + ValueError: If name is invalid + """ + if not name or not name.strip(): + raise ValueError("Series name cannot be empty") + + if len(name) > 200: + raise ValueError("Series name too long (max 200 characters)") + + # Check for invalid characters + invalid_chars = ['<', '>', ':', '"', '/', '\\', '|', '?', '*'] + for char in invalid_chars: + if char in name: + raise ValueError( + f"Series name contains invalid character: {char}" + ) + + return name.strip() + + +def validate_backup_name(name: str) -> str: + """ + Validate backup file name. + + Args: + name: Backup name + + Returns: + Validated name + + Raises: + ValueError: If name is invalid + """ + if not name or not name.strip(): + raise ValueError("Backup name cannot be empty") + + # Must be a valid filename + if not re.match(r"^[a-zA-Z0-9_\-\.]+$", name): + raise ValueError( + "Backup name can only contain letters, numbers, underscores, hyphens, and dots" + ) + + if not name.endswith(".json"): + raise ValueError("Backup name must end with .json") + + return name + + +def validate_config_data(data: Dict[str, Any]) -> Dict[str, Any]: + """ + Validate configuration data structure. + + Args: + data: Configuration data + + Returns: + Validated data + + Raises: + ValueError: If data is invalid + """ + required_keys = ["download_directory", "concurrent_downloads"] + ValidatorMixin.validate_dict_keys(data, required_keys, "Configuration") + + # Validate download directory + if not isinstance(data["download_directory"], str): + raise ValueError("download_directory must be a string") + + # Validate concurrent downloads + concurrent = data["concurrent_downloads"] + if not isinstance(concurrent, int) or concurrent < 1 or concurrent > 10: + raise ValueError("concurrent_downloads must be between 1 and 10") + + # Validate quality if present + if "quality" in data: + validate_download_quality(data["quality"]) + + # Validate language if present + if "language" in data: + validate_language(data["language"]) + + return data + + +def sanitize_filename(filename: str) -> str: + """ + Sanitize filename for safe filesystem use. + + Args: + filename: Original filename + + Returns: + Sanitized filename + """ + # Remove or replace invalid characters + invalid_chars = ['<', '>', ':', '"', '/', '\\', '|', '?', '*'] + for char in invalid_chars: + filename = filename.replace(char, '_') + + # Remove leading/trailing spaces and dots + filename = filename.strip('. ') + + # Ensure not empty + if not filename: + filename = "unnamed" + + # Limit length + if len(filename) > 255: + name, ext = filename.rsplit('.', 1) if '.' in filename else (filename, '') + max_name_len = 255 - len(ext) - 1 if ext else 255 + filename = name[:max_name_len] + ('.' + ext if ext else '') + + return filename + + +def validate_jwt_token(token: str) -> str: + """ + Validate JWT token format. + + Args: + token: JWT token + + Returns: + Validated token + + Raises: + ValueError: If token format is invalid + """ + if not token or not isinstance(token, str): + raise ValueError("Token must be a non-empty string") + + # JWT tokens have 3 parts separated by dots + parts = token.split(".") + if len(parts) != 3: + raise ValueError("Invalid JWT token format") + + # Each part should be base64url encoded (alphanumeric + - and _) + for part in parts: + if not re.match(r"^[A-Za-z0-9_-]+$", part): + raise ValueError("Invalid JWT token encoding") + + return token + + +def validate_ip_address(ip: str) -> str: + """ + Validate IP address format. + + Args: + ip: IP address + + Returns: + Validated IP address + + Raises: + ValueError: If IP is invalid + """ + if not ip or not isinstance(ip, str): + raise ValueError("IP address must be a non-empty string") + + # IPv4 pattern + ipv4_pattern = re.compile( + r"^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}" + r"(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$" + ) + + # IPv6 pattern (simplified) + ipv6_pattern = re.compile( + r"^(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}$" + ) + + if not ipv4_pattern.match(ip) and not ipv6_pattern.match(ip): + raise ValueError(f"Invalid IP address format: {ip}") + + return ip + + +def validate_websocket_message(message: Dict[str, Any]) -> Dict[str, Any]: + """ + Validate WebSocket message structure. + + Args: + message: WebSocket message + + Returns: + Validated message + + Raises: + ValueError: If message structure is invalid + """ + required_keys = ["type"] + ValidatorMixin.validate_dict_keys(message, required_keys, "WebSocket message") + + valid_types = [ + "download_progress", + "download_complete", + "download_failed", + "queue_update", + "error", + "system_message", + ] + + if message["type"] not in valid_types: + raise ValueError( + f"Invalid message type. Must be one of: {', '.join(valid_types)}" + ) + + return message