Aniworld/src/infrastructure/security/database_integrity.py
2025-10-23 19:00:49 +02:00

331 lines
10 KiB
Python

"""Database integrity verification utilities.
This module provides database integrity checks including:
- Foreign key constraint validation
- Orphaned record detection
- Data consistency checks
"""
import logging
from typing import Any, Dict, List, Optional

from sqlalchemy import func, select, text
from sqlalchemy.orm import Session

from src.server.database.models import AnimeSeries, DownloadQueueItem, Episode
# Logger used by every check in this module; it is named after the module
# so its level and handlers can be configured independently.
logger = logging.getLogger(name=__name__)
class DatabaseIntegrityChecker:
    """Check database integrity and consistency.

    Every ``_check_*`` method returns the number of problems it found
    (``0`` when clean) or ``-1`` when the check itself raised. A
    human-readable message for each problem and for each failed check is
    appended to :attr:`issues`.

    Row counts are computed by the database (``SELECT count(*)``) instead of
    loading every matching row as an ORM object just to measure its length.
    """

    # Queue-item statuses accepted by the data-consistency check. Any other
    # value, including NULL, is reported as invalid.
    VALID_QUEUE_STATUSES = frozenset(
        {"pending", "downloading", "completed", "failed"}
    )

    def __init__(self, session: Optional[Session] = None):
        """Initialize the database integrity checker.

        Args:
            session: SQLAlchemy session used for all queries. It may be
                ``None`` at construction time, but :meth:`check_all` and
                :meth:`repair_orphaned_records` raise ``ValueError`` if it is
                still ``None`` when they are called.
        """
        self.session = session
        # Messages describing problems found (and checks that errored) during
        # the most recent check_all() run.
        self.issues: List[str] = []

    def check_all(self) -> Dict[str, Any]:
        """Run all integrity checks.

        Resets :attr:`issues` before running.

        Returns:
            Dictionary with one count per check (``-1`` if that check
            failed), plus ``total_issues`` (number of issue messages) and
            ``issues`` (the messages themselves).

        Raises:
            ValueError: If no session was provided.
        """
        if self.session is None:
            raise ValueError("Session required for integrity checks")
        self.issues = []
        # Dict literals evaluate in order, so "total_issues" is computed only
        # after every check above it has appended its messages.
        results = {
            "orphaned_episodes": self._check_orphaned_episodes(),
            "orphaned_queue_items": self._check_orphaned_queue_items(),
            "invalid_references": self._check_invalid_references(),
            "duplicate_keys": self._check_duplicate_keys(),
            "data_consistency": self._check_data_consistency(),
            "total_issues": len(self.issues),
            "issues": self.issues,
        }
        return results

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _orphaned_rows_stmt(model: Any) -> Any:
        """Build a SELECT of *model* rows that have no matching AnimeSeries.

        The statement LEFT OUTER JOINs on ``model.series_id ==
        AnimeSeries.id`` and keeps rows where the join found nothing. Rows
        whose ``series_id`` is NULL therefore match as well.
        """
        return (
            select(model)
            .outerjoin(AnimeSeries, model.series_id == AnimeSeries.id)
            .where(AnimeSeries.id.is_(None))
        )

    def _count(self, stmt: Any) -> int:
        """Return how many rows *stmt* yields, counted by the database."""
        count_stmt = select(func.count()).select_from(stmt.subquery())
        return int(self.session.execute(count_stmt).scalar_one())

    def _record_issue(self, msg: str) -> None:
        """Append *msg* to :attr:`issues` and log it as a warning."""
        self.issues.append(msg)
        logger.warning(msg)

    def _record_failure(self, msg: str) -> None:
        """Append *msg* to :attr:`issues` and log it with the traceback.

        Call this only from inside an ``except`` block, because
        ``logger.exception`` attaches the exception currently being handled.
        """
        self.issues.append(msg)
        logger.exception(msg)

    # ------------------------------------------------------------------
    # Individual checks
    # ------------------------------------------------------------------

    def _check_orphaned_episodes(self) -> int:
        """Check for episodes without a parent series.

        Returns:
            Number of orphaned episodes (including those with a NULL
            ``series_id``), or ``-1`` if the query failed.
        """
        try:
            count = self._count(self._orphaned_rows_stmt(Episode))
        except Exception as e:
            self._record_failure(f"Error checking orphaned episodes: {e}")
            return -1
        if count:
            self._record_issue(
                f"Found {count} orphaned episodes without parent series"
            )
        else:
            logger.info("No orphaned episodes found")
        return count

    def _check_orphaned_queue_items(self) -> int:
        """Check for download-queue items without a parent series.

        Returns:
            Number of orphaned queue items (including those with a NULL
            ``series_id``), or ``-1`` if the query failed.
        """
        try:
            count = self._count(self._orphaned_rows_stmt(DownloadQueueItem))
        except Exception as e:
            self._record_failure(f"Error checking orphaned queue items: {e}")
            return -1
        if count:
            self._record_issue(
                f"Found {count} orphaned queue items "
                f"without parent series"
            )
        else:
            logger.info("No orphaned queue items found")
        return count

    def _check_invalid_references(self) -> int:
        """Check for non-NULL ``series_id`` values that reference no series.

        Unlike the orphan checks, rows with a NULL ``series_id`` are not
        counted. Queries are built from the ORM models, so the table names
        always match the mapped tables.

        Returns:
            Total number of dangling references across episodes and queue
            items, or ``-1`` if a query failed.
        """
        issues_found = 0
        try:
            bad_episodes = self._count(
                self._orphaned_rows_stmt(Episode).where(
                    Episode.series_id.is_not(None)
                )
            )
            if bad_episodes:
                self._record_issue(
                    f"Found {bad_episodes} episodes with invalid series_id"
                )
                issues_found += bad_episodes

            bad_queue_items = self._count(
                self._orphaned_rows_stmt(DownloadQueueItem).where(
                    DownloadQueueItem.series_id.is_not(None)
                )
            )
            if bad_queue_items:
                self._record_issue(
                    f"Found {bad_queue_items} queue items with invalid series_id"
                )
                issues_found += bad_queue_items
        except Exception as e:
            self._record_failure(f"Error checking invalid references: {e}")
            return -1
        if issues_found == 0:
            logger.info("No invalid foreign key references found")
        return issues_found

    def _check_duplicate_keys(self) -> int:
        """Check for ``anime_key`` values shared by more than one series row.

        Returns:
            Number of distinct duplicated keys (one issue message is recorded
            per key), or ``-1`` if the query failed.
        """
        try:
            # NOTE(review): assumes the table mapped by AnimeSeries has a
            # column named "anime_key" (as the previous raw SQL did); confirm
            # against the model.
            key_col = AnimeSeries.__table__.c.anime_key
            stmt = (
                select(key_col, func.count())
                .group_by(key_col)
                .having(func.count() > 1)
            )
            duplicates = self.session.execute(stmt).all()
        except Exception as e:
            self._record_failure(f"Error checking duplicate keys: {e}")
            return -1
        for key, occurrences in duplicates:
            self._record_issue(
                f"Duplicate anime_key found: {key} "
                f"({occurrences} times)"
            )
        if not duplicates:
            logger.info("No duplicate keys found")
        return len(duplicates)

    def _check_data_consistency(self) -> int:
        """Check for out-of-range or invalid column values.

        Three checks are run:
        - episodes with a negative season or episode number;
        - queue items with progress outside ``0..100``;
        - queue items whose status is NULL or not in
          :attr:`VALID_QUEUE_STATUSES`.

        Returns:
            Total number of offending rows, or ``-1`` if a query failed.
        """
        issues_found = 0
        try:
            bad_numbers = self._count(
                select(Episode).where(
                    (Episode.season < 0) | (Episode.episode_number < 0)
                )
            )
            if bad_numbers:
                self._record_issue(
                    f"Found {bad_numbers} episodes with invalid "
                    f"season/episode numbers"
                )
                issues_found += bad_numbers

            bad_progress = self._count(
                select(DownloadQueueItem).where(
                    (DownloadQueueItem.progress < 0)
                    | (DownloadQueueItem.progress > 100)
                )
            )
            if bad_progress:
                self._record_issue(
                    f"Found {bad_progress} queue items with invalid progress "
                    f"percentages"
                )
                issues_found += bad_progress

            status = DownloadQueueItem.status
            # SQL "NOT IN" evaluates to NULL (not true) for NULL values, so
            # NULL statuses must be matched explicitly.
            bad_status = self._count(
                select(DownloadQueueItem).where(
                    (~status.in_(sorted(self.VALID_QUEUE_STATUSES)))
                    | status.is_(None)
                )
            )
            if bad_status:
                self._record_issue(
                    f"Found {bad_status} queue items with invalid status"
                )
                issues_found += bad_status
        except Exception as e:
            self._record_failure(f"Error checking data consistency: {e}")
            return -1
        if issues_found == 0:
            logger.info("No data consistency issues found")
        return issues_found

    # ------------------------------------------------------------------
    # Repair
    # ------------------------------------------------------------------

    def repair_orphaned_records(self) -> int:
        """Delete episodes and queue items whose parent series is missing.

        This uses the same orphan definition as the orphan checks, so rows
        with a NULL ``series_id`` are deleted too. Episodes are processed
        before queue items. Rows are removed one by one with
        ``Session.delete``, and the transaction is then committed. On any
        error the session is rolled back and the exception is re-raised.

        Returns:
            Number of records removed.

        Raises:
            ValueError: If no session was provided.
        """
        if self.session is None:
            raise ValueError("Session required for repair operations")
        removed = 0
        try:
            for model in (Episode, DownloadQueueItem):
                orphans = (
                    self.session.execute(self._orphaned_rows_stmt(model))
                    .scalars()
                    .all()
                )
                for orphan in orphans:
                    self.session.delete(orphan)
                removed += len(orphans)
            self.session.commit()
        except Exception as e:
            self.session.rollback()
            logger.exception("Error removing orphaned records: %s", e)
            raise
        logger.info("Removed %d orphaned records", removed)
        return removed
def check_database_integrity(session: Session) -> Dict[str, Any]:
    """Run every integrity check against *session* and return the report.

    This is a shortcut that builds a ``DatabaseIntegrityChecker`` bound to
    *session* and delegates to its ``check_all`` method.

    Args:
        session: SQLAlchemy session used to query the database.

    Returns:
        The result dictionary produced by ``DatabaseIntegrityChecker.check_all``.
    """
    return DatabaseIntegrityChecker(session).check_all()