Files
Aniworld/tests/unit/test_config_encryption.py

366 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Unit tests for configuration encryption module.
Tests encryption/decryption of sensitive configuration values,
key management, and configuration dictionary encryption.
"""
import os
import stat
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from cryptography.fernet import Fernet
from src.infrastructure.security.config_encryption import (
ConfigEncryption,
get_config_encryption,
)
class TestConfigEncryptionInit:
    """Tests for ConfigEncryption initialization."""

    def test_creates_key_file_if_not_exists(self, tmp_path: Path):
        """Key file is generated on init when it doesn't exist."""
        key_file = tmp_path / "encryption.key"
        assert not key_file.exists()

        ConfigEncryption(key_file=key_file)

        assert key_file.exists()
        # The generated key must be a valid Fernet key; Fernet() raises
        # ValueError for malformed key material.
        Fernet(key_file.read_bytes())

    def test_uses_existing_key_file(self, tmp_path: Path):
        """Existing key file is reused without regeneration."""
        key_file = tmp_path / "encryption.key"
        original_key = Fernet.generate_key()
        key_file.write_bytes(original_key)

        ConfigEncryption(key_file=key_file)

        assert key_file.read_bytes() == original_key

    def test_default_key_file_path(self):
        """Default key_file path points to data/encryption.key."""
        # Patch out key creation so that constructing with the default path
        # never writes a key into the real data directory.
        with patch.object(ConfigEncryption, "_ensure_key_exists"):
            enc = ConfigEncryption()

        key_file = Path(enc.key_file)
        assert key_file.name == "encryption.key"
        assert key_file.parent.name == "data"

    def test_key_file_permissions_set_600(self, tmp_path: Path):
        """Generated key file should have owner-only permissions (0o600)."""
        if os.name == "nt":
            # Windows has no POSIX permission bits; os.chmod can only toggle
            # the read-only flag there, so 0o600 cannot be observed.
            pytest.skip("POSIX file permissions are not supported on Windows")

        key_file = tmp_path / "encryption.key"
        ConfigEncryption(key_file=key_file)

        # S_IMODE strips the file-type bits, leaving only permission bits.
        assert stat.S_IMODE(key_file.stat().st_mode) == 0o600
class TestKeyManagement:
    """Key generation, key loading, and Fernet cipher construction."""

    @staticmethod
    def _fresh(tmp_path: Path) -> ConfigEncryption:
        """Build an instance backed by a newly generated key in *tmp_path*."""
        return ConfigEncryption(key_file=tmp_path / "encryption.key")

    def test_generate_new_key_creates_valid_fernet_key(self, tmp_path: Path):
        """Generated key is a valid Fernet key."""
        self._fresh(tmp_path)
        fernet = Fernet((tmp_path / "encryption.key").read_bytes())
        # A usable key must be able to round-trip a payload.
        assert fernet.decrypt(fernet.encrypt(b"test")) == b"test"

    def test_load_key_returns_bytes(self, tmp_path: Path):
        """_load_key hands back the non-empty raw bytes of the key file."""
        loaded = self._fresh(tmp_path)._load_key()
        assert isinstance(loaded, bytes)
        assert loaded

    def test_load_key_raises_if_file_missing(self, tmp_path: Path):
        """_load_key raises FileNotFoundError when key file is absent."""
        # Bypass __init__ so no key file gets generated.
        bare = ConfigEncryption.__new__(ConfigEncryption)
        bare._cipher = None
        bare.key_file = tmp_path / "nonexistent.key"
        with pytest.raises(FileNotFoundError):
            bare._load_key()

    def test_get_cipher_returns_fernet(self, tmp_path: Path):
        """_get_cipher yields a Fernet object."""
        assert isinstance(self._fresh(tmp_path)._get_cipher(), Fernet)

    def test_get_cipher_caches_instance(self, tmp_path: Path):
        """Repeated _get_cipher calls share one cached Fernet object."""
        enc = self._fresh(tmp_path)
        first, second = enc._get_cipher(), enc._get_cipher()
        assert first is second
