cleanup
This commit is contained in:
20
src/infrastructure/security/__init__.py
Normal file
20
src/infrastructure/security/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""Security utilities for the Aniworld application.
|
||||
|
||||
This module provides security-related utilities including:
|
||||
- File integrity verification with checksums
|
||||
- Database integrity checks
|
||||
- Configuration encryption
|
||||
"""
|
||||
|
||||
from .config_encryption import ConfigEncryption, get_config_encryption
|
||||
from .database_integrity import DatabaseIntegrityChecker, check_database_integrity
|
||||
from .file_integrity import FileIntegrityManager, get_integrity_manager
|
||||
|
||||
__all__ = [
|
||||
"FileIntegrityManager",
|
||||
"get_integrity_manager",
|
||||
"DatabaseIntegrityChecker",
|
||||
"check_database_integrity",
|
||||
"ConfigEncryption",
|
||||
"get_config_encryption",
|
||||
]
|
||||
274
src/infrastructure/security/config_encryption.py
Normal file
274
src/infrastructure/security/config_encryption.py
Normal file
@@ -0,0 +1,274 @@
|
||||
"""Configuration encryption utilities.
|
||||
|
||||
This module provides encryption/decryption for sensitive configuration
|
||||
values such as passwords, API keys, and tokens.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConfigEncryption:
    """Handles encryption/decryption of sensitive configuration values.

    A Fernet key is stored on disk (created on first use with 0o600
    permissions) and used to encrypt/decrypt individual string values or
    whole configuration dictionaries.
    """

    def __init__(self, key_file: Optional[Path] = None):
        """Initialize the configuration encryption.

        Args:
            key_file: Path to store encryption key.
                Defaults to data/encryption.key
        """
        if key_file is None:
            # Walk up from src/infrastructure/security/ to the project root.
            project_root = Path(__file__).parent.parent.parent.parent
            key_file = project_root / "data" / "encryption.key"

        self.key_file = Path(key_file)
        # Cipher is created lazily in _get_cipher() on first use.
        self._cipher: Optional[Fernet] = None
        self._ensure_key_exists()

    def _ensure_key_exists(self) -> None:
        """Ensure encryption key exists or create one."""
        if not self.key_file.exists():
            logger.info(f"Creating new encryption key at {self.key_file}")
            self._generate_new_key()
        else:
            logger.info(f"Using existing encryption key from {self.key_file}")

    def _generate_new_key(self) -> None:
        """Generate and store a new encryption key.

        Raises:
            IOError: If the key file cannot be written.
        """
        try:
            self.key_file.parent.mkdir(parents=True, exist_ok=True)

            # Generate a secure random key
            key = Fernet.generate_key()

            # Write key with restrictive permissions (owner read/write only)
            self.key_file.write_bytes(key)
            os.chmod(self.key_file, 0o600)

            logger.info("Generated new encryption key")

        except IOError as e:
            logger.error(f"Failed to generate encryption key: {e}")
            raise

    def _load_key(self) -> bytes:
        """Load encryption key from file.

        Returns:
            Encryption key bytes

        Raises:
            FileNotFoundError: If key file doesn't exist
        """
        if not self.key_file.exists():
            raise FileNotFoundError(
                f"Encryption key not found: {self.key_file}"
            )

        try:
            key = self.key_file.read_bytes()
            return key
        except IOError as e:
            logger.error(f"Failed to load encryption key: {e}")
            raise

    def _get_cipher(self) -> Fernet:
        """Get or create Fernet cipher instance.

        Returns:
            Fernet cipher instance
        """
        # Lazily construct and memoize the cipher so the key file is only
        # read once per instance (until rotate_key() resets it).
        if self._cipher is None:
            key = self._load_key()
            self._cipher = Fernet(key)
        return self._cipher

    def encrypt_value(self, value: str) -> str:
        """Encrypt a configuration value.

        Args:
            value: Plain text value to encrypt

        Returns:
            Base64-encoded encrypted value

        Raises:
            ValueError: If value is empty
        """
        if not value:
            raise ValueError("Cannot encrypt empty value")

        try:
            cipher = self._get_cipher()
            encrypted_bytes = cipher.encrypt(value.encode('utf-8'))

            # Return as base64 string for easy storage
            # NOTE(review): Fernet tokens are already URL-safe base64, so
            # this adds a second base64 layer. Stored values depend on this
            # double encoding; decrypt_value() mirrors it.
            encrypted_str = base64.b64encode(encrypted_bytes).decode('utf-8')

            logger.debug("Encrypted configuration value")
            return encrypted_str

        except Exception as e:
            logger.error(f"Failed to encrypt value: {e}")
            raise

    def decrypt_value(self, encrypted_value: str) -> str:
        """Decrypt a configuration value.

        Args:
            encrypted_value: Base64-encoded encrypted value

        Returns:
            Decrypted plain text value

        Raises:
            ValueError: If encrypted value is invalid
        """
        if not encrypted_value:
            raise ValueError("Cannot decrypt empty value")

        try:
            cipher = self._get_cipher()

            # Decode from base64
            encrypted_bytes = base64.b64decode(encrypted_value.encode('utf-8'))

            # Decrypt
            decrypted_bytes = cipher.decrypt(encrypted_bytes)
            decrypted_str = decrypted_bytes.decode('utf-8')

            logger.debug("Decrypted configuration value")
            return decrypted_str

        except Exception as e:
            logger.error(f"Failed to decrypt value: {e}")
            raise

    def encrypt_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
        """Encrypt sensitive fields in configuration dictionary.

        Fields whose (lowercased) name contains one of the sensitive
        substrings are replaced by {'encrypted': True, 'value': <token>};
        all other fields pass through unchanged.

        Args:
            config: Configuration dictionary

        Returns:
            Dictionary with encrypted sensitive fields
        """
        # List of sensitive field names to encrypt
        sensitive_fields = {
            'password',
            'passwd',
            'secret',
            'key',
            'token',
            'api_key',
            'apikey',
            'auth_token',
            'jwt_secret',
            'master_password',
        }

        encrypted_config = {}

        for key, value in config.items():
            key_lower = key.lower()

            # Check if field name suggests sensitive data
            # NOTE(review): this is a substring match, so e.g. 'monkey'
            # also matches 'key' and would be encrypted — confirm intended.
            is_sensitive = any(
                field in key_lower for field in sensitive_fields
            )

            if is_sensitive and isinstance(value, str) and value:
                try:
                    encrypted_config[key] = {
                        'encrypted': True,
                        'value': self.encrypt_value(value)
                    }
                    logger.debug(f"Encrypted config field: {key}")
                except Exception as e:
                    # Best-effort: on failure keep the plain value rather
                    # than dropping the field.
                    logger.warning(f"Failed to encrypt {key}: {e}")
                    encrypted_config[key] = value
            else:
                encrypted_config[key] = value

        return encrypted_config

    def decrypt_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
        """Decrypt sensitive fields in configuration dictionary.

        Args:
            config: Configuration dictionary with encrypted fields

        Returns:
            Dictionary with decrypted values
        """
        decrypted_config = {}

        for key, value in config.items():
            # Check if this is an encrypted field
            # (the marker dict produced by encrypt_config()).
            if (
                isinstance(value, dict) and
                value.get('encrypted') is True and
                'value' in value
            ):
                try:
                    decrypted_config[key] = self.decrypt_value(
                        value['value']
                    )
                    logger.debug(f"Decrypted config field: {key}")
                except Exception as e:
                    # A field that cannot be decrypted is mapped to None
                    # instead of raising, so one bad value does not take
                    # down the whole config load.
                    logger.error(f"Failed to decrypt {key}: {e}")
                    decrypted_config[key] = None
            else:
                decrypted_config[key] = value

        return decrypted_config

    def rotate_key(self, new_key_file: Optional[Path] = None) -> None:
        """Rotate encryption key.

        **Warning**: This will invalidate all previously encrypted data.

        Args:
            new_key_file: Path for new key file (optional)
        """
        logger.warning(
            "Rotating encryption key - all encrypted data will "
            "need re-encryption"
        )

        # Backup old key if it exists
        if self.key_file.exists():
            backup_path = self.key_file.with_suffix('.key.bak')
            self.key_file.rename(backup_path)
            logger.info(f"Backed up old key to {backup_path}")

        # Generate new key
        if new_key_file:
            self.key_file = new_key_file

        self._generate_new_key()
        self._cipher = None  # Reset cipher to use new key
