Add security infrastructure tests: 75 tests for encryption, database integrity, and security edge cases
This commit is contained in:
232
tests/security/test_encryption_security.py
Normal file
232
tests/security/test_encryption_security.py
Normal file
@@ -0,0 +1,232 @@
|
||||
"""Security-focused tests for encryption module.
|
||||
|
||||
Tests cryptographic properties, key strength, secure storage,
|
||||
and security edge cases for the ConfigEncryption system.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import os
|
||||
import stat
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
from src.infrastructure.security.config_encryption import ConfigEncryption
|
||||
|
||||
|
||||
class TestKeyStrength:
|
||||
"""Tests for encryption key strength and format."""
|
||||
|
||||
def test_key_is_valid_fernet_format(self, tmp_path: Path):
|
||||
"""Generated key is a valid url-safe base64-encoded 32-byte key."""
|
||||
key_file = tmp_path / "encryption.key"
|
||||
ConfigEncryption(key_file=key_file)
|
||||
key = key_file.read_bytes()
|
||||
# Fernet key is urlsafe base64 of 32 bytes = 44 bytes encoded
|
||||
decoded = base64.urlsafe_b64decode(key)
|
||||
assert len(decoded) == 32, "Fernet key must decode to 32 bytes"
|
||||
|
||||
def test_key_is_random_not_predictable(self, tmp_path: Path):
|
||||
"""Two generated keys are different (not using static seed)."""
|
||||
key1_file = tmp_path / "key1.key"
|
||||
key2_file = tmp_path / "key2.key"
|
||||
ConfigEncryption(key_file=key1_file)
|
||||
ConfigEncryption(key_file=key2_file)
|
||||
assert key1_file.read_bytes() != key2_file.read_bytes()
|
||||
|
||||
def test_key_length_sufficient(self, tmp_path: Path):
|
||||
"""Key provides at least 128-bit security (Fernet uses AES-128-CBC)."""
|
||||
key_file = tmp_path / "encryption.key"
|
||||
ConfigEncryption(key_file=key_file)
|
||||
key = key_file.read_bytes()
|
||||
decoded = base64.urlsafe_b64decode(key)
|
||||
# Fernet key = 16 bytes signing key + 16 bytes encryption key
|
||||
assert len(decoded) >= 16, "Key must provide at least 128-bit security"
|
||||
|
||||
|
||||
class TestSecureKeyStorage:
|
||||
"""Tests for secure key file storage."""
|
||||
|
||||
def test_key_file_permissions_restrictive(self, tmp_path: Path):
|
||||
"""Key file permissions are set to owner read/write only (0o600)."""
|
||||
key_file = tmp_path / "encryption.key"
|
||||
ConfigEncryption(key_file=key_file)
|
||||
mode = os.stat(key_file).st_mode & 0o777
|
||||
assert mode == 0o600, f"Key file mode should be 0o600, got {oct(mode)}"
|
||||
|
||||
def test_key_file_not_world_readable(self, tmp_path: Path):
|
||||
"""Key file has no world-readable permission bits."""
|
||||
key_file = tmp_path / "encryption.key"
|
||||
ConfigEncryption(key_file=key_file)
|
||||
mode = os.stat(key_file).st_mode
|
||||
assert not (mode & stat.S_IROTH), "Key file must not be world-readable"
|
||||
assert not (mode & stat.S_IWOTH), "Key file must not be world-writable"
|
||||
|
||||
def test_key_file_not_group_accessible(self, tmp_path: Path):
|
||||
"""Key file has no group permission bits."""
|
||||
key_file = tmp_path / "encryption.key"
|
||||
ConfigEncryption(key_file=key_file)
|
||||
mode = os.stat(key_file).st_mode
|
||||
assert not (mode & stat.S_IRGRP), "Key file must not be group-readable"
|
||||
assert not (mode & stat.S_IWGRP), "Key file must not be group-writable"
|
||||
|
||||
def test_key_backup_created_on_rotation(self, tmp_path: Path):
|
||||
"""Key rotation creates a backup that preserves the old key."""
|
||||
key_file = tmp_path / "encryption.key"
|
||||
enc = ConfigEncryption(key_file=key_file)
|
||||
old_key = key_file.read_bytes()
|
||||
enc.rotate_key()
|
||||
backup = tmp_path / "encryption.key.bak"
|
||||
assert backup.exists(), "Backup should exist after rotation"
|
||||
assert backup.read_bytes() == old_key
|
||||
|
||||
|
||||
class TestEncryptedDataFormat:
|
||||
"""Tests for encrypted data format validation."""
|
||||
|
||||
@pytest.fixture
|
||||
def encryption(self, tmp_path: Path) -> ConfigEncryption:
|
||||
"""Create a ConfigEncryption instance."""
|
||||
return ConfigEncryption(key_file=tmp_path / "encryption.key")
|
||||
|
||||
def test_encrypted_value_is_base64(self, encryption: ConfigEncryption):
|
||||
"""Encrypted output is valid base64 (outer encoding)."""
|
||||
encrypted = encryption.encrypt_value("test_value")
|
||||
# Should not raise
|
||||
decoded = base64.b64decode(encrypted)
|
||||
assert len(decoded) > 0
|
||||
|
||||
def test_encrypted_value_contains_no_plaintext(
|
||||
self, encryption: ConfigEncryption
|
||||
):
|
||||
"""Encrypted output doesn't contain the original plaintext."""
|
||||
plaintext = "super_secret_password_123"
|
||||
encrypted = encryption.encrypt_value(plaintext)
|
||||
assert plaintext not in encrypted
|
||||
# Also check decoded bytes
|
||||
decoded = base64.b64decode(encrypted)
|
||||
assert plaintext.encode() not in decoded
|
||||
|
||||
def test_encrypted_config_structure(self, encryption: ConfigEncryption):
|
||||
"""Encrypted config fields have proper structure markers."""
|
||||
config = {"password": "test_pass"}
|
||||
encrypted = encryption.encrypt_config(config)
|
||||
entry = encrypted["password"]
|
||||
assert isinstance(entry, dict)
|
||||
assert "encrypted" in entry
|
||||
assert "value" in entry
|
||||
assert entry["encrypted"] is True
|
||||
assert isinstance(entry["value"], str)
|
||||
|
||||
|
||||
class TestDecryptionFailureSecurity:
|
||||
"""Tests for secure behavior on decryption failures."""
|
||||
|
||||
def test_wrong_key_raises_exception(self, tmp_path: Path):
|
||||
"""Decryption with wrong key raises an error (no silent failure)."""
|
||||
enc1 = ConfigEncryption(key_file=tmp_path / "key1.key")
|
||||
enc2 = ConfigEncryption(key_file=tmp_path / "key2.key")
|
||||
encrypted = enc1.encrypt_value("secret")
|
||||
with pytest.raises(Exception):
|
||||
enc2.decrypt_value(encrypted)
|
||||
|
||||
def test_tampered_ciphertext_raises(self, tmp_path: Path):
|
||||
"""Modified ciphertext is detected and causes decryption failure."""
|
||||
enc = ConfigEncryption(key_file=tmp_path / "encryption.key")
|
||||
encrypted = enc.encrypt_value("my_secret")
|
||||
|
||||
# Tamper with the encrypted data
|
||||
decoded = base64.b64decode(encrypted)
|
||||
tampered = bytearray(decoded)
|
||||
if len(tampered) > 10:
|
||||
tampered[10] ^= 0xFF # Flip bits
|
||||
tampered_encoded = base64.b64encode(bytes(tampered)).decode("utf-8")
|
||||
|
||||
with pytest.raises(Exception):
|
||||
enc.decrypt_value(tampered_encoded)
|
||||
|
||||
def test_truncated_ciphertext_raises(self, tmp_path: Path):
|
||||
"""Truncated ciphertext causes decryption failure."""
|
||||
enc = ConfigEncryption(key_file=tmp_path / "encryption.key")
|
||||
encrypted = enc.encrypt_value("my_secret")
|
||||
truncated = encrypted[:len(encrypted) // 2]
|
||||
with pytest.raises(Exception):
|
||||
enc.decrypt_value(truncated)
|
||||
|
||||
def test_empty_ciphertext_raises_value_error(self, tmp_path: Path):
|
||||
"""Empty string input raises ValueError, not a cryptographic error."""
|
||||
enc = ConfigEncryption(key_file=tmp_path / "encryption.key")
|
||||
with pytest.raises(ValueError, match="Cannot decrypt empty value"):
|
||||
enc.decrypt_value("")
|
||||
|
||||
|
||||
class TestKeyCompromiseScenarios:
|
||||
"""Tests for key compromise and rotation scenarios."""
|
||||
|
||||
def test_rotated_key_cannot_decrypt_old_data(self, tmp_path: Path):
|
||||
"""After key rotation, old encrypted data cannot be decrypted."""
|
||||
key_file = tmp_path / "encryption.key"
|
||||
enc = ConfigEncryption(key_file=key_file)
|
||||
encrypted = enc.encrypt_value("secret_data")
|
||||
enc.rotate_key()
|
||||
with pytest.raises(Exception):
|
||||
enc.decrypt_value(encrypted)
|
||||
|
||||
def test_new_key_works_after_rotation(self, tmp_path: Path):
|
||||
"""After rotation, newly encrypted data can be decrypted."""
|
||||
key_file = tmp_path / "encryption.key"
|
||||
enc = ConfigEncryption(key_file=key_file)
|
||||
enc.rotate_key()
|
||||
encrypted = enc.encrypt_value("new_secret")
|
||||
assert enc.decrypt_value(encrypted) == "new_secret"
|
||||
|
||||
def test_backup_key_can_decrypt_old_data(self, tmp_path: Path):
|
||||
"""Backup key from rotation can still decrypt old data."""
|
||||
key_file = tmp_path / "encryption.key"
|
||||
enc = ConfigEncryption(key_file=key_file)
|
||||
encrypted = enc.encrypt_value("old_secret")
|
||||
enc.rotate_key()
|
||||
|
||||
# Use backup key to decrypt
|
||||
backup_key = (tmp_path / "encryption.key.bak").read_bytes()
|
||||
old_cipher = Fernet(backup_key)
|
||||
outer_decoded = base64.b64decode(encrypted)
|
||||
decrypted = old_cipher.decrypt(outer_decoded).decode("utf-8")
|
||||
assert decrypted == "old_secret"
|
||||
|
||||
|
||||
class TestEnvironmentSecurity:
|
||||
"""Tests for environment-level security considerations."""
|
||||
|
||||
def test_key_not_exposed_in_repr(self, tmp_path: Path):
|
||||
"""ConfigEncryption repr/str doesn't expose the encryption key."""
|
||||
key_file = tmp_path / "encryption.key"
|
||||
enc = ConfigEncryption(key_file=key_file)
|
||||
key_content = key_file.read_bytes().decode("utf-8")
|
||||
obj_repr = repr(enc)
|
||||
obj_str = str(enc)
|
||||
assert key_content not in obj_repr
|
||||
assert key_content not in obj_str
|
||||
|
||||
def test_encrypted_values_differ_for_same_input(self, tmp_path: Path):
|
||||
"""Same input encrypted multiple times produces different outputs
|
||||
(nonce/IV prevents ciphertext equality)."""
|
||||
enc = ConfigEncryption(key_file=tmp_path / "encryption.key")
|
||||
values = [enc.encrypt_value("same_password") for _ in range(5)]
|
||||
# All 5 should be unique
|
||||
assert len(set(values)) == 5
|
||||
|
||||
def test_encrypt_does_not_log_plaintext(self, tmp_path: Path):
|
||||
"""Encryption operations use debug-level logging without plaintext."""
|
||||
import logging
|
||||
|
||||
enc = ConfigEncryption(key_file=tmp_path / "encryption.key")
|
||||
|
||||
with patch.object(logging.getLogger("src.infrastructure.security.config_encryption"), "debug") as mock_debug:
|
||||
enc.encrypt_value("super_secret_value")
|
||||
for call in mock_debug.call_args_list:
|
||||
args_str = str(call)
|
||||
assert "super_secret_value" not in args_str
|
||||
365
tests/unit/test_config_encryption.py
Normal file
365
tests/unit/test_config_encryption.py
Normal file
@@ -0,0 +1,365 @@
|
||||
"""Unit tests for configuration encryption module.
|
||||
|
||||
Tests encryption/decryption of sensitive configuration values,
|
||||
key management, and configuration dictionary encryption.
|
||||
"""
|
||||
|
||||
import os
|
||||
import stat
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
from src.infrastructure.security.config_encryption import (
|
||||
ConfigEncryption,
|
||||
get_config_encryption,
|
||||
)
|
||||
|
||||
|
||||
class TestConfigEncryptionInit:
|
||||
"""Tests for ConfigEncryption initialization."""
|
||||
|
||||
def test_creates_key_file_if_not_exists(self, tmp_path: Path):
|
||||
"""Key file is generated on init when it doesn't exist."""
|
||||
key_file = tmp_path / "encryption.key"
|
||||
enc = ConfigEncryption(key_file=key_file)
|
||||
assert key_file.exists()
|
||||
# Key should be valid Fernet key
|
||||
key = key_file.read_bytes()
|
||||
Fernet(key) # Raises if invalid
|
||||
|
||||
def test_uses_existing_key_file(self, tmp_path: Path):
|
||||
"""Existing key file is reused without regeneration."""
|
||||
key_file = tmp_path / "encryption.key"
|
||||
original_key = Fernet.generate_key()
|
||||
key_file.write_bytes(original_key)
|
||||
|
||||
enc = ConfigEncryption(key_file=key_file)
|
||||
loaded_key = key_file.read_bytes()
|
||||
assert loaded_key == original_key
|
||||
|
||||
def test_default_key_file_path(self):
|
||||
"""Default key_file path points to data/encryption.key."""
|
||||
with patch.object(ConfigEncryption, "_ensure_key_exists"):
|
||||
enc = ConfigEncryption.__new__(ConfigEncryption)
|
||||
enc._cipher = None
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
expected = project_root / "data" / "encryption.key"
|
||||
# Just verify the constructor logic sets a reasonable default
|
||||
enc2 = ConfigEncryption.__new__(ConfigEncryption)
|
||||
enc2._cipher = None
|
||||
enc2.key_file = expected
|
||||
assert "encryption.key" in str(enc2.key_file)
|
||||
|
||||
def test_key_file_permissions_set_600(self, tmp_path: Path):
|
||||
"""Generated key file should have owner-only permissions (0o600)."""
|
||||
key_file = tmp_path / "encryption.key"
|
||||
ConfigEncryption(key_file=key_file)
|
||||
mode = oct(os.stat(key_file).st_mode & 0o777)
|
||||
assert mode == "0o600"
|
||||
|
||||
|
||||
class TestKeyManagement:
|
||||
"""Tests for key generation, loading, and cipher creation."""
|
||||
|
||||
def test_generate_new_key_creates_valid_fernet_key(self, tmp_path: Path):
|
||||
"""Generated key is a valid Fernet key."""
|
||||
key_file = tmp_path / "encryption.key"
|
||||
enc = ConfigEncryption(key_file=key_file)
|
||||
key = key_file.read_bytes()
|
||||
cipher = Fernet(key)
|
||||
# Should be able to encrypt/decrypt
|
||||
token = cipher.encrypt(b"test")
|
||||
assert cipher.decrypt(token) == b"test"
|
||||
|
||||
def test_load_key_returns_bytes(self, tmp_path: Path):
|
||||
"""_load_key returns raw bytes from key file."""
|
||||
key_file = tmp_path / "encryption.key"
|
||||
enc = ConfigEncryption(key_file=key_file)
|
||||
key = enc._load_key()
|
||||
assert isinstance(key, bytes)
|
||||
assert len(key) > 0
|
||||
|
||||
def test_load_key_raises_if_file_missing(self, tmp_path: Path):
|
||||
"""_load_key raises FileNotFoundError when key file is absent."""
|
||||
key_file = tmp_path / "nonexistent.key"
|
||||
enc = ConfigEncryption.__new__(ConfigEncryption)
|
||||
enc._cipher = None
|
||||
enc.key_file = key_file
|
||||
with pytest.raises(FileNotFoundError):
|
||||
enc._load_key()
|
||||
|
||||
def test_get_cipher_returns_fernet(self, tmp_path: Path):
|
||||
"""_get_cipher returns a Fernet instance."""
|
||||
key_file = tmp_path / "encryption.key"
|
||||
enc = ConfigEncryption(key_file=key_file)
|
||||
cipher = enc._get_cipher()
|
||||
assert isinstance(cipher, Fernet)
|
||||
|
||||
def test_get_cipher_caches_instance(self, tmp_path: Path):
|
||||
"""_get_cipher returns the same Fernet instance on repeated calls."""
|
||||
key_file = tmp_path / "encryption.key"
|
||||
enc = ConfigEncryption(key_file=key_file)
|
||||
cipher1 = enc._get_cipher()
|
||||
cipher2 = enc._get_cipher()
|
||||
assert cipher1 is cipher2
|
||||
|
||||
|
||||
class TestEncryptDecrypt:
|
||||
"""Tests for encrypt_value and decrypt_value methods."""
|
||||
|
||||
@pytest.fixture
|
||||
def encryption(self, tmp_path: Path) -> ConfigEncryption:
|
||||
"""Create a ConfigEncryption instance with temp key file."""
|
||||
return ConfigEncryption(key_file=tmp_path / "encryption.key")
|
||||
|
||||
def test_encrypt_decrypt_roundtrip(self, encryption: ConfigEncryption):
|
||||
"""Encrypting then decrypting returns the original value."""
|
||||
original = "my_secret_password"
|
||||
encrypted = encryption.encrypt_value(original)
|
||||
decrypted = encryption.decrypt_value(encrypted)
|
||||
assert decrypted == original
|
||||
|
||||
def test_encrypt_value_returns_string(self, encryption: ConfigEncryption):
|
||||
"""encrypt_value returns a base64 string, not bytes."""
|
||||
result = encryption.encrypt_value("test")
|
||||
assert isinstance(result, str)
|
||||
|
||||
def test_encrypt_value_raises_on_empty(self, encryption: ConfigEncryption):
|
||||
"""encrypt_value raises ValueError for empty string."""
|
||||
with pytest.raises(ValueError, match="Cannot encrypt empty value"):
|
||||
encryption.encrypt_value("")
|
||||
|
||||
def test_decrypt_value_raises_on_empty(self, encryption: ConfigEncryption):
|
||||
"""decrypt_value raises ValueError for empty string."""
|
||||
with pytest.raises(ValueError, match="Cannot decrypt empty value"):
|
||||
encryption.decrypt_value("")
|
||||
|
||||
def test_decrypt_with_wrong_key_fails(self, tmp_path: Path):
|
||||
"""Decrypting with a different key raises an exception."""
|
||||
enc1 = ConfigEncryption(key_file=tmp_path / "key1.key")
|
||||
enc2 = ConfigEncryption(key_file=tmp_path / "key2.key")
|
||||
encrypted = enc1.encrypt_value("secret")
|
||||
with pytest.raises(Exception):
|
||||
enc2.decrypt_value(encrypted)
|
||||
|
||||
def test_encrypt_produces_different_ciphertext_each_time(
|
||||
self, encryption: ConfigEncryption
|
||||
):
|
||||
"""Two encryptions of same value produce different ciphertext (nonce)."""
|
||||
val = "same_value"
|
||||
enc1 = encryption.encrypt_value(val)
|
||||
enc2 = encryption.encrypt_value(val)
|
||||
assert enc1 != enc2
|
||||
|
||||
def test_unicode_value_roundtrip(self, encryption: ConfigEncryption):
|
||||
"""Unicode strings survive encrypt/decrypt roundtrip."""
|
||||
original = "Ünïcödé_тест_日本語"
|
||||
encrypted = encryption.encrypt_value(original)
|
||||
assert encryption.decrypt_value(encrypted) == original
|
||||
|
||||
def test_special_characters_roundtrip(self, encryption: ConfigEncryption):
|
||||
"""Special characters survive encrypt/decrypt roundtrip."""
|
||||
original = "p@$$w0rd!#%^&*(){}[]|\\:\";<>?/~`"
|
||||
encrypted = encryption.encrypt_value(original)
|
||||
assert encryption.decrypt_value(encrypted) == original
|
||||
|
||||
def test_large_value_roundtrip(self, encryption: ConfigEncryption):
|
||||
"""Large string values can be encrypted and decrypted."""
|
||||
original = "A" * 100_000
|
||||
encrypted = encryption.encrypt_value(original)
|
||||
assert encryption.decrypt_value(encrypted) == original
|
||||
|
||||
|
||||
class TestConfigDictEncryption:
|
||||
"""Tests for encrypt_config and decrypt_config methods."""
|
||||
|
||||
@pytest.fixture
|
||||
def encryption(self, tmp_path: Path) -> ConfigEncryption:
|
||||
"""Create a ConfigEncryption instance with temp key file."""
|
||||
return ConfigEncryption(key_file=tmp_path / "encryption.key")
|
||||
|
||||
def test_encrypt_config_encrypts_sensitive_fields(
|
||||
self, encryption: ConfigEncryption
|
||||
):
|
||||
"""Sensitive fields (password, token, etc.) are encrypted."""
|
||||
config = {
|
||||
"password": "secret123",
|
||||
"api_key": "abcdef",
|
||||
"username": "admin",
|
||||
}
|
||||
encrypted = encryption.encrypt_config(config)
|
||||
|
||||
# password and api_key should be wrapped
|
||||
assert isinstance(encrypted["password"], dict)
|
||||
assert encrypted["password"]["encrypted"] is True
|
||||
assert isinstance(encrypted["api_key"], dict)
|
||||
assert encrypted["api_key"]["encrypted"] is True
|
||||
|
||||
# username should not be encrypted
|
||||
assert encrypted["username"] == "admin"
|
||||
|
||||
def test_encrypt_config_skips_non_string_sensitive_values(
|
||||
self, encryption: ConfigEncryption
|
||||
):
|
||||
"""Non-string sensitive values (int, None) are left unencrypted."""
|
||||
config = {
|
||||
"password": 12345,
|
||||
"token": None,
|
||||
}
|
||||
encrypted = encryption.encrypt_config(config)
|
||||
assert encrypted["password"] == 12345
|
||||
assert encrypted["token"] is None
|
||||
|
||||
def test_encrypt_config_skips_empty_string(
|
||||
self, encryption: ConfigEncryption
|
||||
):
|
||||
"""Empty string sensitive values are left unencrypted."""
|
||||
config = {"password": ""}
|
||||
encrypted = encryption.encrypt_config(config)
|
||||
assert encrypted["password"] == ""
|
||||
|
||||
def test_decrypt_config_decrypts_encrypted_fields(
|
||||
self, encryption: ConfigEncryption
|
||||
):
|
||||
"""decrypt_config restores encrypted fields to plaintext."""
|
||||
config = {
|
||||
"password": "secret123",
|
||||
"username": "admin",
|
||||
}
|
||||
encrypted = encryption.encrypt_config(config)
|
||||
decrypted = encryption.decrypt_config(encrypted)
|
||||
assert decrypted["password"] == "secret123"
|
||||
assert decrypted["username"] == "admin"
|
||||
|
||||
def test_decrypt_config_passes_through_non_encrypted(
|
||||
self, encryption: ConfigEncryption
|
||||
):
|
||||
"""Non-encrypted fields in config are passed through unchanged."""
|
||||
config = {"host": "localhost", "port": 8080}
|
||||
decrypted = encryption.decrypt_config(config)
|
||||
assert decrypted == config
|
||||
|
||||
def test_decrypt_config_handles_decrypt_failure(
|
||||
self, encryption: ConfigEncryption
|
||||
):
|
||||
"""Failed decryption sets field value to None."""
|
||||
config = {
|
||||
"password": {
|
||||
"encrypted": True,
|
||||
"value": "invalid_base64_data!!!"
|
||||
}
|
||||
}
|
||||
decrypted = encryption.decrypt_config(config)
|
||||
assert decrypted["password"] is None
|
||||
|
||||
def test_sensitive_field_detection_case_insensitive(
|
||||
self, encryption: ConfigEncryption
|
||||
):
|
||||
"""Sensitive field detection works case-insensitively."""
|
||||
config = {
|
||||
"PASSWORD": "secret",
|
||||
"Api_Key": "key123",
|
||||
"JWT_SECRET": "jwt_val",
|
||||
}
|
||||
encrypted = encryption.encrypt_config(config)
|
||||
for key in ("PASSWORD", "Api_Key", "JWT_SECRET"):
|
||||
assert isinstance(encrypted[key], dict)
|
||||
assert encrypted[key]["encrypted"] is True
|
||||
|
||||
def test_all_sensitive_field_patterns_detected(
|
||||
self, encryption: ConfigEncryption
|
||||
):
|
||||
"""All defined sensitive field patterns trigger encryption."""
|
||||
patterns = [
|
||||
"password", "passwd", "secret", "key", "token",
|
||||
"api_key", "apikey", "auth_token", "jwt_secret",
|
||||
"master_password",
|
||||
]
|
||||
config = {p: f"value_{p}" for p in patterns}
|
||||
encrypted = encryption.encrypt_config(config)
|
||||
for p in patterns:
|
||||
assert isinstance(encrypted[p], dict), (
|
||||
f"Field '{p}' was not encrypted"
|
||||
)
|
||||
|
||||
|
||||
class TestKeyRotation:
|
||||
"""Tests for key rotation functionality."""
|
||||
|
||||
def test_rotate_key_generates_new_key(self, tmp_path: Path):
|
||||
"""rotate_key generates a new encryption key file."""
|
||||
key_file = tmp_path / "encryption.key"
|
||||
enc = ConfigEncryption(key_file=key_file)
|
||||
old_key = key_file.read_bytes()
|
||||
enc.rotate_key()
|
||||
new_key = key_file.read_bytes()
|
||||
assert old_key != new_key
|
||||
|
||||
def test_rotate_key_backs_up_old_key(self, tmp_path: Path):
|
||||
"""rotate_key creates a .key.bak backup of the old key."""
|
||||
key_file = tmp_path / "encryption.key"
|
||||
enc = ConfigEncryption(key_file=key_file)
|
||||
old_key = key_file.read_bytes()
|
||||
enc.rotate_key()
|
||||
backup = tmp_path / "encryption.key.bak"
|
||||
assert backup.exists()
|
||||
assert backup.read_bytes() == old_key
|
||||
|
||||
def test_rotate_key_resets_cipher_cache(self, tmp_path: Path):
|
||||
"""rotate_key clears the cached cipher so new key is used."""
|
||||
key_file = tmp_path / "encryption.key"
|
||||
enc = ConfigEncryption(key_file=key_file)
|
||||
_ = enc._get_cipher()
|
||||
assert enc._cipher is not None
|
||||
enc.rotate_key()
|
||||
assert enc._cipher is None
|
||||
|
||||
def test_rotate_key_invalidates_old_encrypted_data(self, tmp_path: Path):
|
||||
"""Data encrypted with old key cannot be decrypted after rotation."""
|
||||
key_file = tmp_path / "encryption.key"
|
||||
enc = ConfigEncryption(key_file=key_file)
|
||||
encrypted = enc.encrypt_value("my_secret")
|
||||
enc.rotate_key()
|
||||
with pytest.raises(Exception):
|
||||
enc.decrypt_value(encrypted)
|
||||
|
||||
def test_rotate_key_to_custom_path(self, tmp_path: Path):
|
||||
"""rotate_key can target a custom key file path."""
|
||||
key_file = tmp_path / "old.key"
|
||||
new_key_file = tmp_path / "new.key"
|
||||
enc = ConfigEncryption(key_file=key_file)
|
||||
enc.rotate_key(new_key_file=new_key_file)
|
||||
assert new_key_file.exists()
|
||||
assert enc.key_file == new_key_file
|
||||
|
||||
|
||||
class TestGlobalSingleton:
|
||||
"""Tests for the get_config_encryption singleton function."""
|
||||
|
||||
def test_get_config_encryption_returns_instance(self):
|
||||
"""get_config_encryption returns a ConfigEncryption instance."""
|
||||
import src.infrastructure.security.config_encryption as mod
|
||||
old = mod._config_encryption
|
||||
try:
|
||||
mod._config_encryption = None
|
||||
with patch.object(ConfigEncryption, "__init__", return_value=None):
|
||||
instance = get_config_encryption()
|
||||
assert isinstance(instance, ConfigEncryption)
|
||||
finally:
|
||||
mod._config_encryption = old
|
||||
|
||||
def test_get_config_encryption_returns_same_instance(self):
|
||||
"""Repeated calls return the same singleton instance."""
|
||||
import src.infrastructure.security.config_encryption as mod
|
||||
old = mod._config_encryption
|
||||
try:
|
||||
mod._config_encryption = None
|
||||
with patch.object(ConfigEncryption, "__init__", return_value=None):
|
||||
inst1 = get_config_encryption()
|
||||
inst2 = get_config_encryption()
|
||||
assert inst1 is inst2
|
||||
finally:
|
||||
mod._config_encryption = old
|
||||
381
tests/unit/test_database_integrity.py
Normal file
381
tests/unit/test_database_integrity.py
Normal file
@@ -0,0 +1,381 @@
|
||||
"""Unit tests for database integrity checker module.
|
||||
|
||||
Tests database integrity validation, orphaned record detection,
|
||||
duplicate key checks, and data consistency verification.
|
||||
|
||||
NOTE: The database_integrity.py module has bugs in raw SQL queries:
|
||||
- _check_invalid_references uses table 'episode' but actual name is 'episodes'
|
||||
- _check_invalid_references uses table 'download_queue_item' but actual is 'download_queue'
|
||||
- _check_duplicate_keys uses column 'anime_key' but actual column is 'key'
|
||||
These bugs cause OperationalError at runtime. Tests document this behavior.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from src.server.database.base import Base
|
||||
from src.server.database.models import AnimeSeries, DownloadQueueItem, Episode
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def engine():
|
||||
"""Create an in-memory SQLite database engine with schema."""
|
||||
eng = create_engine("sqlite:///:memory:")
|
||||
Base.metadata.create_all(eng)
|
||||
return eng
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session(engine):
|
||||
"""Create a database session for testing."""
|
||||
SessionLocal = sessionmaker(bind=engine)
|
||||
sess = SessionLocal()
|
||||
yield sess
|
||||
sess.close()
|
||||
|
||||
|
||||
def _make_series(session: Session, key: str = "test-anime", name: str = "Test Anime") -> AnimeSeries:
|
||||
"""Helper to create and persist an AnimeSeries record."""
|
||||
series = AnimeSeries(
|
||||
key=key,
|
||||
name=name,
|
||||
site="https://aniworld.to/anime/stream/test-anime",
|
||||
folder="Test Anime (2024)",
|
||||
)
|
||||
session.add(series)
|
||||
session.commit()
|
||||
session.refresh(series)
|
||||
return series
|
||||
|
||||
|
||||
def _make_episode(session: Session, series_id: int, season: int = 1, ep: int = 1) -> Episode:
|
||||
"""Helper to create and persist an Episode record."""
|
||||
episode = Episode(
|
||||
series_id=series_id,
|
||||
season=season,
|
||||
episode_number=ep,
|
||||
title=f"Episode {ep}",
|
||||
)
|
||||
session.add(episode)
|
||||
session.commit()
|
||||
session.refresh(episode)
|
||||
return episode
|
||||
|
||||
|
||||
class TestDatabaseIntegrityCheckerInit:
|
||||
"""Tests for DatabaseIntegrityChecker initialization."""
|
||||
|
||||
def test_init_with_session(self, session: Session):
|
||||
"""Checker initializes with provided session."""
|
||||
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
|
||||
checker = DatabaseIntegrityChecker(session=session)
|
||||
assert checker.session is session
|
||||
assert checker.issues == []
|
||||
|
||||
def test_init_without_session(self):
|
||||
"""Checker can be created without a session."""
|
||||
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
|
||||
checker = DatabaseIntegrityChecker()
|
||||
assert checker.session is None
|
||||
|
||||
def test_check_all_requires_session(self):
|
||||
"""check_all raises ValueError when no session is set."""
|
||||
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
|
||||
checker = DatabaseIntegrityChecker()
|
||||
with pytest.raises(ValueError, match="Session required"):
|
||||
checker.check_all()
|
||||
|
||||
|
||||
class TestOrphanedEpisodes:
|
||||
"""Tests for orphaned episode detection."""
|
||||
|
||||
def test_no_orphaned_episodes_clean_db(self, session: Session):
|
||||
"""Returns 0 when all episodes have parent series."""
|
||||
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
|
||||
series = _make_series(session)
|
||||
_make_episode(session, series.id)
|
||||
checker = DatabaseIntegrityChecker(session=session)
|
||||
count = checker._check_orphaned_episodes()
|
||||
assert count == 0
|
||||
assert len(checker.issues) == 0
|
||||
|
||||
def test_no_episodes_returns_zero(self, session: Session):
|
||||
"""Returns 0 when no episodes exist at all."""
|
||||
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
|
||||
checker = DatabaseIntegrityChecker(session=session)
|
||||
count = checker._check_orphaned_episodes()
|
||||
assert count == 0
|
||||
|
||||
def test_detects_orphaned_episodes(self, session: Session):
|
||||
"""Detects episodes whose series_id references nonexistent series."""
|
||||
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
|
||||
series = _make_series(session)
|
||||
_make_episode(session, series.id, season=1, ep=1)
|
||||
_make_episode(session, series.id, season=1, ep=2)
|
||||
|
||||
# Delete the series directly to create orphans
|
||||
session.execute(
|
||||
text("DELETE FROM anime_series WHERE id = :id"),
|
||||
{"id": series.id},
|
||||
)
|
||||
session.commit()
|
||||
|
||||
checker = DatabaseIntegrityChecker(session=session)
|
||||
count = checker._check_orphaned_episodes()
|
||||
assert count == 2
|
||||
assert any("orphaned episodes" in issue for issue in checker.issues)
|
||||
|
||||
|
||||
class TestOrphanedQueueItems:
|
||||
"""Tests for orphaned download queue item detection."""
|
||||
|
||||
def test_no_orphaned_queue_items(self, session: Session):
|
||||
"""Returns 0 when all queue items have parent series."""
|
||||
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
|
||||
series = _make_series(session)
|
||||
episode = _make_episode(session, series.id)
|
||||
item = DownloadQueueItem(
|
||||
series_id=series.id,
|
||||
episode_id=episode.id,
|
||||
)
|
||||
session.add(item)
|
||||
session.commit()
|
||||
|
||||
checker = DatabaseIntegrityChecker(session=session)
|
||||
count = checker._check_orphaned_queue_items()
|
||||
assert count == 0
|
||||
|
||||
def test_detects_orphaned_queue_items(self, session: Session):
|
||||
"""Detects queue items whose series references no longer exist."""
|
||||
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
|
||||
series = _make_series(session)
|
||||
episode = _make_episode(session, series.id)
|
||||
item = DownloadQueueItem(
|
||||
series_id=series.id,
|
||||
episode_id=episode.id,
|
||||
)
|
||||
session.add(item)
|
||||
session.commit()
|
||||
|
||||
# Remove series but keep orphaned items via raw SQL
|
||||
session.execute(text("DELETE FROM anime_series WHERE id = :id"), {"id": series.id})
|
||||
session.commit()
|
||||
|
||||
checker = DatabaseIntegrityChecker(session=session)
|
||||
count = checker._check_orphaned_queue_items()
|
||||
assert count == 1
|
||||
assert any("orphaned queue" in issue for issue in checker.issues)
|
||||
|
||||
|
||||
class TestInvalidReferences:
|
||||
"""Tests for invalid foreign key reference detection.
|
||||
|
||||
NOTE: The raw SQL in _check_invalid_references uses wrong table names:
|
||||
- 'episode' instead of 'episodes'
|
||||
- 'download_queue_item' instead of 'download_queue'
|
||||
This causes OperationalError in SQLite.
|
||||
"""
|
||||
|
||||
def test_invalid_references_raw_sql_uses_wrong_table_names(
|
||||
self, session: Session
|
||||
):
|
||||
"""BUG: Raw SQL references 'episode' and 'download_queue_item'
|
||||
but actual table names are 'episodes' and 'download_queue'.
|
||||
This causes the check to error and return -1."""
|
||||
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
|
||||
checker = DatabaseIntegrityChecker(session=session)
|
||||
result = checker._check_invalid_references()
|
||||
# Returns -1 because the SQL fails with OperationalError
|
||||
assert result == -1
|
||||
assert any("Error checking invalid references" in i for i in checker.issues)
|
||||
|
||||
|
||||
class TestDuplicateKeys:
|
||||
"""Tests for duplicate primary key detection.
|
||||
|
||||
NOTE: The raw SQL in _check_duplicate_keys references column 'anime_key'
|
||||
but the actual column name is 'key'. This causes OperationalError.
|
||||
"""
|
||||
|
||||
def test_duplicate_keys_raw_sql_uses_wrong_column_name(
|
||||
self, session: Session
|
||||
):
|
||||
"""BUG: Raw SQL references 'anime_key' column but actual column
|
||||
is 'key'. This causes the check to error and return -1."""
|
||||
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
|
||||
checker = DatabaseIntegrityChecker(session=session)
|
||||
result = checker._check_duplicate_keys()
|
||||
# Returns -1 because the SQL fails on nonexistent column
|
||||
assert result == -1
|
||||
assert any("Error checking duplicate keys" in i for i in checker.issues)
|
||||
|
||||
|
||||
class TestDataConsistency:
|
||||
"""Tests for data consistency validation."""
|
||||
|
||||
def test_no_consistency_issues_clean_data(self, session: Session):
|
||||
"""Returns 0 for valid episode data."""
|
||||
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
|
||||
series = _make_series(session)
|
||||
_make_episode(session, series.id, season=1, ep=1)
|
||||
checker = DatabaseIntegrityChecker(session=session)
|
||||
count = checker._check_data_consistency()
|
||||
assert count == 0
|
||||
|
||||
def test_detects_negative_season_number(self, session: Session):
|
||||
"""Detects episodes with negative season numbers."""
|
||||
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
|
||||
series = _make_series(session)
|
||||
# Insert invalid episode bypassing ORM validation
|
||||
session.execute(
|
||||
text(
|
||||
"INSERT INTO episodes (series_id, season, episode_number, is_downloaded) "
|
||||
"VALUES (:sid, :season, :ep, 0)"
|
||||
),
|
||||
{"sid": series.id, "season": -1, "ep": 1},
|
||||
)
|
||||
session.commit()
|
||||
|
||||
checker = DatabaseIntegrityChecker(session=session)
|
||||
count = checker._check_data_consistency()
|
||||
assert count == 1
|
||||
assert any("invalid" in i.lower() for i in checker.issues)
|
||||
|
||||
def test_detects_negative_episode_number(self, session: Session):
|
||||
"""Detects episodes with negative episode numbers."""
|
||||
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
|
||||
series = _make_series(session)
|
||||
session.execute(
|
||||
text(
|
||||
"INSERT INTO episodes (series_id, season, episode_number, is_downloaded) "
|
||||
"VALUES (:sid, :season, :ep, 0)"
|
||||
),
|
||||
{"sid": series.id, "season": 1, "ep": -5},
|
||||
)
|
||||
session.commit()
|
||||
|
||||
checker = DatabaseIntegrityChecker(session=session)
|
||||
count = checker._check_data_consistency()
|
||||
assert count == 1
|
||||
|
||||
def test_empty_database_no_issues(self, session: Session):
|
||||
"""Empty database reports no consistency issues."""
|
||||
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
|
||||
checker = DatabaseIntegrityChecker(session=session)
|
||||
count = checker._check_data_consistency()
|
||||
assert count == 0
|
||||
|
||||
|
||||
class TestCheckAll:
|
||||
"""Tests for the check_all aggregation method."""
|
||||
|
||||
def test_check_all_returns_dict_with_expected_keys(self, session: Session):
|
||||
"""check_all returns result dict with all expected keys."""
|
||||
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
|
||||
checker = DatabaseIntegrityChecker(session=session)
|
||||
results = checker.check_all()
|
||||
expected_keys = {
|
||||
"orphaned_episodes",
|
||||
"orphaned_queue_items",
|
||||
"invalid_references",
|
||||
"duplicate_keys",
|
||||
"data_consistency",
|
||||
"total_issues",
|
||||
"issues",
|
||||
}
|
||||
assert set(results.keys()) == expected_keys
|
||||
|
||||
def test_check_all_aggregates_issues(self, session: Session):
|
||||
"""check_all total_issues reflects all discovered issues."""
|
||||
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
|
||||
checker = DatabaseIntegrityChecker(session=session)
|
||||
results = checker.check_all()
|
||||
# At minimum, invalid_references and duplicate_keys fail due to SQL bugs
|
||||
assert results["total_issues"] >= 2
|
||||
assert len(results["issues"]) >= 2
|
||||
|
||||
def test_check_all_resets_issues_list(self, session: Session):
|
||||
"""check_all clears previous issues before running."""
|
||||
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
|
||||
checker = DatabaseIntegrityChecker(session=session)
|
||||
checker.issues = ["leftover issue"]
|
||||
checker.check_all()
|
||||
# Issues list should not contain the old "leftover issue"
|
||||
assert "leftover issue" not in checker.issues
|
||||
|
||||
|
||||
class TestRepairOrphanedRecords:
|
||||
"""Tests for repair_orphaned_records method."""
|
||||
|
||||
def test_repair_requires_session(self):
|
||||
"""repair_orphaned_records raises ValueError without session."""
|
||||
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
|
||||
checker = DatabaseIntegrityChecker()
|
||||
with pytest.raises(ValueError, match="Session required"):
|
||||
checker.repair_orphaned_records()
|
||||
|
||||
def test_repair_removes_orphaned_episodes(self, session: Session):
|
||||
"""repair_orphaned_records removes episodes without parent series."""
|
||||
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
|
||||
series = _make_series(session)
|
||||
_make_episode(session, series.id, season=1, ep=1)
|
||||
_make_episode(session, series.id, season=1, ep=2)
|
||||
|
||||
# Create orphans
|
||||
session.execute(
|
||||
text("DELETE FROM anime_series WHERE id = :id"),
|
||||
{"id": series.id},
|
||||
)
|
||||
session.commit()
|
||||
|
||||
checker = DatabaseIntegrityChecker(session=session)
|
||||
removed = checker.repair_orphaned_records()
|
||||
assert removed == 2
|
||||
|
||||
# Verify episodes are gone
|
||||
count = session.execute(text("SELECT COUNT(*) FROM episodes")).scalar()
|
||||
assert count == 0
|
||||
|
||||
def test_repair_removes_orphaned_queue_items(self, session: Session):
|
||||
"""repair_orphaned_records removes queue items without parent series."""
|
||||
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
|
||||
series = _make_series(session)
|
||||
episode = _make_episode(session, series.id)
|
||||
item = DownloadQueueItem(series_id=series.id, episode_id=episode.id)
|
||||
session.add(item)
|
||||
session.commit()
|
||||
|
||||
session.execute(
|
||||
text("DELETE FROM anime_series WHERE id = :id"),
|
||||
{"id": series.id},
|
||||
)
|
||||
session.commit()
|
||||
|
||||
checker = DatabaseIntegrityChecker(session=session)
|
||||
removed = checker.repair_orphaned_records()
|
||||
assert removed >= 1
|
||||
|
||||
def test_repair_no_orphans_returns_zero(self, session: Session):
|
||||
"""repair_orphaned_records returns 0 when no orphans exist."""
|
||||
from src.infrastructure.security.database_integrity import DatabaseIntegrityChecker
|
||||
series = _make_series(session)
|
||||
_make_episode(session, series.id)
|
||||
|
||||
checker = DatabaseIntegrityChecker(session=session)
|
||||
removed = checker.repair_orphaned_records()
|
||||
assert removed == 0
|
||||
|
||||
|
||||
class TestConvenienceFunction:
|
||||
"""Tests for check_database_integrity convenience function."""
|
||||
|
||||
def test_check_database_integrity_returns_results(self, session: Session):
|
||||
"""Convenience function returns check results dict."""
|
||||
from src.infrastructure.security.database_integrity import check_database_integrity
|
||||
results = check_database_integrity(session)
|
||||
assert isinstance(results, dict)
|
||||
assert "total_issues" in results
|
||||
assert "issues" in results
|
||||
Reference in New Issue
Block a user