Review comments and additinal tests

This commit is contained in:
Alex Peters
2020-06-30 20:37:21 +02:00
parent a20e568bff
commit f7b4acf47c
7 changed files with 199 additions and 25 deletions

View File

@@ -5,6 +5,7 @@ import (
"github.com/CosmWasm/wasmd/x/wasm/internal/types" "github.com/CosmWasm/wasmd/x/wasm/internal/types"
sdk "github.com/cosmos/cosmos-sdk/types" sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
// authexported "github.com/cosmos/cosmos-sdk/x/auth/exported" // authexported "github.com/cosmos/cosmos-sdk/x/auth/exported"
// "github.com/CosmWasm/wasmd/x/wasm/internal/types" // "github.com/CosmWasm/wasmd/x/wasm/internal/types"
) )
@@ -12,25 +13,33 @@ import (
// InitGenesis sets supply information for genesis. // InitGenesis sets supply information for genesis.
// //
// CONTRACT: all types of accounts must have been already initialized/created // CONTRACT: all types of accounts must have been already initialized/created
func InitGenesis(ctx sdk.Context, keeper Keeper, data types.GenesisState) { func InitGenesis(ctx sdk.Context, keeper Keeper, data types.GenesisState) error {
for _, code := range data.Codes { for i, code := range data.Codes {
newId, err := keeper.Create(ctx, code.CodeInfo.Creator, code.CodesBytes, code.CodeInfo.Source, code.CodeInfo.Builder) newId, err := keeper.Create(ctx, code.CodeInfo.Creator, code.CodesBytes, code.CodeInfo.Source, code.CodeInfo.Builder)
if err != nil { if err != nil {
panic(err) return sdkerrors.Wrapf(err, "code number %d", i)
} }
newInfo := keeper.GetCodeInfo(ctx, newId) newInfo := keeper.GetCodeInfo(ctx, newId)
if !bytes.Equal(code.CodeInfo.CodeHash, newInfo.CodeHash) { if !bytes.Equal(code.CodeInfo.CodeHash, newInfo.CodeHash) {
panic("code hashes not same") return sdkerrors.Wrap(types.ErrInvalid, "code hashes not same")
} }
} }
for _, contract := range data.Contracts { for i, contract := range data.Contracts {
keeper.importContract(ctx, contract.ContractAddress, &contract.ContractInfo, contract.ContractState) err := keeper.importContract(ctx, contract.ContractAddress, &contract.ContractInfo, contract.ContractState)
if err != nil {
return sdkerrors.Wrapf(err, "contract number %d", i)
}
} }
for _, seq := range data.Sequences { for i, seq := range data.Sequences {
keeper.importAutoIncrementID(ctx, seq.IDKey, seq.Value) err := keeper.importAutoIncrementID(ctx, seq.IDKey, seq.Value)
if err != nil {
return sdkerrors.Wrapf(err, "sequence number %d", i)
}
} }
return nil
} }
// ExportGenesis returns a GenesisState for a given context and keeper. // ExportGenesis returns a GenesisState for a given context and keeper.

View File

@@ -124,11 +124,42 @@ func TestFailFastImport(t *testing.T) {
}, },
expSuccess: true, expSuccess: true,
}, },
"happy path: code info with two contracts": {
src: types.GenesisState{
Codes: []types.Code{{
CodeInfo: wasmTypes.CodeInfo{
CodeHash: codeHash[:],
Creator: anyAddress,
},
CodesBytes: wasmCode,
}},
Contracts: []types.Contract{
{
ContractAddress: addrFromUint64(1<<32 + 1),
ContractInfo: wasmTypes.ContractInfo{
CodeID: 1,
Creator: anyAddress,
Label: "any",
Created: &types.AbsoluteTxPosition{BlockHeight: 1, TxIndex: 1},
},
}, {
ContractAddress: addrFromUint64(2<<32 + 1),
ContractInfo: wasmTypes.ContractInfo{
CodeID: 1,
Creator: anyAddress,
Label: "any",
Created: &types.AbsoluteTxPosition{BlockHeight: 1, TxIndex: 1},
},
},
},
},
expSuccess: true,
},
"prevent contracts that points to non existing codeID": { "prevent contracts that points to non existing codeID": {
src: types.GenesisState{ src: types.GenesisState{
Contracts: []types.Contract{ Contracts: []types.Contract{
{ {
ContractAddress: addrFromUint64(1<<32 + 1), ContractAddress: contractAddress(1, 1),
ContractInfo: wasmTypes.ContractInfo{ ContractInfo: wasmTypes.ContractInfo{
CodeID: 1, CodeID: 1,
Creator: anyAddress, Creator: anyAddress,
@@ -150,7 +181,7 @@ func TestFailFastImport(t *testing.T) {
}}, }},
Contracts: []types.Contract{ Contracts: []types.Contract{
{ {
ContractAddress: addrFromUint64(1<<32 + 1), ContractAddress: contractAddress(1, 1),
ContractInfo: wasmTypes.ContractInfo{ ContractInfo: wasmTypes.ContractInfo{
CodeID: 1, CodeID: 1,
Creator: anyAddress, Creator: anyAddress,
@@ -158,7 +189,7 @@ func TestFailFastImport(t *testing.T) {
Created: &types.AbsoluteTxPosition{BlockHeight: 1, TxIndex: 1}, Created: &types.AbsoluteTxPosition{BlockHeight: 1, TxIndex: 1},
}, },
}, { }, {
ContractAddress: addrFromUint64(1<<32 + 1), ContractAddress: contractAddress(1, 1),
ContractInfo: wasmTypes.ContractInfo{ ContractInfo: wasmTypes.ContractInfo{
CodeID: 1, CodeID: 1,
Creator: anyAddress, Creator: anyAddress,
@@ -217,13 +248,12 @@ func TestFailFastImport(t *testing.T) {
defer cleanup() defer cleanup()
require.NoError(t, types.ValidateGenesis(spec.src)) require.NoError(t, types.ValidateGenesis(spec.src))
got := InitGenesis(ctx, keeper, spec.src)
if spec.expSuccess { if spec.expSuccess {
InitGenesis(ctx, keeper, spec.src) require.NoError(t, got)
return return
} }
require.Panics(t, func() { require.Error(t, got)
InitGenesis(ctx, keeper, spec.src)
})
}) })
} }
} }

View File

@@ -2,10 +2,10 @@ package keeper
import ( import (
"encoding/binary" "encoding/binary"
"fmt"
"path/filepath" "path/filepath"
"github.com/cosmos/cosmos-sdk/x/staking" "github.com/cosmos/cosmos-sdk/x/staking"
"github.com/pkg/errors"
wasm "github.com/CosmWasm/go-cosmwasm" wasm "github.com/CosmWasm/go-cosmwasm"
wasmTypes "github.com/CosmWasm/go-cosmwasm/types" wasmTypes "github.com/CosmWasm/go-cosmwasm/types"
@@ -404,7 +404,7 @@ func (k Keeper) GetContractState(ctx sdk.Context, contractAddress sdk.AccAddress
return prefixStore.Iterator(nil, nil) return prefixStore.Iterator(nil, nil)
} }
func (k Keeper) importContractState(ctx sdk.Context, contractAddress sdk.AccAddress, models []types.Model) { func (k Keeper) importContractState(ctx sdk.Context, contractAddress sdk.AccAddress, models []types.Model) error {
prefixStoreKey := types.GetContractStorePrefixKey(contractAddress) prefixStoreKey := types.GetContractStorePrefixKey(contractAddress)
prefixStore := prefix.NewStore(ctx.KVStore(k.storeKey), prefixStoreKey) prefixStore := prefix.NewStore(ctx.KVStore(k.storeKey), prefixStoreKey)
for _, model := range models { for _, model := range models {
@@ -412,10 +412,11 @@ func (k Keeper) importContractState(ctx sdk.Context, contractAddress sdk.AccAddr
model.Value = []byte{} model.Value = []byte{}
} }
if prefixStore.Has(model.Key) { if prefixStore.Has(model.Key) {
panic(fmt.Sprintf("duplicate key: %x", model.Key)) return sdkerrors.Wrapf(types.ErrDuplicate, "duplicate key: %x", model.Key)
} }
prefixStore.Set(model.Key, model.Value) prefixStore.Set(model.Key, model.Value)
} }
return nil
} }
func (k Keeper) GetCodeInfo(ctx sdk.Context, codeID uint64) *types.CodeInfo { func (k Keeper) GetCodeInfo(ctx sdk.Context, codeID uint64) *types.CodeInfo {
@@ -471,10 +472,15 @@ func consumeGas(ctx sdk.Context, gas uint64) {
// generates a contract address from codeID + instanceID // generates a contract address from codeID + instanceID
func (k Keeper) generateContractAddress(ctx sdk.Context, codeID uint64) sdk.AccAddress { func (k Keeper) generateContractAddress(ctx sdk.Context, codeID uint64) sdk.AccAddress {
instanceID := k.autoIncrementID(ctx, types.KeyLastInstanceID) instanceID := k.autoIncrementID(ctx, types.KeyLastInstanceID)
return contractAddress(codeID, instanceID)
}
func contractAddress(codeID, instanceID uint64) sdk.AccAddress {
// NOTE: It is possible to get a duplicate address if either codeID or instanceID // NOTE: It is possible to get a duplicate address if either codeID or instanceID
// overflow 32 bits. This is highly improbable, but something that could be refactored. // overflow 32 bits. This is highly improbable, but something that could be refactored.
contractID := codeID<<32 + instanceID contractID := codeID<<32 + instanceID
return addrFromUint64(contractID) return addrFromUint64(contractID)
} }
func (k Keeper) GetNextCodeID(ctx sdk.Context) uint64 { func (k Keeper) GetNextCodeID(ctx sdk.Context) uint64 {
@@ -510,24 +516,25 @@ func (k Keeper) peekAutoIncrementID(ctx sdk.Context, lastIDKey []byte) uint64 {
return id return id
} }
func (k Keeper) importAutoIncrementID(ctx sdk.Context, lastIDKey []byte, val uint64) { func (k Keeper) importAutoIncrementID(ctx sdk.Context, lastIDKey []byte, val uint64) error {
store := ctx.KVStore(k.storeKey) store := ctx.KVStore(k.storeKey)
if store.Has(lastIDKey) { if store.Has(lastIDKey) {
panic(fmt.Sprintf("duplicate autoincrement id: %s", string(lastIDKey))) return sdkerrors.Wrapf(types.ErrDuplicate, "autoincrement id: %s", string(lastIDKey))
} }
bz := sdk.Uint64ToBigEndian(val) bz := sdk.Uint64ToBigEndian(val)
store.Set(lastIDKey, bz) store.Set(lastIDKey, bz)
return nil
} }
func (k Keeper) importContract(ctx sdk.Context, address sdk.AccAddress, c *types.ContractInfo, state []types.Model) { func (k Keeper) importContract(ctx sdk.Context, address sdk.AccAddress, c *types.ContractInfo, state []types.Model) error {
if !k.containsCodeInfo(ctx, c.CodeID) { if !k.containsCodeInfo(ctx, c.CodeID) {
panic(fmt.Sprintf("unknown code id: %d", c.CodeID)) return errors.Wrapf(types.ErrNotFound, "code id: %d", c.CodeID)
} }
if k.containsContractInfo(ctx, address) { if k.containsContractInfo(ctx, address) {
panic(fmt.Sprintf("duplicate contract: %s", address)) return errors.Wrapf(types.ErrDuplicate, "contract: %s", address)
} }
k.setContractInfo(ctx, address, c) k.setContractInfo(ctx, address, c)
k.importContractState(ctx, address, state) return k.importContractState(ctx, address, state)
} }
func addrFromUint64(id uint64) sdk.AccAddress { func addrFromUint64(id uint64) sdk.AccAddress {

View File

@@ -46,4 +46,7 @@ var (
// ErrInvalid error for content that is invalid in this context // ErrInvalid error for content that is invalid in this context
ErrInvalid = sdkErrors.Register(DefaultCodespace, 13, "invalid") ErrInvalid = sdkErrors.Register(DefaultCodespace, 13, "invalid")
// ErrDuplicate error for content that exsists
ErrDuplicate = sdkErrors.Register(DefaultCodespace, 14, "duplicate")
) )

View File

@@ -0,0 +1,107 @@
package types
import (
"crypto/sha256"
"testing"
"github.com/stretchr/testify/require"
"github.com/tendermint/tendermint/libs/rand"
)
func TestValidateGenesisState(t *testing.T) {
specs := map[string]struct {
srcMutator func(state GenesisState)
expError bool
}{
"all good": {
srcMutator: func(s GenesisState) {},
},
"codeinfo invalid": {
srcMutator: func(s GenesisState) {
s.Codes[0].CodeInfo.CodeHash = nil
},
expError: true,
},
"contract invalid": {
srcMutator: func(s GenesisState) {
s.Contracts[0].ContractAddress = nil
},
expError: true,
},
"sequence invalid": {
srcMutator: func(s GenesisState) {
s.Sequences[0].IDKey = nil
},
expError: true,
},
}
for msg, spec := range specs {
t.Run(msg, func(t *testing.T) {
state := genesisFixture(spec.srcMutator)
got := state.ValidateBasic()
if spec.expError {
require.Error(t, got)
return
}
require.NoError(t, got)
})
}
}
func genesisFixture(mutators ...func(state GenesisState)) GenesisState {
const (
numCodes = 2
numContracts = 2
numSequences = 2
)
fixture := GenesisState{
Codes: make([]Code, numCodes),
Contracts: make([]Contract, numContracts),
Sequences: make([]Sequence, numSequences),
}
for i := 0; i < numCodes; i++ {
fixture.Codes[i] = codeFixture()
}
for i := 0; i < numContracts; i++ {
fixture.Contracts[i] = contractFixture()
}
for i := 0; i < numSequences; i++ {
fixture.Sequences[i] = Sequence{
IDKey: rand.Bytes(5),
Value: uint64(i),
}
}
for _, m := range mutators {
m(fixture)
}
return fixture
}
func codeFixture() Code {
wasmCode := rand.Bytes(100)
codeHash := sha256.Sum256(wasmCode)
anyAddress := make([]byte, 20)
return Code{
CodeInfo: CodeInfo{
CodeHash: codeHash[:],
Creator: anyAddress,
},
CodesBytes: wasmCode,
}
}
func contractFixture() Contract {
anyAddress := make([]byte, 20)
return Contract{
ContractAddress: anyAddress,
ContractInfo: ContractInfo{
CodeID: 1,
Creator: anyAddress,
Label: "any",
Created: &AbsoluteTxPosition{BlockHeight: 1, TxIndex: 1},
},
}
}

View File

@@ -97,6 +97,12 @@ func (c *ContractInfo) ValidateBasic() error {
if err := validateLabel(c.Label); err != nil { if err := validateLabel(c.Label); err != nil {
return sdkerrors.Wrap(err, "label") return sdkerrors.Wrap(err, "label")
} }
if err := c.Created.ValidateBasic(); err != nil {
return sdkerrors.Wrap(err, "created")
}
if err := c.LastUpdated.ValidateBasic(); err != nil {
return sdkerrors.Wrap(err, "last updated")
}
return nil return nil
} }
@@ -119,6 +125,16 @@ func (a *AbsoluteTxPosition) LessThan(b *AbsoluteTxPosition) bool {
return a.BlockHeight < b.BlockHeight || (a.BlockHeight == b.BlockHeight && a.TxIndex < b.TxIndex) return a.BlockHeight < b.BlockHeight || (a.BlockHeight == b.BlockHeight && a.TxIndex < b.TxIndex)
} }
func (a *AbsoluteTxPosition) ValidateBasic() error {
if a == nil {
return nil
}
if a.BlockHeight < 0 {
return sdkerrors.Wrap(ErrInvalid, "height")
}
return nil
}
// NewCreatedAt gets a timestamp from the context // NewCreatedAt gets a timestamp from the context
func NewCreatedAt(ctx sdk.Context) *AbsoluteTxPosition { func NewCreatedAt(ctx sdk.Context) *AbsoluteTxPosition {
// we must safely handle nil gas meters // we must safely handle nil gas meters

View File

@@ -114,7 +114,9 @@ func (am AppModule) NewQuerierHandler() sdk.Querier {
func (am AppModule) InitGenesis(ctx sdk.Context, data json.RawMessage) []abci.ValidatorUpdate { func (am AppModule) InitGenesis(ctx sdk.Context, data json.RawMessage) []abci.ValidatorUpdate {
var genesisState GenesisState var genesisState GenesisState
ModuleCdc.MustUnmarshalJSON(data, &genesisState) ModuleCdc.MustUnmarshalJSON(data, &genesisState)
InitGenesis(ctx, am.keeper, genesisState) if err := InitGenesis(ctx, am.keeper, genesisState); err != nil {
panic(err)
}
return []abci.ValidatorUpdate{} return []abci.ValidatorUpdate{}
} }