class TestEncryptDecrypt:
    """encrypt_value / decrypt_value: round trips and input validation."""

    @pytest.fixture
    def encryption(self, tmp_path: Path) -> ConfigEncryption:
        """Provide an instance whose key lives in the per-test tmp dir."""
        key_path = tmp_path / "encryption.key"
        return ConfigEncryption(key_file=key_path)

    @staticmethod
    def _roundtrip(encryption: ConfigEncryption, plaintext: str) -> str:
        """Encrypt *plaintext* and immediately decrypt the resulting token."""
        token = encryption.encrypt_value(plaintext)
        return encryption.decrypt_value(token)

    def test_encrypt_decrypt_roundtrip(self, encryption: ConfigEncryption):
        """A value survives encryption followed by decryption unchanged."""
        secret = "my_secret_password"
        assert self._roundtrip(encryption, secret) == secret

    def test_encrypt_value_returns_string(self, encryption: ConfigEncryption):
        """encrypt_value hands back text (str), never raw bytes."""
        token = encryption.encrypt_value("test")
        assert isinstance(token, str)

    def test_encrypt_value_raises_on_empty(self, encryption: ConfigEncryption):
        """An empty plaintext is rejected with ValueError."""
        with pytest.raises(ValueError, match="Cannot encrypt empty value"):
            encryption.encrypt_value("")

    def test_decrypt_value_raises_on_empty(self, encryption: ConfigEncryption):
        """An empty token is rejected with ValueError."""
        with pytest.raises(ValueError, match="Cannot decrypt empty value"):
            encryption.decrypt_value("")

    def test_decrypt_with_wrong_key_fails(self, tmp_path: Path):
        """A token produced under one key cannot be opened with another."""
        owner, stranger = (
            ConfigEncryption(key_file=tmp_path / name)
            for name in ("key1.key", "key2.key")
        )
        token = owner.encrypt_value("secret")
        with pytest.raises(Exception):
            stranger.decrypt_value(token)

    def test_encrypt_produces_different_ciphertext_each_time(
        self, encryption: ConfigEncryption
    ):
        """Repeated encryptions of one value differ (fresh nonce per token)."""
        first, second = (
            encryption.encrypt_value("same_value") for _ in range(2)
        )
        assert first != second

    def test_unicode_value_roundtrip(self, encryption: ConfigEncryption):
        """Non-ASCII text (accents, Cyrillic, CJK) round-trips intact."""
        text = "Ünïcödé_тест_日本語"
        assert self._roundtrip(encryption, text) == text

    def test_special_characters_roundtrip(self, encryption: ConfigEncryption):
        """Punctuation, quotes and backslashes round-trip intact."""
        text = "p@$$w0rd!#%^&*(){}[]|\\:\";<>?/~`"
        assert self._roundtrip(encryption, text) == text

    def test_large_value_roundtrip(self, encryption: ConfigEncryption):
        """A 100k-character value round-trips intact."""
        text = "A" * 100_000
        assert self._roundtrip(encryption, text) == text
class TestConfigDictEncryption:
    """Tests for encrypt_config and decrypt_config methods."""

    @pytest.fixture
    def encryption(self, tmp_path: Path) -> ConfigEncryption:
        """Create a ConfigEncryption instance with temp key file."""
        return ConfigEncryption(key_file=tmp_path / "encryption.key")

    @staticmethod
    def _assert_field_encrypted(
        encryption: ConfigEncryption,
        encrypted: dict,
        field: str,
        plaintext: str,
    ) -> None:
        """Assert ``encrypted[field]`` is a real ciphertext envelope.

        Checking only the ``{"encrypted": True, ...}`` marker would also pass
        for an envelope that stores the secret verbatim. This helper therefore
        also checks that the stored token differs from *plaintext* and
        decrypts back to it with the same key. The ``"value"`` key mirrors
        the envelope shape fed to decrypt_config in
        test_decrypt_config_handles_decrypt_failure.
        """
        wrapped = encrypted[field]
        assert isinstance(wrapped, dict), f"Field '{field}' was not encrypted"
        assert wrapped["encrypted"] is True, (
            f"Field '{field}' is not flagged as encrypted"
        )
        assert wrapped["value"] != plaintext, (
            f"Field '{field}' stores its plaintext"
        )
        assert encryption.decrypt_value(wrapped["value"]) == plaintext

    def test_encrypt_config_encrypts_sensitive_fields(
        self, encryption: ConfigEncryption
    ):
        """Sensitive fields (password, token, etc.) are encrypted."""
        config = {
            "password": "secret123",
            "api_key": "abcdef",
            "username": "admin",
        }
        encrypted = encryption.encrypt_config(config)
        # password and api_key are sensitive and must be wrapped.
        self._assert_field_encrypted(
            encryption, encrypted, "password", "secret123"
        )
        self._assert_field_encrypted(encryption, encrypted, "api_key", "abcdef")
        # username is not sensitive and must pass through untouched.
        assert encrypted["username"] == "admin"

    def test_encrypt_config_skips_non_string_sensitive_values(
        self, encryption: ConfigEncryption
    ):
        """Non-string sensitive values (int, None) are left unencrypted."""
        config = {
            "password": 12345,
            "token": None,
        }
        encrypted = encryption.encrypt_config(config)
        assert encrypted["password"] == 12345
        assert encrypted["token"] is None

    def test_encrypt_config_skips_empty_string(
        self, encryption: ConfigEncryption
    ):
        """Empty string sensitive values are left unencrypted."""
        config = {"password": ""}
        encrypted = encryption.encrypt_config(config)
        assert encrypted["password"] == ""

    def test_decrypt_config_decrypts_encrypted_fields(
        self, encryption: ConfigEncryption
    ):
        """decrypt_config restores encrypted fields to plaintext."""
        config = {
            "password": "secret123",
            "username": "admin",
        }
        encrypted = encryption.encrypt_config(config)
        decrypted = encryption.decrypt_config(encrypted)
        assert decrypted["password"] == "secret123"
        assert decrypted["username"] == "admin"

    def test_decrypt_config_passes_through_non_encrypted(
        self, encryption: ConfigEncryption
    ):
        """Non-encrypted fields in config are passed through unchanged."""
        config = {"host": "localhost", "port": 8080}
        decrypted = encryption.decrypt_config(config)
        assert decrypted == config

    def test_decrypt_config_handles_decrypt_failure(
        self, encryption: ConfigEncryption
    ):
        """Failed decryption sets field value to None."""
        config = {
            "password": {
                "encrypted": True,
                "value": "invalid_base64_data!!!",
            }
        }
        decrypted = encryption.decrypt_config(config)
        assert decrypted["password"] is None

    def test_sensitive_field_detection_case_insensitive(
        self, encryption: ConfigEncryption
    ):
        """Sensitive field detection works case-insensitively."""
        plaintexts = {
            "PASSWORD": "secret",
            "Api_Key": "key123",
            "JWT_SECRET": "jwt_val",
        }
        # Pass a copy so expected plaintexts stay intact even if
        # encrypt_config were to modify its argument in place.
        encrypted = encryption.encrypt_config(dict(plaintexts))
        for field, plaintext in plaintexts.items():
            self._assert_field_encrypted(encryption, encrypted, field, plaintext)

    def test_all_sensitive_field_patterns_detected(
        self, encryption: ConfigEncryption
    ):
        """All defined sensitive field patterns trigger encryption."""
        patterns = [
            "password", "passwd", "secret", "key", "token",
            "api_key", "apikey", "auth_token", "jwt_secret",
            "master_password",
        ]
        config = {p: f"value_{p}" for p in patterns}
        encrypted = encryption.encrypt_config(config)
        for p in patterns:
            self._assert_field_encrypted(encryption, encrypted, p, f"value_{p}")
