logging: fixes bug where unsampled logger is called from a sampled one (#1369)

We had a logging bug where an unsampled function (such as fd_write to
stdout/stderr) would end up with its logging "after" hook called, if it
was called from a sampled function.

For example, if a wasm function called fd_write with stdout, the before
hook on fd_write would skip, but the after would not, and accidentally
use its caller's parameters. This results in a panic due to incorrect
length.

This fixes the bug by ensuring we mute the logging context if there's
one in progress. It ensures the bug won't pop up again by adding test
data that matches the call pattern (from xpdf-wasm).

Thanks to @jerbob92 for helping isolate this!

Signed-off-by: Adrian Cole <adrian@tetrate.io>
This commit is contained in:
Crypt Keeper
2023-04-17 17:36:15 +01:00
committed by GitHub
parent b9e03dc691
commit 9263bef174
5 changed files with 40 additions and 11 deletions

View File

@@ -80,5 +80,6 @@ func Example_customListenerFactory() {
} }
// Output: // Output:
// wasi_snapshot_preview1.fd_write
// wasi_snapshot_preview1.random_get // wasi_snapshot_preview1.random_get
} }

View File

@@ -162,9 +162,12 @@ func (f *loggingListenerFactory) NewListener(fnd api.FunctionDefinition) experim
type logState struct { type logState struct {
w logging.Writer w logging.Writer
nestLevel int nestLevel int
unsampled bool
params []uint64 params []uint64
} }
var unsampledLogState = &logState{unsampled: true}
// loggingListener implements experimental.FunctionListener to log entrance and after // loggingListener implements experimental.FunctionListener to log entrance and after
// of each function call. // of each function call.
type loggingListener struct { type loggingListener struct {
@@ -178,28 +181,41 @@ type loggingListener struct {
// Before logs to stdout the module and function name, prefixed with '-->' and // Before logs to stdout the module and function name, prefixed with '-->' and
// indented based on the call nesting level. // indented based on the call nesting level.
func (l *loggingListener) Before(ctx context.Context, mod api.Module, _ api.FunctionDefinition, params []uint64) context.Context { func (l *loggingListener) Before(ctx context.Context, mod api.Module, _ api.FunctionDefinition, params []uint64) context.Context {
if s := l.pSampler; s != nil && !s(ctx, mod, params) { // First, see if this invocation is sampled.
return ctx sampled := true
if s := l.pSampler; s != nil {
sampled = s(ctx, mod, params)
} }
// Check to see if the calling function was logging.
var state *logState
var nestLevel int var nestLevel int
if ls := ctx.Value(logging.LoggerKey{}); ls != nil { if v := ctx.Value(logging.LoggerKey{}); v != nil {
nestLevel = ls.(*logState).nestLevel if !sampled { // override to mute this invocation.
return context.WithValue(ctx, logging.LoggerKey{}, unsampledLogState)
} }
state = v.(*logState)
nestLevel = state.nestLevel
} else if !sampled {
return ctx // lack of LoggerKey == not sampled.
}
// We're starting to log: increase the indentation level.
nestLevel++ nestLevel++
l.logIndented(ctx, mod, nestLevel, true, params, nil, nil) l.logIndented(ctx, mod, nestLevel, true, params, nil, nil)
ls := &logState{w: l.w, nestLevel: nestLevel} // We need to propagate this invocation's parameters to the after callback.
state = &logState{w: l.w, nestLevel: nestLevel}
if pLen := len(params); pLen > 0 { if pLen := len(params); pLen > 0 {
ls.params = make([]uint64, pLen) state.params = make([]uint64, pLen)
copy(ls.params, params) // safe copy copy(state.params, params) // safe copy
} else { // empty } else { // empty
ls.params = params state.params = params
} }
// Increase the next nesting level. // Overwrite the logging key with this invocation's state.
return context.WithValue(ctx, logging.LoggerKey{}, ls) return context.WithValue(ctx, logging.LoggerKey{}, state)
} }
// After logs to stdout the module and function name, prefixed with '<--' and // After logs to stdout the module and function name, prefixed with '<--' and
@@ -208,6 +224,9 @@ func (l *loggingListener) After(ctx context.Context, mod api.Module, _ api.Funct
// Note: We use the nest level directly even though it is the "next" nesting level. // Note: We use the nest level directly even though it is the "next" nesting level.
// This works because our indent of zero nesting is one tab. // This works because our indent of zero nesting is one tab.
if state, ok := ctx.Value(logging.LoggerKey{}).(*logState); ok { if state, ok := ctx.Value(logging.LoggerKey{}).(*logState); ok {
if state == unsampledLogState {
return
}
l.logIndented(ctx, mod, state.nestLevel, false, state.params, err, results) l.logIndented(ctx, mod, state.nestLevel, false, state.params, err, results)
} }
} }

Binary file not shown.

View File

@@ -2,12 +2,21 @@
(import "wasi_snapshot_preview1" "random_get" (import "wasi_snapshot_preview1" "random_get"
(func $wasi.random_get (param $buf i32) (param $buf_len i32) (result (;errno;) i32))) (func $wasi.random_get (param $buf i32) (param $buf_len i32) (result (;errno;) i32)))
(import "wasi_snapshot_preview1" "fd_write"
(func $wasi.fd_write (param $fd i32) (param $iovs i32) (param $iovs_len i32) (param $result.size i32) (result (;errno;) i32)))
(table 8 funcref) ;; Define a function table with a single element at index 3. (table 8 funcref) ;; Define a function table with a single element at index 3.
(elem (i32.const 3) $wasi.random_get) (elem (i32.const 3) $wasi.random_get)
(memory 1 1) ;; Memory is needed for WASI (memory 1 1) ;; Memory is needed for WASI
(func $wasi_rand (param $len i32) (func $wasi_rand (param $len i32)
;; call fd_write with an unsampled FD, inside a sampled function.
i32.const 1 ;; $fd = stdout
i32.const 0 i32.const 0 i32.const 0 ;; $iovs, $iovs_len $result.size = 0.
call $wasi.fd_write
drop ;; errno
i32.const 4 local.get 0 ;; buf, buf_len i32.const 4 local.get 0 ;; buf, buf_len
call $wasi.random_get call $wasi.random_get
drop ;; errno drop ;; errno

View File

@@ -269,7 +269,7 @@ type (
funcType *wasm.FunctionType funcType *wasm.FunctionType
// def is the api.Function for this function. Created during compilation. // def is the api.Function for this function. Created during compilation.
def api.FunctionDefinition def api.FunctionDefinition
// parent holds code from which this is crated. // parent holds code from which this is created.
parent *code parent *code
} }