"""Unit tests for configuration encryption module.

Tests encryption/decryption of sensitive configuration values,
key management, and configuration dictionary encryption.
"""

import os
|
||
import stat
|
||
from pathlib import Path
|
||
from unittest.mock import MagicMock, patch
|
||
|
||
import pytest
|
||
from cryptography.fernet import Fernet
|
||
|
||
from src.infrastructure.security.config_encryption import (
|
||
ConfigEncryption,
|
||
get_config_encryption,
|
||
)
|
||
|
||
|
||
class TestConfigEncryptionInit:
    """Tests for ConfigEncryption initialization."""

    def test_creates_key_file_if_not_exists(self, tmp_path: Path):
        """Key file is generated on init when it doesn't exist."""
        key_file = tmp_path / "encryption.key"
        ConfigEncryption(key_file=key_file)
        assert key_file.exists()
        # Key should be a valid Fernet key
        Fernet(key_file.read_bytes())  # Raises if invalid

    def test_uses_existing_key_file(self, tmp_path: Path):
        """Existing key file is reused without regeneration."""
        key_file = tmp_path / "encryption.key"
        original_key = Fernet.generate_key()
        key_file.write_bytes(original_key)

        enc = ConfigEncryption(key_file=key_file)

        # The file on disk must be left untouched...
        assert key_file.read_bytes() == original_key
        # ...and the instance must actually encrypt with that key: a token
        # produced by its cipher has to decrypt under the pre-existing key.
        token = enc._get_cipher().encrypt(b"probe")
        assert Fernet(original_key).decrypt(token) == b"probe"

    def test_default_key_file_path(self):
        """Omitting key_file makes the instance default to data/encryption.key."""
        # Construct a real instance (instead of assigning key_file by hand) so
        # the constructor's default is what gets checked. Key creation is
        # patched out so the test does not create a key in the project tree.
        with patch.object(ConfigEncryption, "_ensure_key_exists"):
            enc = ConfigEncryption()
        default = Path(enc.key_file)
        assert default.name == "encryption.key"
        assert default.parent.name == "data"

    def test_key_file_permissions_set_600(self, tmp_path: Path):
        """Generated key file should have owner-only permissions (0o600)."""
        # Checks POSIX permission bits; these are not meaningful on Windows.
        key_file = tmp_path / "encryption.key"
        ConfigEncryption(key_file=key_file)
        mode = stat.S_IMODE(os.stat(key_file).st_mode)
        assert mode == 0o600, f"expected 0o600, got {oct(mode)}"


class TestKeyManagement:
    """Tests for key generation, loading, and cipher creation."""

    def test_generate_new_key_creates_valid_fernet_key(self, tmp_path: Path):
        """Generated key is a valid Fernet key."""
        key_file = tmp_path / "encryption.key"
        ConfigEncryption(key_file=key_file)
        cipher = Fernet(key_file.read_bytes())
        # Should be able to encrypt/decrypt
        token = cipher.encrypt(b"test")
        assert cipher.decrypt(token) == b"test"

    def test_load_key_returns_bytes(self, tmp_path: Path):
        """_load_key returns the key as bytes that Fernet accepts."""
        key_file = tmp_path / "encryption.key"
        enc = ConfigEncryption(key_file=key_file)
        key = enc._load_key()
        assert isinstance(key, bytes)
        # A bare non-empty check would accept garbage; require a usable key.
        Fernet(key)  # Raises ValueError if the key is malformed

    def test_load_key_raises_if_file_missing(self, tmp_path: Path):
        """_load_key raises FileNotFoundError when key file is absent."""
        key_file = tmp_path / "nonexistent.key"
        # Bypass __init__ so no key file gets created for the missing path.
        enc = ConfigEncryption.__new__(ConfigEncryption)
        enc._cipher = None
        enc.key_file = key_file
        with pytest.raises(FileNotFoundError):
            enc._load_key()

    def test_get_cipher_returns_fernet(self, tmp_path: Path):
        """_get_cipher returns a Fernet instance."""
        key_file = tmp_path / "encryption.key"
        enc = ConfigEncryption(key_file=key_file)
        cipher = enc._get_cipher()
        assert isinstance(cipher, Fernet)

    def test_get_cipher_caches_instance(self, tmp_path: Path):
        """_get_cipher returns the same Fernet instance on repeated calls."""
        key_file = tmp_path / "encryption.key"
        enc = ConfigEncryption(key_file=key_file)
        cipher1 = enc._get_cipher()
        cipher2 = enc._get_cipher()
        assert cipher1 is cipher2


