This commit is contained in:
2025-10-23 19:00:49 +02:00
parent 3d5c19939c
commit c81a493fb1
9 changed files with 1225 additions and 131 deletions

View File

@@ -0,0 +1,20 @@
"""Security utilities for the Aniworld application.
This module provides security-related utilities including:
- File integrity verification with checksums
- Database integrity checks
- Configuration encryption
"""
from .config_encryption import ConfigEncryption, get_config_encryption
from .database_integrity import DatabaseIntegrityChecker, check_database_integrity
from .file_integrity import FileIntegrityManager, get_integrity_manager
__all__ = [
"FileIntegrityManager",
"get_integrity_manager",
"DatabaseIntegrityChecker",
"check_database_integrity",
"ConfigEncryption",
"get_config_encryption",
]

View File

@@ -0,0 +1,274 @@
"""Configuration encryption utilities.
This module provides encryption/decryption for sensitive configuration
values such as passwords, API keys, and tokens.
"""
import base64
import logging
import os
from pathlib import Path
from typing import Any, Dict, Optional
from cryptography.fernet import Fernet
logger = logging.getLogger(__name__)
class ConfigEncryption:
    """Handles encryption/decryption of sensitive configuration values.

    A Fernet symmetric key is stored on disk (created on first use with
    owner-only permissions) and used to encrypt/decrypt individual string
    values as well as the sensitive fields of configuration dictionaries.
    """

    def __init__(self, key_file: Optional[Path] = None):
        """Initialize the configuration encryption.

        Args:
            key_file: Path to store encryption key.
                Defaults to data/encryption.key
        """
        if key_file is None:
            # Four levels up from this module; assumed to be the project
            # root — TODO confirm against the actual package layout.
            project_root = Path(__file__).parent.parent.parent.parent
            key_file = project_root / "data" / "encryption.key"
        self.key_file = Path(key_file)
        # Cipher is created lazily by _get_cipher() on first use.
        self._cipher: Optional[Fernet] = None
        self._ensure_key_exists()

    def _ensure_key_exists(self) -> None:
        """Ensure encryption key exists or create one."""
        if not self.key_file.exists():
            logger.info(f"Creating new encryption key at {self.key_file}")
            self._generate_new_key()
        else:
            logger.info(f"Using existing encryption key from {self.key_file}")

    def _generate_new_key(self) -> None:
        """Generate and store a new encryption key.

        Raises:
            IOError: If the key file cannot be written.
        """
        try:
            self.key_file.parent.mkdir(parents=True, exist_ok=True)
            # Generate a secure random key
            key = Fernet.generate_key()
            # Write key with restrictive permissions (owner read/write only)
            self.key_file.write_bytes(key)
            os.chmod(self.key_file, 0o600)
            logger.info("Generated new encryption key")
        except IOError as e:
            logger.error(f"Failed to generate encryption key: {e}")
            raise

    def _load_key(self) -> bytes:
        """Load encryption key from file.

        Returns:
            Encryption key bytes

        Raises:
            FileNotFoundError: If key file doesn't exist
            IOError: If the key file cannot be read
        """
        if not self.key_file.exists():
            raise FileNotFoundError(
                f"Encryption key not found: {self.key_file}"
            )
        try:
            return self.key_file.read_bytes()
        except IOError as e:
            logger.error(f"Failed to load encryption key: {e}")
            raise

    def _get_cipher(self) -> Fernet:
        """Get or create the Fernet cipher instance (cached after first call).

        Returns:
            Fernet cipher instance
        """
        if self._cipher is None:
            self._cipher = Fernet(self._load_key())
        return self._cipher

    def encrypt_value(self, value: str) -> str:
        """Encrypt a configuration value.

        Args:
            value: Plain text value to encrypt

        Returns:
            Base64-encoded encrypted value

        Raises:
            ValueError: If value is empty
        """
        if not value:
            raise ValueError("Cannot encrypt empty value")
        try:
            cipher = self._get_cipher()
            encrypted_bytes = cipher.encrypt(value.encode('utf-8'))
            # NOTE: Fernet tokens are already urlsafe-base64; this second
            # base64 layer is redundant but kept so values encrypted by
            # earlier versions of this module remain decryptable.
            encrypted_str = base64.b64encode(encrypted_bytes).decode('utf-8')
            logger.debug("Encrypted configuration value")
            return encrypted_str
        except Exception as e:
            logger.error(f"Failed to encrypt value: {e}")
            raise

    def decrypt_value(self, encrypted_value: str) -> str:
        """Decrypt a configuration value.

        Args:
            encrypted_value: Base64-encoded encrypted value

        Returns:
            Decrypted plain text value

        Raises:
            ValueError: If encrypted value is invalid
        """
        if not encrypted_value:
            raise ValueError("Cannot decrypt empty value")
        try:
            cipher = self._get_cipher()
            # Undo the outer base64 layer added by encrypt_value().
            encrypted_bytes = base64.b64decode(encrypted_value.encode('utf-8'))
            decrypted_bytes = cipher.decrypt(encrypted_bytes)
            decrypted_str = decrypted_bytes.decode('utf-8')
            logger.debug("Decrypted configuration value")
            return decrypted_str
        except Exception as e:
            logger.error(f"Failed to decrypt value: {e}")
            raise

    def encrypt_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
        """Encrypt sensitive fields in a configuration dictionary.

        Fields whose (lowercased) name contains one of the known sensitive
        substrings are replaced by ``{'encrypted': True, 'value': <token>}``;
        everything else is passed through unchanged.

        Args:
            config: Configuration dictionary

        Returns:
            Dictionary with encrypted sensitive fields
        """
        # Substrings that mark a field name as sensitive.
        sensitive_fields = {
            'password',
            'passwd',
            'secret',
            'key',
            'token',
            'api_key',
            'apikey',
            'auth_token',
            'jwt_secret',
            'master_password',
        }
        encrypted_config = {}
        for key, value in config.items():
            key_lower = key.lower()
            # Substring match: e.g. "db_password" and "api_key_v2" both hit.
            is_sensitive = any(
                field in key_lower for field in sensitive_fields
            )
            # Only non-empty string values are encrypted.
            if is_sensitive and isinstance(value, str) and value:
                try:
                    encrypted_config[key] = {
                        'encrypted': True,
                        'value': self.encrypt_value(value)
                    }
                    logger.debug(f"Encrypted config field: {key}")
                except Exception as e:
                    # Best-effort: on failure keep the plain value rather
                    # than losing the setting.
                    logger.warning(f"Failed to encrypt {key}: {e}")
                    encrypted_config[key] = value
            else:
                encrypted_config[key] = value
        return encrypted_config

    def decrypt_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
        """Decrypt sensitive fields in a configuration dictionary.

        Args:
            config: Configuration dictionary with encrypted fields

        Returns:
            Dictionary with decrypted values; fields that fail to decrypt
            are set to ``None``.
        """
        decrypted_config = {}
        for key, value in config.items():
            # Encrypted fields carry the marker dict written by
            # encrypt_config().
            if (
                isinstance(value, dict) and
                value.get('encrypted') is True and
                'value' in value
            ):
                try:
                    decrypted_config[key] = self.decrypt_value(
                        value['value']
                    )
                    logger.debug(f"Decrypted config field: {key}")
                except Exception as e:
                    logger.error(f"Failed to decrypt {key}: {e}")
                    decrypted_config[key] = None
            else:
                decrypted_config[key] = value
        return decrypted_config

    def rotate_key(self, new_key_file: Optional[Path] = None) -> None:
        """Rotate encryption key.

        **Warning**: This will invalidate all previously encrypted data.

        Args:
            new_key_file: Path for new key file (optional)
        """
        logger.warning(
            "Rotating encryption key - all encrypted data will "
            "need re-encryption"
        )
        # Backup old key if it exists
        if self.key_file.exists():
            backup_path = self.key_file.with_suffix('.key.bak')
            # replace() overwrites an existing backup on every platform;
            # rename() would raise FileExistsError on Windows.
            self.key_file.replace(backup_path)
            logger.info(f"Backed up old key to {backup_path}")
        # Generate new key
        if new_key_file:
            # Normalize to Path so later .exists()/.with_suffix() calls
            # work even when callers pass a plain string.
            self.key_file = Path(new_key_file)
        self._generate_new_key()
        self._cipher = None  # Reset cipher to use new key
