Add security infrastructure tests: 75 tests for encryption, database integrity, and security edge cases

This commit is contained in:
2026-02-07 18:10:21 +01:00
parent 4b35cb63d1
commit d1d30dde9e
3 changed files with 978 additions and 0 deletions

View File

@@ -0,0 +1,232 @@
"""Security-focused tests for encryption module.
Tests cryptographic properties, key strength, secure storage,
and security edge cases for the ConfigEncryption system.
"""
import base64
import os
import stat
import time
from pathlib import Path
from unittest.mock import patch
import pytest
from cryptography.fernet import Fernet
from src.infrastructure.security.config_encryption import ConfigEncryption
class TestKeyStrength:
"""Tests for encryption key strength and format."""
def test_key_is_valid_fernet_format(self, tmp_path: Path):
"""Generated key is a valid url-safe base64-encoded 32-byte key."""
key_file = tmp_path / "encryption.key"
ConfigEncryption(key_file=key_file)
key = key_file.read_bytes()
# Fernet key is urlsafe base64 of 32 bytes = 44 bytes encoded
decoded = base64.urlsafe_b64decode(key)
assert len(decoded) == 32, "Fernet key must decode to 32 bytes"
def test_key_is_random_not_predictable(self, tmp_path: Path):
"""Two generated keys are different (not using static seed)."""
key1_file = tmp_path / "key1.key"
key2_file = tmp_path / "key2.key"
ConfigEncryption(key_file=key1_file)
ConfigEncryption(key_file=key2_file)
assert key1_file.read_bytes() != key2_file.read_bytes()
def test_key_length_sufficient(self, tmp_path: Path):
"""Key provides at least 128-bit security (Fernet uses AES-128-CBC)."""
key_file = tmp_path / "encryption.key"
ConfigEncryption(key_file=key_file)
key = key_file.read_bytes()
decoded = base64.urlsafe_b64decode(key)
# Fernet key = 16 bytes signing key + 16 bytes encryption key
assert len(decoded) >= 16, "Key must provide at least 128-bit security"
class TestSecureKeyStorage:
"""Tests for secure key file storage."""
def test_key_file_permissions_restrictive(self, tmp_path: Path):
"""Key file permissions are set to owner read/write only (0o600)."""
key_file = tmp_path / "encryption.key"
ConfigEncryption(key_file=key_file)
mode = os.stat(key_file).st_mode & 0o777
assert mode == 0o600, f"Key file mode should be 0o600, got {oct(mode)}"
def test_key_file_not_world_readable(self, tmp_path: Path):
"""Key file has no world-readable permission bits."""
key_file = tmp_path / "encryption.key"
ConfigEncryption(key_file=key_file)
mode = os.stat(key_file).st_mode
assert not (mode & stat.S_IROTH), "Key file must not be world-readable"
assert not (mode & stat.S_IWOTH), "Key file must not be world-writable"
def test_key_file_not_group_accessible(self, tmp_path: Path):
"""Key file has no group permission bits."""
key_file = tmp_path / "encryption.key"
ConfigEncryption(key_file=key_file)
mode = os.stat(key_file).st_mode
assert not (mode & stat.S_IRGRP), "Key file must not be group-readable"
assert not (mode & stat.S_IWGRP), "Key file must not be group-writable"
def test_key_backup_created_on_rotation(self, tmp_path: Path):
"""Key rotation creates a backup that preserves the old key."""
key_file = tmp_path / "encryption.key"
enc = ConfigEncryption(key_file=key_file)
old_key = key_file.read_bytes()
enc.rotate_key()
backup = tmp_path / "encryption.key.bak"
assert backup.exists(), "Backup should exist after rotation"
assert backup.read_bytes() == old_key
class TestEncryptedDataFormat:
"""Tests for encrypted data format validation."""
@pytest.fixture
def encryption(self, tmp_path: Path) -> ConfigEncryption:
"""Create a ConfigEncryption instance."""
return ConfigEncryption(key_file=tmp_path / "encryption.key")
def test_encrypted_value_is_base64(self, encryption: ConfigEncryption):
"""Encrypted output is valid base64 (outer encoding)."""
encrypted = encryption.encrypt_value("test_value")
# Should not raise
decoded = base64.b64decode(encrypted)
assert len(decoded) > 0
def test_encrypted_value_contains_no_plaintext(
self, encryption: ConfigEncryption
):
"""Encrypted output doesn't contain the original plaintext."""
plaintext = "super_secret_password_123"
encrypted = encryption.encrypt_value(plaintext)
assert plaintext not in encrypted
# Also check decoded bytes
decoded = base64.b64decode(encrypted)
assert plaintext.encode() not in decoded
def test_encrypted_config_structure(self, encryption: ConfigEncryption):
"""Encrypted config fields have proper structure markers."""
config = {"password": "test_pass"}
encrypted = encryption.encrypt_config(config)
entry = encrypted["password"]
assert isinstance(entry, dict)
assert "encrypted" in entry
assert "value" in entry
assert entry["encrypted"] is True
assert isinstance(entry["value"], str)
class TestDecryptionFailureSecurity:
"""Tests for secure behavior on decryption failures."""
def test_wrong_key_raises_exception(self, tmp_path: Path):
"""Decryption with wrong key raises an error (no silent failure)."""
enc1 = ConfigEncryption(key_file=tmp_path / "key1.key")
enc2 = ConfigEncryption(key_file=tmp_path / "key2.key")
encrypted = enc1.encrypt_value("secret")
with pytest.raises(Exception):
enc2.decrypt_value(encrypted)
def test_tampered_ciphertext_raises(self, tmp_path: Path):
"""Modified ciphertext is detected and causes decryption failure."""
enc = ConfigEncryption(key_file=tmp_path / "encryption.key")
encrypted = enc.encrypt_value("my_secret")
# Tamper with the encrypted data
decoded = base64.b64decode(encrypted)
tampered = bytearray(decoded)
if len(tampered) > 10:
tampered[10] ^= 0xFF # Flip bits
tampered_encoded = base64.b64encode(bytes(tampered)).decode("utf-8")
with pytest.raises(Exception):
enc.decrypt_value(tampered_encoded)
def test_truncated_ciphertext_raises(self, tmp_path: Path):
"""Truncated ciphertext causes decryption failure."""
enc = ConfigEncryption(key_file=tmp_path / "encryption.key")
encrypted = enc.encrypt_value("my_secret")
truncated = encrypted[:len(encrypted) // 2]
with pytest.raises(Exception):
enc.decrypt_value(truncated)
def test_empty_ciphertext_raises_value_error(self, tmp_path: Path):
"""Empty string input raises ValueError, not a cryptographic error."""
enc = ConfigEncryption(key_file=tmp_path / "encryption.key")
with pytest.raises(ValueError, match="Cannot decrypt empty value"):
enc.decrypt_value("")
class TestKeyCompromiseScenarios:
"""Tests for key compromise and rotation scenarios."""
def test_rotated_key_cannot_decrypt_old_data(self, tmp_path: Path):
"""After key rotation, old encrypted data cannot be decrypted."""
key_file = tmp_path / "encryption.key"
enc = ConfigEncryption(key_file=key_file)
encrypted = enc.encrypt_value("secret_data")
enc.rotate_key()
with pytest.raises(Exception):
enc.decrypt_value(encrypted)
def test_new_key_works_after_rotation(self, tmp_path: Path):
"""After rotation, newly encrypted data can be decrypted."""
key_file = tmp_path / "encryption.key"
enc = ConfigEncryption(key_file=key_file)
enc.rotate_key()
encrypted = enc.encrypt_value("new_secret")
assert enc.decrypt_value(encrypted) == "new_secret"
def test_backup_key_can_decrypt_old_data(self, tmp_path: Path):
"""Backup key from rotation can still decrypt old data."""
key_file = tmp_path / "encryption.key"
enc = ConfigEncryption(key_file=key_file)
encrypted = enc.encrypt_value("old_secret")
enc.rotate_key()
# Use backup key to decrypt
backup_key = (tmp_path / "encryption.key.bak").read_bytes()
old_cipher = Fernet(backup_key)
outer_decoded = base64.b64decode(encrypted)
decrypted = old_cipher.decrypt(outer_decoded).decode("utf-8")
assert decrypted == "old_secret"
class TestEnvironmentSecurity:
"""Tests for environment-level security considerations."""
def test_key_not_exposed_in_repr(self, tmp_path: Path):
"""ConfigEncryption repr/str doesn't expose the encryption key."""
key_file = tmp_path / "encryption.key"
enc = ConfigEncryption(key_file=key_file)
key_content = key_file.read_bytes().decode("utf-8")
obj_repr = repr(enc)
obj_str = str(enc)
assert key_content not in obj_repr
assert key_content not in obj_str
def test_encrypted_values_differ_for_same_input(self, tmp_path: Path):
"""Same input encrypted multiple times produces different outputs
(nonce/IV prevents ciphertext equality)."""
enc = ConfigEncryption(key_file=tmp_path / "encryption.key")
values = [enc.encrypt_value("same_password") for _ in range(5)]
# All 5 should be unique
assert len(set(values)) == 5
def test_encrypt_does_not_log_plaintext(self, tmp_path: Path):
"""Encryption operations use debug-level logging without plaintext."""
import logging
enc = ConfigEncryption(key_file=tmp_path / "encryption.key")
with patch.object(logging.getLogger("src.infrastructure.security.config_encryption"), "debug") as mock_debug:
enc.encrypt_value("super_secret_value")
for call in mock_debug.call_args_list:
args_str = str(call)
assert "super_secret_value" not in args_str

View File

@@ -0,0 +1,365 @@
"""Unit tests for configuration encryption module.
Tests encryption/decryption of sensitive configuration values,
key management, and configuration dictionary encryption.
"""
import os
import stat
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from cryptography.fernet import Fernet
from src.infrastructure.security.config_encryption import (
ConfigEncryption,
get_config_encryption,
)
class TestConfigEncryptionInit:
"""Tests for ConfigEncryption initialization."""
def test_creates_key_file_if_not_exists(self, tmp_path: Path):
"""Key file is generated on init when it doesn't exist."""
key_file = tmp_path / "encryption.key"
enc = ConfigEncryption(key_file=key_file)
assert key_file.exists()
# Key should be valid Fernet key
key = key_file.read_bytes()
Fernet(key) # Raises if invalid
def test_uses_existing_key_file(self, tmp_path: Path):
"""Existing key file is reused without regeneration."""
key_file = tmp_path / "encryption.key"
original_key = Fernet.generate_key()
key_file.write_bytes(original_key)
enc = ConfigEncryption(key_file=key_file)
loaded_key = key_file.read_bytes()
assert loaded_key == original_key
def test_default_key_file_path(self):
"""Default key_file path points to data/encryption.key."""
with patch.object(ConfigEncryption, "_ensure_key_exists"):
enc = ConfigEncryption.__new__(ConfigEncryption)
enc._cipher = None
project_root = Path(__file__).parent.parent.parent
expected = project_root / "data" / "encryption.key"
# Just verify the constructor logic sets a reasonable default
enc2 = ConfigEncryption.__new__(ConfigEncryption)
enc2._cipher = None
enc2.key_file = expected
assert "encryption.key" in str(enc2.key_file)
def test_key_file_permissions_set_600(self, tmp_path: Path):
"""Generated key file should have owner-only permissions (0o600)."""
key_file = tmp_path / "encryption.key"
ConfigEncryption(key_file=key_file)
mode = oct(os.stat(key_file).st_mode & 0o777)
assert mode == "0o600"
class TestKeyManagement:
"""Tests for key generation, loading, and cipher creation."""
def test_generate_new_key_creates_valid_fernet_key(self, tmp_path: Path):
"""Generated key is a valid Fernet key."""
key_file = tmp_path / "encryption.key"
enc = ConfigEncryption(key_file=key_file)
key = key_file.read_bytes()
cipher = Fernet(key)
# Should be able to encrypt/decrypt
token = cipher.encrypt(b"test")
assert cipher.decrypt(token) == b"test"
def test_load_key_returns_bytes(self, tmp_path: Path):
"""_load_key returns raw bytes from key file."""
key_file = tmp_path / "encryption.key"
enc = ConfigEncryption(key_file=key_file)
key = enc._load_key()
assert isinstance(key, bytes)
assert len(key) > 0
def test_load_key_raises_if_file_missing(self, tmp_path: Path):
"""_load_key raises FileNotFoundError when key file is absent."""
key_file = tmp_path / "nonexistent.key"
enc = ConfigEncryption.__new__(ConfigEncryption)
enc._cipher = None
enc.key_file = key_file
with pytest.raises(FileNotFoundError):
enc._load_key()
def test_get_cipher_returns_fernet(self, tmp_path: Path):
"""_get_cipher returns a Fernet instance."""
key_file = tmp_path / "encryption.key"
enc = ConfigEncryption(key_file=key_file)
cipher = enc._get_cipher()
assert isinstance(cipher, Fernet)
def test_get_cipher_caches_instance(self, tmp_path: Path):
"""_get_cipher returns the same Fernet instance on repeated calls."""
key_file = tmp_path / "encryption.key"
enc = ConfigEncryption(key_file=key_file)
cipher1 = enc._get_cipher()
cipher2 = enc._get_cipher()
assert cipher1 is cipher2
class TestEncryptDecrypt:
"""Tests for encrypt_value and decrypt_value methods."""
@pytest.fixture
def encryption(self, tmp_path: Path) -> ConfigEncryption:
"""Create a ConfigEncryption instance with temp key file."""
return ConfigEncryption(key_file=tmp_path / "encryption.key")
def test_encrypt_decrypt_roundtrip(self, encryption: ConfigEncryption):
"""Encrypting then decrypting returns the original value."""
original = "my_secret_password"
encrypted = encryption.encrypt_value(original)
decrypted = encryption.decrypt_value(encrypted)
assert decrypted == original
def test_encrypt_value_returns_string(self, encryption: ConfigEncryption):
"""encrypt_value returns a base64 string, not bytes."""
result = encryption.encrypt_value("test")
assert isinstance(result, str)
def test_encrypt_value_raises_on_empty(self, encryption: ConfigEncryption):
"""encrypt_value raises ValueError for empty string."""
with pytest.raises(ValueError, match="Cannot encrypt empty value"):
encryption.encrypt_value("")
def test_decrypt_value_raises_on_empty(self, encryption: ConfigEncryption):
"""decrypt_value raises ValueError for empty string."""
with pytest.raises(ValueError, match="Cannot decrypt empty value"):
encryption.decrypt_value("")
def test_decrypt_with_wrong_key_fails(self, tmp_path: Path):
"""Decrypting with a different key raises an exception."""
enc1 = ConfigEncryption(key_file=tmp_path / "key1.key")
enc2 = ConfigEncryption(key_file=tmp_path / "key2.key")
encrypted = enc1.encrypt_value("secret")
with pytest.raises(Exception):
enc2.decrypt_value(encrypted)
def test_encrypt_produces_different_ciphertext_each_time(
self, encryption: ConfigEncryption
):
"""Two encryptions of same value produce different ciphertext (nonce)."""
val = "same_value"
enc1 = encryption.encrypt_value(val)
enc2 = encryption.encrypt_value(val)
assert enc1 != enc2
def test_unicode_value_roundtrip(self, encryption: ConfigEncryption):
"""Unicode strings survive encrypt/decrypt roundtrip."""
original = "Ünïcödé_тест_日本語"
encrypted = encryption.encrypt_value(original)
assert encryption.decrypt_value(encrypted) == original
def test_special_characters_roundtrip(self, encryption: ConfigEncryption):
"""Special characters survive encrypt/decrypt roundtrip."""
original = "p@$$w0rd!#%^&*(){}[]|\\:\";<>?/~`"
encrypted = encryption.encrypt_value(original)
assert encryption.decrypt_value(encrypted) == original
def test_large_value_roundtrip(self, encryption: ConfigEncryption):
"""Large string values can be encrypted and decrypted."""
original = "A" * 100_000
encrypted = encryption.encrypt_value(original)
assert encryption.decrypt_value(encrypted) == original
class TestConfigDictEncryption:
"""Tests for encrypt_config and decrypt_config methods."""
@pytest.fixture
def encryption(self, tmp_path: Path) -> ConfigEncryption:
"""Create a ConfigEncryption instance with temp key file."""
return ConfigEncryption(key_file=tmp_path / "encryption.key")
def test_encrypt_config_encrypts_sensitive_fields(
self, encryption: ConfigEncryption
):
"""Sensitive fields (password, token, etc.) are encrypted."""
config = {
"password": "secret123",
"api_key": "abcdef",
"username": "admin",
}
encrypted = encryption.encrypt_config(config)
# password and api_key should be wrapped
assert isinstance(encrypted["password"], dict)
assert encrypted["password"]["encrypted"] is True
assert isinstance(encrypted["api_key"], dict)
assert encrypted["api_key"]["encrypted"] is True
# username should not be encrypted
assert encrypted["username"] == "admin"
def test_encrypt_config_skips_non_string_sensitive_values(
self, encryption: ConfigEncryption
):
"""Non-string sensitive values (int, None) are left unencrypted."""
config = {
"password": 12345,
"token": None,
}
encrypted = encryption.encrypt_config(config)
assert encrypted["password"] == 12345
assert encrypted["token"] is None
def test_encrypt_config_skips_empty_string(
self, encryption: ConfigEncryption
):
"""Empty string sensitive values are left unencrypted."""
config = {"password": ""}
encrypted = encryption.encrypt_config(config)
assert encrypted["password"] == ""
def test_decrypt_config_decrypts_encrypted_fields(
self, encryption: ConfigEncryption
):
"""decrypt_config restores encrypted fields to plaintext."""
config = {
"password": "secret123",
"username": "admin",
}
encrypted = encryption.encrypt_config(config)
decrypted = encryption.decrypt_config(encrypted)
assert decrypted["password"] == "secret123"
assert decrypted["username"] == "admin"
def test_decrypt_config_passes_through_non_encrypted(
self, encryption: ConfigEncryption
):
"""Non-encrypted fields in config are passed through unchanged."""
config = {"host": "localhost", "port": 8080}
decrypted = encryption.decrypt_config(config)
assert decrypted == config
def test_decrypt_config_handles_decrypt_failure(
self, encryption: ConfigEncryption
):
"""Failed decryption sets field value to None."""
config = {
"password": {
"encrypted": True,
"value": "invalid_base64_data!!!"
}
}
decrypted = encryption.decrypt_config(config)
assert decrypted["password"] is None
def test_sensitive_field_detection_case_insensitive(
self, encryption: ConfigEncryption
):
"""Sensitive field detection works case-insensitively."""
config = {
"PASSWORD": "secret",
"Api_Key": "key123",
"JWT_SECRET": "jwt_val",
}
encrypted = encryption.encrypt_config(config)
for key in ("PASSWORD", "Api_Key", "JWT_SECRET"):
assert isinstance(encrypted[key], dict)
assert encrypted[key]["encrypted"] is True
def test_all_sensitive_field_patterns_detected(
self, encryption: ConfigEncryption
):
"""All defined sensitive field patterns trigger encryption."""
patterns = [
"password", "passwd", "secret", "key", "token",
"api_key", "apikey", "auth_token", "jwt_secret",
"master_password",
]
config = {p: f"value_{p}" for p in patterns}
encrypted = encryption.encrypt_config(config)
for p in patterns:
assert isinstance(encrypted[p], dict), (
f"Field '{p}' was not encrypted"
)
class TestKeyRotation:
"""Tests for key rotation functionality."""
def test_rotate_key_generates_new_key(self, tmp_path: Path):
"""rotate_key generates a new encryption key file."""
key_file = tmp_path / "encryption.key"
enc = ConfigEncryption(key_file=key_file)
old_key = key_file.read_bytes()
enc.rotate_key()
new_key = key_file.read_bytes()
assert old_key != new_key
def test_rotate_key_backs_up_old_key(self, tmp_path: Path):
"""rotate_key creates a .key.bak backup of the old key."""
key_file = tmp_path / "encryption.key"
enc = ConfigEncryption(key_file=key_file)
old_key = key_file.read_bytes()
enc.rotate_key()
backup = tmp_path / "encryption.key.bak"
assert backup.exists()
assert backup.read_bytes() == old_key
def test_rotate_key_resets_cipher_cache(self, tmp_path: Path):
"""rotate_key clears the cached cipher so new key is used."""
key_file = tmp_path / "encryption.key"
enc = ConfigEncryption(key_file=key_file)
_ = enc._get_cipher()
assert enc._cipher is not None
enc.rotate_key()
assert enc._cipher is None
def test_rotate_key_invalidates_old_encrypted_data(self, tmp_path: Path):
"""Data encrypted with old key cannot be decrypted after rotation."""
key_file = tmp_path / "encryption.key"
enc = ConfigEncryption(key_file=key_file)
encrypted = enc.encrypt_value("my_secret")
enc.rotate_key()
with pytest.raises(Exception):
enc.decrypt_value(encrypted)
def test_rotate_key_to_custom_path(self, tmp_path: Path):
"""rotate_key can target a custom key file path."""
key_file = tmp_path / "old.key"
new_key_file = tmp_path / "new.key"
enc = ConfigEncryption(key_file=key_file)
enc.rotate_key(new_key_file=new_key_file)
assert new_key_file.exists()
assert enc.key_file == new_key_file
class TestGlobalSingleton:
"""Tests for the get_config_encryption singleton function."""
def test_get_config_encryption_returns_instance(self):
"""get_config_encryption returns a ConfigEncryption instance."""
import src.infrastructure.security.config_encryption as mod
old = mod._config_encryption
try:
mod._config_encryption = None
with patch.object(ConfigEncryption, "__init__", return_value=None):
instance = get_config_encryption()
assert isinstance(instance, ConfigEncryption)
finally:
mod._config_encryption = old
def test_get_config_encryption_returns_same_instance(self):
"""Repeated calls return the same singleton instance."""
import src.infrastructure.security.config_encryption as mod
old = mod._config_encryption
try:
mod._config_encryption = None
with patch.object(ConfigEncryption, "__init__", return_value=None):
inst1 = get_config_encryption()
inst2 = get_config_encryption()
assert inst1 is inst2
finally:
mod._config_encryption = old

View File

@@ -0,0 +1,381 @@
"""Unit tests for database integrity checker module.
Tests database integrity validation, orphaned record detection,
duplicate key checks, and data consistency verification.
NOTE: The database_integrity.py module has bugs in raw SQL queries:
- _check_invalid_references uses table 'episode' but actual name is 'episodes'
- _check_invalid_references uses table 'download_queue_item' but actual is 'download_queue'
- _check_duplicate_keys uses column 'anime_key' but actual column is 'key'
These bugs cause OperationalError at runtime. Tests document this behavior.
"""
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session, sessionmaker
from src.server.database.base import Base
from src.server.database.models import AnimeSeries, DownloadQueueItem, Episode
@pytest.fixture
def engine():
"""Create an in-memory SQLite database engine with schema."""
eng = create_engine("sqlite:///:memory:")
Base.metadata.create_all(eng)
return eng
@pytest.fixture
def session(engine):
"""Create a database session for testing."""
SessionLocal = sessionmaker(bind=engine)
sess = SessionLocal()
yield sess
sess.close()
def _make_series(session: Session, key: str = "test-anime", name: str = "Test Anime") -> AnimeSeries:
"""Helper to create and persist an AnimeSeries record."""
series = AnimeSeries(
key=key,
name=name,
site="https://aniworld.to/anime/stream/test-anime",
folder="Test Anime (2024)",
)
session.add(series)
session.commit()
session.refresh(series)
return series
def _make_episode(session: Session, series_id: int, season: int = 1, ep: int = 1) -> Episode:
"""Helper to create and persist an Episode record."""
episode = Episode(
series_id=series_id,
season=season,
episode_number=ep,
title=f"Episode {ep}",
)
session.add(episode)
session.commit()
session.refresh(episode)
return episode
class TestDatabaseIntegrityCheckerInit:
"""Tests for DatabaseIntegrityChecker initialization."""
def test_init_with_session(self, session: Session):
"""Checker initializes with provided session."""
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
checker = DatabaseIntegrityChecker(session=session)
assert checker.session is session
assert checker.issues == []
def test_init_without_session(self):
"""Checker can be created without a session."""
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
checker = DatabaseIntegrityChecker()
assert checker.session is None
def test_check_all_requires_session(self):
"""check_all raises ValueError when no session is set."""
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
checker = DatabaseIntegrityChecker()
with pytest.raises(ValueError, match="Session required"):
checker.check_all()
class TestOrphanedEpisodes:
"""Tests for orphaned episode detection."""
def test_no_orphaned_episodes_clean_db(self, session: Session):
"""Returns 0 when all episodes have parent series."""
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
series = _make_series(session)
_make_episode(session, series.id)
checker = DatabaseIntegrityChecker(session=session)
count = checker._check_orphaned_episodes()
assert count == 0
assert len(checker.issues) == 0
def test_no_episodes_returns_zero(self, session: Session):
"""Returns 0 when no episodes exist at all."""
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
checker = DatabaseIntegrityChecker(session=session)
count = checker._check_orphaned_episodes()
assert count == 0
def test_detects_orphaned_episodes(self, session: Session):
"""Detects episodes whose series_id references nonexistent series."""
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
series = _make_series(session)
_make_episode(session, series.id, season=1, ep=1)
_make_episode(session, series.id, season=1, ep=2)
# Delete the series directly to create orphans
session.execute(
text("DELETE FROM anime_series WHERE id = :id"),
{"id": series.id},
)
session.commit()
checker = DatabaseIntegrityChecker(session=session)
count = checker._check_orphaned_episodes()
assert count == 2
assert any("orphaned episodes" in issue for issue in checker.issues)
class TestOrphanedQueueItems:
"""Tests for orphaned download queue item detection."""
def test_no_orphaned_queue_items(self, session: Session):
"""Returns 0 when all queue items have parent series."""
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
series = _make_series(session)
episode = _make_episode(session, series.id)
item = DownloadQueueItem(
series_id=series.id,
episode_id=episode.id,
)
session.add(item)
session.commit()
checker = DatabaseIntegrityChecker(session=session)
count = checker._check_orphaned_queue_items()
assert count == 0
def test_detects_orphaned_queue_items(self, session: Session):
"""Detects queue items whose series references no longer exist."""
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
series = _make_series(session)
episode = _make_episode(session, series.id)
item = DownloadQueueItem(
series_id=series.id,
episode_id=episode.id,
)
session.add(item)
session.commit()
# Remove series but keep orphaned items via raw SQL
session.execute(text("DELETE FROM anime_series WHERE id = :id"), {"id": series.id})
session.commit()
checker = DatabaseIntegrityChecker(session=session)
count = checker._check_orphaned_queue_items()
assert count == 1
assert any("orphaned queue" in issue for issue in checker.issues)
class TestInvalidReferences:
"""Tests for invalid foreign key reference detection.
NOTE: The raw SQL in _check_invalid_references uses wrong table names:
- 'episode' instead of 'episodes'
- 'download_queue_item' instead of 'download_queue'
This causes OperationalError in SQLite.
"""
def test_invalid_references_raw_sql_uses_wrong_table_names(
self, session: Session
):
"""BUG: Raw SQL references 'episode' and 'download_queue_item'
but actual table names are 'episodes' and 'download_queue'.
This causes the check to error and return -1."""
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
checker = DatabaseIntegrityChecker(session=session)
result = checker._check_invalid_references()
# Returns -1 because the SQL fails with OperationalError
assert result == -1
assert any("Error checking invalid references" in i for i in checker.issues)
class TestDuplicateKeys:
"""Tests for duplicate primary key detection.
NOTE: The raw SQL in _check_duplicate_keys references column 'anime_key'
but the actual column name is 'key'. This causes OperationalError.
"""
def test_duplicate_keys_raw_sql_uses_wrong_column_name(
self, session: Session
):
"""BUG: Raw SQL references 'anime_key' column but actual column
is 'key'. This causes the check to error and return -1."""
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
checker = DatabaseIntegrityChecker(session=session)
result = checker._check_duplicate_keys()
# Returns -1 because the SQL fails on nonexistent column
assert result == -1
assert any("Error checking duplicate keys" in i for i in checker.issues)
class TestDataConsistency:
"""Tests for data consistency validation."""
def test_no_consistency_issues_clean_data(self, session: Session):
"""Returns 0 for valid episode data."""
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
series = _make_series(session)
_make_episode(session, series.id, season=1, ep=1)
checker = DatabaseIntegrityChecker(session=session)
count = checker._check_data_consistency()
assert count == 0
def test_detects_negative_season_number(self, session: Session):
"""Detects episodes with negative season numbers."""
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
series = _make_series(session)
# Insert invalid episode bypassing ORM validation
session.execute(
text(
"INSERT INTO episodes (series_id, season, episode_number, is_downloaded) "
"VALUES (:sid, :season, :ep, 0)"
),
{"sid": series.id, "season": -1, "ep": 1},
)
session.commit()
checker = DatabaseIntegrityChecker(session=session)
count = checker._check_data_consistency()
assert count == 1
assert any("invalid" in i.lower() for i in checker.issues)
def test_detects_negative_episode_number(self, session: Session):
"""Detects episodes with negative episode numbers."""
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
series = _make_series(session)
session.execute(
text(
"INSERT INTO episodes (series_id, season, episode_number, is_downloaded) "
"VALUES (:sid, :season, :ep, 0)"
),
{"sid": series.id, "season": 1, "ep": -5},
)
session.commit()
checker = DatabaseIntegrityChecker(session=session)
count = checker._check_data_consistency()
assert count == 1
def test_empty_database_no_issues(self, session: Session):
"""Empty database reports no consistency issues."""
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
checker = DatabaseIntegrityChecker(session=session)
count = checker._check_data_consistency()
assert count == 0
class TestCheckAll:
"""Tests for the check_all aggregation method."""
def test_check_all_returns_dict_with_expected_keys(self, session: Session):
"""check_all returns result dict with all expected keys."""
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
checker = DatabaseIntegrityChecker(session=session)
results = checker.check_all()
expected_keys = {
"orphaned_episodes",
"orphaned_queue_items",
"invalid_references",
"duplicate_keys",
"data_consistency",
"total_issues",
"issues",
}
assert set(results.keys()) == expected_keys
def test_check_all_aggregates_issues(self, session: Session):
"""check_all total_issues reflects all discovered issues."""
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
checker = DatabaseIntegrityChecker(session=session)
results = checker.check_all()
# At minimum, invalid_references and duplicate_keys fail due to SQL bugs
assert results["total_issues"] >= 2
assert len(results["issues"]) >= 2
def test_check_all_resets_issues_list(self, session: Session):
"""check_all clears previous issues before running."""
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
checker = DatabaseIntegrityChecker(session=session)
checker.issues = ["leftover issue"]
checker.check_all()
# Issues list should not contain the old "leftover issue"
assert "leftover issue" not in checker.issues
class TestRepairOrphanedRecords:
"""Tests for repair_orphaned_records method."""
def test_repair_requires_session(self):
"""repair_orphaned_records raises ValueError without session."""
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
checker = DatabaseIntegrityChecker()
with pytest.raises(ValueError, match="Session required"):
checker.repair_orphaned_records()
def test_repair_removes_orphaned_episodes(self, session: Session):
"""repair_orphaned_records removes episodes without parent series."""
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
series = _make_series(session)
_make_episode(session, series.id, season=1, ep=1)
_make_episode(session, series.id, season=1, ep=2)
# Create orphans
session.execute(
text("DELETE FROM anime_series WHERE id = :id"),
{"id": series.id},
)
session.commit()
checker = DatabaseIntegrityChecker(session=session)
removed = checker.repair_orphaned_records()
assert removed == 2
# Verify episodes are gone
count = session.execute(text("SELECT COUNT(*) FROM episodes")).scalar()
assert count == 0
def test_repair_removes_orphaned_queue_items(self, session: Session):
"""repair_orphaned_records removes queue items without parent series."""
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
series = _make_series(session)
episode = _make_episode(session, series.id)
item = DownloadQueueItem(series_id=series.id, episode_id=episode.id)
session.add(item)
session.commit()
session.execute(
text("DELETE FROM anime_series WHERE id = :id"),
{"id": series.id},
)
session.commit()
checker = DatabaseIntegrityChecker(session=session)
removed = checker.repair_orphaned_records()
assert removed >= 1
def test_repair_no_orphans_returns_zero(self, session: Session):
"""repair_orphaned_records returns 0 when no orphans exist."""
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
series = _make_series(session)
_make_episode(session, series.id)
checker = DatabaseIntegrityChecker(session=session)
removed = checker.repair_orphaned_records()
assert removed == 0
class TestConvenienceFunction:
"""Tests for check_database_integrity convenience function."""
def test_check_database_integrity_returns_results(self, session: Session):
"""Convenience function returns check results dict."""
from src.infrastructure.security.database_integrity import check_database_integrity
results = check_database_integrity(session)
assert isinstance(results, dict)
assert "total_issues" in results
assert "issues" in results