class TestEncryptDecrypt:
    """Behaviour of ConfigEncryption.encrypt_value and decrypt_value."""

    @pytest.fixture
    def encryption(self, tmp_path: Path) -> ConfigEncryption:
        """Provide an encryptor whose key file lives inside tmp_path."""
        key_path = tmp_path / "encryption.key"
        return ConfigEncryption(key_file=key_path)

    def test_encrypt_decrypt_roundtrip(self, encryption: ConfigEncryption):
        """decrypt_value(encrypt_value(x)) gives back x."""
        plaintext = "my_secret_password"
        ciphertext = encryption.encrypt_value(plaintext)
        recovered = encryption.decrypt_value(ciphertext)
        assert recovered == plaintext

    def test_encrypt_value_returns_string(self, encryption: ConfigEncryption):
        """encrypt_value hands back text (str), never raw bytes."""
        ciphertext = encryption.encrypt_value("test")
        assert isinstance(ciphertext, str)

    def test_encrypt_value_raises_on_empty(self, encryption: ConfigEncryption):
        """An empty plaintext is rejected with ValueError."""
        expected_message = "Cannot encrypt empty value"
        with pytest.raises(ValueError, match=expected_message):
            encryption.encrypt_value("")

    def test_decrypt_value_raises_on_empty(self, encryption: ConfigEncryption):
        """An empty ciphertext is rejected with ValueError."""
        expected_message = "Cannot decrypt empty value"
        with pytest.raises(ValueError, match=expected_message):
            encryption.decrypt_value("")

    def test_decrypt_with_wrong_key_fails(self, tmp_path: Path):
        """A token made under one key cannot be opened by another key."""
        sender = ConfigEncryption(key_file=tmp_path / "key1.key")
        receiver = ConfigEncryption(key_file=tmp_path / "key2.key")
        token = sender.encrypt_value("secret")
        # The concrete exception type is left to the implementation.
        with pytest.raises(Exception):
            receiver.decrypt_value(token)

    def test_encrypt_produces_different_ciphertext_each_time(
        self, encryption: ConfigEncryption
    ):
        """Identical plaintext yields a fresh token per call (per-call nonce)."""
        plaintext = "same_value"
        first, second = (encryption.encrypt_value(plaintext) for _ in range(2))
        assert first != second

    def test_unicode_value_roundtrip(self, encryption: ConfigEncryption):
        """Non-ASCII text comes back unchanged."""
        plaintext = "Ünïcödé_тест_日本語"
        ciphertext = encryption.encrypt_value(plaintext)
        assert plaintext == encryption.decrypt_value(ciphertext)

    def test_special_characters_roundtrip(self, encryption: ConfigEncryption):
        """Punctuation, quotes and backslashes come back unchanged."""
        plaintext = "p@$$w0rd!#%^&*(){}[]|\\:\";<>?/~`"
        ciphertext = encryption.encrypt_value(plaintext)
        assert plaintext == encryption.decrypt_value(ciphertext)

    def test_large_value_roundtrip(self, encryption: ConfigEncryption):
        """A 100k-character value survives the round trip."""
        plaintext = "A" * 100_000
        ciphertext = encryption.encrypt_value(plaintext)
        assert plaintext == encryption.decrypt_value(ciphertext)