class TestKeyRotation:
    """rotate_key: fresh key material, backup, cache reset, custom target."""

    @staticmethod
    def _setup(tmp_path: Path, name: str = "encryption.key"):
        """Return ``(key_file, instance)`` for a freshly keyed instance."""
        key_file = tmp_path / name
        return key_file, ConfigEncryption(key_file=key_file)

    def test_rotate_key_generates_new_key(self, tmp_path: Path):
        """After rotation the key file holds different key bytes."""
        key_file, enc = self._setup(tmp_path)
        before = key_file.read_bytes()
        enc.rotate_key()
        assert key_file.read_bytes() != before

    def test_rotate_key_backs_up_old_key(self, tmp_path: Path):
        """The previous key is preserved next to the key as ``<name>.bak``."""
        key_file, enc = self._setup(tmp_path)
        before = key_file.read_bytes()
        enc.rotate_key()
        backup = key_file.with_name(key_file.name + ".bak")
        assert backup.exists()
        assert backup.read_bytes() == before

    def test_rotate_key_resets_cipher_cache(self, tmp_path: Path):
        """Rotation drops the cached cipher so the new key gets loaded."""
        _, enc = self._setup(tmp_path)
        enc._get_cipher()
        assert enc._cipher is not None
        enc.rotate_key()
        assert enc._cipher is None

    def test_rotate_key_invalidates_old_encrypted_data(self, tmp_path: Path):
        """Tokens created before rotation no longer decrypt afterwards."""
        _, enc = self._setup(tmp_path)
        stale_token = enc.encrypt_value("my_secret")
        enc.rotate_key()
        with pytest.raises(Exception):
            enc.decrypt_value(stale_token)

    def test_rotate_key_to_custom_path(self, tmp_path: Path):
        """Rotation can write the new key to a caller-chosen path."""
        _, enc = self._setup(tmp_path, "old.key")
        target = tmp_path / "new.key"
        enc.rotate_key(new_key_file=target)
        assert target.exists()
        assert enc.key_file == target
class TestGlobalSingleton:
    """Tests for the get_config_encryption singleton function."""

    def test_get_config_encryption_returns_instance(self):
        """get_config_encryption returns a ConfigEncryption instance."""
        import src.infrastructure.security.config_encryption as mod

        # Clear the module-level singleton for the duration of the test;
        # patch.object restores the previous value on exit, even on failure.
        # Patching __init__ keeps construction from touching a real key file.
        with patch.object(mod, "_config_encryption", None), patch.object(
            ConfigEncryption, "__init__", return_value=None
        ):
            instance = get_config_encryption()
        assert isinstance(instance, ConfigEncryption)

    def test_get_config_encryption_returns_same_instance(self):
        """Repeated calls return the same singleton instance."""
        import src.infrastructure.security.config_encryption as mod

        # Same isolation as above: start from an empty singleton slot and
        # restore the original afterwards.
        with patch.object(mod, "_config_encryption", None), patch.object(
            ConfigEncryption, "__init__", return_value=None
        ):
            inst1 = get_config_encryption()
            inst2 = get_config_encryption()
        assert inst1 is inst2