validate and correct nip-04 and nip-44 encryption code
This commit is contained in:
@@ -33,6 +33,11 @@ func ComputeSharedSecret(pk, sk []byte) (sharedSecret []byte, err error) {
|
|||||||
//
|
//
|
||||||
// Deprecated: upgrade to using Decrypt with the NIP-44 algorithm.
|
// Deprecated: upgrade to using Decrypt with the NIP-44 algorithm.
|
||||||
func EncryptNip4(msg []byte, key []byte) (ct []byte, err error) {
|
func EncryptNip4(msg []byte, key []byte) (ct []byte, err error) {
|
||||||
|
// validate key length for AES-256
|
||||||
|
if len(key) != 32 {
|
||||||
|
err = errorf.E("key must be 32 bytes for AES-256, got %d", len(key))
|
||||||
|
return
|
||||||
|
}
|
||||||
// block size is 16 bytes
|
// block size is 16 bytes
|
||||||
iv := make([]byte, 16)
|
iv := make([]byte, 16)
|
||||||
if _, err = frand.Read(iv); chk.E(err) {
|
if _, err = frand.Read(iv); chk.E(err) {
|
||||||
@@ -66,21 +71,35 @@ func EncryptNip4(msg []byte, key []byte) (ct []byte, err error) {
|
|||||||
//
|
//
|
||||||
// Deprecated: upgrade to using Decrypt with the NIP-44 algorithm.
|
// Deprecated: upgrade to using Decrypt with the NIP-44 algorithm.
|
||||||
func DecryptNip4(content, key []byte) (msg []byte, err error) {
|
func DecryptNip4(content, key []byte) (msg []byte, err error) {
|
||||||
|
// validate key length for AES-256
|
||||||
|
if len(key) != 32 {
|
||||||
|
err = errorf.E("key must be 32 bytes for AES-256, got %d", len(key))
|
||||||
|
return
|
||||||
|
}
|
||||||
parts := bytes.Split(content, []byte("?iv="))
|
parts := bytes.Split(content, []byte("?iv="))
|
||||||
if len(parts) < 2 {
|
if len(parts) < 2 {
|
||||||
return nil, errorf.E(
|
return nil, errorf.E(
|
||||||
"error parsing encrypted message: no initialization vector")
|
"error parsing encrypted message: no initialization vector")
|
||||||
}
|
}
|
||||||
ciphertext := make([]byte, base64.StdEncoding.EncodedLen(len(parts[0])))
|
ciphertext := make([]byte, base64.StdEncoding.DecodedLen(len(parts[0])))
|
||||||
iv := make([]byte, base64.StdEncoding.EncodedLen(len(parts[1])))
|
iv := make([]byte, base64.StdEncoding.DecodedLen(len(parts[1])))
|
||||||
if _, err = base64.StdEncoding.Decode(ciphertext, parts[0]); chk.E(err) {
|
var ctLen, ivLen int
|
||||||
|
if ctLen, err = base64.StdEncoding.Decode(ciphertext, parts[0]); chk.E(err) {
|
||||||
err = errorf.E("error decoding ciphertext from base64: %w", err)
|
err = errorf.E("error decoding ciphertext from base64: %w", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if _, err = base64.StdEncoding.Decode(iv, parts[1]); chk.E(err) {
|
if ivLen, err = base64.StdEncoding.Decode(iv, parts[1]); chk.E(err) {
|
||||||
err = errorf.E("error decoding iv from base64: %w", err)
|
err = errorf.E("error decoding iv from base64: %w", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// trim buffers to actual decoded length
|
||||||
|
ciphertext = ciphertext[:ctLen]
|
||||||
|
iv = iv[:ivLen]
|
||||||
|
// validate IV length
|
||||||
|
if len(iv) != 16 {
|
||||||
|
err = errorf.E("initialization vector must be 16 bytes, got %d", len(iv))
|
||||||
|
return
|
||||||
|
}
|
||||||
var block cipher.Block
|
var block cipher.Block
|
||||||
if block, err = aes.NewCipher(key); chk.E(err) {
|
if block, err = aes.NewCipher(key); chk.E(err) {
|
||||||
err = errorf.E("error creating block cipher: %w", err)
|
err = errorf.E("error creating block cipher: %w", err)
|
||||||
@@ -89,18 +108,28 @@ func DecryptNip4(content, key []byte) (msg []byte, err error) {
|
|||||||
mode := cipher.NewCBCDecrypter(block, iv)
|
mode := cipher.NewCBCDecrypter(block, iv)
|
||||||
msg = make([]byte, len(ciphertext))
|
msg = make([]byte, len(ciphertext))
|
||||||
mode.CryptBlocks(msg, ciphertext)
|
mode.CryptBlocks(msg, ciphertext)
|
||||||
// remove padding
|
// remove padding using proper PKCS#7 validation
|
||||||
var (
|
var (
|
||||||
plaintextLen = len(msg)
|
plaintextLen = len(msg)
|
||||||
)
|
)
|
||||||
if plaintextLen > 0 {
|
if plaintextLen == 0 {
|
||||||
// the padding amount is encoded in the padding bytes themselves
|
err = errorf.E("empty ciphertext")
|
||||||
padding := int(msg[plaintextLen-1])
|
return
|
||||||
if padding > plaintextLen {
|
}
|
||||||
err = errorf.E("invalid padding amount: %d", padding)
|
// the padding amount is encoded in the padding bytes themselves
|
||||||
|
padding := int(msg[plaintextLen-1])
|
||||||
|
if padding == 0 || padding > 16 || padding > plaintextLen {
|
||||||
|
err = errorf.E("invalid padding amount: %d", padding)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// validate that all padding bytes have the correct value (PKCS#7)
|
||||||
|
for i := plaintextLen - padding; i < plaintextLen; i++ {
|
||||||
|
if msg[i] != byte(padding) {
|
||||||
|
err = errorf.E("invalid padding: byte at position %d should be %d, got %d",
|
||||||
|
i, padding, msg[i])
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
msg = msg[0 : plaintextLen-padding]
|
|
||||||
}
|
}
|
||||||
|
msg = msg[0 : plaintextLen-padding]
|
||||||
return msg, nil
|
return msg, nil
|
||||||
}
|
}
|
||||||
|
|||||||
323
encryption/nip4_test.go
Normal file
323
encryption/nip4_test.go
Normal file
@@ -0,0 +1,323 @@
|
|||||||
|
package encryption
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"realy.lol/p256k"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestComputeSharedSecret(t *testing.T) {
|
||||||
|
// Test with valid keys
|
||||||
|
sk1, _ := hex.DecodeString("0000000000000000000000000000000000000000000000000000000000000001")
|
||||||
|
sk2, _ := hex.DecodeString("0000000000000000000000000000000000000000000000000000000000000002")
|
||||||
|
|
||||||
|
// Generate public key for sk2
|
||||||
|
signer := new(p256k.Signer)
|
||||||
|
err := signer.InitSec(sk2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
pk2 := signer.Pub()
|
||||||
|
|
||||||
|
sharedSecret, err := ComputeSharedSecret(pk2, sk1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, sharedSecret, 32, "shared secret should be 32 bytes")
|
||||||
|
|
||||||
|
// Test with invalid public key (wrong length)
|
||||||
|
invalidPk := []byte("invalid")
|
||||||
|
_, err = ComputeSharedSecret(invalidPk, sk1)
|
||||||
|
assert.Error(t, err, "should fail with invalid public key")
|
||||||
|
|
||||||
|
// Test with invalid public key (correct length but invalid point)
|
||||||
|
invalidPk32 := make([]byte, 32)
|
||||||
|
// Fill with invalid point data
|
||||||
|
for i := range invalidPk32 {
|
||||||
|
invalidPk32[i] = 0xFF
|
||||||
|
}
|
||||||
|
_, err = ComputeSharedSecret(invalidPk32, sk1)
|
||||||
|
assert.Error(t, err, "should fail with invalid public key point")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncryptNip4_ValidInputs(t *testing.T) {
|
||||||
|
key := make([]byte, 32)
|
||||||
|
_, err := rand.Read(key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
message []byte
|
||||||
|
}{
|
||||||
|
{"empty message", []byte("")},
|
||||||
|
{"single byte", []byte("a")},
|
||||||
|
{"short message", []byte("hello")},
|
||||||
|
{"medium message", []byte("Hello, this is a test message for NIP-4 encryption!")},
|
||||||
|
{"long message", bytes.Repeat([]byte("x"), 1000)},
|
||||||
|
{"unicode message", []byte("🔒 Encrypted message with émojis and spëcial chars 中文")},
|
||||||
|
{"binary data", []byte{0x00, 0x01, 0x02, 0xFF, 0xFE, 0xFD}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
encrypted, err := EncryptNip4(tc.message, key)
|
||||||
|
require.NoError(t, err, "encryption should succeed")
|
||||||
|
assert.NotEmpty(t, encrypted, "encrypted data should not be empty")
|
||||||
|
|
||||||
|
// Verify format: base64(ciphertext) + "?iv=" + base64(iv)
|
||||||
|
parts := bytes.Split(encrypted, []byte("?iv="))
|
||||||
|
assert.Len(t, parts, 2, "encrypted format should have ciphertext and IV parts")
|
||||||
|
|
||||||
|
// Verify both parts are valid base64
|
||||||
|
_, err = base64.StdEncoding.DecodeString(string(parts[0]))
|
||||||
|
assert.NoError(t, err, "ciphertext part should be valid base64")
|
||||||
|
_, err = base64.StdEncoding.DecodeString(string(parts[1]))
|
||||||
|
assert.NoError(t, err, "IV part should be valid base64")
|
||||||
|
|
||||||
|
// Test decryption
|
||||||
|
decrypted, err := DecryptNip4(encrypted, key)
|
||||||
|
require.NoError(t, err, "decryption should succeed")
|
||||||
|
assert.Equal(t, tc.message, decrypted, "decrypted message should match original")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncryptNip4_InvalidKey(t *testing.T) {
|
||||||
|
message := []byte("test message")
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
key []byte
|
||||||
|
}{
|
||||||
|
{"nil key", nil},
|
||||||
|
{"empty key", []byte{}},
|
||||||
|
{"short key", []byte("short")},
|
||||||
|
{"long key", make([]byte, 64)},
|
||||||
|
{"31 bytes key", make([]byte, 31)},
|
||||||
|
{"33 bytes key", make([]byte, 33)},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
_, err := EncryptNip4(message, tc.key)
|
||||||
|
assert.Error(t, err, "should fail with invalid key length")
|
||||||
|
assert.Contains(t, err.Error(), "key must be 32 bytes", "error should mention key length requirement")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptNip4_ValidInputs(t *testing.T) {
|
||||||
|
key := make([]byte, 32)
|
||||||
|
_, err := rand.Read(key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Test with various message sizes
|
||||||
|
messages := [][]byte{
|
||||||
|
[]byte(""),
|
||||||
|
[]byte("a"),
|
||||||
|
[]byte("hello world"),
|
||||||
|
[]byte("This is a longer message to test padding"),
|
||||||
|
bytes.Repeat([]byte("x"), 100),
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, msg := range messages {
|
||||||
|
t.Run(string(rune('A'+i)), func(t *testing.T) {
|
||||||
|
encrypted, err := EncryptNip4(msg, key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
decrypted, err := DecryptNip4(encrypted, key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, msg, decrypted)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptNip4_InvalidKey(t *testing.T) {
|
||||||
|
// Create valid encrypted data first
|
||||||
|
validKey := make([]byte, 32)
|
||||||
|
_, err := rand.Read(validKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
encrypted, err := EncryptNip4([]byte("test"), validKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
key []byte
|
||||||
|
}{
|
||||||
|
{"nil key", nil},
|
||||||
|
{"empty key", []byte{}},
|
||||||
|
{"short key", []byte("short")},
|
||||||
|
{"long key", make([]byte, 64)},
|
||||||
|
{"31 bytes key", make([]byte, 31)},
|
||||||
|
{"33 bytes key", make([]byte, 33)},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
_, err := DecryptNip4(encrypted, tc.key)
|
||||||
|
assert.Error(t, err, "should fail with invalid key length")
|
||||||
|
assert.Contains(t, err.Error(), "key must be 32 bytes", "error should mention key length requirement")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptNip4_InvalidFormat(t *testing.T) {
|
||||||
|
key := make([]byte, 32)
|
||||||
|
_, err := rand.Read(key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
encrypted []byte
|
||||||
|
errorMsg string
|
||||||
|
}{
|
||||||
|
{"no IV separator", []byte("dGVzdA=="), "no initialization vector"},
|
||||||
|
{"empty content", []byte(""), "no initialization vector"},
|
||||||
|
{"only separator", []byte("?iv="), "initialization vector must be 16 bytes"},
|
||||||
|
{"invalid base64 ciphertext", []byte("invalid_base64?iv=dGVzdA=="), "illegal base64 data"},
|
||||||
|
{"invalid base64 IV", []byte("dGVzdA==?iv=invalid_base64"), "illegal base64 data"},
|
||||||
|
{"wrong IV length", []byte("dGVzdA==?iv=dGVzdA=="), "initialization vector must be 16 bytes"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
_, err := DecryptNip4(tc.encrypted, key)
|
||||||
|
assert.Error(t, err, "should fail with invalid format")
|
||||||
|
assert.Contains(t, err.Error(), tc.errorMsg, "error should contain expected message")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecryptNip4_InvalidPadding(t *testing.T) {
|
||||||
|
key := make([]byte, 32)
|
||||||
|
_, err := rand.Read(key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create a valid encrypted message first
|
||||||
|
encrypted, err := EncryptNip4([]byte("test"), key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Decode to manipulate the ciphertext
|
||||||
|
parts := bytes.Split(encrypted, []byte("?iv="))
|
||||||
|
require.Len(t, parts, 2)
|
||||||
|
|
||||||
|
ciphertext, err := base64.StdEncoding.DecodeString(string(parts[0]))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Test cases with corrupted padding
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
modifier func([]byte) []byte
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"corrupt last byte (padding indicator)",
|
||||||
|
func(ct []byte) []byte {
|
||||||
|
ct[len(ct)-1] = 0xFF // invalid padding amount
|
||||||
|
return ct
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"corrupt padding bytes",
|
||||||
|
func(ct []byte) []byte {
|
||||||
|
if len(ct) > 1 {
|
||||||
|
ct[len(ct)-2] = 0x00 // corrupt padding byte
|
||||||
|
}
|
||||||
|
return ct
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"zero padding",
|
||||||
|
func(ct []byte) []byte {
|
||||||
|
ct[len(ct)-1] = 0x00 // zero padding is invalid
|
||||||
|
return ct
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
corruptedCt := tc.modifier(append([]byte(nil), ciphertext...))
|
||||||
|
corruptedB64 := base64.StdEncoding.EncodeToString(corruptedCt)
|
||||||
|
corruptedEncrypted := []byte(corruptedB64 + "?iv=" + string(parts[1]))
|
||||||
|
|
||||||
|
_, err := DecryptNip4(corruptedEncrypted, key)
|
||||||
|
assert.Error(t, err, "should fail with corrupted padding")
|
||||||
|
assert.Contains(t, err.Error(), "padding", "error should mention padding")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNip4_RoundTrip(t *testing.T) {
|
||||||
|
// Test multiple round trips with different keys
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
key := make([]byte, 32)
|
||||||
|
_, err := rand.Read(key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
message := make([]byte, i*10+1) // varying message sizes
|
||||||
|
_, err = rand.Read(message)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
encrypted, err := EncryptNip4(message, key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
decrypted, err := DecryptNip4(encrypted, key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, message, decrypted, "round trip should preserve message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNip4_DifferentKeys(t *testing.T) {
|
||||||
|
key1 := make([]byte, 32)
|
||||||
|
key2 := make([]byte, 32)
|
||||||
|
_, err := rand.Read(key1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = rand.Read(key2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
message := []byte("secret message")
|
||||||
|
|
||||||
|
encrypted, err := EncryptNip4(message, key1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Decryption with wrong key should fail or produce garbage
|
||||||
|
decrypted, err := DecryptNip4(encrypted, key2)
|
||||||
|
if err == nil {
|
||||||
|
// If decryption succeeds, result should be different from original
|
||||||
|
assert.NotEqual(t, message, decrypted, "decryption with wrong key should not produce original message")
|
||||||
|
}
|
||||||
|
// If decryption fails, that's also acceptable (padding validation might catch it)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNip4_Deterministic(t *testing.T) {
|
||||||
|
key := make([]byte, 32)
|
||||||
|
_, err := rand.Read(key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
message := []byte("test message")
|
||||||
|
|
||||||
|
// Encrypt the same message multiple times
|
||||||
|
encrypted1, err := EncryptNip4(message, key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
encrypted2, err := EncryptNip4(message, key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Results should be different due to random IV
|
||||||
|
assert.NotEqual(t, encrypted1, encrypted2, "encryption should not be deterministic")
|
||||||
|
|
||||||
|
// But both should decrypt to the same message
|
||||||
|
decrypted1, err := DecryptNip4(encrypted1, key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
decrypted2, err := DecryptNip4(encrypted2, key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, message, decrypted1)
|
||||||
|
assert.Equal(t, message, decrypted2)
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user