class TestConfigDictEncryption:
    """Tests for encrypt_config and decrypt_config methods."""

    @pytest.fixture
    def encryption(self, tmp_path: Path) -> ConfigEncryption:
        """Create a ConfigEncryption instance with temp key file."""
        return ConfigEncryption(key_file=tmp_path / "encryption.key")

    def test_encrypt_config_encrypts_sensitive_fields(
        self, encryption: ConfigEncryption
    ):
        """Sensitive fields are wrapped as encrypted and no longer hold plaintext."""
        # Keep the plaintexts separately so the comparison below stays valid
        # even if encrypt_config were to modify `config` in place.
        sensitive = {"password": "secret123", "api_key": "abcdef"}
        config = {**sensitive, "username": "admin"}
        encrypted = encryption.encrypt_config(config)

        # password and api_key should be wrapped, and the wrapper must not
        # carry the original plaintext anywhere inside it.
        for field, plaintext in sensitive.items():
            wrapped = encrypted[field]
            assert isinstance(wrapped, dict)
            assert wrapped["encrypted"] is True
            assert plaintext not in str(wrapped), f"{field} leaked plaintext"

        # username should not be encrypted
        assert encrypted["username"] == "admin"

    def test_encrypt_config_skips_non_string_sensitive_values(
        self, encryption: ConfigEncryption
    ):
        """Non-string sensitive values (int, None) are left unencrypted."""
        config = {
            "password": 12345,
            "token": None,
        }
        encrypted = encryption.encrypt_config(config)
        assert encrypted["password"] == 12345
        assert encrypted["token"] is None

    def test_encrypt_config_skips_empty_string(
        self, encryption: ConfigEncryption
    ):
        """Empty string sensitive values are left unencrypted."""
        config = {"password": ""}
        encrypted = encryption.encrypt_config(config)
        assert encrypted["password"] == ""

    def test_decrypt_config_decrypts_encrypted_fields(
        self, encryption: ConfigEncryption
    ):
        """decrypt_config restores encrypted fields to plaintext."""
        config = {
            "password": "secret123",
            "username": "admin",
        }
        encrypted = encryption.encrypt_config(config)
        decrypted = encryption.decrypt_config(encrypted)
        assert decrypted["password"] == "secret123"
        assert decrypted["username"] == "admin"

    def test_decrypt_config_passes_through_non_encrypted(
        self, encryption: ConfigEncryption
    ):
        """Non-encrypted fields in config are passed through unchanged."""
        config = {"host": "localhost", "port": 8080}
        decrypted = encryption.decrypt_config(config)
        assert decrypted == config

    def test_decrypt_config_handles_decrypt_failure(
        self, encryption: ConfigEncryption
    ):
        """Failed decryption sets field value to None."""
        config = {
            "password": {
                "encrypted": True,
                "value": "invalid_base64_data!!!",
            }
        }
        decrypted = encryption.decrypt_config(config)
        assert decrypted["password"] is None

    def test_sensitive_field_detection_case_insensitive(
        self, encryption: ConfigEncryption
    ):
        """Sensitive field detection works case-insensitively."""
        config = {
            "PASSWORD": "secret",
            "Api_Key": "key123",
            "JWT_SECRET": "jwt_val",
        }
        encrypted = encryption.encrypt_config(config)
        for key in ("PASSWORD", "Api_Key", "JWT_SECRET"):
            assert isinstance(encrypted[key], dict)
            assert encrypted[key]["encrypted"] is True

    def test_all_sensitive_field_patterns_detected(
        self, encryption: ConfigEncryption
    ):
        """All defined sensitive field patterns trigger encryption."""
        patterns = [
            "password", "passwd", "secret", "key", "token",
            "api_key", "apikey", "auth_token", "jwt_secret",
            "master_password",
        ]
        config = {p: f"value_{p}" for p in patterns}
        encrypted = encryption.encrypt_config(config)
        for p in patterns:
            assert isinstance(encrypted[p], dict), (
                f"Field '{p}' was not encrypted"
            )
            # Same contract as the other encryption tests: flagged as
            # encrypted and the plaintext no longer present.
            assert encrypted[p]["encrypted"] is True
            assert f"value_{p}" not in str(encrypted[p])