# Module-level singleton, created lazily on first access.
_config_encryption: Optional[ConfigEncryption] = None


def get_config_encryption() -> ConfigEncryption:
    """Return the process-wide configuration encryption instance.

    The instance is constructed on the first call and reused afterwards.

    Returns:
        ConfigEncryption instance
    """
    global _config_encryption
    if _config_encryption is not None:
        return _config_encryption
    _config_encryption = ConfigEncryption()
    return _config_encryption

View File

@@ -0,0 +1,330 @@
"""Database integrity verification utilities.
This module provides database integrity checks including:
- Foreign key constraint validation
- Orphaned record detection
- Data consistency checks
"""
import logging
from typing import Any, Dict, List, Optional
from sqlalchemy import select, text
from sqlalchemy.orm import Session
from src.server.database.models import AnimeSeries, DownloadQueueItem, Episode
logger = logging.getLogger(__name__)
class DatabaseIntegrityChecker:
    """Checks database integrity and consistency.

    Each ``_check_*`` method returns the number of problems it found, or
    ``-1`` as a sentinel when the check itself raised an error.  Every
    problem is also appended to ``self.issues`` as a human-readable message.
    """

    def __init__(self, session: Optional[Session] = None):
        """Initialize the database integrity checker.

        Args:
            session: SQLAlchemy session for database access
        """
        self.session = session
        # Accumulated problem descriptions; reset by check_all().
        self.issues: List[str] = []

    def check_all(self) -> Dict[str, Any]:
        """Run all integrity checks.

        Returns:
            Dictionary with check results and issues found

        Raises:
            ValueError: If no session was supplied at construction time.
        """
        if self.session is None:
            raise ValueError("Session required for integrity checks")
        # Fresh slate so repeated calls don't accumulate stale messages.
        self.issues = []
        results = {
            "orphaned_episodes": self._check_orphaned_episodes(),
            "orphaned_queue_items": self._check_orphaned_queue_items(),
            "invalid_references": self._check_invalid_references(),
            "duplicate_keys": self._check_duplicate_keys(),
            "data_consistency": self._check_data_consistency(),
            "total_issues": len(self.issues),
            "issues": self.issues,
        }
        return results

    def _check_orphaned_episodes(self) -> int:
        """Check for episodes without parent series.

        Returns:
            Number of orphaned episodes found, or -1 on check failure.
        """
        try:
            # Find episodes with non-existent series_id: outer-join to the
            # series table and keep rows where no series matched.
            stmt = select(Episode).outerjoin(
                AnimeSeries, Episode.series_id == AnimeSeries.id
            ).where(AnimeSeries.id.is_(None))
            orphaned = self.session.execute(stmt).scalars().all()
            if orphaned:
                count = len(orphaned)
                msg = f"Found {count} orphaned episodes without parent series"
                self.issues.append(msg)
                logger.warning(msg)
                return count
            logger.info("No orphaned episodes found")
            return 0
        except Exception as e:
            # Record the failure as an issue and return the error sentinel.
            msg = f"Error checking orphaned episodes: {e}"
            self.issues.append(msg)
            logger.error(msg)
            return -1

    def _check_orphaned_queue_items(self) -> int:
        """Check for queue items without parent series.

        Returns:
            Number of orphaned queue items found, or -1 on check failure.
        """
        try:
            # Find queue items with non-existent series_id (same outer-join
            # pattern as _check_orphaned_episodes).
            stmt = select(DownloadQueueItem).outerjoin(
                AnimeSeries,
                DownloadQueueItem.series_id == AnimeSeries.id
            ).where(AnimeSeries.id.is_(None))
            orphaned = self.session.execute(stmt).scalars().all()
            if orphaned:
                count = len(orphaned)
                msg = (
                    f"Found {count} orphaned queue items "
                    f"without parent series"
                )
                self.issues.append(msg)
                logger.warning(msg)
                return count
            logger.info("No orphaned queue items found")
            return 0
        except Exception as e:
            msg = f"Error checking orphaned queue items: {e}"
            self.issues.append(msg)
            logger.error(msg)
            return -1

    def _check_invalid_references(self) -> int:
        """Check for invalid foreign key references.

        NOTE(review): the raw SQL assumes table names 'episode',
        'anime_series' and 'download_queue_item' — confirm they match the
        models' __tablename__ values.

        Returns:
            Number of invalid references found, or -1 on check failure.
        """
        issues_found = 0
        try:
            # Check Episode.series_id references
            stmt = text("""
                SELECT COUNT(*) as count
                FROM episode e
                LEFT JOIN anime_series s ON e.series_id = s.id
                WHERE e.series_id IS NOT NULL AND s.id IS NULL
            """)
            result = self.session.execute(stmt).fetchone()
            if result and result[0] > 0:
                msg = f"Found {result[0]} episodes with invalid series_id"
                self.issues.append(msg)
                logger.warning(msg)
                issues_found += result[0]
            # Check DownloadQueueItem.series_id references
            stmt = text("""
                SELECT COUNT(*) as count
                FROM download_queue_item d
                LEFT JOIN anime_series s ON d.series_id = s.id
                WHERE d.series_id IS NOT NULL AND s.id IS NULL
            """)
            result = self.session.execute(stmt).fetchone()
            if result and result[0] > 0:
                msg = (
                    f"Found {result[0]} queue items with invalid series_id"
                )
                self.issues.append(msg)
                logger.warning(msg)
                issues_found += result[0]
            if issues_found == 0:
                logger.info("No invalid foreign key references found")
            return issues_found
        except Exception as e:
            msg = f"Error checking invalid references: {e}"
            self.issues.append(msg)
            logger.error(msg)
            return -1

    def _check_duplicate_keys(self) -> int:
        """Check for duplicate primary keys.

        Counts anime_key values that occur more than once (one issue per
        duplicated key, regardless of how many copies exist).

        Returns:
            Number of duplicate key issues found, or -1 on check failure.
        """
        issues_found = 0
        try:
            # Check for duplicate anime series keys
            stmt = text("""
                SELECT anime_key, COUNT(*) as count
                FROM anime_series
                GROUP BY anime_key
                HAVING COUNT(*) > 1
            """)
            duplicates = self.session.execute(stmt).fetchall()
            if duplicates:
                for row in duplicates:
                    msg = (
                        f"Duplicate anime_key found: {row[0]} "
                        f"({row[1]} times)"
                    )
                    self.issues.append(msg)
                    logger.warning(msg)
                    issues_found += 1
            if issues_found == 0:
                logger.info("No duplicate keys found")
            return issues_found
        except Exception as e:
            msg = f"Error checking duplicate keys: {e}"
            self.issues.append(msg)
            logger.error(msg)
            return -1

    def _check_data_consistency(self) -> int:
        """Check for data consistency issues.

        Validates value ranges: non-negative season/episode numbers,
        progress within 0-100, and a known download status.

        Returns:
            Number of consistency issues found, or -1 on check failure.
        """
        issues_found = 0
        try:
            # Check for invalid season/episode numbers
            stmt = select(Episode).where(
                (Episode.season < 0) | (Episode.episode_number < 0)
            )
            invalid_episodes = self.session.execute(stmt).scalars().all()
            if invalid_episodes:
                count = len(invalid_episodes)
                msg = (
                    f"Found {count} episodes with invalid "
                    f"season/episode numbers"
                )
                self.issues.append(msg)
                logger.warning(msg)
                issues_found += count
            # Check for invalid progress percentages
            stmt = select(DownloadQueueItem).where(
                (DownloadQueueItem.progress < 0) |
                (DownloadQueueItem.progress > 100)
            )
            invalid_progress = self.session.execute(stmt).scalars().all()
            if invalid_progress:
                count = len(invalid_progress)
                msg = (
                    f"Found {count} queue items with invalid progress "
                    f"percentages"
                )
                self.issues.append(msg)
                logger.warning(msg)
                issues_found += count
            # Check for queue items with invalid status
            # NOTE(review): this set must stay in sync with the status
            # values the queue code actually writes — confirm.
            valid_statuses = {'pending', 'downloading', 'completed', 'failed'}
            stmt = select(DownloadQueueItem).where(
                ~DownloadQueueItem.status.in_(valid_statuses)
            )
            invalid_status = self.session.execute(stmt).scalars().all()
            if invalid_status:
                count = len(invalid_status)
                msg = f"Found {count} queue items with invalid status"
                self.issues.append(msg)
                logger.warning(msg)
                issues_found += count
            if issues_found == 0:
                logger.info("No data consistency issues found")
            return issues_found
        except Exception as e:
            msg = f"Error checking data consistency: {e}"
            self.issues.append(msg)
            logger.error(msg)
            return -1

    def repair_orphaned_records(self) -> int:
        """Remove orphaned records from database.

        Deletes episodes and queue items whose series_id no longer refers
        to an existing series, then commits.  Rolls back on any error.

        Returns:
            Number of records removed

        Raises:
            ValueError: If no session was supplied at construction time.
            Exception: Re-raises any database error after rollback.
        """
        if self.session is None:
            raise ValueError("Session required for repair operations")
        removed = 0
        try:
            # Remove orphaned episodes
            stmt = select(Episode).outerjoin(
                AnimeSeries, Episode.series_id == AnimeSeries.id
            ).where(AnimeSeries.id.is_(None))
            orphaned_episodes = self.session.execute(stmt).scalars().all()
            for episode in orphaned_episodes:
                self.session.delete(episode)
                removed += 1
            # Remove orphaned queue items
            stmt = select(DownloadQueueItem).outerjoin(
                AnimeSeries,
                DownloadQueueItem.series_id == AnimeSeries.id
            ).where(AnimeSeries.id.is_(None))
            orphaned_queue = self.session.execute(stmt).scalars().all()
            for item in orphaned_queue:
                self.session.delete(item)
                removed += 1
            # Single commit covering both delete passes.
            self.session.commit()
            logger.info(f"Removed {removed} orphaned records")
            return removed
        except Exception as e:
            self.session.rollback()
            logger.error(f"Error removing orphaned records: {e}")
            raise
