diff --git a/x/wasm/internal/keeper/genesis.go b/x/wasm/internal/keeper/genesis.go index b23ec7ce..04a4220e 100644 --- a/x/wasm/internal/keeper/genesis.go +++ b/x/wasm/internal/keeper/genesis.go @@ -5,6 +5,7 @@ import ( "github.com/CosmWasm/wasmd/x/wasm/internal/types" sdk "github.com/cosmos/cosmos-sdk/types" + sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" // authexported "github.com/cosmos/cosmos-sdk/x/auth/exported" // "github.com/CosmWasm/wasmd/x/wasm/internal/types" ) @@ -12,25 +13,33 @@ import ( // InitGenesis sets supply information for genesis. // // CONTRACT: all types of accounts must have been already initialized/created -func InitGenesis(ctx sdk.Context, keeper Keeper, data types.GenesisState) { - for _, code := range data.Codes { +func InitGenesis(ctx sdk.Context, keeper Keeper, data types.GenesisState) error { + for i, code := range data.Codes { newId, err := keeper.Create(ctx, code.CodeInfo.Creator, code.CodesBytes, code.CodeInfo.Source, code.CodeInfo.Builder) if err != nil { - panic(err) + return sdkerrors.Wrapf(err, "code number %d", i) + } newInfo := keeper.GetCodeInfo(ctx, newId) if !bytes.Equal(code.CodeInfo.CodeHash, newInfo.CodeHash) { - panic("code hashes not same") + return sdkerrors.Wrap(types.ErrInvalid, "code hashes not same") } } - for _, contract := range data.Contracts { - keeper.importContract(ctx, contract.ContractAddress, &contract.ContractInfo, contract.ContractState) + for i, contract := range data.Contracts { + err := keeper.importContract(ctx, contract.ContractAddress, &contract.ContractInfo, contract.ContractState) + if err != nil { + return sdkerrors.Wrapf(err, "contract number %d", i) + } } - for _, seq := range data.Sequences { - keeper.importAutoIncrementID(ctx, seq.IDKey, seq.Value) + for i, seq := range data.Sequences { + err := keeper.importAutoIncrementID(ctx, seq.IDKey, seq.Value) + if err != nil { + return sdkerrors.Wrapf(err, "sequence number %d", i) + } } + return nil } // ExportGenesis returns a GenesisState for a given context and keeper. diff --git a/x/wasm/internal/keeper/genesis_test.go b/x/wasm/internal/keeper/genesis_test.go index 74813514..0fedc631 100644 --- a/x/wasm/internal/keeper/genesis_test.go +++ b/x/wasm/internal/keeper/genesis_test.go @@ -124,11 +124,42 @@ func TestFailFastImport(t *testing.T) { }, expSuccess: true, }, + "happy path: code info with two contracts": { + src: types.GenesisState{ + Codes: []types.Code{{ + CodeInfo: wasmTypes.CodeInfo{ + CodeHash: codeHash[:], + Creator: anyAddress, + }, + CodesBytes: wasmCode, + }}, + Contracts: []types.Contract{ + { + ContractAddress: addrFromUint64(1<<32 + 1), + ContractInfo: wasmTypes.ContractInfo{ + CodeID: 1, + Creator: anyAddress, + Label: "any", + Created: &types.AbsoluteTxPosition{BlockHeight: 1, TxIndex: 1}, + }, + }, { + ContractAddress: addrFromUint64(2<<32 + 1), + ContractInfo: wasmTypes.ContractInfo{ + CodeID: 1, + Creator: anyAddress, + Label: "any", + Created: &types.AbsoluteTxPosition{BlockHeight: 1, TxIndex: 1}, + }, + }, + }, + }, + expSuccess: true, + }, "prevent contracts that points to non existing codeID": { src: types.GenesisState{ Contracts: []types.Contract{ { - ContractAddress: addrFromUint64(1<<32 + 1), + ContractAddress: contractAddress(1, 1), ContractInfo: wasmTypes.ContractInfo{ CodeID: 1, Creator: anyAddress, @@ -150,7 +181,7 @@ func TestFailFastImport(t *testing.T) { }}, Contracts: []types.Contract{ { - ContractAddress: addrFromUint64(1<<32 + 1), + ContractAddress: contractAddress(1, 1), ContractInfo: wasmTypes.ContractInfo{ CodeID: 1, Creator: anyAddress, @@ -158,7 +189,7 @@ func TestFailFastImport(t *testing.T) { Created: &types.AbsoluteTxPosition{BlockHeight: 1, TxIndex: 1}, }, }, { - ContractAddress: addrFromUint64(1<<32 + 1), + ContractAddress: contractAddress(1, 1), ContractInfo: wasmTypes.ContractInfo{ CodeID: 1, Creator: anyAddress, @@ -217,13 +248,12 @@ func TestFailFastImport(t *testing.T) { defer cleanup() require.NoError(t, types.ValidateGenesis(spec.src)) + got := InitGenesis(ctx, keeper, spec.src) if spec.expSuccess { - InitGenesis(ctx, keeper, spec.src) + require.NoError(t, got) return } - require.Panics(t, func() { - InitGenesis(ctx, keeper, spec.src) - }) + require.Error(t, got) }) } } diff --git a/x/wasm/internal/keeper/keeper.go b/x/wasm/internal/keeper/keeper.go index 74aa88ea..cdf9701f 100644 --- a/x/wasm/internal/keeper/keeper.go +++ b/x/wasm/internal/keeper/keeper.go @@ -2,10 +2,10 @@ package keeper import ( "encoding/binary" - "fmt" "path/filepath" "github.com/cosmos/cosmos-sdk/x/staking" + "github.com/pkg/errors" wasm "github.com/CosmWasm/go-cosmwasm" wasmTypes "github.com/CosmWasm/go-cosmwasm/types" @@ -404,7 +404,7 @@ func (k Keeper) GetContractState(ctx sdk.Context, contractAddress sdk.AccAddress return prefixStore.Iterator(nil, nil) } -func (k Keeper) importContractState(ctx sdk.Context, contractAddress sdk.AccAddress, models []types.Model) { +func (k Keeper) importContractState(ctx sdk.Context, contractAddress sdk.AccAddress, models []types.Model) error { prefixStoreKey := types.GetContractStorePrefixKey(contractAddress) prefixStore := prefix.NewStore(ctx.KVStore(k.storeKey), prefixStoreKey) for _, model := range models { @@ -412,10 +412,11 @@ func (k Keeper) importContractState(ctx sdk.Context, contractAddress sdk.AccAddr model.Value = []byte{} } if prefixStore.Has(model.Key) { - panic(fmt.Sprintf("duplicate key: %x", model.Key)) + return sdkerrors.Wrapf(types.ErrDuplicate, "duplicate key: %x", model.Key) } prefixStore.Set(model.Key, model.Value) } + return nil } func (k Keeper) GetCodeInfo(ctx sdk.Context, codeID uint64) *types.CodeInfo { @@ -471,10 +472,15 @@ func consumeGas(ctx sdk.Context, gas uint64) { // generates a contract address from codeID + instanceID func (k Keeper) generateContractAddress(ctx sdk.Context, codeID uint64) sdk.AccAddress { instanceID := k.autoIncrementID(ctx, types.KeyLastInstanceID) + return contractAddress(codeID, instanceID) +} + +func contractAddress(codeID, instanceID uint64) sdk.AccAddress { // NOTE: It is possible to get a duplicate address if either codeID or instanceID // overflow 32 bits. This is highly improbable, but something that could be refactored. contractID := codeID<<32 + instanceID return addrFromUint64(contractID) + } func (k Keeper) GetNextCodeID(ctx sdk.Context) uint64 { @@ -510,24 +516,25 @@ func (k Keeper) peekAutoIncrementID(ctx sdk.Context, lastIDKey []byte) uint64 { return id } -func (k Keeper) importAutoIncrementID(ctx sdk.Context, lastIDKey []byte, val uint64) { +func (k Keeper) importAutoIncrementID(ctx sdk.Context, lastIDKey []byte, val uint64) error { store := ctx.KVStore(k.storeKey) if store.Has(lastIDKey) { - panic(fmt.Sprintf("duplicate autoincrement id: %s", string(lastIDKey))) + return sdkerrors.Wrapf(types.ErrDuplicate, "autoincrement id: %s", string(lastIDKey)) } bz := sdk.Uint64ToBigEndian(val) store.Set(lastIDKey, bz) + return nil } -func (k Keeper) importContract(ctx sdk.Context, address sdk.AccAddress, c *types.ContractInfo, state []types.Model) { +func (k Keeper) importContract(ctx sdk.Context, address sdk.AccAddress, c *types.ContractInfo, state []types.Model) error { if !k.containsCodeInfo(ctx, c.CodeID) { - panic(fmt.Sprintf("unknown code id: %d", c.CodeID)) + return errors.Wrapf(types.ErrNotFound, "code id: %d", c.CodeID) } if k.containsContractInfo(ctx, address) { - panic(fmt.Sprintf("duplicate contract: %s", address)) + return errors.Wrapf(types.ErrDuplicate, "contract: %s", address) } k.setContractInfo(ctx, address, c) - k.importContractState(ctx, address, state) + return k.importContractState(ctx, address, state) } func addrFromUint64(id uint64) sdk.AccAddress { diff --git a/x/wasm/internal/types/errors.go b/x/wasm/internal/types/errors.go index 3ab5b5c4..ad70bdab 100644 --- a/x/wasm/internal/types/errors.go +++ b/x/wasm/internal/types/errors.go @@ -46,4 +46,7 @@ var ( // ErrInvalid error for content that is invalid in this context ErrInvalid = sdkErrors.Register(DefaultCodespace, 13, "invalid") + + // ErrDuplicate error for content that exsists + ErrDuplicate = sdkErrors.Register(DefaultCodespace, 14, "duplicate") ) diff --git a/x/wasm/internal/types/genesis_test.go b/x/wasm/internal/types/genesis_test.go new file mode 100644 index 00000000..5d84fc9d --- /dev/null +++ b/x/wasm/internal/types/genesis_test.go @@ -0,0 +1,107 @@ +package types + +import ( + "crypto/sha256" + "testing" + + "github.com/stretchr/testify/require" + "github.com/tendermint/tendermint/libs/rand" +) + +func TestValidateGenesisState(t *testing.T) { + specs := map[string]struct { + srcMutator func(state GenesisState) + expError bool + }{ + "all good": { + srcMutator: func(s GenesisState) {}, + }, + "codeinfo invalid": { + srcMutator: func(s GenesisState) { + s.Codes[0].CodeInfo.CodeHash = nil + }, + expError: true, + }, + "contract invalid": { + srcMutator: func(s GenesisState) { + s.Contracts[0].ContractAddress = nil + }, + expError: true, + }, + "sequence invalid": { + srcMutator: func(s GenesisState) { + s.Sequences[0].IDKey = nil + }, + expError: true, + }, + } + for msg, spec := range specs { + t.Run(msg, func(t *testing.T) { + state := genesisFixture(spec.srcMutator) + got := state.ValidateBasic() + if spec.expError { + require.Error(t, got) + return + } + require.NoError(t, got) + }) + } + +} + +func genesisFixture(mutators ...func(state GenesisState)) GenesisState { + const ( + numCodes = 2 + numContracts = 2 + numSequences = 2 + ) + + fixture := GenesisState{ + Codes: make([]Code, numCodes), + Contracts: make([]Contract, numContracts), + Sequences: make([]Sequence, numSequences), + } + for i := 0; i < numCodes; i++ { + fixture.Codes[i] = codeFixture() + } + for i := 0; i < numContracts; i++ { + fixture.Contracts[i] = contractFixture() + } + for i := 0; i < numSequences; i++ { + fixture.Sequences[i] = Sequence{ + IDKey: rand.Bytes(5), + Value: uint64(i), + } + } + for _, m := range mutators { + m(fixture) + } + return fixture +} + +func codeFixture() Code { + wasmCode := rand.Bytes(100) + codeHash := sha256.Sum256(wasmCode) + anyAddress := make([]byte, 20) + + return Code{ + CodeInfo: CodeInfo{ + CodeHash: codeHash[:], + Creator: anyAddress, + }, + CodesBytes: wasmCode, + } +} + +func contractFixture() Contract { + anyAddress := make([]byte, 20) + return Contract{ + ContractAddress: anyAddress, + ContractInfo: ContractInfo{ + CodeID: 1, + Creator: anyAddress, + Label: "any", + Created: &AbsoluteTxPosition{BlockHeight: 1, TxIndex: 1}, + }, + } +} diff --git a/x/wasm/internal/types/types.go b/x/wasm/internal/types/types.go index 1f6cbe10..896cf8e9 100644 --- a/x/wasm/internal/types/types.go +++ b/x/wasm/internal/types/types.go @@ -97,6 +97,12 @@ func (c *ContractInfo) ValidateBasic() error { if err := validateLabel(c.Label); err != nil { return sdkerrors.Wrap(err, "label") } + if err := c.Created.ValidateBasic(); err != nil { + return sdkerrors.Wrap(err, "created") + } + if err := c.LastUpdated.ValidateBasic(); err != nil { + return sdkerrors.Wrap(err, "last updated") + } return nil } @@ -119,6 +125,16 @@ func (a *AbsoluteTxPosition) LessThan(b *AbsoluteTxPosition) bool { return a.BlockHeight < b.BlockHeight || (a.BlockHeight == b.BlockHeight && a.TxIndex < b.TxIndex) } +func (a *AbsoluteTxPosition) ValidateBasic() error { + if a == nil { + return nil + } + if a.BlockHeight < 0 { + return sdkerrors.Wrap(ErrInvalid, "height") + } + return nil +} + // NewCreatedAt gets a timestamp from the context func NewCreatedAt(ctx sdk.Context) *AbsoluteTxPosition { // we must safely handle nil gas meters diff --git a/x/wasm/module.go b/x/wasm/module.go index e0215165..89643903 100644 --- a/x/wasm/module.go +++ b/x/wasm/module.go @@ -114,7 +114,9 @@ func (am AppModule) NewQuerierHandler() sdk.Querier { func (am AppModule) InitGenesis(ctx sdk.Context, data json.RawMessage) []abci.ValidatorUpdate { var genesisState GenesisState ModuleCdc.MustUnmarshalJSON(data, &genesisState) - InitGenesis(ctx, am.keeper, genesisState) + if err := InitGenesis(ctx, am.keeper, genesisState); err != nil { + panic(err) + } return []abci.ValidatorUpdate{} }