diff --git a/encryption/nip4.go b/encryption/nip4.go index 1401af1..8834957 100644 --- a/encryption/nip4.go +++ b/encryption/nip4.go @@ -33,6 +33,11 @@ func ComputeSharedSecret(pk, sk []byte) (sharedSecret []byte, err error) { // // Deprecated: upgrade to using Decrypt with the NIP-44 algorithm. func EncryptNip4(msg []byte, key []byte) (ct []byte, err error) { + // validate key length for AES-256 + if len(key) != 32 { + err = errorf.E("key must be 32 bytes for AES-256, got %d", len(key)) + return + } // block size is 16 bytes iv := make([]byte, 16) if _, err = frand.Read(iv); chk.E(err) { @@ -66,21 +71,35 @@ func EncryptNip4(msg []byte, key []byte) (ct []byte, err error) { // // Deprecated: upgrade to using Decrypt with the NIP-44 algorithm. func DecryptNip4(content, key []byte) (msg []byte, err error) { + // validate key length for AES-256 + if len(key) != 32 { + err = errorf.E("key must be 32 bytes for AES-256, got %d", len(key)) + return + } parts := bytes.Split(content, []byte("?iv=")) if len(parts) < 2 { return nil, errorf.E( "error parsing encrypted message: no initialization vector") } - ciphertext := make([]byte, base64.StdEncoding.EncodedLen(len(parts[0]))) - iv := make([]byte, base64.StdEncoding.EncodedLen(len(parts[1]))) - if _, err = base64.StdEncoding.Decode(ciphertext, parts[0]); chk.E(err) { + ciphertext := make([]byte, base64.StdEncoding.DecodedLen(len(parts[0]))) + iv := make([]byte, base64.StdEncoding.DecodedLen(len(parts[1]))) + var ctLen, ivLen int + if ctLen, err = base64.StdEncoding.Decode(ciphertext, parts[0]); chk.E(err) { err = errorf.E("error decoding ciphertext from base64: %w", err) return } - if _, err = base64.StdEncoding.Decode(iv, parts[1]); chk.E(err) { + if ivLen, err = base64.StdEncoding.Decode(iv, parts[1]); chk.E(err) { err = errorf.E("error decoding iv from base64: %w", err) return } + // trim buffers to actual decoded length + ciphertext = ciphertext[:ctLen] + iv = iv[:ivLen] + // validate IV length + if len(iv) != 16 { + err = errorf.E("initialization vector must be 16 bytes, got %d", len(iv)) + return + } var block cipher.Block if block, err = aes.NewCipher(key); chk.E(err) { err = errorf.E("error creating block cipher: %w", err) @@ -89,18 +108,28 @@ func DecryptNip4(content, key []byte) (msg []byte, err error) { mode := cipher.NewCBCDecrypter(block, iv) msg = make([]byte, len(ciphertext)) mode.CryptBlocks(msg, ciphertext) - // remove padding + // remove padding using proper PKCS#7 validation var ( plaintextLen = len(msg) ) - if plaintextLen > 0 { - // the padding amount is encoded in the padding bytes themselves - padding := int(msg[plaintextLen-1]) - if padding > plaintextLen { - err = errorf.E("invalid padding amount: %d", padding) + if plaintextLen == 0 { + err = errorf.E("empty ciphertext") + return + } + // the padding amount is encoded in the padding bytes themselves + padding := int(msg[plaintextLen-1]) + if padding == 0 || padding > 16 || padding > plaintextLen { + err = errorf.E("invalid padding amount: %d", padding) + return + } + // validate that all padding bytes have the correct value (PKCS#7) + for i := plaintextLen - padding; i < plaintextLen; i++ { + if msg[i] != byte(padding) { + err = errorf.E("invalid padding: byte at position %d should be %d, got %d", + i, padding, msg[i]) return } - msg = msg[0 : plaintextLen-padding] } + msg = msg[0 : plaintextLen-padding] return msg, nil } diff --git a/encryption/nip4_test.go b/encryption/nip4_test.go new file mode 100644 index 0000000..3e7dbbe --- /dev/null +++ b/encryption/nip4_test.go @@ -0,0 +1,323 @@ +package encryption + +import ( + "bytes" + "crypto/rand" + "encoding/base64" + "encoding/hex" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "realy.lol/p256k" +) + +func TestComputeSharedSecret(t *testing.T) { + // Test with valid keys + sk1, _ := hex.DecodeString("0000000000000000000000000000000000000000000000000000000000000001") + sk2, _ := hex.DecodeString("0000000000000000000000000000000000000000000000000000000000000002") + + // Generate public key for sk2 + signer := new(p256k.Signer) + err := signer.InitSec(sk2) + require.NoError(t, err) + pk2 := signer.Pub() + + sharedSecret, err := ComputeSharedSecret(pk2, sk1) + require.NoError(t, err) + assert.Len(t, sharedSecret, 32, "shared secret should be 32 bytes") + + // Test with invalid public key (wrong length) + invalidPk := []byte("invalid") + _, err = ComputeSharedSecret(invalidPk, sk1) + assert.Error(t, err, "should fail with invalid public key") + + // Test with invalid public key (correct length but invalid point) + invalidPk32 := make([]byte, 32) + // Fill with invalid point data + for i := range invalidPk32 { + invalidPk32[i] = 0xFF + } + _, err = ComputeSharedSecret(invalidPk32, sk1) + assert.Error(t, err, "should fail with invalid public key point") +} + +func TestEncryptNip4_ValidInputs(t *testing.T) { + key := make([]byte, 32) + _, err := rand.Read(key) + require.NoError(t, err) + + testCases := []struct { + name string + message []byte + }{ + {"empty message", []byte("")}, + {"single byte", []byte("a")}, + {"short message", []byte("hello")}, + {"medium message", []byte("Hello, this is a test message for NIP-4 encryption!")}, + {"long message", bytes.Repeat([]byte("x"), 1000)}, + {"unicode message", []byte("🔒 Encrypted message with émojis and spëcial chars 中文")}, + {"binary data", []byte{0x00, 0x01, 0x02, 0xFF, 0xFE, 0xFD}}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + encrypted, err := EncryptNip4(tc.message, key) + require.NoError(t, err, "encryption should succeed") + assert.NotEmpty(t, encrypted, "encrypted data should not be empty") + + // Verify format: base64(ciphertext) + "?iv=" + base64(iv) + parts := bytes.Split(encrypted, []byte("?iv=")) + assert.Len(t, parts, 2, "encrypted format should have ciphertext and IV parts") + + // Verify both parts are valid base64 + _, err = base64.StdEncoding.DecodeString(string(parts[0])) + assert.NoError(t, err, "ciphertext part should be valid base64") + _, err = base64.StdEncoding.DecodeString(string(parts[1])) + assert.NoError(t, err, "IV part should be valid base64") + + // Test decryption + decrypted, err := DecryptNip4(encrypted, key) + require.NoError(t, err, "decryption should succeed") + assert.Equal(t, tc.message, decrypted, "decrypted message should match original") + }) + } +} + +func TestEncryptNip4_InvalidKey(t *testing.T) { + message := []byte("test message") + + testCases := []struct { + name string + key []byte + }{ + {"nil key", nil}, + {"empty key", []byte{}}, + {"short key", []byte("short")}, + {"long key", make([]byte, 64)}, + {"31 bytes key", make([]byte, 31)}, + {"33 bytes key", make([]byte, 33)}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, err := EncryptNip4(message, tc.key) + assert.Error(t, err, "should fail with invalid key length") + assert.Contains(t, err.Error(), "key must be 32 bytes", "error should mention key length requirement") + }) + } +} + +func TestDecryptNip4_ValidInputs(t *testing.T) { + key := make([]byte, 32) + _, err := rand.Read(key) + require.NoError(t, err) + + // Test with various message sizes + messages := [][]byte{ + []byte(""), + []byte("a"), + []byte("hello world"), + []byte("This is a longer message to test padding"), + bytes.Repeat([]byte("x"), 100), + } + + for i, msg := range messages { + t.Run(string(rune('A'+i)), func(t *testing.T) { + encrypted, err := EncryptNip4(msg, key) + require.NoError(t, err) + + decrypted, err := DecryptNip4(encrypted, key) + require.NoError(t, err) + assert.Equal(t, msg, decrypted) + }) + } +} + +func TestDecryptNip4_InvalidKey(t *testing.T) { + // Create valid encrypted data first + validKey := make([]byte, 32) + _, err := rand.Read(validKey) + require.NoError(t, err) + + encrypted, err := EncryptNip4([]byte("test"), validKey) + require.NoError(t, err) + + testCases := []struct { + name string + key []byte + }{ + {"nil key", nil}, + {"empty key", []byte{}}, + {"short key", []byte("short")}, + {"long key", make([]byte, 64)}, + {"31 bytes key", make([]byte, 31)}, + {"33 bytes key", make([]byte, 33)}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, err := DecryptNip4(encrypted, tc.key) + assert.Error(t, err, "should fail with invalid key length") + assert.Contains(t, err.Error(), "key must be 32 bytes", "error should mention key length requirement") + }) + } +} + +func TestDecryptNip4_InvalidFormat(t *testing.T) { + key := make([]byte, 32) + _, err := rand.Read(key) + require.NoError(t, err) + + testCases := []struct { + name string + encrypted []byte + errorMsg string + }{ + {"no IV separator", []byte("dGVzdA=="), "no initialization vector"}, + {"empty content", []byte(""), "no initialization vector"}, + {"only separator", []byte("?iv="), "initialization vector must be 16 bytes"}, + {"invalid base64 ciphertext", []byte("invalid_base64?iv=dGVzdA=="), "illegal base64 data"}, + {"invalid base64 IV", []byte("dGVzdA==?iv=invalid_base64"), "illegal base64 data"}, + {"wrong IV length", []byte("dGVzdA==?iv=dGVzdA=="), "initialization vector must be 16 bytes"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, err := DecryptNip4(tc.encrypted, key) + assert.Error(t, err, "should fail with invalid format") + assert.Contains(t, err.Error(), tc.errorMsg, "error should contain expected message") + }) + } +} + +func TestDecryptNip4_InvalidPadding(t *testing.T) { + key := make([]byte, 32) + _, err := rand.Read(key) + require.NoError(t, err) + + // Create a valid encrypted message first + encrypted, err := EncryptNip4([]byte("test"), key) + require.NoError(t, err) + + // Decode to manipulate the ciphertext + parts := bytes.Split(encrypted, []byte("?iv=")) + require.Len(t, parts, 2) + + ciphertext, err := base64.StdEncoding.DecodeString(string(parts[0])) + require.NoError(t, err) + + // Test cases with corrupted padding + testCases := []struct { + name string + modifier func([]byte) []byte + }{ + { + "corrupt last byte (padding indicator)", + func(ct []byte) []byte { + ct[len(ct)-1] = 0xFF // invalid padding amount + return ct + }, + }, + { + "corrupt padding bytes", + func(ct []byte) []byte { + if len(ct) > 1 { + ct[len(ct)-2] = 0x00 // corrupt padding byte + } + return ct + }, + }, + { + "zero padding", + func(ct []byte) []byte { + ct[len(ct)-1] = 0x00 // zero padding is invalid + return ct + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + corruptedCt := tc.modifier(append([]byte(nil), ciphertext...)) + corruptedB64 := base64.StdEncoding.EncodeToString(corruptedCt) + corruptedEncrypted := []byte(corruptedB64 + "?iv=" + string(parts[1])) + + _, err := DecryptNip4(corruptedEncrypted, key) + assert.Error(t, err, "should fail with corrupted padding") + assert.Contains(t, err.Error(), "padding", "error should mention padding") + }) + } +} + +func TestNip4_RoundTrip(t *testing.T) { + // Test multiple round trips with different keys + for i := 0; i < 10; i++ { + key := make([]byte, 32) + _, err := rand.Read(key) + require.NoError(t, err) + + message := make([]byte, i*10+1) // varying message sizes + _, err = rand.Read(message) + require.NoError(t, err) + + encrypted, err := EncryptNip4(message, key) + require.NoError(t, err) + + decrypted, err := DecryptNip4(encrypted, key) + require.NoError(t, err) + + assert.Equal(t, message, decrypted, "round trip should preserve message") + } +} + +func TestNip4_DifferentKeys(t *testing.T) { + key1 := make([]byte, 32) + key2 := make([]byte, 32) + _, err := rand.Read(key1) + require.NoError(t, err) + _, err = rand.Read(key2) + require.NoError(t, err) + + message := []byte("secret message") + + encrypted, err := EncryptNip4(message, key1) + require.NoError(t, err) + + // Decryption with wrong key should fail or produce garbage + decrypted, err := DecryptNip4(encrypted, key2) + if err == nil { + // If decryption succeeds, result should be different from original + assert.NotEqual(t, message, decrypted, "decryption with wrong key should not produce original message") + } + // If decryption fails, that's also acceptable (padding validation might catch it) +} + +func TestNip4_Deterministic(t *testing.T) { + key := make([]byte, 32) + _, err := rand.Read(key) + require.NoError(t, err) + + message := []byte("test message") + + // Encrypt the same message multiple times + encrypted1, err := EncryptNip4(message, key) + require.NoError(t, err) + + encrypted2, err := EncryptNip4(message, key) + require.NoError(t, err) + + // Results should be different due to random IV + assert.NotEqual(t, encrypted1, encrypted2, "encryption should not be deterministic") + + // But both should decrypt to the same message + decrypted1, err := DecryptNip4(encrypted1, key) + require.NoError(t, err) + + decrypted2, err := DecryptNip4(encrypted2, key) + require.NoError(t, err) + + assert.Equal(t, message, decrypted1) + assert.Equal(t, message, decrypted2) +}