def check_database_integrity(session: Session) -> Dict[str, Any]:
    """Run every database integrity check in one call.

    Args:
        session: SQLAlchemy session

    Returns:
        Dictionary with check results
    """
    return DatabaseIntegrityChecker(session).check_all()

View File

@@ -0,0 +1,232 @@
"""File integrity verification utilities.
This module provides checksum calculation and verification for
downloaded files. Supports SHA256 hashing for file integrity validation.
"""
import hashlib
import json
import logging
from pathlib import Path
from typing import Dict, Optional
logger = logging.getLogger(__name__)
class FileIntegrityManager:
    """Manages file integrity checksums and verification.

    Checksums are persisted as a JSON mapping of absolute file path to
    hex digest, loaded at construction and saved after every mutation.
    """

    def __init__(self, checksum_file: Optional[Path] = None):
        """Initialize the file integrity manager.

        Args:
            checksum_file: Path to store checksums.
                Defaults to data/checksums.json
        """
        if checksum_file is None:
            # Four levels up from this module; assumed to be the project
            # root — TODO confirm against the actual package layout.
            project_root = Path(__file__).parent.parent.parent.parent
            checksum_file = project_root / "data" / "checksums.json"
        self.checksum_file = Path(checksum_file)
        # Mapping of resolved file path (str) -> hex checksum.
        self.checksums: Dict[str, str] = {}
        self._load_checksums()

    def _load_checksums(self) -> None:
        """Load checksums from file; start empty if missing or corrupt."""
        if self.checksum_file.exists():
            try:
                with open(self.checksum_file, 'r', encoding='utf-8') as f:
                    self.checksums = json.load(f)
                count = len(self.checksums)
                logger.info(
                    f"Loaded {count} checksums from {self.checksum_file}"
                )
            except (json.JSONDecodeError, IOError) as e:
                # A corrupt store is not fatal — fall back to empty.
                logger.error(f"Failed to load checksums: {e}")
                self.checksums = {}
        else:
            logger.info(f"Checksum file does not exist: {self.checksum_file}")
            self.checksums = {}

    def _save_checksums(self) -> None:
        """Save checksums to file (best-effort; errors are logged)."""
        try:
            self.checksum_file.parent.mkdir(parents=True, exist_ok=True)
            with open(self.checksum_file, 'w', encoding='utf-8') as f:
                json.dump(self.checksums, f, indent=2)
            count = len(self.checksums)
            logger.debug(
                f"Saved {count} checksums to {self.checksum_file}"
            )
        except IOError as e:
            logger.error(f"Failed to save checksums: {e}")

    def calculate_checksum(
        self, file_path: Path, algorithm: str = "sha256"
    ) -> str:
        """Calculate checksum for a file.

        Args:
            file_path: Path to the file
            algorithm: Hash algorithm to use (default: sha256)

        Returns:
            Hexadecimal checksum string

        Raises:
            FileNotFoundError: If file doesn't exist
            ValueError: If algorithm is not supported
        """
        if not file_path.exists():
            raise FileNotFoundError(f"File not found: {file_path}")
        if algorithm not in hashlib.algorithms_available:
            raise ValueError(f"Unsupported hash algorithm: {algorithm}")
        hash_obj = hashlib.new(algorithm)
        try:
            with open(file_path, 'rb') as f:
                # Read file in chunks to handle large files
                for chunk in iter(lambda: f.read(8192), b''):
                    hash_obj.update(chunk)
            checksum = hash_obj.hexdigest()
            # Bug fix: the previous version logged the literal "(unknown)"
            # instead of the file name and left an unused local behind.
            logger.debug(
                "Calculated %s checksum for %s: %s",
                algorithm, file_path.name, checksum,
            )
            return checksum
        except IOError as e:
            logger.error(f"Failed to read file {file_path}: {e}")
            raise

    def store_checksum(
        self, file_path: Path, checksum: Optional[str] = None
    ) -> str:
        """Calculate and store checksum for a file.

        Args:
            file_path: Path to the file
            checksum: Pre-calculated checksum (optional, will calculate
                if not provided)

        Returns:
            The stored checksum

        Raises:
            FileNotFoundError: If file doesn't exist
        """
        if checksum is None:
            checksum = self.calculate_checksum(file_path)
        # Use the absolute resolved path as key so lookups are stable
        # regardless of the current working directory.
        key = str(file_path.resolve())
        self.checksums[key] = checksum
        self._save_checksums()
        logger.info(f"Stored checksum for {file_path.name}")
        return checksum

    def verify_checksum(
        self, file_path: Path, expected_checksum: Optional[str] = None
    ) -> bool:
        """Verify file integrity by comparing checksums.

        Args:
            file_path: Path to the file
            expected_checksum: Expected checksum (optional, will look up
                stored checksum)

        Returns:
            True if checksum matches, False otherwise (including when no
            stored checksum exists or the file cannot be read)

        Raises:
            FileNotFoundError: If file doesn't exist
        """
        if not file_path.exists():
            raise FileNotFoundError(f"File not found: {file_path}")
        # Get expected checksum from storage if not provided
        if expected_checksum is None:
            key = str(file_path.resolve())
            expected_checksum = self.checksums.get(key)
            if expected_checksum is None:
                filename = file_path.name
                logger.warning(
                    "No stored checksum found for %s", filename
                )
                return False
        # Calculate current checksum
        try:
            current_checksum = self.calculate_checksum(file_path)
            if current_checksum == expected_checksum:
                filename = file_path.name
                logger.info("Checksum verification passed for %s", filename)
                return True
            else:
                filename = file_path.name
                logger.warning(
                    "Checksum mismatch for %s: "
                    "expected %s, got %s",
                    filename,
                    expected_checksum,
                    current_checksum
                )
                return False
        except (IOError, OSError) as e:
            # Treat read failures as verification failure, not a crash.
            logger.error("Failed to verify checksum for %s: %s", file_path, e)
            return False

    def remove_checksum(self, file_path: Path) -> bool:
        """Remove checksum for a file.

        Args:
            file_path: Path to the file

        Returns:
            True if checksum was removed, False if not found
        """
        key = str(file_path.resolve())
        if key in self.checksums:
            del self.checksums[key]
            self._save_checksums()
            logger.info(f"Removed checksum for {file_path.name}")
            return True
        else:
            logger.debug(f"No checksum found to remove for {file_path.name}")
            return False

    def has_checksum(self, file_path: Path) -> bool:
        """Check if a checksum exists for a file.

        Args:
            file_path: Path to the file

        Returns:
            True if checksum exists, False otherwise
        """
        key = str(file_path.resolve())
        return key in self.checksums
# Module-level singleton, created lazily on first access.
_integrity_manager: Optional[FileIntegrityManager] = None


def get_integrity_manager() -> FileIntegrityManager:
    """Return the process-wide file integrity manager.

    The instance is constructed on the first call and reused afterwards.

    Returns:
        FileIntegrityManager instance
    """
    global _integrity_manager
    if _integrity_manager is not None:
        return _integrity_manager
    _integrity_manager = FileIntegrityManager()
    return _integrity_manager