class TestKeyRotation:
    """Tests for key rotation functionality."""

    def test_rotate_key_generates_new_key(self, tmp_path: Path):
        """rotate_key generates a new encryption key file."""
        key_file = tmp_path / "encryption.key"
        enc = ConfigEncryption(key_file=key_file)
        old_key = key_file.read_bytes()
        enc.rotate_key()
        new_key = key_file.read_bytes()
        assert old_key != new_key

    def test_rotate_key_backs_up_old_key(self, tmp_path: Path):
        """rotate_key creates a .key.bak backup of the old key."""
        key_file = tmp_path / "encryption.key"
        enc = ConfigEncryption(key_file=key_file)
        old_key = key_file.read_bytes()
        enc.rotate_key()
        backup = tmp_path / "encryption.key.bak"
        assert backup.exists()
        assert backup.read_bytes() == old_key

    def test_rotate_key_resets_cipher_cache(self, tmp_path: Path):
        """rotate_key clears the cached cipher so the new key is used."""
        key_file = tmp_path / "encryption.key"
        enc = ConfigEncryption(key_file=key_file)
        _ = enc._get_cipher()
        assert enc._cipher is not None
        enc.rotate_key()
        assert enc._cipher is None
        # Clearing the attribute is an implementation detail; also check the
        # observable effect: the next cipher encrypts under the rotated key.
        token = enc._get_cipher().encrypt(b"probe")
        assert Fernet(key_file.read_bytes()).decrypt(token) == b"probe"

    def test_rotate_key_invalidates_old_encrypted_data(self, tmp_path: Path):
        """Data encrypted with old key cannot be decrypted after rotation."""
        key_file = tmp_path / "encryption.key"
        enc = ConfigEncryption(key_file=key_file)
        encrypted = enc.encrypt_value("my_secret")
        enc.rotate_key()
        with pytest.raises(Exception):
            enc.decrypt_value(encrypted)

    def test_rotate_key_to_custom_path(self, tmp_path: Path):
        """rotate_key can target a custom key file path."""
        key_file = tmp_path / "old.key"
        new_key_file = tmp_path / "new.key"
        enc = ConfigEncryption(key_file=key_file)
        enc.rotate_key(new_key_file=new_key_file)
        assert new_key_file.exists()
        assert enc.key_file == new_key_file
        # The new file must hold a usable key, and the instance must now
        # encrypt with it rather than with the key at the old path.
        token = enc._get_cipher().encrypt(b"probe")
        assert Fernet(new_key_file.read_bytes()).decrypt(token) == b"probe"


class TestGlobalSingleton:
    """Behaviour of the module-level get_config_encryption accessor."""

    def test_get_config_encryption_returns_instance(self):
        """With no cached instance, the accessor builds a ConfigEncryption."""
        import src.infrastructure.security.config_encryption as mod

        # patch.object empties the cached singleton and, on exit, puts the
        # previous value back even if get_config_encryption() stored a new one.
        # __init__ is stubbed so the real constructor (which writes a key
        # file) does not run.
        with patch.object(mod, "_config_encryption", None):
            with patch.object(ConfigEncryption, "__init__", return_value=None):
                created = get_config_encryption()
                assert isinstance(created, ConfigEncryption)

    def test_get_config_encryption_returns_same_instance(self):
        """Every call after the first hands back the cached object."""
        import src.infrastructure.security.config_encryption as mod

        with patch.object(mod, "_config_encryption", None):
            with patch.object(ConfigEncryption, "__init__", return_value=None):
                first = get_config_encryption()
                second = get_config_encryption()
                assert first is second