diff --git a/app/test_support.go b/app/test_support.go new file mode 100644 index 00000000..6c6c0167 --- /dev/null +++ b/app/test_support.go @@ -0,0 +1,40 @@ +package app + +import ( + "github.com/cosmos/cosmos-sdk/baseapp" + authkeeper "github.com/cosmos/cosmos-sdk/x/auth/keeper" + bankkeeper "github.com/cosmos/cosmos-sdk/x/bank/keeper" + capabilitykeeper "github.com/cosmos/cosmos-sdk/x/capability/keeper" + stakingkeeper "github.com/cosmos/cosmos-sdk/x/staking/keeper" + ibckeeper "github.com/cosmos/ibc-go/v7/modules/core/keeper" + + "github.com/CosmWasm/wasmd/x/wasm" +) + +func (app *WasmApp) GetIBCKeeper() *ibckeeper.Keeper { + return app.IBCKeeper +} + +func (app *WasmApp) GetScopedIBCKeeper() capabilitykeeper.ScopedKeeper { + return app.ScopedIBCKeeper +} + +func (app *WasmApp) GetBaseApp() *baseapp.BaseApp { + return app.BaseApp +} + +func (app *WasmApp) GetBankKeeper() bankkeeper.Keeper { + return app.BankKeeper +} + +func (app *WasmApp) GetStakingKeeper() *stakingkeeper.Keeper { + return app.StakingKeeper +} + +func (app *WasmApp) GetAccountKeeper() authkeeper.AccountKeeper { + return app.AccountKeeper +} + +func (app *WasmApp) GetWasmKeeper() wasm.Keeper { + return app.WasmKeeper +} diff --git a/tests/e2e/gov_test.go b/tests/e2e/gov_test.go index 5ce967da..1db5eca9 100644 --- a/tests/e2e/gov_test.go +++ b/tests/e2e/gov_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/CosmWasm/wasmd/app" "github.com/CosmWasm/wasmd/tests/e2e" "github.com/CosmWasm/wasmd/x/wasm/ibctesting" ) @@ -41,7 +42,8 @@ func TestGovVoteByContract(t *testing.T) { e2e.MustExecViaReflectContract(t, chain, contractAddr, delegateMsg) signer := chain.SenderAccount.GetAddress().String() - govKeeper, accountKeeper := chain.App.GovKeeper, chain.App.AccountKeeper + app := chain.App.(*app.WasmApp) + govKeeper, accountKeeper := app.GovKeeper, app.AccountKeeper communityPoolBalance := chain.Balance(accountKeeper.GetModuleAccount(chain.GetContext(), distributiontypes.ModuleName).GetAddress(), sdk.DefaultBondDenom) require.False(t, communityPoolBalance.IsZero()) diff --git a/tests/e2e/ibc_fees_test.go b/tests/e2e/ibc_fees_test.go index 30fcb5b9..8c8ca4fa 100644 --- a/tests/e2e/ibc_fees_test.go +++ b/tests/e2e/ibc_fees_test.go @@ -52,7 +52,8 @@ func TestIBCFeesTransfer(t *testing.T) { } // with an ics-20 transfer channel setup between both chains coord.Setup(path) - require.True(t, chainA.App.IBCFeeKeeper.IsFeeEnabled(chainA.GetContext(), ibctransfertypes.PortID, path.EndpointA.ChannelID)) + appA := chainA.App.(*app.WasmApp) + require.True(t, appA.IBCFeeKeeper.IsFeeEnabled(chainA.GetContext(), ibctransfertypes.PortID, path.EndpointA.ChannelID)) // and with a payee registered on both chains _, err := chainA.SendMsgs(ibcfee.NewMsgRegisterPayee(ibctransfertypes.PortID, path.EndpointA.ChannelID, actorChainA.String(), payee.String())) require.NoError(t, err) @@ -66,7 +67,7 @@ func TestIBCFeesTransfer(t *testing.T) { feeMsg := ibcfee.NewMsgPayPacketFee(ibcPackageFee, ibctransfertypes.PortID, path.EndpointA.ChannelID, actorChainA.String(), nil) _, err = chainA.SendMsgs(feeMsg, ibcPayloadMsg) require.NoError(t, err) - pendingIncentivisedPackages := chainA.App.IBCFeeKeeper.GetIdentifiedPacketFeesForChannel(chainA.GetContext(), ibctransfertypes.PortID, path.EndpointA.ChannelID) + pendingIncentivisedPackages := appA.IBCFeeKeeper.GetIdentifiedPacketFeesForChannel(chainA.GetContext(), ibctransfertypes.PortID, path.EndpointA.ChannelID) assert.Len(t, pendingIncentivisedPackages, 1) // and packages relayed @@ -91,7 +92,8 @@ func TestIBCFeesTransfer(t *testing.T) { feeMsg = ibcfee.NewMsgPayPacketFee(ibcPackageFee, ibctransfertypes.PortID, path.EndpointB.ChannelID, actorChainB.String(), nil) _, err = chainB.SendMsgs(feeMsg, ibcPayloadMsg) require.NoError(t, err) - pendingIncentivisedPackages = chainB.App.IBCFeeKeeper.GetIdentifiedPacketFeesForChannel(chainB.GetContext(), ibctransfertypes.PortID, path.EndpointB.ChannelID) + appB := chainB.App.(*app.WasmApp) + pendingIncentivisedPackages = appB.IBCFeeKeeper.GetIdentifiedPacketFeesForChannel(chainB.GetContext(), ibctransfertypes.PortID, path.EndpointB.ChannelID) assert.Len(t, pendingIncentivisedPackages, 1) // when packages relayed @@ -145,8 +147,10 @@ func TestIBCFeesWasm(t *testing.T) { } // with an ics-29 fee enabled channel setup between both chains coord.Setup(path) - require.True(t, chainA.App.IBCFeeKeeper.IsFeeEnabled(chainA.GetContext(), ibcContractPortID, path.EndpointA.ChannelID)) - require.True(t, chainB.App.IBCFeeKeeper.IsFeeEnabled(chainB.GetContext(), ibctransfertypes.PortID, path.EndpointB.ChannelID)) + appA := chainA.App.(*app.WasmApp) + appB := chainB.App.(*app.WasmApp) + require.True(t, appA.IBCFeeKeeper.IsFeeEnabled(chainA.GetContext(), ibcContractPortID, path.EndpointA.ChannelID)) + require.True(t, appB.IBCFeeKeeper.IsFeeEnabled(chainB.GetContext(), ibctransfertypes.PortID, path.EndpointB.ChannelID)) // and with a payee registered for A -> B _, err := chainA.SendMsgs(ibcfee.NewMsgRegisterPayee(ibcContractPortID, path.EndpointA.ChannelID, actorChainA.String(), payee.String())) require.NoError(t, err) @@ -165,7 +169,7 @@ func TestIBCFeesWasm(t *testing.T) { feeMsg := ibcfee.NewMsgPayPacketFee(ibcPackageFee, ibcContractPortID, path.EndpointA.ChannelID, actorChainA.String(), nil) _, err = chainA.SendMsgs(feeMsg, &execMsg) require.NoError(t, err) - pendingIncentivisedPackages := chainA.App.IBCFeeKeeper.GetIdentifiedPacketFeesForChannel(chainA.GetContext(), ibcContractPortID, path.EndpointA.ChannelID) + pendingIncentivisedPackages := appA.IBCFeeKeeper.GetIdentifiedPacketFeesForChannel(chainA.GetContext(), ibcContractPortID, path.EndpointA.ChannelID) assert.Len(t, pendingIncentivisedPackages, 1) // and packages relayed @@ -173,13 +177,13 @@ func TestIBCFeesWasm(t *testing.T) { // then // on chain A - gotCW20Balance, err := chainA.App.WasmKeeper.QuerySmart(chainA.GetContext(), cw20ContractAddr, []byte(fmt.Sprintf(`{"balance":{"address": %q}}`, actorChainA.String()))) + gotCW20Balance, err := appA.WasmKeeper.QuerySmart(chainA.GetContext(), cw20ContractAddr, []byte(fmt.Sprintf(`{"balance":{"address": %q}}`, actorChainA.String()))) require.NoError(t, err) assert.JSONEq(t, `{"balance":"99999900"}`, string(gotCW20Balance)) payeeBalance := chainA.AllBalances(payee) assert.Equal(t, sdk.NewCoin(sdk.DefaultBondDenom, sdk.NewInt(2)).String(), payeeBalance.String()) // and on chain B - pendingIncentivisedPackages = chainA.App.IBCFeeKeeper.GetIdentifiedPacketFeesForChannel(chainA.GetContext(), ibcContractPortID, path.EndpointA.ChannelID) + pendingIncentivisedPackages = appA.IBCFeeKeeper.GetIdentifiedPacketFeesForChannel(chainA.GetContext(), ibcContractPortID, path.EndpointA.ChannelID) assert.Len(t, pendingIncentivisedPackages, 0) expBalance := ibctransfertypes.GetTransferCoin(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, "cw20:"+cw20ContractAddr.String(), sdk.NewInt(100)) gotBalance := chainB.Balance(actorChainB, expBalance.Denom) @@ -197,7 +201,7 @@ func TestIBCFeesWasm(t *testing.T) { feeMsg = ibcfee.NewMsgPayPacketFee(ibcPackageFee, ibctransfertypes.PortID, path.EndpointB.ChannelID, actorChainB.String(), nil) _, err = chainB.SendMsgs(feeMsg, ibcPayloadMsg) require.NoError(t, err) - pendingIncentivisedPackages = chainB.App.IBCFeeKeeper.GetIdentifiedPacketFeesForChannel(chainB.GetContext(), ibctransfertypes.PortID, path.EndpointB.ChannelID) + pendingIncentivisedPackages = appB.IBCFeeKeeper.GetIdentifiedPacketFeesForChannel(chainB.GetContext(), ibctransfertypes.PortID, path.EndpointB.ChannelID) assert.Len(t, pendingIncentivisedPackages, 1) // when packages relayed @@ -205,7 +209,7 @@ func TestIBCFeesWasm(t *testing.T) { // then // on chain A - gotCW20Balance, err = chainA.App.WasmKeeper.QuerySmart(chainA.GetContext(), cw20ContractAddr, []byte(fmt.Sprintf(`{"balance":{"address": %q}}`, actorChainA.String()))) + gotCW20Balance, err = appA.WasmKeeper.QuerySmart(chainA.GetContext(), cw20ContractAddr, []byte(fmt.Sprintf(`{"balance":{"address": %q}}`, actorChainA.String()))) require.NoError(t, err) assert.JSONEq(t, `{"balance":"100000000"}`, string(gotCW20Balance)) // and on chain B diff --git a/tests/e2e/ica_test.go b/tests/e2e/ica_test.go index 3ac2fdb8..4234b023 100644 --- a/tests/e2e/ica_test.go +++ b/tests/e2e/ica_test.go @@ -5,6 +5,8 @@ import ( "testing" "time" + "github.com/CosmWasm/wasmd/app" + sdk "github.com/cosmos/cosmos-sdk/types" "github.com/cosmos/cosmos-sdk/types/address" banktypes "github.com/cosmos/cosmos-sdk/x/bank/types" @@ -30,7 +32,8 @@ func TestICA(t *testing.T) { coord := wasmibctesting.NewCoordinator(t, 2) hostChain := coord.GetChain(ibctesting.GetChainID(1)) hostParams := hosttypes.NewParams(true, []string{sdk.MsgTypeURL(&banktypes.MsgSend{})}) - hostChain.App.ICAHostKeeper.SetParams(hostChain.GetContext(), hostParams) + hostApp := hostChain.App.(*app.WasmApp) + hostApp.ICAHostKeeper.SetParams(hostChain.GetContext(), hostParams) controllerChain := coord.GetChain(ibctesting.GetChainID(2)) @@ -58,7 +61,8 @@ func TestICA(t *testing.T) { coord.CreateChannels(path) // assert ICA exists on controller - icaRsp, err := controllerChain.App.ICAControllerKeeper.InterchainAccount(sdk.WrapSDKContext(controllerChain.GetContext()), &icacontrollertypes.QueryInterchainAccountRequest{ + contApp := controllerChain.App.(*app.WasmApp) + icaRsp, err := contApp.ICAControllerKeeper.InterchainAccount(sdk.WrapSDKContext(controllerChain.GetContext()), &icacontrollertypes.QueryInterchainAccountRequest{ Owner: ownerAddr.String(), ConnectionId: path.EndpointA.ConnectionID, }) diff --git a/x/wasm/ibc_integration_test.go b/x/wasm/ibc_integration_test.go index 9a3e383e..97030f44 100644 --- a/x/wasm/ibc_integration_test.go +++ b/x/wasm/ibc_integration_test.go @@ -3,6 +3,8 @@ package wasm_test import ( "testing" + "github.com/CosmWasm/wasmd/app" + sdk "github.com/cosmos/cosmos-sdk/types" "github.com/CosmWasm/wasmd/x/wasm/types" @@ -54,7 +56,8 @@ func TestOnChanOpenInitVersion(t *testing.T) { chainA = coordinator.GetChain(wasmibctesting.GetChainID(1)) chainB = coordinator.GetChain(wasmibctesting.GetChainID(2)) myContractAddr = chainA.SeedNewContractInstance() - contractInfo = chainA.App.WasmKeeper.GetContractInfo(chainA.GetContext(), myContractAddr) + appA = chainA.App.(*app.WasmApp) + contractInfo = appA.WasmKeeper.GetContractInfo(chainA.GetContext(), myContractAddr) ) path := wasmibctesting.NewPath(chainA, chainB) diff --git a/x/wasm/ibctesting/chain.go b/x/wasm/ibctesting/chain.go index 0862de17..cfa540ce 100644 --- a/x/wasm/ibctesting/chain.go +++ b/x/wasm/ibctesting/chain.go @@ -5,6 +5,13 @@ import ( "testing" "time" + "github.com/cosmos/cosmos-sdk/baseapp" + storetypes "github.com/cosmos/cosmos-sdk/store/types" + authkeeper "github.com/cosmos/cosmos-sdk/x/auth/keeper" + bankkeeper "github.com/cosmos/cosmos-sdk/x/bank/keeper" + stakingkeeper "github.com/cosmos/cosmos-sdk/x/staking/keeper" + ibckeeper "github.com/cosmos/ibc-go/v7/modules/core/keeper" + errorsmod "cosmossdk.io/errors" "github.com/cosmos/cosmos-sdk/client" @@ -48,6 +55,24 @@ type SenderAccount struct { SenderAccount authtypes.AccountI } +// ChainApp Abstract chain app definition used for testing +type ChainApp interface { + abci.Application + AppCodec() codec.Codec + NewContext(isCheckTx bool, header tmproto.Header) sdk.Context + LastBlockHeight() int64 + LastCommitID() storetypes.CommitID + GetBaseApp() *baseapp.BaseApp + + TxConfig() client.TxConfig + GetScopedIBCKeeper() capabilitykeeper.ScopedKeeper + GetIBCKeeper() *ibckeeper.Keeper + GetBankKeeper() bankkeeper.Keeper + GetStakingKeeper() *stakingkeeper.Keeper + GetAccountKeeper() authkeeper.AccountKeeper + GetWasmKeeper() wasm.Keeper +} + // TestChain is a testing struct that wraps a simapp with the last TM Header, the current ABCI // header and the validators of the TestChain. It also contains a field called ChainID. This // is the clientID that *other* chains use to refer to this TestChain. The SenderAccount @@ -57,7 +82,7 @@ type TestChain struct { t *testing.T Coordinator *Coordinator - App *app.WasmApp + App ChainApp ChainID string LastHeader *ibctm.Header // header for last block height committed CurrentHeader tmproto.Header // header for current block height @@ -88,9 +113,23 @@ type PacketAck struct { Ack []byte } +// ChainAppFactory abstract factory method that usually implemented by app.SetupWithGenesisValSet +type ChainAppFactory func(t *testing.T, valSet *tmtypes.ValidatorSet, genAccs []authtypes.GenesisAccount, chainID string, opts []wasm.Option, balances ...banktypes.Balance) ChainApp + +// DefaultWasmAppFactory instantiates and sets up the default wasmd app +func DefaultWasmAppFactory(t *testing.T, valSet *tmtypes.ValidatorSet, genAccs []authtypes.GenesisAccount, chainID string, opts []wasm.Option, balances ...banktypes.Balance) ChainApp { + return app.SetupWithGenesisValSet(t, valSet, genAccs, chainID, opts, balances...) +} + +// NewDefaultTestChain initializes a new test chain with a default of 4 validators +// Use this function if the tests do not need custom control over the validator set +func NewDefaultTestChain(t *testing.T, coord *Coordinator, chainID string, opts ...wasm.Option) *TestChain { + return NewTestChain(t, coord, DefaultWasmAppFactory, chainID, opts...) +} + // NewTestChain initializes a new test chain with a default of 4 validators // Use this function if the tests do not need custom control over the validator set -func NewTestChain(t *testing.T, coord *Coordinator, chainID string, opts ...wasm.Option) *TestChain { +func NewTestChain(t *testing.T, coord *Coordinator, appFactory ChainAppFactory, chainID string, opts ...wasm.Option) *TestChain { // generate validators private/public key var ( validatorsPerChain = 4 @@ -111,7 +150,7 @@ func NewTestChain(t *testing.T, coord *Coordinator, chainID string, opts ...wasm // or, if equal, by address lexical order valSet := tmtypes.NewValidatorSet(validators) - return NewTestChainWithValSet(t, coord, chainID, valSet, signersByAddress, opts...) + return NewTestChainWithValSet(t, coord, appFactory, chainID, valSet, signersByAddress, opts...) } // NewTestChainWithValSet initializes a new TestChain instance with the given validator set @@ -129,7 +168,7 @@ func NewTestChain(t *testing.T, coord *Coordinator, chainID string, opts ...wasm // // CONTRACT: Validator array must be provided in the order expected by Tendermint. // i.e. sorted first by power and then lexicographically by address. -func NewTestChainWithValSet(t *testing.T, coord *Coordinator, chainID string, valSet *tmtypes.ValidatorSet, signers map[string]tmtypes.PrivValidator, opts ...wasm.Option) *TestChain { +func NewTestChainWithValSet(t *testing.T, coord *Coordinator, appFactory ChainAppFactory, chainID string, valSet *tmtypes.ValidatorSet, signers map[string]tmtypes.PrivValidator, opts ...wasm.Option) *TestChain { genAccs := []authtypes.GenesisAccount{} genBals := []banktypes.Balance{} senderAccs := []SenderAccount{} @@ -158,7 +197,7 @@ func NewTestChainWithValSet(t *testing.T, coord *Coordinator, chainID string, va senderAccs = append(senderAccs, senderAcc) } - wasmApp := app.SetupWithGenesisValSet(t, valSet, genAccs, chainID, opts, genBals...) + wasmApp := appFactory(t, valSet, genAccs, chainID, opts, genBals...) // create current header and call begin block header := tmproto.Header{ @@ -176,7 +215,7 @@ func NewTestChainWithValSet(t *testing.T, coord *Coordinator, chainID string, va ChainID: chainID, App: wasmApp, CurrentHeader: header, - QueryServer: wasmApp.IBCKeeper, + QueryServer: wasmApp.GetIBCKeeper(), TxConfig: txConfig, Codec: wasmApp.AppCodec(), Vals: valSet, @@ -194,7 +233,7 @@ func NewTestChainWithValSet(t *testing.T, coord *Coordinator, chainID string, va // GetContext returns the current context for the application. func (chain *TestChain) GetContext() sdk.Context { - return chain.App.BaseApp.NewContext(false, chain.CurrentHeader) + return chain.App.NewContext(false, chain.CurrentHeader) } // QueryProof performs an abci query with the given key and returns the proto encoded merkle proof @@ -322,7 +361,7 @@ func (chain *TestChain) SendMsgs(msgs ...sdk.Msg) (*sdk.Result, error) { _, r, err := app.SignAndDeliverWithoutCommit( chain.t, chain.TxConfig, - chain.App.BaseApp, + chain.App.GetBaseApp(), chain.GetContext().BlockHeader(), msgs, chain.ChainID, @@ -351,7 +390,7 @@ func (chain *TestChain) SendMsgs(msgs ...sdk.Msg) (*sdk.Result, error) { } func (chain *TestChain) CaptureIBCEvents(r *sdk.Result) { - toSend := getSendPackets(r.Events) + toSend := GetSendPackets(r.Events) if len(toSend) > 0 { // Keep a queue on the chain that we can relay in tests chain.PendingSendPackets = append(chain.PendingSendPackets, toSend...) @@ -361,7 +400,7 @@ func (chain *TestChain) CaptureIBCEvents(r *sdk.Result) { // GetClientState retrieves the client state for the provided clientID. The client is // expected to exist otherwise testing will fail. func (chain *TestChain) GetClientState(clientID string) exported.ClientState { - clientState, found := chain.App.IBCKeeper.ClientKeeper.GetClientState(chain.GetContext(), clientID) + clientState, found := chain.App.GetIBCKeeper().ClientKeeper.GetClientState(chain.GetContext(), clientID) require.True(chain.t, found) return clientState @@ -370,13 +409,13 @@ func (chain *TestChain) GetClientState(clientID string) exported.ClientState { // GetConsensusState retrieves the consensus state for the provided clientID and height. // It will return a success boolean depending on if consensus state exists or not. func (chain *TestChain) GetConsensusState(clientID string, height exported.Height) (exported.ConsensusState, bool) { - return chain.App.IBCKeeper.ClientKeeper.GetClientConsensusState(chain.GetContext(), clientID, height) + return chain.App.GetIBCKeeper().ClientKeeper.GetClientConsensusState(chain.GetContext(), clientID, height) } // GetValsAtHeight will return the validator set of the chain at a given height. It will return // a success boolean depending on if the validator set exists or not at that height. func (chain *TestChain) GetValsAtHeight(height int64) (*tmtypes.ValidatorSet, bool) { - histInfo, ok := chain.App.StakingKeeper.GetHistoricalInfo(chain.GetContext(), height) + histInfo, ok := chain.App.GetStakingKeeper().GetHistoricalInfo(chain.GetContext(), height) if !ok { return nil, false } @@ -393,7 +432,7 @@ func (chain *TestChain) GetValsAtHeight(height int64) (*tmtypes.ValidatorSet, bo // GetAcknowledgement retrieves an acknowledgement for the provided packet. If the // acknowledgement does not exist then testing will fail. func (chain *TestChain) GetAcknowledgement(packet exported.PacketI) []byte { - ack, found := chain.App.IBCKeeper.ChannelKeeper.GetPacketAcknowledgement(chain.GetContext(), packet.GetDestPort(), packet.GetDestChannel(), packet.GetSequence()) + ack, found := chain.App.GetIBCKeeper().ChannelKeeper.GetPacketAcknowledgement(chain.GetContext(), packet.GetDestPort(), packet.GetDestChannel(), packet.GetSequence()) require.True(chain.t, found) return ack @@ -401,7 +440,7 @@ func (chain *TestChain) GetAcknowledgement(packet exported.PacketI) []byte { // GetPrefix returns the prefix for used by a chain in connection creation func (chain *TestChain) GetPrefix() commitmenttypes.MerklePrefix { - return commitmenttypes.NewMerklePrefix(chain.App.IBCKeeper.ConnectionKeeper.GetCommitmentPrefix().Bytes()) + return commitmenttypes.NewMerklePrefix(chain.App.GetIBCKeeper().ConnectionKeeper.GetCommitmentPrefix().Bytes()) } // ConstructUpdateTMClientHeader will construct a valid 07-tendermint Header to update the @@ -547,10 +586,10 @@ func MakeBlockID(hash []byte, partSetSize uint32, partSetHash []byte) tmtypes.Bl // Other applications must bind to the port in InitGenesis or modify this code. func (chain *TestChain) CreatePortCapability(scopedKeeper capabilitykeeper.ScopedKeeper, portID string) { // check if the portId is already binded, if not bind it - _, ok := chain.App.ScopedIBCKeeper.GetCapability(chain.GetContext(), host.PortPath(portID)) + _, ok := chain.App.GetScopedIBCKeeper().GetCapability(chain.GetContext(), host.PortPath(portID)) if !ok { // create capability using the IBC capability keeper - portCap, err := chain.App.ScopedIBCKeeper.NewCapability(chain.GetContext(), host.PortPath(portID)) + portCap, err := chain.App.GetScopedIBCKeeper().NewCapability(chain.GetContext(), host.PortPath(portID)) require.NoError(chain.t, err) // claim capability using the scopedKeeper @@ -564,7 +603,7 @@ func (chain *TestChain) CreatePortCapability(scopedKeeper capabilitykeeper.Scope // GetPortCapability returns the port capability for the given portID. The capability must // exist, otherwise testing will fail. func (chain *TestChain) GetPortCapability(portID string) *capabilitytypes.Capability { - portCap, ok := chain.App.ScopedIBCKeeper.GetCapability(chain.GetContext(), host.PortPath(portID)) + portCap, ok := chain.App.GetScopedIBCKeeper().GetCapability(chain.GetContext(), host.PortPath(portID)) require.True(chain.t, ok) return portCap @@ -576,9 +615,9 @@ func (chain *TestChain) GetPortCapability(portID string) *capabilitytypes.Capabi func (chain *TestChain) CreateChannelCapability(scopedKeeper capabilitykeeper.ScopedKeeper, portID, channelID string) { capName := host.ChannelCapabilityPath(portID, channelID) // check if the portId is already binded, if not bind it - _, ok := chain.App.ScopedIBCKeeper.GetCapability(chain.GetContext(), capName) + _, ok := chain.App.GetScopedIBCKeeper().GetCapability(chain.GetContext(), capName) if !ok { - portCap, err := chain.App.ScopedIBCKeeper.NewCapability(chain.GetContext(), capName) + portCap, err := chain.App.GetScopedIBCKeeper().NewCapability(chain.GetContext(), capName) require.NoError(chain.t, err) err = scopedKeeper.ClaimCapability(chain.GetContext(), portCap, capName) require.NoError(chain.t, err) @@ -590,7 +629,7 @@ func (chain *TestChain) CreateChannelCapability(scopedKeeper capabilitykeeper.Sc // GetChannelCapability returns the channel capability for the given portID and channelID. // The capability must exist, otherwise testing will fail. func (chain *TestChain) GetChannelCapability(portID, channelID string) *capabilitytypes.Capability { - chanCap, ok := chain.App.ScopedIBCKeeper.GetCapability(chain.GetContext(), host.ChannelCapabilityPath(portID, channelID)) + chanCap, ok := chain.App.GetScopedIBCKeeper().GetCapability(chain.GetContext(), host.ChannelCapabilityPath(portID, channelID)) require.True(chain.t, ok) return chanCap @@ -603,9 +642,9 @@ func (chain *TestChain) GetTimeoutHeight() clienttypes.Height { } func (chain *TestChain) Balance(acc sdk.AccAddress, denom string) sdk.Coin { - return chain.App.BankKeeper.GetBalance(chain.GetContext(), acc, denom) + return chain.App.GetBankKeeper().GetBalance(chain.GetContext(), acc, denom) } func (chain *TestChain) AllBalances(acc sdk.AccAddress) sdk.Coins { - return chain.App.BankKeeper.GetAllBalances(chain.GetContext(), acc) + return chain.App.GetBankKeeper().GetAllBalances(chain.GetContext(), acc) } diff --git a/x/wasm/ibctesting/coordinator.go b/x/wasm/ibctesting/coordinator.go index 557db830..056bb8c9 100644 --- a/x/wasm/ibctesting/coordinator.go +++ b/x/wasm/ibctesting/coordinator.go @@ -42,7 +42,7 @@ func NewCoordinator(t *testing.T, n int, opts ...[]wasmkeeper.Option) *Coordinat if len(opts) > (i - 1) { x = opts[i-1] } - chains[chainID] = NewTestChain(t, coord, chainID, x...) + chains[chainID] = NewDefaultTestChain(t, coord, chainID, x...) } coord.Chains = chains diff --git a/x/wasm/ibctesting/endpoint.go b/x/wasm/ibctesting/endpoint.go index 4eb47baf..726aca22 100644 --- a/x/wasm/ibctesting/endpoint.go +++ b/x/wasm/ibctesting/endpoint.go @@ -181,7 +181,7 @@ func (endpoint *Endpoint) UpgradeChain() error { } // update chain - baseapp.SetChainID(newChainID)(endpoint.Chain.App.BaseApp) + baseapp.SetChainID(newChainID)(endpoint.Chain.App.GetBaseApp()) endpoint.Chain.ChainID = newChainID endpoint.Chain.CurrentHeader.ChainID = newChainID endpoint.Chain.NextBlock() // commit changes @@ -446,7 +446,7 @@ func (endpoint *Endpoint) SendPacket( channelCap := endpoint.Chain.GetChannelCapability(endpoint.ChannelConfig.PortID, endpoint.ChannelID) // no need to send message, acting as a module - sequence, err := endpoint.Chain.App.IBCKeeper.ChannelKeeper.SendPacket(endpoint.Chain.GetContext(), channelCap, endpoint.ChannelConfig.PortID, endpoint.ChannelID, timeoutHeight, timeoutTimestamp, data) + sequence, err := endpoint.Chain.App.GetIBCKeeper().ChannelKeeper.SendPacket(endpoint.Chain.GetContext(), channelCap, endpoint.ChannelConfig.PortID, endpoint.ChannelID, timeoutHeight, timeoutTimestamp, data) if err != nil { return 0, err } @@ -501,7 +501,7 @@ func (endpoint *Endpoint) WriteAcknowledgement(ack exported.Acknowledgement, pac channelCap := endpoint.Chain.GetChannelCapability(packet.GetDestPort(), packet.GetDestChannel()) // no need to send message, acting as a handler - err := endpoint.Chain.App.IBCKeeper.ChannelKeeper.WriteAcknowledgement(endpoint.Chain.GetContext(), channelCap, packet, ack) + err := endpoint.Chain.App.GetIBCKeeper().ChannelKeeper.WriteAcknowledgement(endpoint.Chain.GetContext(), channelCap, packet, ack) if err != nil { return err } @@ -538,7 +538,7 @@ func (endpoint *Endpoint) TimeoutPacket(packet channeltypes.Packet) error { } proof, proofHeight := endpoint.Counterparty.QueryProof(packetKey) - nextSeqRecv, found := endpoint.Counterparty.Chain.App.IBCKeeper.ChannelKeeper.GetNextSequenceRecv(endpoint.Counterparty.Chain.GetContext(), endpoint.ChannelConfig.PortID, endpoint.ChannelID) + nextSeqRecv, found := endpoint.Counterparty.Chain.App.GetIBCKeeper().ChannelKeeper.GetNextSequenceRecv(endpoint.Counterparty.Chain.GetContext(), endpoint.ChannelConfig.PortID, endpoint.ChannelID) require.True(endpoint.Chain.t, found) timeoutMsg := channeltypes.NewMsgTimeout( @@ -568,7 +568,7 @@ func (endpoint *Endpoint) TimeoutOnClose(packet channeltypes.Packet) error { channelKey := host.ChannelKey(packet.GetDestPort(), packet.GetDestChannel()) proofClosed, _ := endpoint.Counterparty.QueryProof(channelKey) - nextSeqRecv, found := endpoint.Counterparty.Chain.App.IBCKeeper.ChannelKeeper.GetNextSequenceRecv(endpoint.Counterparty.Chain.GetContext(), endpoint.ChannelConfig.PortID, endpoint.ChannelID) + nextSeqRecv, found := endpoint.Counterparty.Chain.App.GetIBCKeeper().ChannelKeeper.GetNextSequenceRecv(endpoint.Counterparty.Chain.GetContext(), endpoint.ChannelConfig.PortID, endpoint.ChannelID) require.True(endpoint.Chain.t, found) timeoutOnCloseMsg := channeltypes.NewMsgTimeoutOnClose( @@ -584,7 +584,7 @@ func (endpoint *Endpoint) SetChannelClosed() error { channel := endpoint.GetChannel() channel.State = channeltypes.CLOSED - endpoint.Chain.App.IBCKeeper.ChannelKeeper.SetChannel(endpoint.Chain.GetContext(), endpoint.ChannelConfig.PortID, endpoint.ChannelID, channel) + endpoint.Chain.App.GetIBCKeeper().ChannelKeeper.SetChannel(endpoint.Chain.GetContext(), endpoint.ChannelConfig.PortID, endpoint.ChannelID, channel) endpoint.Chain.Coordinator.CommitBlock(endpoint.Chain) @@ -599,7 +599,7 @@ func (endpoint *Endpoint) GetClientState() exported.ClientState { // SetClientState sets the client state for this endpoint. func (endpoint *Endpoint) SetClientState(clientState exported.ClientState) { - endpoint.Chain.App.IBCKeeper.ClientKeeper.SetClientState(endpoint.Chain.GetContext(), endpoint.ClientID, clientState) + endpoint.Chain.App.GetIBCKeeper().ClientKeeper.SetClientState(endpoint.Chain.GetContext(), endpoint.ClientID, clientState) } // GetConsensusState retrieves the Consensus State for this endpoint at the provided height. @@ -613,13 +613,13 @@ func (endpoint *Endpoint) GetConsensusState(height exported.Height) exported.Con // SetConsensusState sets the consensus state for this endpoint. func (endpoint *Endpoint) SetConsensusState(consensusState exported.ConsensusState, height exported.Height) { - endpoint.Chain.App.IBCKeeper.ClientKeeper.SetClientConsensusState(endpoint.Chain.GetContext(), endpoint.ClientID, height, consensusState) + endpoint.Chain.App.GetIBCKeeper().ClientKeeper.SetClientConsensusState(endpoint.Chain.GetContext(), endpoint.ClientID, height, consensusState) } // GetConnection retrieves an IBC Connection for the endpoint. The // connection is expected to exist otherwise testing will fail. func (endpoint *Endpoint) GetConnection() connectiontypes.ConnectionEnd { - connection, found := endpoint.Chain.App.IBCKeeper.ConnectionKeeper.GetConnection(endpoint.Chain.GetContext(), endpoint.ConnectionID) + connection, found := endpoint.Chain.App.GetIBCKeeper().ConnectionKeeper.GetConnection(endpoint.Chain.GetContext(), endpoint.ConnectionID) require.True(endpoint.Chain.t, found) return connection @@ -627,13 +627,13 @@ func (endpoint *Endpoint) GetConnection() connectiontypes.ConnectionEnd { // SetConnection sets the connection for this endpoint. func (endpoint *Endpoint) SetConnection(connection connectiontypes.ConnectionEnd) { - endpoint.Chain.App.IBCKeeper.ConnectionKeeper.SetConnection(endpoint.Chain.GetContext(), endpoint.ConnectionID, connection) + endpoint.Chain.App.GetIBCKeeper().ConnectionKeeper.SetConnection(endpoint.Chain.GetContext(), endpoint.ConnectionID, connection) } // GetChannel retrieves an IBC Channel for the endpoint. The channel // is expected to exist otherwise testing will fail. func (endpoint *Endpoint) GetChannel() channeltypes.Channel { - channel, found := endpoint.Chain.App.IBCKeeper.ChannelKeeper.GetChannel(endpoint.Chain.GetContext(), endpoint.ChannelConfig.PortID, endpoint.ChannelID) + channel, found := endpoint.Chain.App.GetIBCKeeper().ChannelKeeper.GetChannel(endpoint.Chain.GetContext(), endpoint.ChannelConfig.PortID, endpoint.ChannelID) require.True(endpoint.Chain.t, found) return channel @@ -641,7 +641,7 @@ func (endpoint *Endpoint) GetChannel() channeltypes.Channel { // SetChannel sets the channel for this endpoint. func (endpoint *Endpoint) SetChannel(channel channeltypes.Channel) { - endpoint.Chain.App.IBCKeeper.ChannelKeeper.SetChannel(endpoint.Chain.GetContext(), endpoint.ChannelConfig.PortID, endpoint.ChannelID, channel) + endpoint.Chain.App.GetIBCKeeper().ChannelKeeper.SetChannel(endpoint.Chain.GetContext(), endpoint.ChannelConfig.PortID, endpoint.ChannelID, channel) } // QueryClientStateProof performs and abci query for a client stat associated diff --git a/x/wasm/ibctesting/event_utils.go b/x/wasm/ibctesting/event_utils.go index a3769014..3ca72f8b 100644 --- a/x/wasm/ibctesting/event_utils.go +++ b/x/wasm/ibctesting/event_utils.go @@ -13,11 +13,11 @@ import ( channeltypes "github.com/cosmos/ibc-go/v7/modules/core/04-channel/types" ) -func getSendPackets(evts []abci.Event) []channeltypes.Packet { +func GetSendPackets(evts []abci.Event) []channeltypes.Packet { var res []channeltypes.Packet for _, evt := range evts { if evt.Type == channeltypes.EventTypeSendPacket { - packet := parsePacketFromEvent(evt) + packet := ParsePacketFromEvent(evt) res = append(res, packet) } } @@ -32,7 +32,7 @@ func getSendPackets(evts []abci.Event) []channeltypes.Packet { // } //} -func parsePacketFromEvent(evt abci.Event) channeltypes.Packet { +func ParsePacketFromEvent(evt abci.Event) channeltypes.Packet { return channeltypes.Packet{ Sequence: getUintField(evt, channeltypes.AttributeKeySequence), SourcePort: getField(evt, channeltypes.AttributeKeySrcPort), diff --git a/x/wasm/ibctesting/faucet.go b/x/wasm/ibctesting/faucet.go index 4de2c4e0..19b9a947 100644 --- a/x/wasm/ibctesting/faucet.go +++ b/x/wasm/ibctesting/faucet.go @@ -28,12 +28,12 @@ func (chain *TestChain) SendNonDefaultSenderMsgs(senderPrivKey cryptotypes.PrivK chain.Coordinator.UpdateTimeForChain(chain) addr := sdk.AccAddress(senderPrivKey.PubKey().Address().Bytes()) - account := chain.App.AccountKeeper.GetAccount(chain.GetContext(), addr) + account := chain.App.GetAccountKeeper().GetAccount(chain.GetContext(), addr) require.NotNil(chain.t, account) _, r, err := app.SignAndDeliverWithoutCommit( chain.t, chain.TxConfig, - chain.App.BaseApp, + chain.App.GetBaseApp(), chain.GetContext().BlockHeader(), msgs, chain.ChainID, diff --git a/x/wasm/ibctesting/path.go b/x/wasm/ibctesting/path.go index c7f1af8f..5c39e361 100644 --- a/x/wasm/ibctesting/path.go +++ b/x/wasm/ibctesting/path.go @@ -41,7 +41,7 @@ func (path *Path) SetChannelOrdered() { // if EndpointA does not contain a packet commitment for that packet. An error is returned // if a relay step fails or the packet commitment does not exist on either endpoint. func (path *Path) RelayPacket(packet channeltypes.Packet, _ []byte) error { - pc := path.EndpointA.Chain.App.IBCKeeper.ChannelKeeper.GetPacketCommitment(path.EndpointA.Chain.GetContext(), packet.GetSourcePort(), packet.GetSourceChannel(), packet.GetSequence()) + pc := path.EndpointA.Chain.App.GetIBCKeeper().ChannelKeeper.GetPacketCommitment(path.EndpointA.Chain.GetContext(), packet.GetSourcePort(), packet.GetSourceChannel(), packet.GetSequence()) if bytes.Equal(pc, channeltypes.CommitPacket(path.EndpointA.Chain.App.AppCodec(), packet)) { // packet found, relay from A to B @@ -64,7 +64,7 @@ func (path *Path) RelayPacket(packet channeltypes.Packet, _ []byte) error { return err } - pc = path.EndpointB.Chain.App.IBCKeeper.ChannelKeeper.GetPacketCommitment(path.EndpointB.Chain.GetContext(), packet.GetSourcePort(), packet.GetSourceChannel(), packet.GetSequence()) + pc = path.EndpointB.Chain.App.GetIBCKeeper().ChannelKeeper.GetPacketCommitment(path.EndpointB.Chain.GetContext(), packet.GetSourcePort(), packet.GetSourceChannel(), packet.GetSequence()) if bytes.Equal(pc, channeltypes.CommitPacket(path.EndpointB.Chain.App.AppCodec(), packet)) { // packet found, relay B to A diff --git a/x/wasm/ibctesting/wasm.go b/x/wasm/ibctesting/wasm.go index 53bd7dc5..1e6a300d 100644 --- a/x/wasm/ibctesting/wasm.go +++ b/x/wasm/ibctesting/wasm.go @@ -125,5 +125,5 @@ func (chain *TestChain) SmartQuery(contractAddr string, queryMsg interface{}, re // ContractInfo is a helper function to returns the ContractInfo for the given contract address func (chain *TestChain) ContractInfo(contractAddr sdk.AccAddress) *types.ContractInfo { - return chain.App.WasmKeeper.GetContractInfo(chain.GetContext(), contractAddr) + return chain.App.GetWasmKeeper().GetContractInfo(chain.GetContext(), contractAddr) } diff --git a/x/wasm/relay_pingpong_test.go b/x/wasm/relay_pingpong_test.go index ff52de3d..4ae022b1 100644 --- a/x/wasm/relay_pingpong_test.go +++ b/x/wasm/relay_pingpong_test.go @@ -5,6 +5,8 @@ import ( "fmt" "testing" + app2 "github.com/CosmWasm/wasmd/app" + ibctransfertypes "github.com/cosmos/ibc-go/v7/modules/apps/transfer/types" ibctesting "github.com/cosmos/ibc-go/v7/testing" @@ -306,7 +308,8 @@ func (p player) incrementCounter(key []byte, store wasmvm.KVStore) uint64 { } func (p player) QueryState(key []byte) uint64 { - raw := p.chain.App.WasmKeeper.QueryRaw(p.chain.GetContext(), p.contractAddr, key) + app := p.chain.App.(*app2.WasmApp) + raw := app.WasmKeeper.QueryRaw(p.chain.GetContext(), p.contractAddr, key) return sdk.BigEndianToUint64(raw) } diff --git a/x/wasm/relay_test.go b/x/wasm/relay_test.go index 1017522d..81698f5a 100644 --- a/x/wasm/relay_test.go +++ b/x/wasm/relay_test.go @@ -6,6 +6,8 @@ import ( "testing" "time" + "github.com/CosmWasm/wasmd/app" + errorsmod "cosmossdk.io/errors" "cosmossdk.io/math" wasmvm "github.com/CosmWasm/wasmvm" @@ -20,7 +22,7 @@ import ( wasmibctesting "github.com/CosmWasm/wasmd/x/wasm/ibctesting" wasmkeeper "github.com/CosmWasm/wasmd/x/wasm/keeper" - wasmtesting "github.com/CosmWasm/wasmd/x/wasm/keeper/wasmtesting" + "github.com/CosmWasm/wasmd/x/wasm/keeper/wasmtesting" "github.com/CosmWasm/wasmd/x/wasm/types" ) @@ -210,7 +212,7 @@ func TestContractCanInitiateIBCTransferMsg(t *testing.T) { require.Equal(t, 0, len(chainB.PendingSendPackets)) // and dest chain balance contains voucher - bankKeeperB := chainB.App.BankKeeper + bankKeeperB := chainB.App.(*app.WasmApp).BankKeeper expBalance := ibctransfertypes.GetTransferCoin(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, coinToSendToB.Denom, coinToSendToB.Amount) gotBalance := chainB.Balance(chainB.SenderAccount.GetAddress(), expBalance.Denom) assert.Equal(t, expBalance, gotBalance, "got total balance: %s", bankKeeperB.GetAllBalances(chainB.GetContext(), chainB.SenderAccount.GetAddress())) @@ -285,7 +287,7 @@ func TestContractCanEmulateIBCTransferMessage(t *testing.T) { require.Equal(t, 0, len(chainB.PendingSendPackets)) // and dest chain balance contains voucher - bankKeeperB := chainB.App.BankKeeper + bankKeeperB := chainB.App.(*app.WasmApp).BankKeeper expBalance := ibctransfertypes.GetTransferCoin(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, coinToSendToB.Denom, coinToSendToB.Amount) gotBalance := chainB.Balance(chainB.SenderAccount.GetAddress(), expBalance.Denom) assert.Equal(t, expBalance, gotBalance, "got total balance: %s", bankKeeperB.GetAllBalances(chainB.GetContext(), chainB.SenderAccount.GetAddress())) @@ -712,7 +714,7 @@ func (c *ackReceiverContract) IBCPacketReceive(_ wasmvm.Checksum, _ wasmvmtypes. // call original ibctransfer keeper to not copy all code into this ibcPacket := toIBCPacket(packet) ctx := c.chain.GetContext() // HACK: please note that this is not reverted after checkTX - err := c.chain.App.TransferKeeper.OnRecvPacket(ctx, ibcPacket, src) + err := c.chain.App.(*app.WasmApp).TransferKeeper.OnRecvPacket(ctx, ibcPacket, src) if err != nil { return nil, 0, errorsmod.Wrap(err, "within our smart contract") } @@ -737,7 +739,7 @@ func (c *ackReceiverContract) IBCPacketAck(_ wasmvm.Checksum, _ wasmvmtypes.Env, // call original ibctransfer keeper to not copy all code into this ctx := c.chain.GetContext() // HACK: please note that this is not reverted after checkTX ibcPacket := toIBCPacket(msg.OriginalPacket) - err := c.chain.App.TransferKeeper.OnAcknowledgementPacket(ctx, ibcPacket, data, ack) + err := c.chain.App.(*app.WasmApp).TransferKeeper.OnAcknowledgementPacket(ctx, ibcPacket, data, ack) if err != nil { return nil, 0, errorsmod.Wrap(err, "within our smart contract") }