package encryption import ( "bytes" "crypto/aes" "crypto/cipher" "encoding/base64" "lol.mleku.dev/chk" "lol.mleku.dev/errorf" "lukechampine.com/frand" ) // EncryptNip4 encrypts message with key using aes-256-cbc. key should be the shared secret generated by // ComputeSharedSecret. // // Returns: base64(encrypted_bytes) + "?iv=" + base64(initialization_vector). func EncryptNip4(msg, key []byte) (ct []byte, err error) { // block size is 16 bytes iv := make([]byte, 16) if _, err = frand.Read(iv); chk.E(err) { err = errorf.E("error creating initialization vector: %w", err) return } // automatically picks aes-256 based on key length (32 bytes) var block cipher.Block if block, err = aes.NewCipher(key); chk.E(err) { err = errorf.E("error creating block cipher: %w", err) return } mode := cipher.NewCBCEncrypter(block, iv) plaintext := []byte(msg) // add padding base := len(plaintext) // this will be a number between 1 and 16 (inclusive), never 0 bs := block.BlockSize() padding := bs - base%bs // encode the padding in all the padding bytes themselves padText := bytes.Repeat([]byte{byte(padding)}, padding) paddedMsgBytes := append(plaintext, padText...) ciphertext := make([]byte, len(paddedMsgBytes)) mode.CryptBlocks(ciphertext, paddedMsgBytes) return []byte(base64.StdEncoding.EncodeToString(ciphertext) + "?iv=" + base64.StdEncoding.EncodeToString(iv)), nil } // DecryptNip4 decrypts a content string using the shared secret key. The inverse operation to message -> // EncryptNip4(message, key). func DecryptNip4(content, key []byte) (msg []byte, err error) { parts := bytes.Split(content, []byte("?iv=")) if len(parts) < 2 { return nil, errorf.E( "error parsing encrypted message: no initialization vector", ) } ciphertext := make([]byte, base64.StdEncoding.EncodedLen(len(parts[0]))) if _, err = base64.StdEncoding.Decode(ciphertext, parts[0]); chk.E(err) { err = errorf.E("error decoding ciphertext from base64: %w", err) return } iv := make([]byte, base64.StdEncoding.EncodedLen(len(parts[1]))) if _, err = base64.StdEncoding.Decode(iv, parts[1]); chk.E(err) { err = errorf.E("error decoding iv from base64: %w", err) return } var block cipher.Block if block, err = aes.NewCipher(key); chk.E(err) { err = errorf.E("error creating block cipher: %w", err) return } mode := cipher.NewCBCDecrypter(block, iv) msg = make([]byte, len(ciphertext)) mode.CryptBlocks(msg, ciphertext) // remove padding var ( plaintextLen = len(msg) ) if plaintextLen > 0 { // the padding amount is encoded in the padding bytes themselves padding := int(msg[plaintextLen-1]) if padding > plaintextLen { err = errorf.E("invalid padding amount: %d", padding) return } msg = msg[0 : plaintextLen-padding] } return msg, nil }