Review comments and additinal tests
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
||||
|
||||
"github.com/CosmWasm/wasmd/x/wasm/internal/types"
|
||||
sdk "github.com/cosmos/cosmos-sdk/types"
|
||||
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
|
||||
// authexported "github.com/cosmos/cosmos-sdk/x/auth/exported"
|
||||
// "github.com/CosmWasm/wasmd/x/wasm/internal/types"
|
||||
)
|
||||
@@ -12,25 +13,33 @@ import (
|
||||
// InitGenesis sets supply information for genesis.
|
||||
//
|
||||
// CONTRACT: all types of accounts must have been already initialized/created
|
||||
func InitGenesis(ctx sdk.Context, keeper Keeper, data types.GenesisState) {
|
||||
for _, code := range data.Codes {
|
||||
func InitGenesis(ctx sdk.Context, keeper Keeper, data types.GenesisState) error {
|
||||
for i, code := range data.Codes {
|
||||
newId, err := keeper.Create(ctx, code.CodeInfo.Creator, code.CodesBytes, code.CodeInfo.Source, code.CodeInfo.Builder)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
return sdkerrors.Wrapf(err, "code number %d", i)
|
||||
|
||||
}
|
||||
newInfo := keeper.GetCodeInfo(ctx, newId)
|
||||
if !bytes.Equal(code.CodeInfo.CodeHash, newInfo.CodeHash) {
|
||||
panic("code hashes not same")
|
||||
return sdkerrors.Wrap(types.ErrInvalid, "code hashes not same")
|
||||
}
|
||||
}
|
||||
|
||||
for _, contract := range data.Contracts {
|
||||
keeper.importContract(ctx, contract.ContractAddress, &contract.ContractInfo, contract.ContractState)
|
||||
for i, contract := range data.Contracts {
|
||||
err := keeper.importContract(ctx, contract.ContractAddress, &contract.ContractInfo, contract.ContractState)
|
||||
if err != nil {
|
||||
return sdkerrors.Wrapf(err, "contract number %d", i)
|
||||
}
|
||||
}
|
||||
|
||||
for _, seq := range data.Sequences {
|
||||
keeper.importAutoIncrementID(ctx, seq.IDKey, seq.Value)
|
||||
for i, seq := range data.Sequences {
|
||||
err := keeper.importAutoIncrementID(ctx, seq.IDKey, seq.Value)
|
||||
if err != nil {
|
||||
return sdkerrors.Wrapf(err, "sequence number %d", i)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExportGenesis returns a GenesisState for a given context and keeper.
|
||||
|
||||
@@ -124,11 +124,42 @@ func TestFailFastImport(t *testing.T) {
|
||||
},
|
||||
expSuccess: true,
|
||||
},
|
||||
"happy path: code info with two contracts": {
|
||||
src: types.GenesisState{
|
||||
Codes: []types.Code{{
|
||||
CodeInfo: wasmTypes.CodeInfo{
|
||||
CodeHash: codeHash[:],
|
||||
Creator: anyAddress,
|
||||
},
|
||||
CodesBytes: wasmCode,
|
||||
}},
|
||||
Contracts: []types.Contract{
|
||||
{
|
||||
ContractAddress: addrFromUint64(1<<32 + 1),
|
||||
ContractInfo: wasmTypes.ContractInfo{
|
||||
CodeID: 1,
|
||||
Creator: anyAddress,
|
||||
Label: "any",
|
||||
Created: &types.AbsoluteTxPosition{BlockHeight: 1, TxIndex: 1},
|
||||
},
|
||||
}, {
|
||||
ContractAddress: addrFromUint64(2<<32 + 1),
|
||||
ContractInfo: wasmTypes.ContractInfo{
|
||||
CodeID: 1,
|
||||
Creator: anyAddress,
|
||||
Label: "any",
|
||||
Created: &types.AbsoluteTxPosition{BlockHeight: 1, TxIndex: 1},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expSuccess: true,
|
||||
},
|
||||
"prevent contracts that points to non existing codeID": {
|
||||
src: types.GenesisState{
|
||||
Contracts: []types.Contract{
|
||||
{
|
||||
ContractAddress: addrFromUint64(1<<32 + 1),
|
||||
ContractAddress: contractAddress(1, 1),
|
||||
ContractInfo: wasmTypes.ContractInfo{
|
||||
CodeID: 1,
|
||||
Creator: anyAddress,
|
||||
@@ -150,7 +181,7 @@ func TestFailFastImport(t *testing.T) {
|
||||
}},
|
||||
Contracts: []types.Contract{
|
||||
{
|
||||
ContractAddress: addrFromUint64(1<<32 + 1),
|
||||
ContractAddress: contractAddress(1, 1),
|
||||
ContractInfo: wasmTypes.ContractInfo{
|
||||
CodeID: 1,
|
||||
Creator: anyAddress,
|
||||
@@ -158,7 +189,7 @@ func TestFailFastImport(t *testing.T) {
|
||||
Created: &types.AbsoluteTxPosition{BlockHeight: 1, TxIndex: 1},
|
||||
},
|
||||
}, {
|
||||
ContractAddress: addrFromUint64(1<<32 + 1),
|
||||
ContractAddress: contractAddress(1, 1),
|
||||
ContractInfo: wasmTypes.ContractInfo{
|
||||
CodeID: 1,
|
||||
Creator: anyAddress,
|
||||
@@ -217,13 +248,12 @@ func TestFailFastImport(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
require.NoError(t, types.ValidateGenesis(spec.src))
|
||||
got := InitGenesis(ctx, keeper, spec.src)
|
||||
if spec.expSuccess {
|
||||
InitGenesis(ctx, keeper, spec.src)
|
||||
require.NoError(t, got)
|
||||
return
|
||||
}
|
||||
require.Panics(t, func() {
|
||||
InitGenesis(ctx, keeper, spec.src)
|
||||
})
|
||||
require.Error(t, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,10 +2,10 @@ package keeper
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/cosmos/cosmos-sdk/x/staking"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
wasm "github.com/CosmWasm/go-cosmwasm"
|
||||
wasmTypes "github.com/CosmWasm/go-cosmwasm/types"
|
||||
@@ -404,7 +404,7 @@ func (k Keeper) GetContractState(ctx sdk.Context, contractAddress sdk.AccAddress
|
||||
return prefixStore.Iterator(nil, nil)
|
||||
}
|
||||
|
||||
func (k Keeper) importContractState(ctx sdk.Context, contractAddress sdk.AccAddress, models []types.Model) {
|
||||
func (k Keeper) importContractState(ctx sdk.Context, contractAddress sdk.AccAddress, models []types.Model) error {
|
||||
prefixStoreKey := types.GetContractStorePrefixKey(contractAddress)
|
||||
prefixStore := prefix.NewStore(ctx.KVStore(k.storeKey), prefixStoreKey)
|
||||
for _, model := range models {
|
||||
@@ -412,10 +412,11 @@ func (k Keeper) importContractState(ctx sdk.Context, contractAddress sdk.AccAddr
|
||||
model.Value = []byte{}
|
||||
}
|
||||
if prefixStore.Has(model.Key) {
|
||||
panic(fmt.Sprintf("duplicate key: %x", model.Key))
|
||||
return sdkerrors.Wrapf(types.ErrDuplicate, "duplicate key: %x", model.Key)
|
||||
}
|
||||
prefixStore.Set(model.Key, model.Value)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (k Keeper) GetCodeInfo(ctx sdk.Context, codeID uint64) *types.CodeInfo {
|
||||
@@ -471,10 +472,15 @@ func consumeGas(ctx sdk.Context, gas uint64) {
|
||||
// generates a contract address from codeID + instanceID
|
||||
func (k Keeper) generateContractAddress(ctx sdk.Context, codeID uint64) sdk.AccAddress {
|
||||
instanceID := k.autoIncrementID(ctx, types.KeyLastInstanceID)
|
||||
return contractAddress(codeID, instanceID)
|
||||
}
|
||||
|
||||
func contractAddress(codeID, instanceID uint64) sdk.AccAddress {
|
||||
// NOTE: It is possible to get a duplicate address if either codeID or instanceID
|
||||
// overflow 32 bits. This is highly improbable, but something that could be refactored.
|
||||
contractID := codeID<<32 + instanceID
|
||||
return addrFromUint64(contractID)
|
||||
|
||||
}
|
||||
|
||||
func (k Keeper) GetNextCodeID(ctx sdk.Context) uint64 {
|
||||
@@ -510,24 +516,25 @@ func (k Keeper) peekAutoIncrementID(ctx sdk.Context, lastIDKey []byte) uint64 {
|
||||
return id
|
||||
}
|
||||
|
||||
func (k Keeper) importAutoIncrementID(ctx sdk.Context, lastIDKey []byte, val uint64) {
|
||||
func (k Keeper) importAutoIncrementID(ctx sdk.Context, lastIDKey []byte, val uint64) error {
|
||||
store := ctx.KVStore(k.storeKey)
|
||||
if store.Has(lastIDKey) {
|
||||
panic(fmt.Sprintf("duplicate autoincrement id: %s", string(lastIDKey)))
|
||||
return sdkerrors.Wrapf(types.ErrDuplicate, "autoincrement id: %s", string(lastIDKey))
|
||||
}
|
||||
bz := sdk.Uint64ToBigEndian(val)
|
||||
store.Set(lastIDKey, bz)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (k Keeper) importContract(ctx sdk.Context, address sdk.AccAddress, c *types.ContractInfo, state []types.Model) {
|
||||
func (k Keeper) importContract(ctx sdk.Context, address sdk.AccAddress, c *types.ContractInfo, state []types.Model) error {
|
||||
if !k.containsCodeInfo(ctx, c.CodeID) {
|
||||
panic(fmt.Sprintf("unknown code id: %d", c.CodeID))
|
||||
return errors.Wrapf(types.ErrNotFound, "code id: %d", c.CodeID)
|
||||
}
|
||||
if k.containsContractInfo(ctx, address) {
|
||||
panic(fmt.Sprintf("duplicate contract: %s", address))
|
||||
return errors.Wrapf(types.ErrDuplicate, "contract: %s", address)
|
||||
}
|
||||
k.setContractInfo(ctx, address, c)
|
||||
k.importContractState(ctx, address, state)
|
||||
return k.importContractState(ctx, address, state)
|
||||
}
|
||||
|
||||
func addrFromUint64(id uint64) sdk.AccAddress {
|
||||
|
||||
@@ -46,4 +46,7 @@ var (
|
||||
|
||||
// ErrInvalid error for content that is invalid in this context
|
||||
ErrInvalid = sdkErrors.Register(DefaultCodespace, 13, "invalid")
|
||||
|
||||
// ErrDuplicate error for content that exsists
|
||||
ErrDuplicate = sdkErrors.Register(DefaultCodespace, 14, "duplicate")
|
||||
)
|
||||
|
||||
107
x/wasm/internal/types/genesis_test.go
Normal file
107
x/wasm/internal/types/genesis_test.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tendermint/tendermint/libs/rand"
|
||||
)
|
||||
|
||||
func TestValidateGenesisState(t *testing.T) {
|
||||
specs := map[string]struct {
|
||||
srcMutator func(state GenesisState)
|
||||
expError bool
|
||||
}{
|
||||
"all good": {
|
||||
srcMutator: func(s GenesisState) {},
|
||||
},
|
||||
"codeinfo invalid": {
|
||||
srcMutator: func(s GenesisState) {
|
||||
s.Codes[0].CodeInfo.CodeHash = nil
|
||||
},
|
||||
expError: true,
|
||||
},
|
||||
"contract invalid": {
|
||||
srcMutator: func(s GenesisState) {
|
||||
s.Contracts[0].ContractAddress = nil
|
||||
},
|
||||
expError: true,
|
||||
},
|
||||
"sequence invalid": {
|
||||
srcMutator: func(s GenesisState) {
|
||||
s.Sequences[0].IDKey = nil
|
||||
},
|
||||
expError: true,
|
||||
},
|
||||
}
|
||||
for msg, spec := range specs {
|
||||
t.Run(msg, func(t *testing.T) {
|
||||
state := genesisFixture(spec.srcMutator)
|
||||
got := state.ValidateBasic()
|
||||
if spec.expError {
|
||||
require.Error(t, got)
|
||||
return
|
||||
}
|
||||
require.NoError(t, got)
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func genesisFixture(mutators ...func(state GenesisState)) GenesisState {
|
||||
const (
|
||||
numCodes = 2
|
||||
numContracts = 2
|
||||
numSequences = 2
|
||||
)
|
||||
|
||||
fixture := GenesisState{
|
||||
Codes: make([]Code, numCodes),
|
||||
Contracts: make([]Contract, numContracts),
|
||||
Sequences: make([]Sequence, numSequences),
|
||||
}
|
||||
for i := 0; i < numCodes; i++ {
|
||||
fixture.Codes[i] = codeFixture()
|
||||
}
|
||||
for i := 0; i < numContracts; i++ {
|
||||
fixture.Contracts[i] = contractFixture()
|
||||
}
|
||||
for i := 0; i < numSequences; i++ {
|
||||
fixture.Sequences[i] = Sequence{
|
||||
IDKey: rand.Bytes(5),
|
||||
Value: uint64(i),
|
||||
}
|
||||
}
|
||||
for _, m := range mutators {
|
||||
m(fixture)
|
||||
}
|
||||
return fixture
|
||||
}
|
||||
|
||||
func codeFixture() Code {
|
||||
wasmCode := rand.Bytes(100)
|
||||
codeHash := sha256.Sum256(wasmCode)
|
||||
anyAddress := make([]byte, 20)
|
||||
|
||||
return Code{
|
||||
CodeInfo: CodeInfo{
|
||||
CodeHash: codeHash[:],
|
||||
Creator: anyAddress,
|
||||
},
|
||||
CodesBytes: wasmCode,
|
||||
}
|
||||
}
|
||||
|
||||
func contractFixture() Contract {
|
||||
anyAddress := make([]byte, 20)
|
||||
return Contract{
|
||||
ContractAddress: anyAddress,
|
||||
ContractInfo: ContractInfo{
|
||||
CodeID: 1,
|
||||
Creator: anyAddress,
|
||||
Label: "any",
|
||||
Created: &AbsoluteTxPosition{BlockHeight: 1, TxIndex: 1},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -97,6 +97,12 @@ func (c *ContractInfo) ValidateBasic() error {
|
||||
if err := validateLabel(c.Label); err != nil {
|
||||
return sdkerrors.Wrap(err, "label")
|
||||
}
|
||||
if err := c.Created.ValidateBasic(); err != nil {
|
||||
return sdkerrors.Wrap(err, "created")
|
||||
}
|
||||
if err := c.LastUpdated.ValidateBasic(); err != nil {
|
||||
return sdkerrors.Wrap(err, "last updated")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -119,6 +125,16 @@ func (a *AbsoluteTxPosition) LessThan(b *AbsoluteTxPosition) bool {
|
||||
return a.BlockHeight < b.BlockHeight || (a.BlockHeight == b.BlockHeight && a.TxIndex < b.TxIndex)
|
||||
}
|
||||
|
||||
func (a *AbsoluteTxPosition) ValidateBasic() error {
|
||||
if a == nil {
|
||||
return nil
|
||||
}
|
||||
if a.BlockHeight < 0 {
|
||||
return sdkerrors.Wrap(ErrInvalid, "height")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewCreatedAt gets a timestamp from the context
|
||||
func NewCreatedAt(ctx sdk.Context) *AbsoluteTxPosition {
|
||||
// we must safely handle nil gas meters
|
||||
|
||||
@@ -114,7 +114,9 @@ func (am AppModule) NewQuerierHandler() sdk.Querier {
|
||||
func (am AppModule) InitGenesis(ctx sdk.Context, data json.RawMessage) []abci.ValidatorUpdate {
|
||||
var genesisState GenesisState
|
||||
ModuleCdc.MustUnmarshalJSON(data, &genesisState)
|
||||
InitGenesis(ctx, am.keeper, genesisState)
|
||||
if err := InitGenesis(ctx, am.keeper, genesisState); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return []abci.ValidatorUpdate{}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user