validate and correct nip-04 and nip-44 encryption code

This commit is contained in:
2025-06-26 17:38:09 +01:00
parent 8dd3deb811
commit 6392dd8edf
2 changed files with 363 additions and 11 deletions

View File

@@ -33,6 +33,11 @@ func ComputeSharedSecret(pk, sk []byte) (sharedSecret []byte, err error) {
//
// Deprecated: upgrade to using Decrypt with the NIP-44 algorithm.
func EncryptNip4(msg []byte, key []byte) (ct []byte, err error) {
// validate key length for AES-256
if len(key) != 32 {
err = errorf.E("key must be 32 bytes for AES-256, got %d", len(key))
return
}
// block size is 16 bytes
iv := make([]byte, 16)
if _, err = frand.Read(iv); chk.E(err) {
@@ -66,21 +71,35 @@ func EncryptNip4(msg []byte, key []byte) (ct []byte, err error) {
//
// Deprecated: upgrade to using Decrypt with the NIP-44 algorithm.
func DecryptNip4(content, key []byte) (msg []byte, err error) {
// validate key length for AES-256
if len(key) != 32 {
err = errorf.E("key must be 32 bytes for AES-256, got %d", len(key))
return
}
parts := bytes.Split(content, []byte("?iv="))
if len(parts) < 2 {
return nil, errorf.E(
"error parsing encrypted message: no initialization vector")
}
ciphertext := make([]byte, base64.StdEncoding.EncodedLen(len(parts[0])))
iv := make([]byte, base64.StdEncoding.EncodedLen(len(parts[1])))
if _, err = base64.StdEncoding.Decode(ciphertext, parts[0]); chk.E(err) {
ciphertext := make([]byte, base64.StdEncoding.DecodedLen(len(parts[0])))
iv := make([]byte, base64.StdEncoding.DecodedLen(len(parts[1])))
var ctLen, ivLen int
if ctLen, err = base64.StdEncoding.Decode(ciphertext, parts[0]); chk.E(err) {
err = errorf.E("error decoding ciphertext from base64: %w", err)
return
}
if _, err = base64.StdEncoding.Decode(iv, parts[1]); chk.E(err) {
if ivLen, err = base64.StdEncoding.Decode(iv, parts[1]); chk.E(err) {
err = errorf.E("error decoding iv from base64: %w", err)
return
}
// trim buffers to actual decoded length
ciphertext = ciphertext[:ctLen]
iv = iv[:ivLen]
// validate IV length
if len(iv) != 16 {
err = errorf.E("initialization vector must be 16 bytes, got %d", len(iv))
return
}
var block cipher.Block
if block, err = aes.NewCipher(key); chk.E(err) {
err = errorf.E("error creating block cipher: %w", err)
@@ -89,18 +108,28 @@ func DecryptNip4(content, key []byte) (msg []byte, err error) {
mode := cipher.NewCBCDecrypter(block, iv)
msg = make([]byte, len(ciphertext))
mode.CryptBlocks(msg, ciphertext)
// remove padding
// remove padding using proper PKCS#7 validation
var (
plaintextLen = len(msg)
)
if plaintextLen > 0 {
if plaintextLen == 0 {
err = errorf.E("empty ciphertext")
return
}
// the padding amount is encoded in the padding bytes themselves
padding := int(msg[plaintextLen-1])
if padding > plaintextLen {
if padding == 0 || padding > 16 || padding > plaintextLen {
err = errorf.E("invalid padding amount: %d", padding)
return
}
msg = msg[0 : plaintextLen-padding]
// validate that all padding bytes have the correct value (PKCS#7)
for i := plaintextLen - padding; i < plaintextLen; i++ {
if msg[i] != byte(padding) {
err = errorf.E("invalid padding: byte at position %d should be %d, got %d",
i, padding, msg[i])
return
}
}
msg = msg[0 : plaintextLen-padding]
return msg, nil
}

323
encryption/nip4_test.go Normal file
View File

@@ -0,0 +1,323 @@
package encryption
import (
"bytes"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"realy.lol/p256k"
)
func TestComputeSharedSecret(t *testing.T) {
// Test with valid keys
sk1, _ := hex.DecodeString("0000000000000000000000000000000000000000000000000000000000000001")
sk2, _ := hex.DecodeString("0000000000000000000000000000000000000000000000000000000000000002")
// Generate public key for sk2
signer := new(p256k.Signer)
err := signer.InitSec(sk2)
require.NoError(t, err)
pk2 := signer.Pub()
sharedSecret, err := ComputeSharedSecret(pk2, sk1)
require.NoError(t, err)
assert.Len(t, sharedSecret, 32, "shared secret should be 32 bytes")
// Test with invalid public key (wrong length)
invalidPk := []byte("invalid")
_, err = ComputeSharedSecret(invalidPk, sk1)
assert.Error(t, err, "should fail with invalid public key")
// Test with invalid public key (correct length but invalid point)
invalidPk32 := make([]byte, 32)
// Fill with invalid point data
for i := range invalidPk32 {
invalidPk32[i] = 0xFF
}
_, err = ComputeSharedSecret(invalidPk32, sk1)
assert.Error(t, err, "should fail with invalid public key point")
}
func TestEncryptNip4_ValidInputs(t *testing.T) {
key := make([]byte, 32)
_, err := rand.Read(key)
require.NoError(t, err)
testCases := []struct {
name string
message []byte
}{
{"empty message", []byte("")},
{"single byte", []byte("a")},
{"short message", []byte("hello")},
{"medium message", []byte("Hello, this is a test message for NIP-4 encryption!")},
{"long message", bytes.Repeat([]byte("x"), 1000)},
{"unicode message", []byte("🔒 Encrypted message with émojis and spëcial chars 中文")},
{"binary data", []byte{0x00, 0x01, 0x02, 0xFF, 0xFE, 0xFD}},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
encrypted, err := EncryptNip4(tc.message, key)
require.NoError(t, err, "encryption should succeed")
assert.NotEmpty(t, encrypted, "encrypted data should not be empty")
// Verify format: base64(ciphertext) + "?iv=" + base64(iv)
parts := bytes.Split(encrypted, []byte("?iv="))
assert.Len(t, parts, 2, "encrypted format should have ciphertext and IV parts")
// Verify both parts are valid base64
_, err = base64.StdEncoding.DecodeString(string(parts[0]))
assert.NoError(t, err, "ciphertext part should be valid base64")
_, err = base64.StdEncoding.DecodeString(string(parts[1]))
assert.NoError(t, err, "IV part should be valid base64")
// Test decryption
decrypted, err := DecryptNip4(encrypted, key)
require.NoError(t, err, "decryption should succeed")
assert.Equal(t, tc.message, decrypted, "decrypted message should match original")
})
}
}
func TestEncryptNip4_InvalidKey(t *testing.T) {
message := []byte("test message")
testCases := []struct {
name string
key []byte
}{
{"nil key", nil},
{"empty key", []byte{}},
{"short key", []byte("short")},
{"long key", make([]byte, 64)},
{"31 bytes key", make([]byte, 31)},
{"33 bytes key", make([]byte, 33)},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := EncryptNip4(message, tc.key)
assert.Error(t, err, "should fail with invalid key length")
assert.Contains(t, err.Error(), "key must be 32 bytes", "error should mention key length requirement")
})
}
}
func TestDecryptNip4_ValidInputs(t *testing.T) {
key := make([]byte, 32)
_, err := rand.Read(key)
require.NoError(t, err)
// Test with various message sizes
messages := [][]byte{
[]byte(""),
[]byte("a"),
[]byte("hello world"),
[]byte("This is a longer message to test padding"),
bytes.Repeat([]byte("x"), 100),
}
for i, msg := range messages {
t.Run(string(rune('A'+i)), func(t *testing.T) {
encrypted, err := EncryptNip4(msg, key)
require.NoError(t, err)
decrypted, err := DecryptNip4(encrypted, key)
require.NoError(t, err)
assert.Equal(t, msg, decrypted)
})
}
}
func TestDecryptNip4_InvalidKey(t *testing.T) {
// Create valid encrypted data first
validKey := make([]byte, 32)
_, err := rand.Read(validKey)
require.NoError(t, err)
encrypted, err := EncryptNip4([]byte("test"), validKey)
require.NoError(t, err)
testCases := []struct {
name string
key []byte
}{
{"nil key", nil},
{"empty key", []byte{}},
{"short key", []byte("short")},
{"long key", make([]byte, 64)},
{"31 bytes key", make([]byte, 31)},
{"33 bytes key", make([]byte, 33)},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := DecryptNip4(encrypted, tc.key)
assert.Error(t, err, "should fail with invalid key length")
assert.Contains(t, err.Error(), "key must be 32 bytes", "error should mention key length requirement")
})
}
}
func TestDecryptNip4_InvalidFormat(t *testing.T) {
key := make([]byte, 32)
_, err := rand.Read(key)
require.NoError(t, err)
testCases := []struct {
name string
encrypted []byte
errorMsg string
}{
{"no IV separator", []byte("dGVzdA=="), "no initialization vector"},
{"empty content", []byte(""), "no initialization vector"},
{"only separator", []byte("?iv="), "initialization vector must be 16 bytes"},
{"invalid base64 ciphertext", []byte("invalid_base64?iv=dGVzdA=="), "illegal base64 data"},
{"invalid base64 IV", []byte("dGVzdA==?iv=invalid_base64"), "illegal base64 data"},
{"wrong IV length", []byte("dGVzdA==?iv=dGVzdA=="), "initialization vector must be 16 bytes"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := DecryptNip4(tc.encrypted, key)
assert.Error(t, err, "should fail with invalid format")
assert.Contains(t, err.Error(), tc.errorMsg, "error should contain expected message")
})
}
}
func TestDecryptNip4_InvalidPadding(t *testing.T) {
key := make([]byte, 32)
_, err := rand.Read(key)
require.NoError(t, err)
// Create a valid encrypted message first
encrypted, err := EncryptNip4([]byte("test"), key)
require.NoError(t, err)
// Decode to manipulate the ciphertext
parts := bytes.Split(encrypted, []byte("?iv="))
require.Len(t, parts, 2)
ciphertext, err := base64.StdEncoding.DecodeString(string(parts[0]))
require.NoError(t, err)
// Test cases with corrupted padding
testCases := []struct {
name string
modifier func([]byte) []byte
}{
{
"corrupt last byte (padding indicator)",
func(ct []byte) []byte {
ct[len(ct)-1] = 0xFF // invalid padding amount
return ct
},
},
{
"corrupt padding bytes",
func(ct []byte) []byte {
if len(ct) > 1 {
ct[len(ct)-2] = 0x00 // corrupt padding byte
}
return ct
},
},
{
"zero padding",
func(ct []byte) []byte {
ct[len(ct)-1] = 0x00 // zero padding is invalid
return ct
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
corruptedCt := tc.modifier(append([]byte(nil), ciphertext...))
corruptedB64 := base64.StdEncoding.EncodeToString(corruptedCt)
corruptedEncrypted := []byte(corruptedB64 + "?iv=" + string(parts[1]))
_, err := DecryptNip4(corruptedEncrypted, key)
assert.Error(t, err, "should fail with corrupted padding")
assert.Contains(t, err.Error(), "padding", "error should mention padding")
})
}
}
func TestNip4_RoundTrip(t *testing.T) {
// Test multiple round trips with different keys
for i := 0; i < 10; i++ {
key := make([]byte, 32)
_, err := rand.Read(key)
require.NoError(t, err)
message := make([]byte, i*10+1) // varying message sizes
_, err = rand.Read(message)
require.NoError(t, err)
encrypted, err := EncryptNip4(message, key)
require.NoError(t, err)
decrypted, err := DecryptNip4(encrypted, key)
require.NoError(t, err)
assert.Equal(t, message, decrypted, "round trip should preserve message")
}
}
func TestNip4_DifferentKeys(t *testing.T) {
key1 := make([]byte, 32)
key2 := make([]byte, 32)
_, err := rand.Read(key1)
require.NoError(t, err)
_, err = rand.Read(key2)
require.NoError(t, err)
message := []byte("secret message")
encrypted, err := EncryptNip4(message, key1)
require.NoError(t, err)
// Decryption with wrong key should fail or produce garbage
decrypted, err := DecryptNip4(encrypted, key2)
if err == nil {
// If decryption succeeds, result should be different from original
assert.NotEqual(t, message, decrypted, "decryption with wrong key should not produce original message")
}
// If decryption fails, that's also acceptable (padding validation might catch it)
}
func TestNip4_Deterministic(t *testing.T) {
key := make([]byte, 32)
_, err := rand.Read(key)
require.NoError(t, err)
message := []byte("test message")
// Encrypt the same message multiple times
encrypted1, err := EncryptNip4(message, key)
require.NoError(t, err)
encrypted2, err := EncryptNip4(message, key)
require.NoError(t, err)
// Results should be different due to random IV
assert.NotEqual(t, encrypted1, encrypted2, "encryption should not be deterministic")
// But both should decrypt to the same message
decrypted1, err := DecryptNip4(encrypted1, key)
require.NoError(t, err)
decrypted2, err := DecryptNip4(encrypted2, key)
require.NoError(t, err)
assert.Equal(t, message, decrypted1)
assert.Equal(t, message, decrypted2)
}