|
||||
|
||||
|
||||
# Lazily-created module-wide singleton, populated on first access.
_config_encryption: Optional[ConfigEncryption] = None


def get_config_encryption() -> ConfigEncryption:
    """Return the process-wide configuration encryption instance.

    The instance is constructed on first call and cached for all
    subsequent callers.

    Returns:
        ConfigEncryption instance
    """
    global _config_encryption
    instance = _config_encryption
    if instance is None:
        instance = ConfigEncryption()
        _config_encryption = instance
    return instance
|
||||
330
src/infrastructure/security/database_integrity.py
Normal file
330
src/infrastructure/security/database_integrity.py
Normal file
@@ -0,0 +1,330 @@
|
||||
"""Database integrity verification utilities.
|
||||
|
||||
This module provides database integrity checks including:
|
||||
- Foreign key constraint validation
|
||||
- Orphaned record detection
|
||||
- Data consistency checks
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.server.database.models import AnimeSeries, DownloadQueueItem, Episode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatabaseIntegrityChecker:
    """Checks database integrity and consistency.

    Individual checks return the number of issues found, or -1 as a
    sentinel when the check itself failed; human-readable messages are
    accumulated in ``self.issues``.
    """

    def __init__(self, session: Optional[Session] = None):
        """Initialize the database integrity checker.

        Args:
            session: SQLAlchemy session for database access
        """
        self.session = session
        # Accumulated human-readable issue descriptions; reset by check_all().
        self.issues: List[str] = []

    def check_all(self) -> Dict[str, Any]:
        """Run all integrity checks.

        Returns:
            Dictionary with check results and issues found

        Raises:
            ValueError: If no session was provided at construction time.
        """
        if self.session is None:
            raise ValueError("Session required for integrity checks")

        self.issues = []
        results = {
            "orphaned_episodes": self._check_orphaned_episodes(),
            "orphaned_queue_items": self._check_orphaned_queue_items(),
            "invalid_references": self._check_invalid_references(),
            "duplicate_keys": self._check_duplicate_keys(),
            "data_consistency": self._check_data_consistency(),
            "total_issues": len(self.issues),
            "issues": self.issues,
        }

        return results

    def _check_orphaned_episodes(self) -> int:
        """Check for episodes without parent series.

        Returns:
            Number of orphaned episodes found, or -1 if the check failed.
        """
        try:
            # Find episodes with non-existent series_id
            # (outer join against AnimeSeries; a NULL series row means orphan).
            stmt = select(Episode).outerjoin(
                AnimeSeries, Episode.series_id == AnimeSeries.id
            ).where(AnimeSeries.id.is_(None))

            orphaned = self.session.execute(stmt).scalars().all()

            if orphaned:
                count = len(orphaned)
                msg = f"Found {count} orphaned episodes without parent series"
                self.issues.append(msg)
                logger.warning(msg)
                return count

            logger.info("No orphaned episodes found")
            return 0

        except Exception as e:
            msg = f"Error checking orphaned episodes: {e}"
            self.issues.append(msg)
            logger.error(msg)
            return -1

    def _check_orphaned_queue_items(self) -> int:
        """Check for queue items without parent series.

        Returns:
            Number of orphaned queue items found, or -1 if the check failed.
        """
        try:
            # Find queue items with non-existent series_id
            stmt = select(DownloadQueueItem).outerjoin(
                AnimeSeries,
                DownloadQueueItem.series_id == AnimeSeries.id
            ).where(AnimeSeries.id.is_(None))

            orphaned = self.session.execute(stmt).scalars().all()

            if orphaned:
                count = len(orphaned)
                msg = (
                    f"Found {count} orphaned queue items "
                    f"without parent series"
                )
                self.issues.append(msg)
                logger.warning(msg)
                return count

            logger.info("No orphaned queue items found")
            return 0

        except Exception as e:
            msg = f"Error checking orphaned queue items: {e}"
            self.issues.append(msg)
            logger.error(msg)
            return -1

    def _check_invalid_references(self) -> int:
        """Check for invalid foreign key references.

        Uses raw SQL rather than the ORM so NULL-tolerant join semantics
        are explicit.
        NOTE(review): assumes table names 'episode', 'anime_series' and
        'download_queue_item' — confirm against the model definitions.

        Returns:
            Number of invalid references found, or -1 if the check failed.
        """
        issues_found = 0

        try:
            # Check Episode.series_id references
            stmt = text("""
                SELECT COUNT(*) as count
                FROM episode e
                LEFT JOIN anime_series s ON e.series_id = s.id
                WHERE e.series_id IS NOT NULL AND s.id IS NULL
            """)
            result = self.session.execute(stmt).fetchone()
            if result and result[0] > 0:
                msg = f"Found {result[0]} episodes with invalid series_id"
                self.issues.append(msg)
                logger.warning(msg)
                issues_found += result[0]

            # Check DownloadQueueItem.series_id references
            stmt = text("""
                SELECT COUNT(*) as count
                FROM download_queue_item d
                LEFT JOIN anime_series s ON d.series_id = s.id
                WHERE d.series_id IS NOT NULL AND s.id IS NULL
            """)
            result = self.session.execute(stmt).fetchone()
            if result and result[0] > 0:
                msg = (
                    f"Found {result[0]} queue items with invalid series_id"
                )
                self.issues.append(msg)
                logger.warning(msg)
                issues_found += result[0]

            if issues_found == 0:
                logger.info("No invalid foreign key references found")

            return issues_found

        except Exception as e:
            msg = f"Error checking invalid references: {e}"
            self.issues.append(msg)
            logger.error(msg)
            return -1

    def _check_duplicate_keys(self) -> int:
        """Check for duplicate primary keys.

        NOTE(review): despite the name, this checks the business key
        'anime_key' for duplicates, not the primary key column.

        Returns:
            Number of duplicate key issues found, or -1 if the check failed.
        """
        issues_found = 0

        try:
            # Check for duplicate anime series keys
            stmt = text("""
                SELECT anime_key, COUNT(*) as count
                FROM anime_series
                GROUP BY anime_key
                HAVING COUNT(*) > 1
            """)
            duplicates = self.session.execute(stmt).fetchall()

            if duplicates:
                for row in duplicates:
                    msg = (
                        f"Duplicate anime_key found: {row[0]} "
                        f"({row[1]} times)"
                    )
                    self.issues.append(msg)
                    logger.warning(msg)
                    issues_found += 1

            if issues_found == 0:
                logger.info("No duplicate keys found")

            return issues_found

        except Exception as e:
            msg = f"Error checking duplicate keys: {e}"
            self.issues.append(msg)
            logger.error(msg)
            return -1

    def _check_data_consistency(self) -> int:
        """Check for data consistency issues.

        Covers negative season/episode numbers, out-of-range progress
        percentages, and unknown queue statuses.

        Returns:
            Number of consistency issues found, or -1 if the check failed.
        """
        issues_found = 0

        try:
            # Check for invalid season/episode numbers
            stmt = select(Episode).where(
                (Episode.season < 0) | (Episode.episode_number < 0)
            )
            invalid_episodes = self.session.execute(stmt).scalars().all()

            if invalid_episodes:
                count = len(invalid_episodes)
                msg = (
                    f"Found {count} episodes with invalid "
                    f"season/episode numbers"
                )
                self.issues.append(msg)
                logger.warning(msg)
                issues_found += count

            # Check for invalid progress percentages
            stmt = select(DownloadQueueItem).where(
                (DownloadQueueItem.progress < 0) |
                (DownloadQueueItem.progress > 100)
            )
            invalid_progress = self.session.execute(stmt).scalars().all()

            if invalid_progress:
                count = len(invalid_progress)
                msg = (
                    f"Found {count} queue items with invalid progress "
                    f"percentages"
                )
                self.issues.append(msg)
                logger.warning(msg)
                issues_found += count

            # Check for queue items with invalid status
            valid_statuses = {'pending', 'downloading', 'completed', 'failed'}
            stmt = select(DownloadQueueItem).where(
                ~DownloadQueueItem.status.in_(valid_statuses)
            )
            invalid_status = self.session.execute(stmt).scalars().all()

            if invalid_status:
                count = len(invalid_status)
                msg = f"Found {count} queue items with invalid status"
                self.issues.append(msg)
                logger.warning(msg)
                issues_found += count

            if issues_found == 0:
                logger.info("No data consistency issues found")

            return issues_found

        except Exception as e:
            msg = f"Error checking data consistency: {e}"
            self.issues.append(msg)
            logger.error(msg)
            return -1

    def repair_orphaned_records(self) -> int:
        """Remove orphaned records from database.

        Deletes episodes and queue items whose series no longer exists,
        committing on success and rolling back on any failure.

        Returns:
            Number of records removed

        Raises:
            ValueError: If no session was provided at construction time.
        """
        if self.session is None:
            raise ValueError("Session required for repair operations")

        removed = 0

        try:
            # Remove orphaned episodes
            stmt = select(Episode).outerjoin(
                AnimeSeries, Episode.series_id == AnimeSeries.id
            ).where(AnimeSeries.id.is_(None))

            orphaned_episodes = self.session.execute(stmt).scalars().all()

            for episode in orphaned_episodes:
                self.session.delete(episode)
                removed += 1

            # Remove orphaned queue items
            stmt = select(DownloadQueueItem).outerjoin(
                AnimeSeries,
                DownloadQueueItem.series_id == AnimeSeries.id
            ).where(AnimeSeries.id.is_(None))

            orphaned_queue = self.session.execute(stmt).scalars().all()

            for item in orphaned_queue:
                self.session.delete(item)
                removed += 1

            self.session.commit()
            logger.info(f"Removed {removed} orphaned records")

            return removed

        except Exception as e:
            # Keep the session usable for the caller after a failed repair.
            self.session.rollback()
            logger.error(f"Error removing orphaned records: {e}")
            raise
|
||||
|
||||
|
||||
def check_database_integrity(session: Session) -> Dict[str, Any]:
    """Convenience wrapper that runs every database integrity check.

    Args:
        session: SQLAlchemy session

    Returns:
        Dictionary with check results
    """
    return DatabaseIntegrityChecker(session).check_all()
|
||||
232
src/infrastructure/security/file_integrity.py
Normal file
232
src/infrastructure/security/file_integrity.py
Normal file
@@ -0,0 +1,232 @@
|
||||
"""File integrity verification utilities.
|
||||
|
||||
This module provides checksum calculation and verification for
|
||||
downloaded files. Supports SHA256 hashing for file integrity validation.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileIntegrityManager:
    """Manages file integrity checksums and verification.

    Checksums are persisted in a JSON file mapping the resolved absolute
    file path to its hex digest, so verification survives restarts.
    """

    def __init__(self, checksum_file: Optional[Path] = None):
        """Initialize the file integrity manager.

        Args:
            checksum_file: Path to store checksums.
                Defaults to data/checksums.json
        """
        if checksum_file is None:
            # Walk up from src/infrastructure/security/ to the project root.
            project_root = Path(__file__).parent.parent.parent.parent
            checksum_file = project_root / "data" / "checksums.json"

        self.checksum_file = Path(checksum_file)
        # Mapping: resolved absolute path (str) -> hex digest (str).
        self.checksums: Dict[str, str] = {}
        self._load_checksums()

    def _load_checksums(self) -> None:
        """Load checksums from file; start empty if missing or unreadable."""
        if self.checksum_file.exists():
            try:
                with open(self.checksum_file, 'r', encoding='utf-8') as f:
                    self.checksums = json.load(f)
                count = len(self.checksums)
                logger.info(
                    f"Loaded {count} checksums from {self.checksum_file}"
                )
            except (json.JSONDecodeError, IOError) as e:
                # A corrupt store must not break startup; fall back to empty.
                logger.error(f"Failed to load checksums: {e}")
                self.checksums = {}
        else:
            logger.info(f"Checksum file does not exist: {self.checksum_file}")
            self.checksums = {}

    def _save_checksums(self) -> None:
        """Save checksums to file (best-effort: errors are logged, not raised)."""
        try:
            self.checksum_file.parent.mkdir(parents=True, exist_ok=True)
            with open(self.checksum_file, 'w', encoding='utf-8') as f:
                json.dump(self.checksums, f, indent=2)
            count = len(self.checksums)
            logger.debug(
                f"Saved {count} checksums to {self.checksum_file}"
            )
        except IOError as e:
            logger.error(f"Failed to save checksums: {e}")

    def calculate_checksum(
        self, file_path: Path, algorithm: str = "sha256"
    ) -> str:
        """Calculate checksum for a file.

        Args:
            file_path: Path to the file
            algorithm: Hash algorithm to use (default: sha256)

        Returns:
            Hexadecimal checksum string

        Raises:
            FileNotFoundError: If file doesn't exist
            ValueError: If algorithm is not supported
        """
        if not file_path.exists():
            raise FileNotFoundError(f"File not found: {file_path}")

        if algorithm not in hashlib.algorithms_available:
            raise ValueError(f"Unsupported hash algorithm: {algorithm}")

        hash_obj = hashlib.new(algorithm)

        try:
            with open(file_path, 'rb') as f:
                # Read file in chunks so arbitrarily large files fit in memory
                for chunk in iter(lambda: f.read(8192), b''):
                    hash_obj.update(chunk)

            checksum = hash_obj.hexdigest()
            # Fix: log the actual file name; the previous message contained
            # the literal placeholder "(unknown)" and left `filename` unused.
            logger.debug(
                f"Calculated {algorithm} checksum for {file_path.name}: "
                f"{checksum}"
            )
            return checksum

        except IOError as e:
            logger.error(f"Failed to read file {file_path}: {e}")
            raise

    def store_checksum(
        self, file_path: Path, checksum: Optional[str] = None
    ) -> str:
        """Calculate and store checksum for a file.

        Args:
            file_path: Path to the file
            checksum: Pre-calculated checksum (optional, will calculate
                if not provided)

        Returns:
            The stored checksum

        Raises:
            FileNotFoundError: If file doesn't exist
        """
        if checksum is None:
            checksum = self.calculate_checksum(file_path)

        # Use the resolved absolute path as key so lookups are unambiguous
        # regardless of the caller's working directory.
        key = str(file_path.resolve())
        self.checksums[key] = checksum
        self._save_checksums()

        logger.info(f"Stored checksum for {file_path.name}")
        return checksum

    def verify_checksum(
        self, file_path: Path, expected_checksum: Optional[str] = None
    ) -> bool:
        """Verify file integrity by comparing checksums.

        Args:
            file_path: Path to the file
            expected_checksum: Expected checksum (optional, will look up
                stored checksum)

        Returns:
            True if checksum matches, False otherwise

        Raises:
            FileNotFoundError: If file doesn't exist
        """
        if not file_path.exists():
            raise FileNotFoundError(f"File not found: {file_path}")

        # Get expected checksum from storage if not provided
        if expected_checksum is None:
            key = str(file_path.resolve())
            expected_checksum = self.checksums.get(key)

            if expected_checksum is None:
                logger.warning(
                    "No stored checksum found for %s", file_path.name
                )
                return False

        # Calculate current checksum and compare
        try:
            current_checksum = self.calculate_checksum(file_path)

            if current_checksum == expected_checksum:
                logger.info(
                    "Checksum verification passed for %s", file_path.name
                )
                return True

            logger.warning(
                "Checksum mismatch for %s: "
                "expected %s, got %s",
                file_path.name,
                expected_checksum,
                current_checksum
            )
            return False

        except (IOError, OSError) as e:
            # Treat an unreadable file as failed verification rather than
            # propagating the error.
            logger.error("Failed to verify checksum for %s: %s", file_path, e)
            return False

    def remove_checksum(self, file_path: Path) -> bool:
        """Remove checksum for a file.

        Args:
            file_path: Path to the file

        Returns:
            True if checksum was removed, False if not found
        """
        key = str(file_path.resolve())

        if key in self.checksums:
            del self.checksums[key]
            self._save_checksums()
            logger.info(f"Removed checksum for {file_path.name}")
            return True

        logger.debug(f"No checksum found to remove for {file_path.name}")
        return False

    def has_checksum(self, file_path: Path) -> bool:
        """Check if a checksum exists for a file.

        Args:
            file_path: Path to the file

        Returns:
            True if checksum exists, False otherwise
        """
        key = str(file_path.resolve())
        return key in self.checksums
|
||||
|
||||
|
||||
# Lazily-created module-wide singleton, populated on first access.
_integrity_manager: Optional[FileIntegrityManager] = None


def get_integrity_manager() -> FileIntegrityManager:
    """Return the process-wide file integrity manager.

    The manager is constructed on first call and cached for all
    subsequent callers.

    Returns:
        FileIntegrityManager instance
    """
    global _integrity_manager
    instance = _integrity_manager
    if instance is None:
        instance = FileIntegrityManager()
        _integrity_manager = instance
    return instance
|
||||
Reference in New Issue
Block a user