From 51f04f5f605955e81d06e5456c9fcd7be2a1870a Mon Sep 17 00:00:00 2001 From: mleku Date: Tue, 2 Sep 2025 20:32:53 +0100 Subject: [PATCH] implemented event and req --- app/config/config.go | 5 +- app/handle-auth.go | 1 + app/handle-close.go | 1 + app/handle-event.go | 66 ++++ app/handle-message.go | 44 +-- app/handle-req.go | 120 ++++++++ app/handle-websocket.go | 12 +- app/listener.go | 16 +- app/main.go | 6 +- app/ok.go | 118 +++++++ app/publisher.go | 222 ++++++++++++++ app/server.go | 5 + go.mod | 18 ++ go.sum | 37 +++ main.go | 9 +- pkg/database/database.go | 13 +- pkg/database/get-indexes-for-event.go | 11 +- pkg/database/get-indexes-from-filter.go | 44 +-- pkg/database/query-events.go | 2 +- pkg/database/query-for-serials.go | 21 +- pkg/database/save-event.go | 13 +- .../envelopes/reqenvelope/reqenvelope.go | 9 +- pkg/encoders/event/binary.go | 4 +- pkg/encoders/event/canonical.go | 6 +- pkg/encoders/filter/filter.go | 45 +++ pkg/encoders/filter/filters.go | 13 +- pkg/encoders/kind/kind.go | 8 +- pkg/encoders/kind/kinds.go | 2 +- pkg/encoders/reason/reason.go | 52 ++++ pkg/encoders/tag/tag.go | 22 +- pkg/encoders/tag/tags.go | 26 +- pkg/interfaces/go.mod | 25 ++ pkg/interfaces/go.sum | 33 ++ pkg/interfaces/publisher/publisher.go | 14 + pkg/interfaces/typer/typer.go | 10 + pkg/protocol/publish/publisher.go | 38 +++ pkg/utils/atomic/.codecov.yml | 19 ++ pkg/utils/atomic/CHANGELOG.md | 130 ++++++++ pkg/utils/atomic/LICENSE | 19 ++ pkg/utils/atomic/Makefile | 79 +++++ pkg/utils/atomic/README.md | 33 ++ pkg/utils/atomic/assert_test.go | 45 +++ pkg/utils/atomic/bool.go | 88 ++++++ pkg/utils/atomic/bool_ext.go | 53 ++++ pkg/utils/atomic/bool_test.go | 150 +++++++++ pkg/utils/atomic/bytes.go | 59 ++++ pkg/utils/atomic/bytes_ext.go | 56 ++++ pkg/utils/atomic/bytes_test.go | 252 +++++++++++++++ pkg/utils/atomic/doc.go | 23 ++ pkg/utils/atomic/duration.go | 89 ++++++ pkg/utils/atomic/duration_ext.go | 40 +++ pkg/utils/atomic/duration_test.go | 73 +++++ pkg/utils/atomic/error.go | 72 +++++ pkg/utils/atomic/error_ext.go | 39 +++ pkg/utils/atomic/error_test.go | 136 +++++++++ pkg/utils/atomic/example_test.go | 42 +++ pkg/utils/atomic/float32.go | 77 +++++ pkg/utils/atomic/float32_ext.go | 76 +++++ pkg/utils/atomic/float32_test.go | 73 +++++ pkg/utils/atomic/float64.go | 77 +++++ pkg/utils/atomic/float64_ext.go | 76 +++++ pkg/utils/atomic/float64_test.go | 73 +++++ pkg/utils/atomic/gen.go | 27 ++ pkg/utils/atomic/int32.go | 109 +++++++ pkg/utils/atomic/int32_test.go | 82 +++++ pkg/utils/atomic/int64.go | 109 +++++++ pkg/utils/atomic/int64_test.go | 82 +++++ .../atomic/internal/gen-atomicint/main.go | 116 +++++++ .../internal/gen-atomicint/wrapper.tmpl | 117 +++++++ .../atomic/internal/gen-atomicwrapper/main.go | 203 ++++++++++++ .../internal/gen-atomicwrapper/wrapper.tmpl | 120 ++++++++ pkg/utils/atomic/nocmp.go | 35 +++ pkg/utils/atomic/nocmp_test.go | 164 ++++++++++ pkg/utils/atomic/pointer_test.go | 100 ++++++ pkg/utils/atomic/stress_test.go | 289 ++++++++++++++++++ pkg/utils/atomic/string.go | 72 +++++ pkg/utils/atomic/string_ext.go | 54 ++++ pkg/utils/atomic/string_test.go | 170 +++++++++++ pkg/utils/atomic/time.go | 55 ++++ pkg/utils/atomic/time_ext.go | 36 +++ pkg/utils/atomic/time_test.go | 86 ++++++ pkg/utils/atomic/tools/tools.go | 30 ++ pkg/utils/atomic/uint32.go | 109 +++++++ pkg/utils/atomic/uint32_test.go | 77 +++++ pkg/utils/atomic/uint64.go | 109 +++++++ pkg/utils/atomic/uint64_test.go | 77 +++++ pkg/utils/atomic/uintptr.go | 109 +++++++ pkg/utils/atomic/uintptr_test.go | 80 +++++ pkg/utils/atomic/unsafe_pointer.go | 65 ++++ pkg/utils/atomic/unsafe_pointer_test.go | 83 +++++ pkg/utils/atomic/value.go | 31 ++ pkg/utils/atomic/value_test.go | 40 +++ pkg/utils/go.mod | 12 + pkg/utils/go.sum | 41 +++ pkg/utils/interrupt/README.md | 2 + pkg/utils/interrupt/main.go | 153 ++++++++++ pkg/utils/interrupt/restart.go | 26 ++ pkg/utils/interrupt/restart_darwin.go | 20 ++ pkg/utils/interrupt/restart_windows.go | 20 ++ pkg/utils/interrupt/sigterm.go | 12 + pkg/utils/normalize/normalize.go | 19 ++ pkg/utils/qu/README.adoc | 60 ++++ pkg/utils/qu/qu.go | 245 +++++++++++++++ pprof.go | 7 +- 104 files changed, 6368 insertions(+), 125 deletions(-) create mode 100644 app/handle-auth.go create mode 100644 app/handle-close.go create mode 100644 app/handle-event.go create mode 100644 app/handle-req.go create mode 100644 app/ok.go create mode 100644 app/publisher.go create mode 100644 pkg/encoders/reason/reason.go create mode 100644 pkg/interfaces/go.sum create mode 100644 pkg/interfaces/publisher/publisher.go create mode 100644 pkg/interfaces/typer/typer.go create mode 100644 pkg/protocol/publish/publisher.go create mode 100644 pkg/utils/atomic/.codecov.yml create mode 100644 pkg/utils/atomic/CHANGELOG.md create mode 100644 pkg/utils/atomic/LICENSE create mode 100644 pkg/utils/atomic/Makefile create mode 100644 pkg/utils/atomic/README.md create mode 100644 pkg/utils/atomic/assert_test.go create mode 100644 pkg/utils/atomic/bool.go create mode 100644 pkg/utils/atomic/bool_ext.go create mode 100644 pkg/utils/atomic/bool_test.go create mode 100644 pkg/utils/atomic/bytes.go create mode 100644 pkg/utils/atomic/bytes_ext.go create mode 100644 pkg/utils/atomic/bytes_test.go create mode 100644 pkg/utils/atomic/doc.go create mode 100644 pkg/utils/atomic/duration.go create mode 100644 pkg/utils/atomic/duration_ext.go create mode 100644 pkg/utils/atomic/duration_test.go create mode 100644 pkg/utils/atomic/error.go create mode 100644 pkg/utils/atomic/error_ext.go create mode 100644 pkg/utils/atomic/error_test.go create mode 100644 pkg/utils/atomic/example_test.go create mode 100644 pkg/utils/atomic/float32.go create mode 100644 pkg/utils/atomic/float32_ext.go create mode 100644 pkg/utils/atomic/float32_test.go create mode 100644 pkg/utils/atomic/float64.go create mode 100644 pkg/utils/atomic/float64_ext.go create mode 100644 pkg/utils/atomic/float64_test.go create mode 100644 pkg/utils/atomic/gen.go create mode 100644 pkg/utils/atomic/int32.go create mode 100644 pkg/utils/atomic/int32_test.go create mode 100644 pkg/utils/atomic/int64.go create mode 100644 pkg/utils/atomic/int64_test.go create mode 100644 pkg/utils/atomic/internal/gen-atomicint/main.go create mode 100644 pkg/utils/atomic/internal/gen-atomicint/wrapper.tmpl create mode 100644 pkg/utils/atomic/internal/gen-atomicwrapper/main.go create mode 100644 pkg/utils/atomic/internal/gen-atomicwrapper/wrapper.tmpl create mode 100644 pkg/utils/atomic/nocmp.go create mode 100644 pkg/utils/atomic/nocmp_test.go create mode 100644 pkg/utils/atomic/pointer_test.go create mode 100644 pkg/utils/atomic/stress_test.go create mode 100644 pkg/utils/atomic/string.go create mode 100644 pkg/utils/atomic/string_ext.go create mode 100644 pkg/utils/atomic/string_test.go create mode 100644 pkg/utils/atomic/time.go create mode 100644 pkg/utils/atomic/time_ext.go create mode 100644 pkg/utils/atomic/time_test.go create mode 100644 pkg/utils/atomic/tools/tools.go create mode 100644 pkg/utils/atomic/uint32.go create mode 100644 pkg/utils/atomic/uint32_test.go create mode 100644 pkg/utils/atomic/uint64.go create mode 100644 pkg/utils/atomic/uint64_test.go create mode 100644 pkg/utils/atomic/uintptr.go create mode 100644 pkg/utils/atomic/uintptr_test.go create mode 100644 pkg/utils/atomic/unsafe_pointer.go create mode 100644 pkg/utils/atomic/unsafe_pointer_test.go create mode 100644 pkg/utils/atomic/value.go create mode 100644 pkg/utils/atomic/value_test.go create mode 100644 pkg/utils/interrupt/README.md create mode 100644 pkg/utils/interrupt/main.go create mode 100644 pkg/utils/interrupt/restart.go create mode 100644 pkg/utils/interrupt/restart_darwin.go create mode 100644 pkg/utils/interrupt/restart_windows.go create mode 100644 pkg/utils/interrupt/sigterm.go create mode 100644 pkg/utils/qu/README.adoc create mode 100644 pkg/utils/qu/qu.go diff --git a/app/config/config.go b/app/config/config.go index 65cc93c..3ca17af 100644 --- a/app/config/config.go +++ b/app/config/config.go @@ -16,7 +16,6 @@ import ( "go-simpler.org/env" lol "lol.mleku.dev" "lol.mleku.dev/chk" - "lol.mleku.dev/log" "next.orly.dev/pkg/version" ) @@ -28,7 +27,8 @@ type C struct { DataDir string `env:"ORLY_DATA_DIR" usage:"storage location for the event store" default:"~/.local/share/ORLY"` Listen string `env:"ORLY_LISTEN" default:"0.0.0.0" usage:"network listen address"` Port int `env:"ORLY_PORT" default:"3334" usage:"port to listen on"` - LogLevel string `env:"ORLY_LOG_LEVEL" default:"info" usage:"debug level: fatal error warn info debug trace"` + LogLevel string `env:"ORLY_LOG_LEVEL" default:"info" usage:"relay log level: fatal error warn info debug trace"` + DBLogLevel string `env:"ORLY_DB_LOG_LEVEL" default:"info" usage:"database log level: fatal error warn info debug trace"` Pprof string `env:"ORLY_PPROF" usage:"enable pprof in modes: cpu,memory,allocation"` IPWhitelist []string `env:"ORLY_IP_WHITELIST" usage:"comma-separated list of IP addresses to allow access from, matches on prefixes to allow private subnets, eg 10.0.0 = 10.0.0.0/8"` } @@ -71,7 +71,6 @@ func New() (cfg *C, err error) { os.Exit(0) } lol.SetLogLevel(cfg.LogLevel) - log.I.S(cfg.IPWhitelist) return } diff --git a/app/handle-auth.go b/app/handle-auth.go new file mode 100644 index 0000000..4879f7a --- /dev/null +++ b/app/handle-auth.go @@ -0,0 +1 @@ +package app diff --git a/app/handle-close.go b/app/handle-close.go new file mode 100644 index 0000000..4879f7a --- /dev/null +++ b/app/handle-close.go @@ -0,0 +1 @@ +package app diff --git a/app/handle-event.go b/app/handle-event.go new file mode 100644 index 0000000..d57d294 --- /dev/null +++ b/app/handle-event.go @@ -0,0 +1,66 @@ +package app + +import ( + "context" + "fmt" + + "encoders.orly/envelopes/eventenvelope" + "lol.mleku.dev/chk" + "lol.mleku.dev/log" + utils "utils.orly" +) + +func (l *Listener) HandleEvent(c context.Context, msg []byte) ( + err error, +) { + // decode the envelope + env := eventenvelope.NewSubmission() + if msg, err = env.Unmarshal(msg); chk.E(err) { + return + } + if len(msg) > 0 { + log.I.F("extra '%s'", msg) + } + // check the event ID is correct + calculatedId := env.E.GetIDBytes() + if !utils.FastEqual(calculatedId, env.E.ID) { + if err = Ok.Invalid( + l, env, "event id is computed incorrectly, "+ + "event has ID %0x, but when computed it is %0x", + env.E.ID, calculatedId, + ); chk.E(err) { + return + } + return + } + // verify the signature + var ok bool + if ok, err = env.Verify(); chk.T(err) { + if err = Ok.Error( + l, env, fmt.Sprintf( + "failed to verify signature: %s", + err.Error(), + ), + ); chk.E(err) { + return + } + } else if !ok { + if err = Ok.Invalid( + l, env, + "signature is invalid", + ); chk.E(err) { + return + } + return + } + // store the event + if _, _, err = l.SaveEvent(c, env.E, false, nil); chk.E(err) { + return + } + // Send a success response after storing + if err = Ok.Ok(l, env, ""); chk.E(err) { + return + } + log.D.F("saved event %0x", env.E.ID) + return +} diff --git a/app/handle-message.go b/app/handle-message.go index 1cd1a85..fca7099 100644 --- a/app/handle-message.go +++ b/app/handle-message.go @@ -7,12 +7,14 @@ import ( "encoders.orly/envelopes/authenvelope" "encoders.orly/envelopes/closeenvelope" "encoders.orly/envelopes/eventenvelope" + "encoders.orly/envelopes/noticeenvelope" "encoders.orly/envelopes/reqenvelope" "lol.mleku.dev/chk" + "lol.mleku.dev/errorf" "lol.mleku.dev/log" ) -func (s *Server) HandleMessage(msg []byte, remote string) { +func (l *Listener) HandleMessage(msg []byte, remote string) { log.D.C( func() string { return fmt.Sprintf( @@ -20,36 +22,36 @@ func (s *Server) HandleMessage(msg []byte, remote string) { ) }, ) - var notice []byte var err error var t string var rem []byte - if t, rem, err = envelopes.Identify(msg); chk.E(err) { - notice = []byte(err.Error()) + if t, rem, err = envelopes.Identify(msg); !chk.E(err) { + switch t { + case eventenvelope.L: + log.D.F("eventenvelope: %s", rem) + err = l.HandleEvent(l.ctx, rem) + case reqenvelope.L: + log.D.F("reqenvelope: %s", rem) + err = l.HandleReq(l.ctx, rem) + case closeenvelope.L: + log.D.F("closeenvelope: %s", rem) + case authenvelope.L: + log.D.F("authenvelope: %s", rem) + default: + err = errorf.E("unknown envelope type %s\n%s", t, rem) + } } - switch t { - case eventenvelope.L: - log.D.F("eventenvelope: %s", rem) - case reqenvelope.L: - log.D.F("reqenvelope: %s", rem) - case closeenvelope.L: - log.D.F("closeenvelope: %s", rem) - case authenvelope.L: - log.D.F("authenvelope: %s", rem) - default: - notice = []byte(fmt.Sprintf("unknown envelope type %s\n%s", t, rem)) - } - if len(notice) > 0 { + if err != nil { log.D.C( func() string { return fmt.Sprintf( - "notice->%s %s", remote, notice, + "notice->%s %s", remote, err, ) }, ) - // if err = noticeenvelope.NewFrom(notice).Write(a.Listener); chk.E(err) { - // return - // } + if err = noticeenvelope.NewFrom(err.Error()).Write(l); chk.E(err) { + return + } } } diff --git a/app/handle-req.go b/app/handle-req.go new file mode 100644 index 0000000..4d6a413 --- /dev/null +++ b/app/handle-req.go @@ -0,0 +1,120 @@ +package app + +import ( + "context" + "errors" + + "encoders.orly/envelopes/closedenvelope" + "encoders.orly/envelopes/eoseenvelope" + "encoders.orly/envelopes/eventenvelope" + "encoders.orly/envelopes/reqenvelope" + "encoders.orly/event" + "encoders.orly/filter" + "encoders.orly/tag" + "github.com/dgraph-io/badger/v4" + "lol.mleku.dev/chk" + "lol.mleku.dev/log" + "utils.orly/normalize" + "utils.orly/pointers" +) + +func (l *Listener) HandleReq(c context.Context, msg []byte) ( + err error, +) { + var rem []byte + env := reqenvelope.New() + if rem, err = env.Unmarshal(msg); chk.E(err) { + return normalize.Error.Errorf(err.Error()) + } + if len(rem) > 0 { + log.I.F("extra '%s'", rem) + } + var events event.S + for _, f := range *env.Filters { + if pointers.Present(f.Limit) { + if *f.Limit == 0 { + continue + } + } + if events, err = l.QueryEvents(c, f); chk.E(err) { + if errors.Is(err, badger.ErrDBClosed) { + return + } + err = nil + } + } + // write out the events to the socket + seen := make(map[string]struct{}) + for _, ev := range events { + // track the IDs we've sent + seen[string(ev.ID)] = struct{}{} + var res *eventenvelope.Result + if res, err = eventenvelope.NewResultWith( + env.Subscription, ev, + ); chk.E(err) { + return + } + if err = res.Write(l); chk.E(err) { + return + } + } + // write the EOSE to signal to the client that all events found have been + // sent. + if err = eoseenvelope.NewFrom(env.Subscription). + Write(l); chk.E(err) { + return + } + // if the query was for just Ids, we know there can't be any more results, + // so cancel the subscription. + cancel := true + var subbedFilters filter.S + for _, f := range *env.Filters { + if f.Ids.Len() < 1 { + cancel = false + subbedFilters = append(subbedFilters, f) + } else { + // remove the IDs that we already sent + var notFounds [][]byte + for _, ev := range events { + if _, ok := seen[string(ev.ID)]; ok { + continue + } + notFounds = append(notFounds, ev.ID) + } + // if all were found, don't add to subbedFilters + if len(notFounds) == 0 { + continue + } + // rewrite the filter Ids to remove the ones we already sent + f.Ids = tag.NewFromBytesSlice(notFounds...) + // add the filter to the list of filters we're subscribing to + subbedFilters = append(subbedFilters, f) + } + // also, if we received the limit number of events, subscription ded + if pointers.Present(f.Limit) { + if len(events) < int(*f.Limit) { + cancel = false + } + } + } + receiver := make(event.C, 32) + // if the subscription should be cancelled, do so + if !cancel { + l.publishers.Receive( + &W{ + Conn: l.conn, + remote: l.remote, + Id: string(env.Subscription), + Receiver: receiver, + Filters: env.Filters, + }, + ) + } else { + if err = closedenvelope.NewFrom( + env.Subscription, nil, + ).Write(l); chk.E(err) { + return + } + } + return +} diff --git a/app/handle-websocket.go b/app/handle-websocket.go index 73d02fa..0cbe9d8 100644 --- a/app/handle-websocket.go +++ b/app/handle-websocket.go @@ -9,6 +9,7 @@ import ( "github.com/coder/websocket" "lol.mleku.dev/chk" "lol.mleku.dev/log" + "protocol.orly/publish" ) const ( @@ -28,7 +29,7 @@ const ( func (s *Server) HandleWebsocket(w http.ResponseWriter, r *http.Request) { remote := GetRemoteFromReq(r) - log.D.F("handling websocket connection from %s", remote) + log.T.F("handling websocket connection from %s", remote) if len(s.Config.IPWhitelist) > 0 { for _, ip := range s.Config.IPWhitelist { log.T.F("checking IP whitelist: %s", ip) @@ -52,6 +53,13 @@ whitelist: return } defer conn.CloseNow() + listener := &Listener{ + ctx: s.Ctx, + Server: s, + conn: conn, + remote: remote, + } + listener.publishers = publish.New(NewPublisher()) go s.Pinger(s.Ctx, conn, time.NewTicker(time.Second*10), cancel) for { select { @@ -85,7 +93,7 @@ whitelist: } continue } - go s.HandleMessage(msg, remote) + go listener.HandleMessage(msg, remote) } } diff --git a/app/listener.go b/app/listener.go index 39a7111..ed80907 100644 --- a/app/listener.go +++ b/app/listener.go @@ -1,9 +1,23 @@ package app import ( + "context" + "github.com/coder/websocket" + "lol.mleku.dev/chk" ) type Listener struct { - conn *websocket.Conn + *Server + conn *websocket.Conn + ctx context.Context + remote string +} + +func (l *Listener) Write(p []byte) (n int, err error) { + if err = l.conn.Write(l.ctx, websocket.MessageText, p); chk.E(err) { + return + } + n = len(p) + return } diff --git a/app/main.go b/app/main.go index 6335313..05ff2ec 100644 --- a/app/main.go +++ b/app/main.go @@ -5,12 +5,15 @@ import ( "fmt" "net/http" + database "database.orly" "lol.mleku.dev/chk" "lol.mleku.dev/log" "next.orly.dev/app/config" ) -func Run(ctx context.Context, cfg *config.C) (quit chan struct{}) { +func Run( + ctx context.Context, cfg *config.C, db *database.D, +) (quit chan struct{}) { // shutdown handler go func() { select { @@ -23,6 +26,7 @@ func Run(ctx context.Context, cfg *config.C) (quit chan struct{}) { l := &Server{ Ctx: ctx, Config: cfg, + D: db, } addr := fmt.Sprintf("%s:%d", cfg.Listen, cfg.Port) log.I.F("starting listener on http://%s", addr) diff --git a/app/ok.go b/app/ok.go new file mode 100644 index 0000000..f7ab81f --- /dev/null +++ b/app/ok.go @@ -0,0 +1,118 @@ +package app + +import ( + "encoders.orly/envelopes/eventenvelope" + "encoders.orly/envelopes/okenvelope" + "encoders.orly/reason" +) + +// OK represents a function that processes events or operations, using provided +// parameters to generate formatted messages and return errors if any issues +// occur during processing. +type OK func( + l *Listener, env *eventenvelope.Submission, format string, params ...any, +) (err error) + +// OKs provides a collection of handler functions for managing different types +// of operational outcomes, each corresponding to specific error or status +// conditions such as authentication requirements, rate limiting, and invalid +// inputs. +type OKs struct { + Ok OK + AuthRequired OK + PoW OK + Duplicate OK + Blocked OK + RateLimited OK + Invalid OK + Error OK + Unsupported OK + Restricted OK +} + +// Ok provides a collection of handler functions for managing different types of +// operational outcomes, each corresponding to specific error or status +// conditions such as authentication requirements, rate limiting, and invalid +// inputs. +var Ok = OKs{ + Ok: func( + l *Listener, env *eventenvelope.Submission, format string, + params ...any, + ) (err error) { + return okenvelope.NewFrom( + env.Id(), true, nil, + ).Write(l) + }, + AuthRequired: func( + l *Listener, env *eventenvelope.Submission, format string, + params ...any, + ) (err error) { + return okenvelope.NewFrom( + env.Id(), false, reason.AuthRequired.F(format, params...), + ).Write(l) + }, + PoW: func( + l *Listener, env *eventenvelope.Submission, format string, + params ...any, + ) (err error) { + return okenvelope.NewFrom( + env.Id(), false, reason.PoW.F(format, params...), + ).Write(l) + }, + Duplicate: func( + l *Listener, env *eventenvelope.Submission, format string, + params ...any, + ) (err error) { + return okenvelope.NewFrom( + env.Id(), false, reason.Duplicate.F(format, params...), + ).Write(l) + }, + Blocked: func( + l *Listener, env *eventenvelope.Submission, format string, + params ...any, + ) (err error) { + return okenvelope.NewFrom( + env.Id(), false, reason.Blocked.F(format, params...), + ).Write(l) + }, + RateLimited: func( + l *Listener, env *eventenvelope.Submission, format string, + params ...any, + ) (err error) { + return okenvelope.NewFrom( + env.Id(), false, reason.RateLimited.F(format, params...), + ).Write(l) + }, + Invalid: func( + l *Listener, env *eventenvelope.Submission, format string, + params ...any, + ) (err error) { + return okenvelope.NewFrom( + env.Id(), false, reason.Invalid.F(format, params...), + ).Write(l) + }, + Error: func( + l *Listener, env *eventenvelope.Submission, format string, + params ...any, + ) (err error) { + return okenvelope.NewFrom( + env.Id(), false, reason.Error.F(format, params...), + ).Write(l) + }, + Unsupported: func( + l *Listener, env *eventenvelope.Submission, format string, + params ...any, + ) (err error) { + return okenvelope.NewFrom( + env.Id(), false, reason.Unsupported.F(format, params...), + ).Write(l) + }, + Restricted: func( + l *Listener, env *eventenvelope.Submission, format string, + params ...any, + ) (err error) { + return okenvelope.NewFrom( + env.Id(), false, reason.Restricted.F(format, params...), + ).Write(l) + }, +} diff --git a/app/publisher.go b/app/publisher.go new file mode 100644 index 0000000..d1631eb --- /dev/null +++ b/app/publisher.go @@ -0,0 +1,222 @@ +package app + +import ( + "context" + "fmt" + "sync" + + "encoders.orly/envelopes/eventenvelope" + "encoders.orly/event" + "encoders.orly/filter" + "github.com/coder/websocket" + "interfaces.orly/publisher" + "interfaces.orly/typer" + "lol.mleku.dev/chk" + "lol.mleku.dev/log" +) + +const Type = "socketapi" + +type Subscription struct { + remote string + *filter.S +} + +// Map is a map of filters associated with a collection of ws.Listener +// connections. +type Map map[*websocket.Conn]map[string]Subscription + +type W struct { + *websocket.Conn + + remote string + + // If Cancel is true, this is a close command. + Cancel bool + + // Id is the subscription Id. If Cancel is true, cancel the named + // subscription, otherwise, cancel the publisher for the socket. + Id string + + // The Receiver holds the event channel for receiving notifications or data + // relevant to this WebSocket connection. + Receiver event.C + + // Filters holds a collection of filters used to match or process events + // associated with this WebSocket connection. It is used to determine which + // notifications or data should be received by the subscriber. + Filters *filter.S +} + +func (w *W) Type() (typeName string) { return Type } + +// P is a structure that manages subscriptions and associated filters for +// websocket listeners. It uses a mutex to synchronize access to a map storing +// subscriber connections and their filter configurations. +type P struct { + c context.Context + // Mx is the mutex for the Map. + Mx sync.Mutex + // Map is the map of subscribers and subscriptions from the websocket api. + Map +} + +var _ publisher.I = &P{} + +func NewPublisher() (publisher *P) { + return &P{ + Map: make(Map), + } +} + +func (p *P) Type() (typeName string) { return Type } + +// Receive handles incoming messages to manage websocket listener subscriptions +// and associated filters. +// +// # Parameters +// +// - msg (publisher.Message): The incoming message to process; expected to be of +// type *W to trigger subscription management actions. +// +// # Expected behaviour +// +// - Checks if the message is of type *W. +// +// - If Cancel is true, removes a subscriber by ID or the entire listener. +// +// - Otherwise, adds the subscription to the map under a mutex lock. +// +// - Logs actions related to subscription creation or removal. +func (p *P) Receive(msg typer.T) { + if m, ok := msg.(*W); ok { + if m.Cancel { + if m.Id == "" { + p.removeSubscriber(m.Conn) + log.T.F("removed listener %s", m.remote) + } else { + p.removeSubscriberId(m.Conn, m.Id) + log.T.C( + func() string { + return fmt.Sprintf( + "removed subscription %s for %s", m.Id, + m.remote, + ) + }, + ) + } + return + } + p.Mx.Lock() + defer p.Mx.Unlock() + if subs, ok := p.Map[m.Conn]; !ok { + subs = make(map[string]Subscription) + subs[m.Id] = Subscription{S: m.Filters, remote: m.remote} + p.Map[m.Conn] = subs + log.T.C( + func() string { + return fmt.Sprintf( + "created new subscription for %s, %s", + m.remote, + m.Filters.Marshal(nil), + ) + }, + ) + } else { + subs[m.Id] = Subscription{S: m.Filters, remote: m.remote} + log.T.C( + func() string { + return fmt.Sprintf( + "added subscription %s for %s", m.Id, + m.remote, + ) + }, + ) + } + } +} + +// Deliver processes and distributes an event to all matching subscribers based on their filter configurations. +// +// # Parameters +// +// - ev (*event.E): The event to be delivered to subscribed clients. +// +// # Expected behaviour +// +// Delivers the event to all subscribers whose filters match the event. It +// applies authentication checks if required by the server and skips delivery +// for unauthenticated users when events are privileged. +func (p *P) Deliver(ev *event.E) { + var err error + p.Mx.Lock() + defer p.Mx.Unlock() + log.T.C( + func() string { + return fmt.Sprintf( + "delivering event %0x to websocket subscribers %d", ev.ID, + len(p.Map), + ) + }, + ) + for w, subs := range p.Map { + log.T.C( + func() string { + return fmt.Sprintf( + "%v %s", subs, + ) + }, + ) + for id, subscriber := range subs { + if !subscriber.Match(ev) { + continue + } + // if p.Server.AuthRequired() { + // if !auth.CheckPrivilege(w.AuthedPubkey(), ev) { + // continue + // } + // } + var res *eventenvelope.Result + if res, err = eventenvelope.NewResultWith(id, ev); chk.E(err) { + continue + } + if err = w.Write( + p.c, websocket.MessageText, res.Marshal(nil), + ); chk.E(err) { + continue + } + log.T.C( + func() string { + return fmt.Sprintf( + "dispatched event %0x to subscription %s, %s", + ev.ID, id, subscriber.remote, + ) + }, + ) + } + } +} + +// removeSubscriberId removes a specific subscription from a subscriber +// websocket. +func (p *P) removeSubscriberId(ws *websocket.Conn, id string) { + p.Mx.Lock() + var subs map[string]Subscription + var ok bool + if subs, ok = p.Map[ws]; ok { + delete(p.Map[ws], id) + _ = subs + if len(subs) == 0 { + delete(p.Map, ws) + } + } + p.Mx.Unlock() +} + +// removeSubscriber removes a websocket from the P collection. +func (p *P) removeSubscriber(ws *websocket.Conn) { + p.Mx.Lock() + clear(p.Map[ws]) + delete(p.Map, ws) + p.Mx.Unlock() +} diff --git a/app/server.go b/app/server.go index 029ec99..2db5c7c 100644 --- a/app/server.go +++ b/app/server.go @@ -4,14 +4,19 @@ import ( "context" "net/http" + "database.orly" "lol.mleku.dev/log" "next.orly.dev/app/config" + "protocol.orly/publish" ) type Server struct { mux *http.ServeMux Config *config.C Ctx context.Context + remote string + *database.D + publishers *publish.S } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { diff --git a/go.mod b/go.mod index c5b8dcb..e1dfd65 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module next.orly.dev go 1.25.0 require ( + database.orly v0.0.0-00010101000000-000000000000 encoders.orly v0.0.0-00010101000000-000000000000 github.com/adrg/xdg v0.5.3 github.com/coder/websocket v1.8.13 @@ -16,22 +17,39 @@ require ( require ( crypto.orly v0.0.0-00010101000000-000000000000 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dgraph-io/badger/v4 v4.8.0 // indirect + github.com/dgraph-io/ristretto/v2 v2.2.0 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect github.com/fatih/color v1.18.0 // indirect github.com/felixge/fgprof v0.9.3 // indirect + github.com/go-logr/logr v1.4.3 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/google/flatbuffers v25.2.10+incompatible // indirect github.com/google/pprof v0.0.0-20211214055906-6f57359322fd // indirect + github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 // indirect + github.com/klauspost/compress v1.18.0 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/templexxx/cpu v0.0.1 // indirect github.com/templexxx/xhex v0.0.0-20200614015412-aed53437177b // indirect + go.opentelemetry.io/auto/sdk v1.1.0 // indirect + go.opentelemetry.io/otel v1.37.0 // indirect + go.opentelemetry.io/otel/metric v1.37.0 // indirect + go.opentelemetry.io/otel/trace v1.37.0 // indirect + go.uber.org/atomic v1.11.0 // indirect golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b // indirect + golang.org/x/net v0.41.0 // indirect golang.org/x/sys v0.35.0 // indirect + google.golang.org/protobuf v1.36.6 // indirect interfaces.orly v0.0.0-00010101000000-000000000000 // indirect ) replace ( crypto.orly => ./pkg/crypto + database.orly => ./pkg/database encoders.orly => ./pkg/encoders interfaces.orly => ./pkg/interfaces next.orly.dev => ../../ diff --git a/go.sum b/go.sum index 5248dd2..e87b3e6 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/adrg/xdg v0.5.3 h1:xRnxJXne7+oWDatRhR1JLnvuccuIeCoBu2rtuLqQB78= github.com/adrg/xdg v0.5.3/go.mod h1:nlTsY+NNiCBGCK2tpm09vRqfVzrc2fLmXGpBLF0zlTQ= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= @@ -8,13 +10,34 @@ github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3C github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgraph-io/badger/v4 v4.8.0 h1:JYph1ChBijCw8SLeybvPINizbDKWZ5n/GYbz2yhN/bs= +github.com/dgraph-io/badger/v4 v4.8.0/go.mod h1:U6on6e8k/RTbUWxqKR0MvugJuVmkxSNc79ap4917h4w= +github.com/dgraph-io/ristretto/v2 v2.2.0 h1:bkY3XzJcXoMuELV8F+vS8kzNgicwQFAaGINAEJdWGOM= +github.com/dgraph-io/ristretto/v2 v2.2.0/go.mod h1:RZrm63UmcBAaYWC1DotLYBmTvgkrs0+XhBd7Npn7/zI= +github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da h1:aIftn67I1fkbMa512G+w+Pxci9hJPB8oMnkcP3iZF38= +github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/felixge/fgprof v0.9.3 h1:VvyZxILNuCiUCSXtPtYmmtGvb65nqXh2QFWc0Wpf2/g= github.com/felixge/fgprof v0.9.3/go.mod h1:RdbpDgzqYVh/T9fPELJyV7EYJuHB55UTEULNun8eiPw= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/google/flatbuffers v25.2.10+incompatible h1:F3vclr7C3HpB1k9mxCGRMXq6FdUalZ6H/pNX4FP1v0Q= +github.com/google/flatbuffers v25.2.10+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/pprof v0.0.0-20211214055906-6f57359322fd h1:1FjCyPC+syAzJ5/2S8fqdZK1R22vvA0J7JZKcuOIQ7Y= github.com/google/pprof v0.0.0-20211214055906-6f57359322fd/go.mod h1:KgnwoLYCZ8IQu3XUZ8Nc/bM9CCZFOyjUNOSygVozoDg= github.com/ianlancetaylor/demangle v0.0.0-20210905161508-09a460cdf81d/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w= +github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 h1:iQTw/8FWTuc7uiaSepXwyf3o52HaUYcV+Tu66S3F5GA= +github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0/go.mod h1:1NbS8ALrpOvjt0rHPNLyCIeMtbizbir8U//inJ+zuB8= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= @@ -37,12 +60,26 @@ github.com/templexxx/xhex v0.0.0-20200614015412-aed53437177b h1:XeDLE6c9mzHpdv3W github.com/templexxx/xhex v0.0.0-20200614015412-aed53437177b/go.mod h1:7rwmCH0wC2fQvNEvPZ3sKXukhyCTyiaZ5VTZMQYpZKQ= go-simpler.org/env v0.12.0 h1:kt/lBts0J1kjWJAnB740goNdvwNxt5emhYngL0Fzufs= go-simpler.org/env v0.12.0/go.mod h1:cc/5Md9JCUM7LVLtN0HYjPTDcI3Q8TDaPlNTAlDU+WI= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ= +go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I= +go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE= +go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E= +go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= +go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b h1:DXr+pvt3nC887026GRP39Ej11UATqWDmWuS99x26cD0= golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b/go.mod h1:4QTo5u+SEIbbKW1RacMZq1YEfOBqeXa19JeshGi+zc4= +golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw= +golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= +google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/main.go b/main.go index 566da79..421a498 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "os" "os/signal" + database "database.orly" "lol.mleku.dev/chk" "lol.mleku.dev/log" "next.orly.dev/app" @@ -21,7 +22,13 @@ func main() { log.I.F("starting %s %s", cfg.AppName, version.V) startProfiler(cfg.Pprof) ctx, cancel := context.WithCancel(context.Background()) - quit := app.Run(ctx, cfg) + var db *database.D + if db, err = database.New( + ctx, cancel, cfg.DataDir, cfg.DBLogLevel, + ); chk.E(err) { + os.Exit(1) + } + quit := app.Run(ctx, cfg, db) sigs := make(chan os.Signal, 1) signal.Notify(sigs, os.Interrupt) for { diff --git a/pkg/database/database.go b/pkg/database/database.go index ff5ec43..22ee3a5 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -2,6 +2,7 @@ package database import ( "context" + "errors" "os" "path/filepath" "time" @@ -43,7 +44,8 @@ func New( return } - // Also ensure the directory exists using apputil.EnsureDir for any potential subdirectories + // Also ensure the directory exists using apputil.EnsureDir for any + // potential subdirectories dummyFile := filepath.Join(dataDir, "dummy.sst") if err = apputil.EnsureDir(dummyFile); chk.E(err) { return @@ -87,8 +89,8 @@ func New( func (d *D) Path() string { return d.dataDir } func (d *D) Wipe() (err error) { - // TODO implement me - panic("implement me") + err = errors.New("not implemented") + return } func (d *D) SetLogLevel(level string) { @@ -98,8 +100,8 @@ func (d *D) SetLogLevel(level string) { func (d *D) EventIdsBySerial(start uint64, count int) ( evs []uint64, err error, ) { - // TODO implement me - panic("implement me") + err = errors.New("not implemented") + return } // Init initializes the database with the given path. @@ -118,6 +120,7 @@ func (d *D) Sync() (err error) { // Close releases resources and closes the database. func (d *D) Close() (err error) { + log.D.F("%s: closing database", d.dataDir) if d.seq != nil { if err = d.seq.Release(); chk.E(err) { return diff --git a/pkg/database/get-indexes-for-event.go b/pkg/database/get-indexes-for-event.go index c343c86..bcbcaa1 100644 --- a/pkg/database/get-indexes-for-event.go +++ b/pkg/database/get-indexes-for-event.go @@ -76,20 +76,21 @@ func GetIndexesForEvent(ev *event.E, serial uint64) ( } // Process tags for tag-related indexes if ev.Tags != nil && ev.Tags.Len() > 0 { - for _, tag := range ev.Tags.ToSliceOfTags() { + for _, t := range *ev.Tags { // only index tags with a value field and the key is a single character - if tag.Len() >= 2 { + if t.Len() >= 2 { // Get the key and value from the tag - keyBytes := tag.Key() + keyBytes := t.Key() // require single-letter key if len(keyBytes) != 1 { continue } // if the key is not a-zA-Z skip - if (keyBytes[0] < 'a' || keyBytes[0] > 'z') && (keyBytes[0] < 'A' || keyBytes[0] > 'Z') { + if (keyBytes[0] < 'a' || keyBytes[0] > 'z') && + (keyBytes[0] < 'A' || keyBytes[0] > 'Z') { continue } - valueBytes := tag.Value() + valueBytes := t.Value() // Create tag key and value key := new(Letter) key.Set(keyBytes[0]) diff --git a/pkg/database/get-indexes-from-filter.go b/pkg/database/get-indexes-from-filter.go index e13cd6f..ef1c8e7 100644 --- a/pkg/database/get-indexes-from-filter.go +++ b/pkg/database/get-indexes-from-filter.go @@ -82,7 +82,7 @@ func GetIndexesFromFilter(f *filter.F) (idxs []Range, err error) { // If there is any Ids in the filter, none of the other fields matter. It // should be an error, but convention just ignores it. if f.Ids.Len() > 0 { - for _, id := range f.Ids.ToSliceOfBytes() { + for _, id := range f.Ids.T { if err = func() (err error) { var i *types2.IdHash if i, err = CreateIdHashFromData(id); chk.E(err) { @@ -123,7 +123,7 @@ func GetIndexesFromFilter(f *filter.F) (idxs []Range, err error) { if f.Tags != nil && f.Tags.Len() > 0 { // sort the tags so they are in iteration order (reverse) - tmp := f.Tags.ToSliceOfTags() + tmp := *f.Tags sort.Slice( tmp, func(i, j int) bool { return bytes.Compare(tmp[i].Key(), tmp[j].Key()) > 0 @@ -134,17 +134,17 @@ func GetIndexesFromFilter(f *filter.F) (idxs []Range, err error) { // TagKindPubkey tkp if f.Kinds != nil && f.Kinds.Len() > 0 && f.Authors != nil && f.Authors.Len() > 0 && f.Tags != nil && f.Tags.Len() > 0 { for _, k := range f.Kinds.ToUint16() { - for _, author := range f.Authors.ToSliceOfBytes() { - for _, tag := range f.Tags.ToSliceOfTags() { + for _, author := range f.Authors.T { + for _, t := range *f.Tags { // accept single-letter keys like "e" or filter-style keys like "#e" - if tag.Len() >= 2 && (len(tag.Key()) == 1 || (len(tag.Key()) == 2 && tag.Key()[0] == '#')) { + if t.Len() >= 2 && (len(t.Key()) == 1 || (len(t.Key()) == 2 && t.Key()[0] == '#')) { kind := new(types2.Uint16) kind.Set(k) var p *types2.PubHash if p, err = CreatePubHashFromData(author); chk.E(err) { return } - keyBytes := tag.Key() + keyBytes := t.Key() key := new(types2.Letter) // If the tag key starts with '#', use the second character as the key if len(keyBytes) == 2 && keyBytes[0] == '#' { @@ -152,7 +152,7 @@ func GetIndexesFromFilter(f *filter.F) (idxs []Range, err error) { } else { key.Set(keyBytes[0]) } - for _, valueBytes := range tag.ToSliceOfBytes()[1:] { + for _, valueBytes := range t.T[1:] { valueHash := new(types2.Ident) valueHash.FromIdent(valueBytes) start, end := new(bytes.Buffer), new(bytes.Buffer) @@ -184,11 +184,11 @@ func GetIndexesFromFilter(f *filter.F) (idxs []Range, err error) { // TagKind tkc if f.Kinds != nil && f.Kinds.Len() > 0 && f.Tags != nil && f.Tags.Len() > 0 { for _, k := range f.Kinds.ToUint16() { - for _, tag := range f.Tags.ToSliceOfTags() { - if tag.Len() >= 2 && (len(tag.Key()) == 1 || (len(tag.Key()) == 2 && tag.Key()[0] == '#')) { + for _, t := range *f.Tags { + if t.Len() >= 2 && (len(t.Key()) == 1 || (len(t.Key()) == 2 && t.Key()[0] == '#')) { kind := new(types2.Uint16) kind.Set(k) - keyBytes := tag.Key() + keyBytes := t.Key() key := new(types2.Letter) // If the tag key starts with '#', use the second character as the key if len(keyBytes) == 2 && keyBytes[0] == '#' { @@ -196,7 +196,7 @@ func GetIndexesFromFilter(f *filter.F) (idxs []Range, err error) { } else { key.Set(keyBytes[0]) } - for _, valueBytes := range tag.ToSliceOfBytes()[1:] { + for _, valueBytes := range t.T[1:] { valueHash := new(types2.Ident) valueHash.FromIdent(valueBytes) start, end := new(bytes.Buffer), new(bytes.Buffer) @@ -226,14 +226,14 @@ func GetIndexesFromFilter(f *filter.F) (idxs []Range, err error) { // TagPubkey tpc if f.Authors != nil && f.Authors.Len() > 0 && f.Tags != nil && f.Tags.Len() > 0 { - for _, author := range f.Authors.ToSliceOfBytes() { - for _, tag := range f.Tags.ToSliceOfTags() { - if tag.Len() >= 2 && (len(tag.Key()) == 1 || (len(tag.Key()) == 2 && tag.Key()[0] == '#')) { + for _, author := range f.Authors.T { + for _, t := range *f.Tags { + if t.Len() >= 2 && (len(t.Key()) == 1 || (len(t.Key()) == 2 && t.Key()[0] == '#')) { var p *types2.PubHash if p, err = CreatePubHashFromData(author); chk.E(err) { return } - keyBytes := tag.Key() + keyBytes := t.Key() key := new(types2.Letter) // If the tag key starts with '#', use the second character as the key if len(keyBytes) == 2 && keyBytes[0] == '#' { @@ -241,7 +241,7 @@ func GetIndexesFromFilter(f *filter.F) (idxs []Range, err error) { } else { key.Set(keyBytes[0]) } - for _, valueBytes := range tag.ToSliceOfBytes()[1:] { + for _, valueBytes := range t.T[1:] { valueHash := new(types2.Ident) valueHash.FromIdent(valueBytes) start, end := new(bytes.Buffer), new(bytes.Buffer) @@ -269,9 +269,9 @@ func GetIndexesFromFilter(f *filter.F) (idxs []Range, err error) { // Tag tc- if f.Tags != nil && f.Tags.Len() > 0 && (f.Authors == nil || f.Authors.Len() == 0) && (f.Kinds == nil || f.Kinds.Len() == 0) { - for _, tag := range f.Tags.ToSliceOfTags() { - if tag.Len() >= 2 && (len(tag.Key()) == 1 || (len(tag.Key()) == 2 && tag.Key()[0] == '#')) { - keyBytes := tag.Key() + for _, t := range *f.Tags { + if t.Len() >= 2 && (len(t.Key()) == 1 || (len(t.Key()) == 2 && t.Key()[0] == '#')) { + keyBytes := t.Key() key := new(types2.Letter) // If the tag key starts with '#', use the second character as the key if len(keyBytes) == 2 && keyBytes[0] == '#' { @@ -279,7 +279,7 @@ func GetIndexesFromFilter(f *filter.F) (idxs []Range, err error) { } else { key.Set(keyBytes[0]) } - for _, valueBytes := range tag.ToSliceOfBytes()[1:] { + for _, valueBytes := range t.T[1:] { valueHash := new(types2.Ident) valueHash.FromIdent(valueBytes) start, end := new(bytes.Buffer), new(bytes.Buffer) @@ -303,7 +303,7 @@ func GetIndexesFromFilter(f *filter.F) (idxs []Range, err error) { // KindPubkey kpc if f.Kinds != nil && f.Kinds.Len() > 0 && f.Authors != nil && f.Authors.Len() > 0 { for _, k := range f.Kinds.ToUint16() { - for _, author := range f.Authors.ToSliceOfBytes() { + for _, author := range f.Authors.T { kind := new(types2.Uint16) kind.Set(k) p := new(types2.PubHash) @@ -350,7 +350,7 @@ func GetIndexesFromFilter(f *filter.F) (idxs []Range, err error) { // Pubkey pc- if f.Authors != nil && f.Authors.Len() > 0 { - for _, author := range f.Authors.ToSliceOfBytes() { + for _, author := range f.Authors.T { p := new(types2.PubHash) if err = p.FromPubkey(author); chk.E(err) { return diff --git a/pkg/database/query-events.go b/pkg/database/query-events.go index d4ef295..e588df6 100644 --- a/pkg/database/query-events.go +++ b/pkg/database/query-events.go @@ -42,7 +42,7 @@ func (d *D) QueryEvents(c context.Context, f *filter.F) ( var expDeletes types.Uint40s var expEvs event.S if f.Ids != nil && f.Ids.Len() > 0 { - for _, idx := range f.Ids.ToSliceOfBytes() { + for _, idx := range f.Ids.T { // we know there is only Ids in this, so run the ID query and fetch. var ser *types.Uint40 if ser, err = d.GetSerialById(idx); chk.E(err) { diff --git a/pkg/database/query-for-serials.go b/pkg/database/query-for-serials.go index 8a66f13..708e1fd 100644 --- a/pkg/database/query-for-serials.go +++ b/pkg/database/query-for-serials.go @@ -17,7 +17,7 @@ func (d *D) QueryForSerials(c context.Context, f *filter.F) ( var founds []*types.Uint40 var idPkTs []*store.IdPkTs if f.Ids != nil && f.Ids.Len() > 0 { - for _, id := range f.Ids.ToSliceOfBytes() { + for _, id := range f.Ids.T { var ser *types.Uint40 if ser, err = d.GetSerialById(id); chk.E(err) { return @@ -29,25 +29,6 @@ func (d *D) QueryForSerials(c context.Context, f *filter.F) ( return } idPkTs = append(idPkTs, tmp...) - - // // fetch the events full id indexes so we can sort them - // for _, ser := range founds { - // // scan for the IdPkTs - // var fidpk *store.IdPkTs - // if fidpk, err = d.GetFullIdPubkeyBySerial(ser); chk.E(err) { - // return - // } - // if fidpk == nil { - // continue - // } - // idPkTs = append(idPkTs, fidpk) - // // sort by timestamp - // sort.Slice( - // idPkTs, func(i, j int) bool { - // return idPkTs[i].Ts > idPkTs[j].Ts - // }, - // ) - // } } else { if idPkTs, err = d.QueryForIds(c, f); chk.E(err) { return diff --git a/pkg/database/save-event.go b/pkg/database/save-event.go index de9028f..68b7d83 100644 --- a/pkg/database/save-event.go +++ b/pkg/database/save-event.go @@ -16,6 +16,7 @@ import ( "interfaces.orly/store" "lol.mleku.dev/chk" "lol.mleku.dev/errorf" + "lol.mleku.dev/log" ) // SaveEvent saves an event to the database, generating all the necessary indexes. @@ -71,16 +72,6 @@ func (d *D) SaveEvent( return } idPkTss = append(idPkTss, tmp...) - // for _, ser := range sers { - // var fidpk *store.IdPkTs - // if fidpk, err = d.GetFullIdPubkeyBySerial(ser); chk.E(err) { - // return - // } - // if fidpk == nil { - // continue - // } - // idPkTss = append(idPkTss, fidpk) - // } // sort by timestamp, so the first is the newest sort.Slice( idPkTss, func(i, j int) bool { @@ -177,6 +168,6 @@ func (d *D) SaveEvent( return }, ) - // log.T.F("total data written: %d bytes keys %d bytes values", kc, vc) + log.T.F("total data written: %d bytes keys %d bytes values", kc, vc) return } diff --git a/pkg/encoders/envelopes/reqenvelope/reqenvelope.go b/pkg/encoders/envelopes/reqenvelope/reqenvelope.go index b2d4ab1..fe6b046 100644 --- a/pkg/encoders/envelopes/reqenvelope/reqenvelope.go +++ b/pkg/encoders/envelopes/reqenvelope/reqenvelope.go @@ -21,7 +21,7 @@ const L = "REQ" // newly received events after it returns an eoseenvelope.T. type T struct { Subscription []byte - Filters filter.S + Filters *filter.S } var _ codec.Envelope = (*T)(nil) @@ -32,14 +32,14 @@ func New() *T { return new(T) } // NewFrom creates a new reqenvelope.T with a provided subscription.Id and // filters.T. -func NewFrom(id []byte, ff filter.S) *T { +func NewFrom(id []byte, ff *filter.S) *T { return &T{ Subscription: id, Filters: ff, } } -func NewWithId[V string | []byte](id V, ff filter.S) (sub *T) { +func NewWithId[V string | []byte](id V, ff *filter.S) (sub *T) { return &T{ Subscription: []byte(id), Filters: ff, @@ -69,7 +69,7 @@ func (en *T) Marshal(dst []byte) (b []byte) { o = append(o, '"') o = append(o, en.Subscription...) o = append(o, '"') - for _, f := range en.Filters { + for _, f := range *en.Filters { o = append(o, ',') o = f.Marshal(o) } @@ -90,6 +90,7 @@ func (en *T) Unmarshal(b []byte) (r []byte, err error) { if r, err = text.Comma(r); chk.E(err) { return } + en.Filters = new(filter.S) if r, err = en.Filters.Unmarshal(r); chk.E(err) { return } diff --git a/pkg/encoders/event/binary.go b/pkg/encoders/event/binary.go index 446715f..8cbcd0b 100644 --- a/pkg/encoders/event/binary.go +++ b/pkg/encoders/event/binary.go @@ -30,9 +30,9 @@ func (ev *E) MarshalBinary(w io.Writer) { varint.Encode(w, uint64(ev.CreatedAt)) varint.Encode(w, uint64(ev.Kind)) varint.Encode(w, uint64(ev.Tags.Len())) - for _, x := range ev.Tags.ToSliceOfTags() { + for _, x := range *ev.Tags { varint.Encode(w, uint64(x.Len())) - for _, y := range x.ToSliceOfBytes() { + for _, y := range x.T { varint.Encode(w, uint64(len(y))) _, _ = w.Write(y) } diff --git a/pkg/encoders/event/canonical.go b/pkg/encoders/event/canonical.go index 4bc225e..598d898 100644 --- a/pkg/encoders/event/canonical.go +++ b/pkg/encoders/event/canonical.go @@ -5,6 +5,7 @@ import ( "encoders.orly/hex" "encoders.orly/ints" "encoders.orly/text" + "lol.mleku.dev/log" ) // ToCanonical converts the event to the canonical encoding used to derive the @@ -14,14 +15,15 @@ func (ev *E) ToCanonical(dst []byte) (b []byte) { b = append(b, "[0,\""...) b = hex.EncAppend(b, ev.Pubkey) b = append(b, "\","...) - b = ints.New(ev.CreatedAt).Marshal(nil) + b = ints.New(ev.CreatedAt).Marshal(b) b = append(b, ',') - b = ints.New(ev.Kind).Marshal(nil) + b = ints.New(ev.Kind).Marshal(b) b = append(b, ',') b = ev.Tags.Marshal(b) b = append(b, ',') b = text.AppendQuote(b, ev.Content, text.NostrEscape) b = append(b, ']') + log.D.F("canonical: %s", b) return } diff --git a/pkg/encoders/filter/filter.go b/pkg/encoders/filter/filter.go index c8c724d..1502fd0 100644 --- a/pkg/encoders/filter/filter.go +++ b/pkg/encoders/filter/filter.go @@ -6,6 +6,7 @@ import ( "crypto.orly/ec/schnorr" "crypto.orly/sha256" + "encoders.orly/event" "encoders.orly/ints" "encoders.orly/kind" "encoders.orly/tag" @@ -100,6 +101,50 @@ func (f *F) Sort() { } } +// MatchesIgnoringTimestampConstraints checks a filter against an event and +// determines if the event matches the filter, ignoring timestamp constraints.. +func (f *F) MatchesIgnoringTimestampConstraints(ev *event.E) bool { + if ev == nil { + return false + } + if f.Ids.Len() > 0 && !f.Ids.Contains(ev.ID) { + return false + } + if f.Kinds.Len() > 0 && !f.Kinds.Contains(ev.Kind) { + return false + } + if f.Authors.Len() > 0 && !f.Authors.Contains(ev.Pubkey) { + return false + } + if f.Tags.Len() > 0 { + for _, v := range *f.Tags { + if v.Len() < 2 { + continue + } + key := v.Key() + values := v.T[1:] + if !ev.Tags.ContainsAny(key, values) { + return false + } + } + } + return true +} + +// Matches checks a filter against an event and determines if the event matches the filter. +func (f *F) Matches(ev *event.E) (match bool) { + if !f.MatchesIgnoringTimestampConstraints(ev) { + return + } + if f.Since.Int() != 0 && ev.CreatedAt < f.Since.I64() { + return + } + if f.Until.Int() != 0 && ev.CreatedAt > f.Until.I64() { + return + } + return true +} + // Marshal a filter into raw JSON bytes, minified. The field ordering and sort // of fields is canonicalized so that a hash can identify the same filter. func (f *F) Marshal(dst []byte) (b []byte) { diff --git a/pkg/encoders/filter/filters.go b/pkg/encoders/filter/filters.go index 036d097..b8e28c2 100644 --- a/pkg/encoders/filter/filters.go +++ b/pkg/encoders/filter/filters.go @@ -1,11 +1,22 @@ package filter import ( + "encoders.orly/event" "lol.mleku.dev/errorf" ) type S []*F +// Match checks if a set of filters.T matches on an event.F. +func (s *S) Match(event *event.E) bool { + for _, f := range *s { + if f.Matches(event) { + return true + } + } + return false +} + // Marshal encodes a slice of filters as a JSON array of objects. // It appends the result to dst and returns the resulting slice. func (s S) Marshal(dst []byte) (b []byte) { @@ -43,7 +54,7 @@ func (s *S) Unmarshal(b []byte) (r []byte, err error) { if len(r) == 0 { return } - f := new(F) + f := New() var rem []byte if rem, err = f.Unmarshal(r); err != nil { return diff --git a/pkg/encoders/kind/kind.go b/pkg/encoders/kind/kind.go index 20f0bb7..3f52722 100644 --- a/pkg/encoders/kind/kind.go +++ b/pkg/encoders/kind/kind.go @@ -59,11 +59,11 @@ func (k *K) ToU64() uint64 { func (k *K) Name() string { return GetString(k.K) } // Equal checks if -func (k *K) Equal(k2 *K) bool { - if k == nil || k2 == nil { +func (k *K) Equal(k2 uint16) bool { + if k == nil { return false } - return k.K == k2.K + return k.K == k2 } var Privileged = []*K{ @@ -80,7 +80,7 @@ var Privileged = []*K{ // the pubkeys in the event and p tags of the event are party to. func (k *K) IsPrivileged() (is bool) { for i := range Privileged { - if k.Equal(Privileged[i]) { + if k.Equal(Privileged[i].K) { return true } } diff --git a/pkg/encoders/kind/kinds.go b/pkg/encoders/kind/kinds.go index ed644fc..756eb73 100644 --- a/pkg/encoders/kind/kinds.go +++ b/pkg/encoders/kind/kinds.go @@ -67,7 +67,7 @@ func (k *S) Clone() (c *S) { // Even if a custom number is found, this codebase does not have the logic to // deal with the kind so such a search is pointless and for which reason static // typing always wins. No mistakes possible with known quantities. -func (k *S) Contains(s *K) bool { +func (k *S) Contains(s uint16) bool { for i := range k.K { if k.K[i].Equal(s) { return true diff --git a/pkg/encoders/reason/reason.go b/pkg/encoders/reason/reason.go new file mode 100644 index 0000000..efe6e52 --- /dev/null +++ b/pkg/encoders/reason/reason.go @@ -0,0 +1,52 @@ +package reason + +import ( + "bytes" + "fmt" + + "lol.mleku.dev/log" +) + +// R is the machine-readable prefix before the colon in an OK or CLOSED envelope message. +// Below are the most common kinds that are mentioned in NIP-01. +type R []byte + +var ( + AuthRequired = R("auth-required") + PoW = R("pow") + Duplicate = R("duplicate") + Blocked = R("blocked") + RateLimited = R("rate-limited") + Invalid = R("invalid") + Error = R("error") + Unsupported = R("unsupported") + Restricted = R("restricted") +) + +// S returns the R as a string +func (r R) S() string { return string(r) } + +// B returns the R as a byte slice. +func (r R) B() []byte { return r } + +// IsPrefix returns whether a text contains the same R prefix. +func (r R) IsPrefix(reason []byte) bool { + return bytes.HasPrefix( + reason, r.B(), + ) +} + +// F allows creation of a full R text with a printf style format. +func (r R) F(format string, params ...any) (o []byte) { + log.D.F(format, params...) + return Msg(r, format, params...) +} + +// Msg constructs a properly formatted message with a machine-readable prefix +// for OK and CLOSED envelopes. +func Msg(prefix R, format string, params ...any) (o []byte) { + if len(prefix) < 1 { + prefix = Error + } + return []byte(fmt.Sprintf(prefix.S()+": "+format, params...)) +} diff --git a/pkg/encoders/tag/tag.go b/pkg/encoders/tag/tag.go index 7bde30d..ceb32ce 100644 --- a/pkg/encoders/tag/tag.go +++ b/pkg/encoders/tag/tag.go @@ -8,6 +8,7 @@ import ( "encoders.orly/text" "lol.mleku.dev/errorf" + utils "utils.orly" "utils.orly/bufpool" ) @@ -67,21 +68,28 @@ func (t *T) Less(i, j int) bool { func (t *T) Swap(i, j int) { t.T[i], t.T[j] = t.T[j], t.T[i] } -func (t *T) ToSliceOfBytes() (b [][]byte) { - return t.T +// Contains returns true if the provided element is found in the tag slice. +func (t *T) Contains(s []byte) (b bool) { + for i := range t.T { + if utils.FastEqual(t.T[i], s) { + return true + } + } + return false } // Marshal encodes a tag.T as standard minified JSON array of strings. func (t *T) Marshal(dst []byte) (b []byte) { - dst = append(dst, '[') + b = dst + b = append(b, '[') for i, s := range t.T { - dst = text.AppendQuote(dst, s, text.NostrEscape) + b = text.AppendQuote(b, s, text.NostrEscape) if i < len(t.T)-1 { - dst = append(dst, ',') + b = append(b, ',') } } - dst = append(dst, ']') - return dst + b = append(b, ']') + return } // MarshalJSON encodes a tag.T as standard minified JSON array of strings. diff --git a/pkg/encoders/tag/tags.go b/pkg/encoders/tag/tags.go index 7d81c44..4e716d4 100644 --- a/pkg/encoders/tag/tags.go +++ b/pkg/encoders/tag/tags.go @@ -44,17 +44,26 @@ func (s *S) Append(t ...*T) { *s = append(*s, t...) } -func (s *S) ToSliceOfTags() (t []T) { - if s == nil { - return +// ContainsAny returns true if any of the values given in `values` matches any +// of the tag elements. +func (s *S) ContainsAny(tagName []byte, values [][]byte) bool { + if len(tagName) < 1 { + return false } - for _, tt := range *s { - if tt == nil { + for _, v := range *s { + if v.Len() < 2 { continue } - t = append(t, *tt) + if !utils.FastEqual(v.Key(), tagName) { + continue + } + for _, candidate := range values { + if bytes.HasPrefix(v.Value(), candidate) { + return true + } + } } - return + return false } // MarshalJSON encodes a tags.T appended to a provided byte slice in JSON form. @@ -74,7 +83,8 @@ func (s *S) MarshalJSON() (b []byte, err error) { } func (s *S) Marshal(dst []byte) (b []byte) { - b = append(dst, '[') + b = dst + b = append(b, '[') for i, ss := range *s { b = ss.Marshal(b) if i < len(*s)-1 { diff --git a/pkg/interfaces/go.mod b/pkg/interfaces/go.mod index 31068ca..59676fd 100644 --- a/pkg/interfaces/go.mod +++ b/pkg/interfaces/go.mod @@ -4,9 +4,34 @@ go 1.25.0 replace ( crypto.orly => ../crypto + database.orly => ../database encoders.orly => ../encoders interfaces.orly => ../interfaces next.orly.dev => ../../ protocol.orly => ../protocol utils.orly => ../utils ) + +require ( + database.orly v0.0.0-00010101000000-000000000000 + encoders.orly v0.0.0-00010101000000-000000000000 + next.orly.dev v0.0.0-00010101000000-000000000000 +) + +require ( + crypto.orly v0.0.0-00010101000000-000000000000 // indirect + github.com/adrg/xdg v0.5.3 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/fatih/color v1.18.0 // indirect + github.com/klauspost/cpuid/v2 v2.3.0 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/templexxx/cpu v0.0.1 // indirect + github.com/templexxx/xhex v0.0.0-20200614015412-aed53437177b // indirect + go-simpler.org/env v0.12.0 // indirect + golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b // indirect + golang.org/x/sys v0.35.0 // indirect + lol.mleku.dev v1.0.2 // indirect + lukechampine.com/frand v1.5.1 // indirect + utils.orly v0.0.0-00010101000000-000000000000 // indirect +) diff --git a/pkg/interfaces/go.sum b/pkg/interfaces/go.sum new file mode 100644 index 0000000..d032de7 --- /dev/null +++ b/pkg/interfaces/go.sum @@ -0,0 +1,33 @@ +github.com/adrg/xdg v0.5.3 h1:xRnxJXne7+oWDatRhR1JLnvuccuIeCoBu2rtuLqQB78= +github.com/adrg/xdg v0.5.3/go.mod h1:nlTsY+NNiCBGCK2tpm09vRqfVzrc2fLmXGpBLF0zlTQ= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= +github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= +github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= +github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/templexxx/cpu v0.0.1 h1:hY4WdLOgKdc8y13EYklu9OUTXik80BkxHoWvTO6MQQY= +github.com/templexxx/cpu v0.0.1/go.mod h1:w7Tb+7qgcAlIyX4NhLuDKt78AHA5SzPmq0Wj6HiEnnk= +github.com/templexxx/xhex v0.0.0-20200614015412-aed53437177b h1:XeDLE6c9mzHpdv3Wb1+pWBaWv/BlHK0ZYIu/KaL6eHg= +github.com/templexxx/xhex v0.0.0-20200614015412-aed53437177b/go.mod h1:7rwmCH0wC2fQvNEvPZ3sKXukhyCTyiaZ5VTZMQYpZKQ= +go-simpler.org/env v0.12.0 h1:kt/lBts0J1kjWJAnB740goNdvwNxt5emhYngL0Fzufs= +go-simpler.org/env v0.12.0/go.mod h1:cc/5Md9JCUM7LVLtN0HYjPTDcI3Q8TDaPlNTAlDU+WI= +golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b h1:DXr+pvt3nC887026GRP39Ej11UATqWDmWuS99x26cD0= +golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b/go.mod h1:4QTo5u+SEIbbKW1RacMZq1YEfOBqeXa19JeshGi+zc4= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +lol.mleku.dev v1.0.2 h1:bSV1hHnkmt1hq+9nSvRwN6wgcI7itbM3XRZ4dMB438c= +lol.mleku.dev v1.0.2/go.mod h1:DQ0WnmkntA9dPLCXgvtIgYt5G0HSqx3wSTLolHgWeLA= +lukechampine.com/frand v1.5.1 h1:fg0eRtdmGFIxhP5zQJzM1lFDbD6CUfu/f+7WgAZd5/w= +lukechampine.com/frand v1.5.1/go.mod h1:4VstaWc2plN4Mjr10chUD46RAVGWhpkZ5Nja8+Azp0Q= diff --git a/pkg/interfaces/publisher/publisher.go b/pkg/interfaces/publisher/publisher.go new file mode 100644 index 0000000..1d38507 --- /dev/null +++ b/pkg/interfaces/publisher/publisher.go @@ -0,0 +1,14 @@ +package publisher + +import ( + "encoders.orly/event" + "interfaces.orly/typer" +) + +type I interface { + typer.T + Deliver(ev *event.E) + Receive(msg typer.T) +} + +type Publishers []I diff --git a/pkg/interfaces/typer/typer.go b/pkg/interfaces/typer/typer.go new file mode 100644 index 0000000..b531923 --- /dev/null +++ b/pkg/interfaces/typer/typer.go @@ -0,0 +1,10 @@ +// Package typer is an interface for server to use to identify their type simply for +// aggregating multiple self-registered server such that the top level can recognise the +// type of a message and match it to the type of handler. +package typer + +type T interface { + // Type returns a type identifier string to allow multiple self-registering publisher.I to + // be used with an abstraction to allow multiple APIs to publish. + Type() string +} diff --git a/pkg/protocol/publish/publisher.go b/pkg/protocol/publish/publisher.go new file mode 100644 index 0000000..9170f4a --- /dev/null +++ b/pkg/protocol/publish/publisher.go @@ -0,0 +1,38 @@ +package publish + +import ( + "encoders.orly/event" + "interfaces.orly/publisher" + "interfaces.orly/typer" +) + +// S is the control structure for the subscription management scheme. +type S struct { + publisher.Publishers +} + +// New creates a new publish.S. +func New(p ...publisher.I) (s *S) { + s = &S{Publishers: p} + return +} + +var _ publisher.I = &S{} + +func (s *S) Type() string { return "publish" } + +func (s *S) Deliver(ev *event.E) { + for _, p := range s.Publishers { + p.Deliver(ev) + } +} + +func (s *S) Receive(msg typer.T) { + t := msg.Type() + for _, p := range s.Publishers { + if p.Type() == t { + p.Receive(msg) + return + } + } +} diff --git a/pkg/utils/atomic/.codecov.yml b/pkg/utils/atomic/.codecov.yml new file mode 100644 index 0000000..571116c --- /dev/null +++ b/pkg/utils/atomic/.codecov.yml @@ -0,0 +1,19 @@ +coverage: + range: 80..100 + round: down + precision: 2 + + status: + project: # measuring the overall project coverage + default: # context, you can create multiple ones with custom titles + enabled: yes # must be yes|true to enable this status + target: 100 # specify the target coverage for each commit status + # option: "auto" (must increase from parent commit or pull request base) + # option: "X%" a static target percentage to hit + if_not_found: success # if parent is not found report status as success, error, or failure + if_ci_failed: error # if ci fails report status as success, error, or failure + +# Also update COVER_IGNORE_PKGS in the Makefile. +ignore: + - /internal/gen-atomicint/ + - /internal/gen-valuewrapper/ diff --git a/pkg/utils/atomic/CHANGELOG.md b/pkg/utils/atomic/CHANGELOG.md new file mode 100644 index 0000000..71db542 --- /dev/null +++ b/pkg/utils/atomic/CHANGELOG.md @@ -0,0 +1,130 @@ +# Changelog +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## Unreleased +- No changes yet. + +## [1.11.0] - 2023-05-02 +### Fixed +- Fix `Swap` and `CompareAndSwap` for `Value` wrappers without initialization. + +### Added +- Add `String` method to `atomic.Pointer[T]` type allowing users to safely print +underlying values of pointers. + +[1.11.0]: https://github.com/uber-go/atomic/compare/v1.10.0...v1.11.0 + +## [1.10.0] - 2022-08-11 +### Added +- Add `atomic.Float32` type for atomic operations on `float32`. +- Add `CompareAndSwap` and `Swap` methods to `atomic.String`, `atomic.Error`, + and `atomic.Value`. +- Add generic `atomic.Pointer[T]` type for atomic operations on pointers of any + type. This is present only for Go 1.18 or higher, and is a drop-in for + replacement for the standard library's `sync/atomic.Pointer` type. + +### Changed +- Deprecate `CAS` methods on all types in favor of corresponding + `CompareAndSwap` methods. + +Thanks to @eNV25 and @icpd for their contributions to this release. + +[1.10.0]: https://github.com/uber-go/atomic/compare/v1.9.0...v1.10.0 + +## [1.9.0] - 2021-07-15 +### Added +- Add `Float64.Swap` to match int atomic operations. +- Add `atomic.Time` type for atomic operations on `time.Time` values. + +[1.9.0]: https://github.com/uber-go/atomic/compare/v1.8.0...v1.9.0 + +## [1.8.0] - 2021-06-09 +### Added +- Add `atomic.Uintptr` type for atomic operations on `uintptr` values. +- Add `atomic.UnsafePointer` type for atomic operations on `unsafe.Pointer` values. + +[1.8.0]: https://github.com/uber-go/atomic/compare/v1.7.0...v1.8.0 + +## [1.7.0] - 2020-09-14 +### Added +- Support JSON serialization and deserialization of primitive atomic types. +- Support Text marshalling and unmarshalling for string atomics. + +### Changed +- Disallow incorrect comparison of atomic values in a non-atomic way. + +### Removed +- Remove dependency on `golang.org/x/{lint, tools}`. + +[1.7.0]: https://github.com/uber-go/atomic/compare/v1.6.0...v1.7.0 + +## [1.6.0] - 2020-02-24 +### Changed +- Drop library dependency on `golang.org/x/{lint, tools}`. + +[1.6.0]: https://github.com/uber-go/atomic/compare/v1.5.1...v1.6.0 + +## [1.5.1] - 2019-11-19 +- Fix bug where `Bool.CAS` and `Bool.Toggle` do work correctly together + causing `CAS` to fail even though the old value matches. + +[1.5.1]: https://github.com/uber-go/atomic/compare/v1.5.0...v1.5.1 + +## [1.5.0] - 2019-10-29 +### Changed +- With Go modules, only the `go.uber.org/atomic` import path is supported now. + If you need to use the old import path, please add a `replace` directive to + your `go.mod`. + +[1.5.0]: https://github.com/uber-go/atomic/compare/v1.4.0...v1.5.0 + +## [1.4.0] - 2019-05-01 +### Added + - Add `atomic.Error` type for atomic operations on `error` values. + +[1.4.0]: https://github.com/uber-go/atomic/compare/v1.3.2...v1.4.0 + +## [1.3.2] - 2018-05-02 +### Added +- Add `atomic.Duration` type for atomic operations on `time.Duration` values. + +[1.3.2]: https://github.com/uber-go/atomic/compare/v1.3.1...v1.3.2 + +## [1.3.1] - 2017-11-14 +### Fixed +- Revert optimization for `atomic.String.Store("")` which caused data races. + +[1.3.1]: https://github.com/uber-go/atomic/compare/v1.3.0...v1.3.1 + +## [1.3.0] - 2017-11-13 +### Added +- Add `atomic.Bool.CAS` for compare-and-swap semantics on bools. + +### Changed +- Optimize `atomic.String.Store("")` by avoiding an allocation. + +[1.3.0]: https://github.com/uber-go/atomic/compare/v1.2.0...v1.3.0 + +## [1.2.0] - 2017-04-12 +### Added +- Shadow `atomic.Value` from `sync/atomic`. + +[1.2.0]: https://github.com/uber-go/atomic/compare/v1.1.0...v1.2.0 + +## [1.1.0] - 2017-03-10 +### Added +- Add atomic `Float64` type. + +### Changed +- Support new `go.uber.org/atomic` import path. + +[1.1.0]: https://github.com/uber-go/atomic/compare/v1.0.0...v1.1.0 + +## [1.0.0] - 2016-07-18 + +- Initial release. + +[1.0.0]: https://github.com/uber-go/atomic/releases/tag/v1.0.0 diff --git a/pkg/utils/atomic/LICENSE b/pkg/utils/atomic/LICENSE new file mode 100644 index 0000000..8765c9f --- /dev/null +++ b/pkg/utils/atomic/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2016 Uber Technologies, Inc. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/pkg/utils/atomic/Makefile b/pkg/utils/atomic/Makefile new file mode 100644 index 0000000..53432ab --- /dev/null +++ b/pkg/utils/atomic/Makefile @@ -0,0 +1,79 @@ +# Directory to place `go install`ed binaries into. +export GOBIN ?= $(shell pwd)/bin + +GOLINT = $(GOBIN)/golint +GEN_ATOMICINT = $(GOBIN)/gen-atomicint +GEN_ATOMICWRAPPER = $(GOBIN)/gen-atomicwrapper +STATICCHECK = $(GOBIN)/staticcheck + +GO_FILES ?= $(shell find . '(' -path .git -o -path vendor ')' -prune -o -name '*.go' -print) + +# Also update ignore section in .codecov.yml. +COVER_IGNORE_PKGS = \ + github.com/p9ds/atomic/internal/gen-atomicint \ + github.com/p9ds/atomic/internal/gen-atomicwrapper + +.PHONY: build +build: + go build ./... + +.PHONY: test +test: + go test -race ./... + +.PHONY: gofmt +gofmt: + $(eval FMT_LOG := $(shell mktemp -t gofmt.XXXXX)) + gofmt -e -s -l $(GO_FILES) > $(FMT_LOG) || true + @[ ! -s "$(FMT_LOG)" ] || (echo "gofmt failed:" && cat $(FMT_LOG) && false) + +$(GOLINT): + cd tools && go install golang.org/x/lint/golint + +$(STATICCHECK): + cd tools && go install honnef.co/go/tools/cmd/staticcheck + +$(GEN_ATOMICWRAPPER): $(wildcard ./internal/gen-atomicwrapper/*) + go build -o $@ ./internal/gen-atomicwrapper + +$(GEN_ATOMICINT): $(wildcard ./internal/gen-atomicint/*) + go build -o $@ ./internal/gen-atomicint + +.PHONY: golint +golint: $(GOLINT) + $(GOLINT) ./... + +.PHONY: staticcheck +staticcheck: $(STATICCHECK) + $(STATICCHECK) ./... + +.PHONY: lint +lint: gofmt golint staticcheck generatenodirty + +# comma separated list of packages to consider for code coverage. +COVER_PKG = $(shell \ + go list -find ./... | \ + grep -v $(foreach pkg,$(COVER_IGNORE_PKGS),-e "^$(pkg)$$") | \ + paste -sd, -) + +.PHONY: cover +cover: + go test -coverprofile=cover.out -coverpkg $(COVER_PKG) -v ./... + go tool cover -html=cover.out -o cover.html + +.PHONY: generate +generate: $(GEN_ATOMICINT) $(GEN_ATOMICWRAPPER) + go generate ./... + +.PHONY: generatenodirty +generatenodirty: + @[ -z "$$(git status --porcelain)" ] || ( \ + echo "Working tree is dirty. Commit your changes first."; \ + git status; \ + exit 1 ) + @make generate + @status=$$(git status --porcelain); \ + [ -z "$$status" ] || ( \ + echo "Working tree is dirty after `make generate`:"; \ + echo "$$status"; \ + echo "Please ensure that the generated code is up-to-date." ) diff --git a/pkg/utils/atomic/README.md b/pkg/utils/atomic/README.md new file mode 100644 index 0000000..3eed44a --- /dev/null +++ b/pkg/utils/atomic/README.md @@ -0,0 +1,33 @@ +# atomic + +Simple wrappers for primitive types to enforce atomic access. + +## Installation + +```shell +$ go get -u github.com/mleku/nodl/pkg/atomic@latest +``` + +## Usage + +The standard library's `sync/atomic` is powerful, but it's easy to forget which +variables must be accessed atomically. `github.com/mleku/nodl/pkg/atomic` preserves all the +functionality of the standard library, but wraps the primitive types to +provide a safer, more convenient API. + +```go +var atom atomic.Uint32 +atom.Store(42) +atom.Sub(2) +atom.CompareAndSwap(40, 11) +``` + +See the [documentation][doc] for a complete API specification. + +## Development Status + +Stable. + +--- + +Released under the [MIT License](LICENSE.txt). \ No newline at end of file diff --git a/pkg/utils/atomic/assert_test.go b/pkg/utils/atomic/assert_test.go new file mode 100644 index 0000000..47cfbf2 --- /dev/null +++ b/pkg/utils/atomic/assert_test.go @@ -0,0 +1,45 @@ +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "encoding/json" + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +// Marks the test as failed if the error cannot be cast into the provided type +// with errors.As. +// +// assertErrorAsType(t, err, new(ErrFoo)) +func assertErrorAsType(t *testing.T, err error, typ interface{}, msgAndArgs ...interface{}) bool { + t.Helper() + + return assert.True(t, errors.As(err, typ), msgAndArgs...) +} + +func assertErrorJSONUnmarshalType(t *testing.T, err error, msgAndArgs ...interface{}) bool { + t.Helper() + + return assertErrorAsType(t, err, new(*json.UnmarshalTypeError), msgAndArgs...) +} diff --git a/pkg/utils/atomic/bool.go b/pkg/utils/atomic/bool.go new file mode 100644 index 0000000..f0a2ddd --- /dev/null +++ b/pkg/utils/atomic/bool.go @@ -0,0 +1,88 @@ +// @generated Code generated by gen-atomicwrapper. + +// Copyright (c) 2020-2023 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "encoding/json" +) + +// Bool is an atomic type-safe wrapper for bool values. +type Bool struct { + _ nocmp // disallow non-atomic comparison + + v Uint32 +} + +var _zeroBool bool + +// NewBool creates a new Bool. +func NewBool(val bool) *Bool { + x := &Bool{} + if val != _zeroBool { + x.Store(val) + } + return x +} + +// Load atomically loads the wrapped bool. +func (x *Bool) Load() bool { + return truthy(x.v.Load()) +} + +// Store atomically stores the passed bool. +func (x *Bool) Store(val bool) { + x.v.Store(boolToInt(val)) +} + +// CAS is an atomic compare-and-swap for bool values. +// +// Deprecated: Use CompareAndSwap. +func (x *Bool) CAS(old, new bool) (swapped bool) { + return x.CompareAndSwap(old, new) +} + +// CompareAndSwap is an atomic compare-and-swap for bool values. +func (x *Bool) CompareAndSwap(old, new bool) (swapped bool) { + return x.v.CompareAndSwap(boolToInt(old), boolToInt(new)) +} + +// Swap atomically stores the given bool and returns the old +// value. +func (x *Bool) Swap(val bool) (old bool) { + return truthy(x.v.Swap(boolToInt(val))) +} + +// MarshalJSON encodes the wrapped bool into JSON. +func (x *Bool) MarshalJSON() ([]byte, error) { + return json.Marshal(x.Load()) +} + +// UnmarshalJSON decodes a bool from JSON. +func (x *Bool) UnmarshalJSON(b []byte) error { + var v bool + if err := json.Unmarshal(b, &v); err != nil { + return err + } + x.Store(v) + return nil +} diff --git a/pkg/utils/atomic/bool_ext.go b/pkg/utils/atomic/bool_ext.go new file mode 100644 index 0000000..a2e60e9 --- /dev/null +++ b/pkg/utils/atomic/bool_ext.go @@ -0,0 +1,53 @@ +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "strconv" +) + +//go:generate bin/gen-atomicwrapper -name=Bool -type=bool -wrapped=Uint32 -pack=boolToInt -unpack=truthy -cas -swap -json -file=bool.go + +func truthy(n uint32) bool { + return n == 1 +} + +func boolToInt(b bool) uint32 { + if b { + return 1 + } + return 0 +} + +// Toggle atomically negates the Boolean and returns the previous value. +func (b *Bool) Toggle() (old bool) { + for { + old := b.Load() + if b.CAS(old, !old) { + return old + } + } +} + +// String encodes the wrapped value as a string. +func (b *Bool) String() string { + return strconv.FormatBool(b.Load()) +} diff --git a/pkg/utils/atomic/bool_test.go b/pkg/utils/atomic/bool_test.go new file mode 100644 index 0000000..6753ebd --- /dev/null +++ b/pkg/utils/atomic/bool_test.go @@ -0,0 +1,150 @@ +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBool(t *testing.T) { + atom := NewBool(false) + require.False(t, atom.Toggle(), "Expected Toggle to return previous value.") + require.True(t, atom.Toggle(), "Expected Toggle to return previous value.") + require.False(t, atom.Toggle(), "Expected Toggle to return previous value.") + require.True(t, atom.Load(), "Unexpected state after swap.") + + require.True(t, atom.CAS(true, true), "CAS should swap when old matches") + require.True(t, atom.Load(), "CAS should have no effect") + require.True(t, atom.CAS(true, false), "CAS should swap when old matches") + require.False(t, atom.Load(), "CAS should have modified the value") + require.False(t, atom.CAS(true, false), "CAS should fail on old mismatch") + require.False(t, atom.Load(), "CAS should not have modified the value") + + atom.Store(false) + require.False(t, atom.Load(), "Unexpected state after store.") + + prev := atom.Swap(false) + require.False(t, prev, "Expected Swap to return previous value.") + + prev = atom.Swap(true) + require.False(t, prev, "Expected Swap to return previous value.") + + t.Run("JSON/Marshal", func(t *testing.T) { + atom.Store(true) + bytes, err := json.Marshal(atom) + require.NoError(t, err, "json.Marshal errored unexpectedly.") + require.Equal(t, []byte("true"), bytes, "json.Marshal encoded the wrong bytes.") + }) + + t.Run("JSON/Unmarshal", func(t *testing.T) { + err := json.Unmarshal([]byte("false"), &atom) + require.NoError(t, err, "json.Unmarshal errored unexpectedly.") + require.False(t, atom.Load(), "json.Unmarshal didn't set the correct value.") + }) + + t.Run("JSON/Unmarshal/Error", func(t *testing.T) { + err := json.Unmarshal([]byte("42"), &atom) + require.Error(t, err, "json.Unmarshal didn't error as expected.") + assertErrorJSONUnmarshalType(t, err, + "json.Unmarshal failed with unexpected error %v, want UnmarshalTypeError.", err) + }) + + t.Run("String", func(t *testing.T) { + t.Run("true", func(t *testing.T) { + assert.Equal(t, "true", NewBool(true).String(), + "String() returned an unexpected value.") + }) + + t.Run("false", func(t *testing.T) { + var b Bool + assert.Equal(t, "false", b.String(), + "String() returned an unexpected value.") + }) + }) +} + +func TestBool_InitializeDefaults(t *testing.T) { + tests := []struct { + msg string + newBool func() *Bool + }{ + { + msg: "Uninitialized", + newBool: func() *Bool { + var b Bool + return &b + }, + }, + { + msg: "NewBool with default", + newBool: func() *Bool { + return NewBool(false) + }, + }, + { + msg: "Bool swapped with default", + newBool: func() *Bool { + b := NewBool(true) + b.Swap(false) + return b + }, + }, + { + msg: "Bool CAS'd with default", + newBool: func() *Bool { + b := NewBool(true) + b.CompareAndSwap(true, false) + return b + }, + }, + } + + for _, tt := range tests { + t.Run(tt.msg, func(t *testing.T) { + t.Run("Marshal", func(t *testing.T) { + b := tt.newBool() + marshalled, err := b.MarshalJSON() + require.NoError(t, err) + assert.Equal(t, "false", string(marshalled)) + }) + + t.Run("String", func(t *testing.T) { + b := tt.newBool() + assert.Equal(t, "false", b.String()) + }) + + t.Run("CompareAndSwap", func(t *testing.T) { + b := tt.newBool() + require.True(t, b.CompareAndSwap(false, true)) + assert.Equal(t, true, b.Load()) + }) + + t.Run("Swap", func(t *testing.T) { + b := tt.newBool() + assert.Equal(t, false, b.Swap(true)) + }) + }) + } +} diff --git a/pkg/utils/atomic/bytes.go b/pkg/utils/atomic/bytes.go new file mode 100644 index 0000000..c36d318 --- /dev/null +++ b/pkg/utils/atomic/bytes.go @@ -0,0 +1,59 @@ +// @generated Code generated by gen-atomicwrapper. + +// Copyright (c) 2020-2025 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +// Bytes is an atomic type-safe wrapper for []byte values. +type Bytes struct { + _ nocmp // disallow non-atomic comparison + + v Value +} + +var _zeroBytes []byte + +// NewBytes creates a new Bytes. +func NewBytes(val []byte) *Bytes { + x := &Bytes{} + if val != nil { + x.Store(val) + } + return x +} + +// Load atomically loads the wrapped []byte. +func (x *Bytes) Load() (b []byte) { + if x.v.Load() == nil { + return + } + vb := x.v.Load().([]byte) + b = make([]byte, len(vb)) + copy(b, vb) + return +} + +// Store atomically stores the passed []byte. +func (x *Bytes) Store(val []byte) { + b := make([]byte, len(val)) + copy(b, val) + x.v.Store(b) +} diff --git a/pkg/utils/atomic/bytes_ext.go b/pkg/utils/atomic/bytes_ext.go new file mode 100644 index 0000000..a4ca879 --- /dev/null +++ b/pkg/utils/atomic/bytes_ext.go @@ -0,0 +1,56 @@ +// Copyright (c) 2020-2025 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "encoding/base64" + "encoding/json" +) + +// MarshalJSON encodes the wrapped []byte as a base64 string. +// +// This makes it encodable as JSON. +func (b *Bytes) MarshalJSON() ([]byte, error) { + data := b.Load() + if data == nil { + return []byte("null"), nil + } + encoded := base64.StdEncoding.EncodeToString(data) + return json.Marshal(encoded) +} + +// UnmarshalJSON decodes a base64 string and replaces the wrapped []byte with it. +// +// This makes it decodable from JSON. +func (b *Bytes) UnmarshalJSON(text []byte) error { + var encoded string + if err := json.Unmarshal(text, &encoded); err != nil { + return err + } + + decoded, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + return err + } + + b.Store(decoded) + return nil +} diff --git a/pkg/utils/atomic/bytes_test.go b/pkg/utils/atomic/bytes_test.go new file mode 100644 index 0000000..e4e5be1 --- /dev/null +++ b/pkg/utils/atomic/bytes_test.go @@ -0,0 +1,252 @@ +// Copyright (c) 2020-2025 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "encoding/json" + "runtime" + "sync" + "testing" + "utils.orly" + + "github.com/stretchr/testify/require" +) + +func TestBytesNoInitialValue(t *testing.T) { + atom := NewBytes([]byte{}) + require.Equal(t, []byte{}, atom.Load(), "Initial value should be empty") +} + +func TestBytes(t *testing.T) { + atom := NewBytes([]byte{}) + require.Equal( + t, []byte{}, atom.Load(), + "Expected Load to return initialized empty value", + ) + + emptyBytes := []byte{} + atom = NewBytes(emptyBytes) + require.Equal( + t, emptyBytes, atom.Load(), + "Expected Load to return initialized empty value", + ) + + testBytes := []byte("test data") + atom = NewBytes(testBytes) + loadedBytes := atom.Load() + require.Equal( + t, testBytes, loadedBytes, "Expected Load to return initialized value", + ) + + // Verify that the returned value is a copy + loadedBytes[0] = 'X' + require.NotEqual( + t, loadedBytes, atom.Load(), "Load should return a copy of the data", + ) + + // Store and verify + newBytes := []byte("new data") + atom.Store(newBytes) + require.Equal(t, newBytes, atom.Load(), "Unexpected value after Store") + + // Modify original data and verify it doesn't affect stored value + newBytes[0] = 'X' + require.NotEqual(t, newBytes, atom.Load(), "Store should copy the data") + + t.Run( + "JSON/Marshal", func(t *testing.T) { + jsonBytes := []byte("json data") + atom.Store(jsonBytes) + bytes, err := json.Marshal(atom) + require.NoError(t, err, "json.Marshal errored unexpectedly.") + require.Equal( + t, []byte(`"anNvbiBkYXRh"`), bytes, + "json.Marshal should encode as base64", + ) + }, + ) + + t.Run( + "JSON/Unmarshal", func(t *testing.T) { + err := json.Unmarshal( + []byte(`"dGVzdCBkYXRh"`), &atom, + ) // "test data" in base64 + require.NoError(t, err, "json.Unmarshal errored unexpectedly.") + require.Equal( + t, []byte("test data"), atom.Load(), + "json.Unmarshal didn't set the correct value.", + ) + }, + ) + + t.Run( + "JSON/Unmarshal/Error", func(t *testing.T) { + err := json.Unmarshal([]byte("42"), &atom) + require.Error(t, err, "json.Unmarshal didn't error as expected.") + }, + ) +} + +func TestBytesConcurrentAccess(t *testing.T) { + const ( + parallelism = 4 + iterations = 1000 + ) + + atom := NewBytes([]byte("initial")) + + var wg sync.WaitGroup + wg.Add(parallelism) + + // Start multiple goroutines that read and write concurrently + for i := 0; i < parallelism; i++ { + go func(id int) { + defer wg.Done() + + // Each goroutine writes a different value + myData := []byte{byte(id)} + + for j := 0; j < iterations; j++ { + // Store our data + atom.Store(myData) + + // Load the data (which might be from another goroutine) + loaded := atom.Load() + + // Verify the loaded data is valid (either our data or another goroutine's data) + require.LessOrEqual( + t, len(loaded), parallelism, + "Loaded data length should not exceed parallelism", + ) + + // If it's our data, verify it's correct + if len(loaded) == 1 && loaded[0] == byte(id) { + require.Equal(t, myData, loaded, "Data corruption detected") + } + } + }(i) + } + + wg.Wait() +} + +func TestBytesDataIntegrity(t *testing.T) { + // Test that large byte slices maintain integrity under concurrent access + const ( + parallelism = 4 + dataSize = 1024 // 1KB + iterations = 100 + ) + + // Create test data sets, each with a unique pattern + testData := make([][]byte, parallelism) + for i := 0; i < parallelism; i++ { + testData[i] = make([]byte, dataSize) + for j := 0; j < dataSize; j++ { + testData[i][j] = byte((i + j) % 256) + } + } + + atom := NewBytes(nil) + var wg sync.WaitGroup + wg.Add(parallelism) + + for i := 0; i < parallelism; i++ { + go func(id int) { + defer wg.Done() + myData := testData[id] + + for j := 0; j < iterations; j++ { + atom.Store(myData) + loaded := atom.Load() + + // Verify the loaded data is one of our test data sets + for k := 0; k < parallelism; k++ { + if utils.FastEqual(loaded, testData[k]) { + // Found a match, data is intact + break + } + if k == parallelism-1 { + // No match found, data corruption + t.Errorf("Data corruption detected: loaded data doesn't match any test set") + } + } + } + }(i) + } + + wg.Wait() +} + +func TestBytesStress(t *testing.T) { + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(4)) + + atom := NewBytes([]byte("initial")) + var wg sync.WaitGroup + + // We'll run 8 goroutines concurrently + workers := 8 + iterations := 1000 + wg.Add(workers) + + start := make(chan struct{}) + + for i := 0; i < workers; i++ { + go func(id int) { + defer wg.Done() + + // Wait for the start signal + <-start + + // Each worker gets its own data + myData := []byte{byte(id)} + + for j := 0; j < iterations; j++ { + // Alternate between reads and writes + if j%2 == 0 { + atom.Store(myData) + } else { + _ = atom.Load() + } + } + }(i) + } + + // Start all goroutines simultaneously + close(start) + wg.Wait() +} + +func BenchmarkBytesParallel(b *testing.B) { + atom := NewBytes([]byte("benchmark")) + + b.RunParallel( + func(pb *testing.PB) { + // Each goroutine gets its own data to prevent false sharing + myData := []byte("goroutine data") + + for pb.Next() { + atom.Store(myData) + _ = atom.Load() + } + }, + ) +} diff --git a/pkg/utils/atomic/doc.go b/pkg/utils/atomic/doc.go new file mode 100644 index 0000000..ae7390e --- /dev/null +++ b/pkg/utils/atomic/doc.go @@ -0,0 +1,23 @@ +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +// Package atomic provides simple wrappers around numerics to enforce atomic +// access. +package atomic diff --git a/pkg/utils/atomic/duration.go b/pkg/utils/atomic/duration.go new file mode 100644 index 0000000..7c23868 --- /dev/null +++ b/pkg/utils/atomic/duration.go @@ -0,0 +1,89 @@ +// @generated Code generated by gen-atomicwrapper. + +// Copyright (c) 2020-2023 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "encoding/json" + "time" +) + +// Duration is an atomic type-safe wrapper for time.Duration values. +type Duration struct { + _ nocmp // disallow non-atomic comparison + + v Int64 +} + +var _zeroDuration time.Duration + +// NewDuration creates a new Duration. +func NewDuration(val time.Duration) *Duration { + x := &Duration{} + if val != _zeroDuration { + x.Store(val) + } + return x +} + +// Load atomically loads the wrapped time.Duration. +func (x *Duration) Load() time.Duration { + return time.Duration(x.v.Load()) +} + +// Store atomically stores the passed time.Duration. +func (x *Duration) Store(val time.Duration) { + x.v.Store(int64(val)) +} + +// CAS is an atomic compare-and-swap for time.Duration values. +// +// Deprecated: Use CompareAndSwap. +func (x *Duration) CAS(old, new time.Duration) (swapped bool) { + return x.CompareAndSwap(old, new) +} + +// CompareAndSwap is an atomic compare-and-swap for time.Duration values. +func (x *Duration) CompareAndSwap(old, new time.Duration) (swapped bool) { + return x.v.CompareAndSwap(int64(old), int64(new)) +} + +// Swap atomically stores the given time.Duration and returns the old +// value. +func (x *Duration) Swap(val time.Duration) (old time.Duration) { + return time.Duration(x.v.Swap(int64(val))) +} + +// MarshalJSON encodes the wrapped time.Duration into JSON. +func (x *Duration) MarshalJSON() ([]byte, error) { + return json.Marshal(x.Load()) +} + +// UnmarshalJSON decodes a time.Duration from JSON. +func (x *Duration) UnmarshalJSON(b []byte) error { + var v time.Duration + if err := json.Unmarshal(b, &v); err != nil { + return err + } + x.Store(v) + return nil +} diff --git a/pkg/utils/atomic/duration_ext.go b/pkg/utils/atomic/duration_ext.go new file mode 100644 index 0000000..62a45b3 --- /dev/null +++ b/pkg/utils/atomic/duration_ext.go @@ -0,0 +1,40 @@ +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import "time" + +//go:generate bin/gen-atomicwrapper -name=Duration -type=time.Duration -wrapped=Int64 -pack=int64 -unpack=time.Duration -cas -swap -json -imports time -file=duration.go + +// Add atomically adds to the wrapped time.Duration and returns the new value. +func (x *Duration) Add(delta time.Duration) time.Duration { + return time.Duration(x.v.Add(int64(delta))) +} + +// Sub atomically subtracts from the wrapped time.Duration and returns the new value. +func (x *Duration) Sub(delta time.Duration) time.Duration { + return time.Duration(x.v.Sub(int64(delta))) +} + +// String encodes the wrapped value as a string. +func (x *Duration) String() string { + return x.Load().String() +} diff --git a/pkg/utils/atomic/duration_test.go b/pkg/utils/atomic/duration_test.go new file mode 100644 index 0000000..f5779fe --- /dev/null +++ b/pkg/utils/atomic/duration_test.go @@ -0,0 +1,73 @@ +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDuration(t *testing.T) { + atom := NewDuration(5 * time.Minute) + + require.Equal(t, 5*time.Minute, atom.Load(), "Load didn't work.") + require.Equal(t, 6*time.Minute, atom.Add(time.Minute), "Add didn't work.") + require.Equal(t, 4*time.Minute, atom.Sub(2*time.Minute), "Sub didn't work.") + + require.True(t, atom.CAS(4*time.Minute, time.Minute), "CAS didn't report a swap.") + require.Equal(t, time.Minute, atom.Load(), "CAS didn't set the correct value.") + + require.Equal(t, time.Minute, atom.Swap(2*time.Minute), "Swap didn't return the old value.") + require.Equal(t, 2*time.Minute, atom.Load(), "Swap didn't set the correct value.") + + atom.Store(10 * time.Minute) + require.Equal(t, 10*time.Minute, atom.Load(), "Store didn't set the correct value.") + + t.Run("JSON/Marshal", func(t *testing.T) { + atom.Store(time.Second) + bytes, err := json.Marshal(atom) + require.NoError(t, err, "json.Marshal errored unexpectedly.") + require.Equal(t, []byte("1000000000"), bytes, "json.Marshal encoded the wrong bytes.") + }) + + t.Run("JSON/Unmarshal", func(t *testing.T) { + err := json.Unmarshal([]byte("1000000000"), &atom) + require.NoError(t, err, "json.Unmarshal errored unexpectedly.") + require.Equal(t, time.Second, atom.Load(), + "json.Unmarshal didn't set the correct value.") + }) + + t.Run("JSON/Unmarshal/Error", func(t *testing.T) { + err := json.Unmarshal([]byte("\"1000000000\""), &atom) + require.Error(t, err, "json.Unmarshal didn't error as expected.") + assertErrorJSONUnmarshalType(t, err, + "json.Unmarshal failed with unexpected error %v, want UnmarshalTypeError.", err) + }) + + t.Run("String", func(t *testing.T) { + assert.Equal(t, "42s", NewDuration(42*time.Second).String(), + "String() returned an unexpected value.") + }) +} diff --git a/pkg/utils/atomic/error.go b/pkg/utils/atomic/error.go new file mode 100644 index 0000000..b7e3f12 --- /dev/null +++ b/pkg/utils/atomic/error.go @@ -0,0 +1,72 @@ +// @generated Code generated by gen-atomicwrapper. + +// Copyright (c) 2020-2023 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +// Error is an atomic type-safe wrapper for error values. +type Error struct { + _ nocmp // disallow non-atomic comparison + + v Value +} + +var _zeroError error + +// NewError creates a new Error. +func NewError(val error) *Error { + x := &Error{} + if val != _zeroError { + x.Store(val) + } + return x +} + +// Load atomically loads the wrapped error. +func (x *Error) Load() error { + return unpackError(x.v.Load()) +} + +// Store atomically stores the passed error. +func (x *Error) Store(val error) { + x.v.Store(packError(val)) +} + +// CompareAndSwap is an atomic compare-and-swap for error values. +func (x *Error) CompareAndSwap(old, new error) (swapped bool) { + if x.v.CompareAndSwap(packError(old), packError(new)) { + return true + } + + if old == _zeroError { + // If the old value is the empty value, then it's possible the + // underlying Value hasn't been set and is nil, so retry with nil. + return x.v.CompareAndSwap(nil, packError(new)) + } + + return false +} + +// Swap atomically stores the given error and returns the old +// value. +func (x *Error) Swap(val error) (old error) { + return unpackError(x.v.Swap(packError(val))) +} diff --git a/pkg/utils/atomic/error_ext.go b/pkg/utils/atomic/error_ext.go new file mode 100644 index 0000000..d31fb63 --- /dev/null +++ b/pkg/utils/atomic/error_ext.go @@ -0,0 +1,39 @@ +// Copyright (c) 2020-2022 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +// atomic.Value panics on nil inputs, or if the underlying type changes. +// Stabilize by always storing a custom struct that we control. + +//go:generate bin/gen-atomicwrapper -name=Error -type=error -wrapped=Value -pack=packError -unpack=unpackError -compareandswap -swap -file=error.go + +type packedError struct{ Value error } + +func packError(v error) interface{} { + return packedError{v} +} + +func unpackError(v interface{}) error { + if err, ok := v.(packedError); ok { + return err.Value + } + return nil +} diff --git a/pkg/utils/atomic/error_test.go b/pkg/utils/atomic/error_test.go new file mode 100644 index 0000000..1f02e6d --- /dev/null +++ b/pkg/utils/atomic/error_test.go @@ -0,0 +1,136 @@ +// Copyright (c) 2016 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestErrorByValue(t *testing.T) { + err := &Error{} + require.Nil(t, err.Load(), "Initial value shall be nil") +} + +func TestNewErrorWithNilArgument(t *testing.T) { + err := NewError(nil) + require.Nil(t, err.Load(), "Initial value shall be nil") +} + +func TestErrorCanStoreNil(t *testing.T) { + err := NewError(errors.New("hello")) + err.Store(nil) + require.Nil(t, err.Load(), "Stored value shall be nil") +} + +func TestNewErrorWithError(t *testing.T) { + err1 := errors.New("hello1") + err2 := errors.New("hello2") + + atom := NewError(err1) + require.Equal(t, err1, atom.Load(), "Expected Load to return initialized value") + + atom.Store(err2) + require.Equal(t, err2, atom.Load(), "Expected Load to return overridden value") +} + +func TestErrorSwap(t *testing.T) { + err1 := errors.New("hello1") + err2 := errors.New("hello2") + + atom := NewError(err1) + require.Equal(t, err1, atom.Load(), "Expected Load to return initialized value") + + old := atom.Swap(err2) + require.Equal(t, err2, atom.Load(), "Expected Load to return overridden value") + require.Equal(t, err1, old, "Expected old to be initial value") +} + +func TestErrorCompareAndSwap(t *testing.T) { + err1 := errors.New("hello1") + err2 := errors.New("hello2") + + atom := NewError(err1) + require.Equal(t, err1, atom.Load(), "Expected Load to return initialized value") + + swapped := atom.CompareAndSwap(err2, err2) + require.False(t, swapped, "Expected swapped to be false") + require.Equal(t, err1, atom.Load(), "Expected Load to return initial value") + + swapped = atom.CompareAndSwap(err1, err2) + require.True(t, swapped, "Expected swapped to be true") + require.Equal(t, err2, atom.Load(), "Expected Load to return overridden value") +} + +func TestError_InitializeDefaults(t *testing.T) { + tests := []struct { + msg string + newError func() *Error + }{ + { + msg: "Uninitialized", + newError: func() *Error { + var e Error + return &e + }, + }, + { + msg: "NewError with default", + newError: func() *Error { + return NewError(nil) + }, + }, + { + msg: "Error swapped with default", + newError: func() *Error { + e := NewError(assert.AnError) + e.Swap(nil) + return e + }, + }, + { + msg: "Error CAS'd with default", + newError: func() *Error { + e := NewError(assert.AnError) + e.CompareAndSwap(assert.AnError, nil) + return e + }, + }, + } + + for _, tt := range tests { + t.Run(tt.msg, func(t *testing.T) { + t.Run("CompareAndSwap", func(t *testing.T) { + e := tt.newError() + require.True(t, e.CompareAndSwap(nil, assert.AnError)) + assert.Equal(t, assert.AnError, e.Load()) + }) + + t.Run("Swap", func(t *testing.T) { + e := tt.newError() + assert.Equal(t, nil, e.Swap(assert.AnError)) + }) + }) + } +} diff --git a/pkg/utils/atomic/example_test.go b/pkg/utils/atomic/example_test.go new file mode 100644 index 0000000..4d6ad9e --- /dev/null +++ b/pkg/utils/atomic/example_test.go @@ -0,0 +1,42 @@ +// Copyright (c) 2016 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic_test + +import ( + "fmt" + "utils.orly/atomic" +) + +func Example() { + // Uint32 is a thin wrapper around the primitive uint32 type. + var atom atomic.Uint32 + + // The wrapper ensures that all operations are atomic. + atom.Store(42) + fmt.Println(atom.Inc()) + fmt.Println(atom.CompareAndSwap(43, 0)) + fmt.Println(atom.Load()) + + // Output: + // 43 + // true + // 0 +} diff --git a/pkg/utils/atomic/float32.go b/pkg/utils/atomic/float32.go new file mode 100644 index 0000000..62c3633 --- /dev/null +++ b/pkg/utils/atomic/float32.go @@ -0,0 +1,77 @@ +// @generated Code generated by gen-atomicwrapper. + +// Copyright (c) 2020-2023 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "encoding/json" + "math" +) + +// Float32 is an atomic type-safe wrapper for float32 values. +type Float32 struct { + _ nocmp // disallow non-atomic comparison + + v Uint32 +} + +var _zeroFloat32 float32 + +// NewFloat32 creates a new Float32. +func NewFloat32(val float32) *Float32 { + x := &Float32{} + if val != _zeroFloat32 { + x.Store(val) + } + return x +} + +// Load atomically loads the wrapped float32. +func (x *Float32) Load() float32 { + return math.Float32frombits(x.v.Load()) +} + +// Store atomically stores the passed float32. +func (x *Float32) Store(val float32) { + x.v.Store(math.Float32bits(val)) +} + +// Swap atomically stores the given float32 and returns the old +// value. +func (x *Float32) Swap(val float32) (old float32) { + return math.Float32frombits(x.v.Swap(math.Float32bits(val))) +} + +// MarshalJSON encodes the wrapped float32 into JSON. +func (x *Float32) MarshalJSON() ([]byte, error) { + return json.Marshal(x.Load()) +} + +// UnmarshalJSON decodes a float32 from JSON. +func (x *Float32) UnmarshalJSON(b []byte) error { + var v float32 + if err := json.Unmarshal(b, &v); err != nil { + return err + } + x.Store(v) + return nil +} diff --git a/pkg/utils/atomic/float32_ext.go b/pkg/utils/atomic/float32_ext.go new file mode 100644 index 0000000..b0cd8d9 --- /dev/null +++ b/pkg/utils/atomic/float32_ext.go @@ -0,0 +1,76 @@ +// Copyright (c) 2020-2022 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "math" + "strconv" +) + +//go:generate bin/gen-atomicwrapper -name=Float32 -type=float32 -wrapped=Uint32 -pack=math.Float32bits -unpack=math.Float32frombits -swap -json -imports math -file=float32.go + +// Add atomically adds to the wrapped float32 and returns the new value. +func (f *Float32) Add(delta float32) float32 { + for { + old := f.Load() + new := old + delta + if f.CAS(old, new) { + return new + } + } +} + +// Sub atomically subtracts from the wrapped float32 and returns the new value. +func (f *Float32) Sub(delta float32) float32 { + return f.Add(-delta) +} + +// CAS is an atomic compare-and-swap for float32 values. +// +// Deprecated: Use CompareAndSwap +func (f *Float32) CAS(old, new float32) (swapped bool) { + return f.CompareAndSwap(old, new) +} + +// CompareAndSwap is an atomic compare-and-swap for float32 values. +// +// Note: CompareAndSwap handles NaN incorrectly. NaN != NaN using Go's inbuilt operators +// but CompareAndSwap allows a stored NaN to compare equal to a passed in NaN. +// This avoids typical CompareAndSwap loops from blocking forever, e.g., +// +// for { +// old := atom.Load() +// new = f(old) +// if atom.CompareAndSwap(old, new) { +// break +// } +// } +// +// If CompareAndSwap did not match NaN to match, then the above would loop forever. +func (f *Float32) CompareAndSwap(old, new float32) (swapped bool) { + return f.v.CompareAndSwap(math.Float32bits(old), math.Float32bits(new)) +} + +// String encodes the wrapped value as a string. +func (f *Float32) String() string { + // 'g' is the behavior for floats with %v. + return strconv.FormatFloat(float64(f.Load()), 'g', -1, 32) +} diff --git a/pkg/utils/atomic/float32_test.go b/pkg/utils/atomic/float32_test.go new file mode 100644 index 0000000..5b7fd51 --- /dev/null +++ b/pkg/utils/atomic/float32_test.go @@ -0,0 +1,73 @@ +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFloat32(t *testing.T) { + atom := NewFloat32(4.2) + + require.Equal(t, float32(4.2), atom.Load(), "Load didn't work.") + + require.True(t, atom.CAS(4.2, 0.5), "CAS didn't report a swap.") + require.Equal(t, float32(0.5), atom.Load(), "CAS didn't set the correct value.") + require.False(t, atom.CAS(0.0, 1.5), "CAS reported a swap.") + + atom.Store(42.0) + require.Equal(t, float32(42.0), atom.Load(), "Store didn't set the correct value.") + require.Equal(t, float32(42.5), atom.Add(0.5), "Add didn't work.") + require.Equal(t, float32(42.0), atom.Sub(0.5), "Sub didn't work.") + + require.Equal(t, float32(42.0), atom.Swap(45.0), "Swap didn't return the old value.") + require.Equal(t, float32(45.0), atom.Load(), "Swap didn't set the correct value.") + + t.Run("JSON/Marshal", func(t *testing.T) { + atom.Store(42.5) + bytes, err := json.Marshal(atom) + require.NoError(t, err, "json.Marshal errored unexpectedly.") + require.Equal(t, []byte("42.5"), bytes, "json.Marshal encoded the wrong bytes.") + }) + + t.Run("JSON/Unmarshal", func(t *testing.T) { + err := json.Unmarshal([]byte("40.5"), &atom) + require.NoError(t, err, "json.Unmarshal errored unexpectedly.") + require.Equal(t, float32(40.5), atom.Load(), + "json.Unmarshal didn't set the correct value.") + }) + + t.Run("JSON/Unmarshal/Error", func(t *testing.T) { + err := json.Unmarshal([]byte("\"40.5\""), &atom) + require.Error(t, err, "json.Unmarshal didn't error as expected.") + assertErrorJSONUnmarshalType(t, err, + "json.Unmarshal failed with unexpected error %v, want UnmarshalTypeError.", err) + }) + + t.Run("String", func(t *testing.T) { + assert.Equal(t, "42.5", NewFloat32(42.5).String(), + "String() returned an unexpected value.") + }) +} diff --git a/pkg/utils/atomic/float64.go b/pkg/utils/atomic/float64.go new file mode 100644 index 0000000..5bc11ca --- /dev/null +++ b/pkg/utils/atomic/float64.go @@ -0,0 +1,77 @@ +// @generated Code generated by gen-atomicwrapper. + +// Copyright (c) 2020-2023 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "encoding/json" + "math" +) + +// Float64 is an atomic type-safe wrapper for float64 values. +type Float64 struct { + _ nocmp // disallow non-atomic comparison + + v Uint64 +} + +var _zeroFloat64 float64 + +// NewFloat64 creates a new Float64. +func NewFloat64(val float64) *Float64 { + x := &Float64{} + if val != _zeroFloat64 { + x.Store(val) + } + return x +} + +// Load atomically loads the wrapped float64. +func (x *Float64) Load() float64 { + return math.Float64frombits(x.v.Load()) +} + +// Store atomically stores the passed float64. +func (x *Float64) Store(val float64) { + x.v.Store(math.Float64bits(val)) +} + +// Swap atomically stores the given float64 and returns the old +// value. +func (x *Float64) Swap(val float64) (old float64) { + return math.Float64frombits(x.v.Swap(math.Float64bits(val))) +} + +// MarshalJSON encodes the wrapped float64 into JSON. +func (x *Float64) MarshalJSON() ([]byte, error) { + return json.Marshal(x.Load()) +} + +// UnmarshalJSON decodes a float64 from JSON. +func (x *Float64) UnmarshalJSON(b []byte) error { + var v float64 + if err := json.Unmarshal(b, &v); err != nil { + return err + } + x.Store(v) + return nil +} diff --git a/pkg/utils/atomic/float64_ext.go b/pkg/utils/atomic/float64_ext.go new file mode 100644 index 0000000..48c52b0 --- /dev/null +++ b/pkg/utils/atomic/float64_ext.go @@ -0,0 +1,76 @@ +// Copyright (c) 2020-2022 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "math" + "strconv" +) + +//go:generate bin/gen-atomicwrapper -name=Float64 -type=float64 -wrapped=Uint64 -pack=math.Float64bits -unpack=math.Float64frombits -swap -json -imports math -file=float64.go + +// Add atomically adds to the wrapped float64 and returns the new value. +func (f *Float64) Add(delta float64) float64 { + for { + old := f.Load() + new := old + delta + if f.CAS(old, new) { + return new + } + } +} + +// Sub atomically subtracts from the wrapped float64 and returns the new value. +func (f *Float64) Sub(delta float64) float64 { + return f.Add(-delta) +} + +// CAS is an atomic compare-and-swap for float64 values. +// +// Deprecated: Use CompareAndSwap +func (f *Float64) CAS(old, new float64) (swapped bool) { + return f.CompareAndSwap(old, new) +} + +// CompareAndSwap is an atomic compare-and-swap for float64 values. +// +// Note: CompareAndSwap handles NaN incorrectly. NaN != NaN using Go's inbuilt operators +// but CompareAndSwap allows a stored NaN to compare equal to a passed in NaN. +// This avoids typical CompareAndSwap loops from blocking forever, e.g., +// +// for { +// old := atom.Load() +// new = f(old) +// if atom.CompareAndSwap(old, new) { +// break +// } +// } +// +// If CompareAndSwap did not match NaN to match, then the above would loop forever. +func (f *Float64) CompareAndSwap(old, new float64) (swapped bool) { + return f.v.CompareAndSwap(math.Float64bits(old), math.Float64bits(new)) +} + +// String encodes the wrapped value as a string. +func (f *Float64) String() string { + // 'g' is the behavior for floats with %v. + return strconv.FormatFloat(f.Load(), 'g', -1, 64) +} diff --git a/pkg/utils/atomic/float64_test.go b/pkg/utils/atomic/float64_test.go new file mode 100644 index 0000000..32fbc58 --- /dev/null +++ b/pkg/utils/atomic/float64_test.go @@ -0,0 +1,73 @@ +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFloat64(t *testing.T) { + atom := NewFloat64(4.2) + + require.Equal(t, float64(4.2), atom.Load(), "Load didn't work.") + + require.True(t, atom.CAS(4.2, 0.5), "CAS didn't report a swap.") + require.Equal(t, float64(0.5), atom.Load(), "CAS didn't set the correct value.") + require.False(t, atom.CAS(0.0, 1.5), "CAS reported a swap.") + + atom.Store(42.0) + require.Equal(t, float64(42.0), atom.Load(), "Store didn't set the correct value.") + require.Equal(t, float64(42.5), atom.Add(0.5), "Add didn't work.") + require.Equal(t, float64(42.0), atom.Sub(0.5), "Sub didn't work.") + + require.Equal(t, float64(42.0), atom.Swap(45.0), "Swap didn't return the old value.") + require.Equal(t, float64(45.0), atom.Load(), "Swap didn't set the correct value.") + + t.Run("JSON/Marshal", func(t *testing.T) { + atom.Store(42.5) + bytes, err := json.Marshal(atom) + require.NoError(t, err, "json.Marshal errored unexpectedly.") + require.Equal(t, []byte("42.5"), bytes, "json.Marshal encoded the wrong bytes.") + }) + + t.Run("JSON/Unmarshal", func(t *testing.T) { + err := json.Unmarshal([]byte("40.5"), &atom) + require.NoError(t, err, "json.Unmarshal errored unexpectedly.") + require.Equal(t, float64(40.5), atom.Load(), + "json.Unmarshal didn't set the correct value.") + }) + + t.Run("JSON/Unmarshal/Error", func(t *testing.T) { + err := json.Unmarshal([]byte("\"40.5\""), &atom) + require.Error(t, err, "json.Unmarshal didn't error as expected.") + assertErrorJSONUnmarshalType(t, err, + "json.Unmarshal failed with unexpected error %v, want UnmarshalTypeError.", err) + }) + + t.Run("String", func(t *testing.T) { + assert.Equal(t, "42.5", NewFloat64(42.5).String(), + "String() returned an unexpected value.") + }) +} diff --git a/pkg/utils/atomic/gen.go b/pkg/utils/atomic/gen.go new file mode 100644 index 0000000..1e9ef4f --- /dev/null +++ b/pkg/utils/atomic/gen.go @@ -0,0 +1,27 @@ +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +//go:generate bin/gen-atomicint -name=Int32 -wrapped=int32 -file=int32.go +//go:generate bin/gen-atomicint -name=Int64 -wrapped=int64 -file=int64.go +//go:generate bin/gen-atomicint -name=Uint32 -wrapped=uint32 -unsigned -file=uint32.go +//go:generate bin/gen-atomicint -name=Uint64 -wrapped=uint64 -unsigned -file=uint64.go +//go:generate bin/gen-atomicint -name=Uintptr -wrapped=uintptr -unsigned -file=uintptr.go diff --git a/pkg/utils/atomic/int32.go b/pkg/utils/atomic/int32.go new file mode 100644 index 0000000..5320eac --- /dev/null +++ b/pkg/utils/atomic/int32.go @@ -0,0 +1,109 @@ +// @generated Code generated by gen-atomicint. + +// Copyright (c) 2020-2023 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "encoding/json" + "strconv" + "sync/atomic" +) + +// Int32 is an atomic wrapper around int32. +type Int32 struct { + _ nocmp // disallow non-atomic comparison + + v int32 +} + +// NewInt32 creates a new Int32. +func NewInt32(val int32) *Int32 { + return &Int32{v: val} +} + +// Load atomically loads the wrapped value. +func (i *Int32) Load() int32 { + return atomic.LoadInt32(&i.v) +} + +// Add atomically adds to the wrapped int32 and returns the new value. +func (i *Int32) Add(delta int32) int32 { + return atomic.AddInt32(&i.v, delta) +} + +// Sub atomically subtracts from the wrapped int32 and returns the new value. +func (i *Int32) Sub(delta int32) int32 { + return atomic.AddInt32(&i.v, -delta) +} + +// Inc atomically increments the wrapped int32 and returns the new value. +func (i *Int32) Inc() int32 { + return i.Add(1) +} + +// Dec atomically decrements the wrapped int32 and returns the new value. +func (i *Int32) Dec() int32 { + return i.Sub(1) +} + +// CAS is an atomic compare-and-swap. +// +// Deprecated: Use CompareAndSwap. +func (i *Int32) CAS(old, new int32) (swapped bool) { + return i.CompareAndSwap(old, new) +} + +// CompareAndSwap is an atomic compare-and-swap. +func (i *Int32) CompareAndSwap(old, new int32) (swapped bool) { + return atomic.CompareAndSwapInt32(&i.v, old, new) +} + +// Store atomically stores the passed value. +func (i *Int32) Store(val int32) { + atomic.StoreInt32(&i.v, val) +} + +// Swap atomically swaps the wrapped int32 and returns the old value. +func (i *Int32) Swap(val int32) (old int32) { + return atomic.SwapInt32(&i.v, val) +} + +// MarshalJSON encodes the wrapped int32 into JSON. +func (i *Int32) MarshalJSON() ([]byte, error) { + return json.Marshal(i.Load()) +} + +// UnmarshalJSON decodes JSON into the wrapped int32. +func (i *Int32) UnmarshalJSON(b []byte) error { + var v int32 + if err := json.Unmarshal(b, &v); err != nil { + return err + } + i.Store(v) + return nil +} + +// String encodes the wrapped value as a string. +func (i *Int32) String() string { + v := i.Load() + return strconv.FormatInt(int64(v), 10) +} diff --git a/pkg/utils/atomic/int32_test.go b/pkg/utils/atomic/int32_test.go new file mode 100644 index 0000000..9992251 --- /dev/null +++ b/pkg/utils/atomic/int32_test.go @@ -0,0 +1,82 @@ +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "encoding/json" + "math" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestInt32(t *testing.T) { + atom := NewInt32(42) + + require.Equal(t, int32(42), atom.Load(), "Load didn't work.") + require.Equal(t, int32(46), atom.Add(4), "Add didn't work.") + require.Equal(t, int32(44), atom.Sub(2), "Sub didn't work.") + require.Equal(t, int32(45), atom.Inc(), "Inc didn't work.") + require.Equal(t, int32(44), atom.Dec(), "Dec didn't work.") + + require.True(t, atom.CAS(44, 0), "CAS didn't report a swap.") + require.Equal(t, int32(0), atom.Load(), "CAS didn't set the correct value.") + + require.Equal(t, int32(0), atom.Swap(1), "Swap didn't return the old value.") + require.Equal(t, int32(1), atom.Load(), "Swap didn't set the correct value.") + + atom.Store(42) + require.Equal(t, int32(42), atom.Load(), "Store didn't set the correct value.") + + t.Run("JSON/Marshal", func(t *testing.T) { + bytes, err := json.Marshal(atom) + require.NoError(t, err, "json.Marshal errored unexpectedly.") + require.Equal(t, []byte("42"), bytes, "json.Marshal encoded the wrong bytes.") + }) + + t.Run("JSON/Unmarshal", func(t *testing.T) { + err := json.Unmarshal([]byte("40"), &atom) + require.NoError(t, err, "json.Unmarshal errored unexpectedly.") + require.Equal(t, int32(40), atom.Load(), "json.Unmarshal didn't set the correct value.") + }) + + t.Run("JSON/Unmarshal/Error", func(t *testing.T) { + err := json.Unmarshal([]byte(`"40"`), &atom) + require.Error(t, err, "json.Unmarshal didn't error as expected.") + assertErrorJSONUnmarshalType(t, err, + "json.Unmarshal failed with unexpected error %v, want UnmarshalTypeError.", err) + }) + + t.Run("String", func(t *testing.T) { + t.Run("positive", func(t *testing.T) { + atom := NewInt32(math.MaxInt32) + assert.Equal(t, "2147483647", atom.String(), + "String() returned an unexpected value.") + }) + + t.Run("negative", func(t *testing.T) { + atom := NewInt32(math.MinInt32) + assert.Equal(t, "-2147483648", atom.String(), + "String() returned an unexpected value.") + }) + }) +} diff --git a/pkg/utils/atomic/int64.go b/pkg/utils/atomic/int64.go new file mode 100644 index 0000000..460821d --- /dev/null +++ b/pkg/utils/atomic/int64.go @@ -0,0 +1,109 @@ +// @generated Code generated by gen-atomicint. + +// Copyright (c) 2020-2023 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "encoding/json" + "strconv" + "sync/atomic" +) + +// Int64 is an atomic wrapper around int64. +type Int64 struct { + _ nocmp // disallow non-atomic comparison + + v int64 +} + +// NewInt64 creates a new Int64. +func NewInt64(val int64) *Int64 { + return &Int64{v: val} +} + +// Load atomically loads the wrapped value. +func (i *Int64) Load() int64 { + return atomic.LoadInt64(&i.v) +} + +// Add atomically adds to the wrapped int64 and returns the new value. +func (i *Int64) Add(delta int64) int64 { + return atomic.AddInt64(&i.v, delta) +} + +// Sub atomically subtracts from the wrapped int64 and returns the new value. +func (i *Int64) Sub(delta int64) int64 { + return atomic.AddInt64(&i.v, -delta) +} + +// Inc atomically increments the wrapped int64 and returns the new value. +func (i *Int64) Inc() int64 { + return i.Add(1) +} + +// Dec atomically decrements the wrapped int64 and returns the new value. +func (i *Int64) Dec() int64 { + return i.Sub(1) +} + +// CAS is an atomic compare-and-swap. +// +// Deprecated: Use CompareAndSwap. +func (i *Int64) CAS(old, new int64) (swapped bool) { + return i.CompareAndSwap(old, new) +} + +// CompareAndSwap is an atomic compare-and-swap. +func (i *Int64) CompareAndSwap(old, new int64) (swapped bool) { + return atomic.CompareAndSwapInt64(&i.v, old, new) +} + +// Store atomically stores the passed value. +func (i *Int64) Store(val int64) { + atomic.StoreInt64(&i.v, val) +} + +// Swap atomically swaps the wrapped int64 and returns the old value. +func (i *Int64) Swap(val int64) (old int64) { + return atomic.SwapInt64(&i.v, val) +} + +// MarshalJSON encodes the wrapped int64 into JSON. +func (i *Int64) MarshalJSON() ([]byte, error) { + return json.Marshal(i.Load()) +} + +// UnmarshalJSON decodes JSON into the wrapped int64. +func (i *Int64) UnmarshalJSON(b []byte) error { + var v int64 + if err := json.Unmarshal(b, &v); err != nil { + return err + } + i.Store(v) + return nil +} + +// String encodes the wrapped value as a string. +func (i *Int64) String() string { + v := i.Load() + return strconv.FormatInt(int64(v), 10) +} diff --git a/pkg/utils/atomic/int64_test.go b/pkg/utils/atomic/int64_test.go new file mode 100644 index 0000000..ed5a104 --- /dev/null +++ b/pkg/utils/atomic/int64_test.go @@ -0,0 +1,82 @@ +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "encoding/json" + "math" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestInt64(t *testing.T) { + atom := NewInt64(42) + + require.Equal(t, int64(42), atom.Load(), "Load didn't work.") + require.Equal(t, int64(46), atom.Add(4), "Add didn't work.") + require.Equal(t, int64(44), atom.Sub(2), "Sub didn't work.") + require.Equal(t, int64(45), atom.Inc(), "Inc didn't work.") + require.Equal(t, int64(44), atom.Dec(), "Dec didn't work.") + + require.True(t, atom.CAS(44, 0), "CAS didn't report a swap.") + require.Equal(t, int64(0), atom.Load(), "CAS didn't set the correct value.") + + require.Equal(t, int64(0), atom.Swap(1), "Swap didn't return the old value.") + require.Equal(t, int64(1), atom.Load(), "Swap didn't set the correct value.") + + atom.Store(42) + require.Equal(t, int64(42), atom.Load(), "Store didn't set the correct value.") + + t.Run("JSON/Marshal", func(t *testing.T) { + bytes, err := json.Marshal(atom) + require.NoError(t, err, "json.Marshal errored unexpectedly.") + require.Equal(t, []byte("42"), bytes, "json.Marshal encoded the wrong bytes.") + }) + + t.Run("JSON/Unmarshal", func(t *testing.T) { + err := json.Unmarshal([]byte("40"), &atom) + require.NoError(t, err, "json.Unmarshal errored unexpectedly.") + require.Equal(t, int64(40), atom.Load(), "json.Unmarshal didn't set the correct value.") + }) + + t.Run("JSON/Unmarshal/Error", func(t *testing.T) { + err := json.Unmarshal([]byte(`"40"`), &atom) + require.Error(t, err, "json.Unmarshal didn't error as expected.") + assertErrorJSONUnmarshalType(t, err, + "json.Unmarshal failed with unexpected error %v, want UnmarshalTypeError.", err) + }) + + t.Run("String", func(t *testing.T) { + t.Run("positive", func(t *testing.T) { + atom := NewInt64(math.MaxInt64) + assert.Equal(t, "9223372036854775807", atom.String(), + "String() returned an unexpected value.") + }) + + t.Run("negative", func(t *testing.T) { + atom := NewInt64(math.MinInt64) + assert.Equal(t, "-9223372036854775808", atom.String(), + "String() returned an unexpected value.") + }) + }) +} diff --git a/pkg/utils/atomic/internal/gen-atomicint/main.go b/pkg/utils/atomic/internal/gen-atomicint/main.go new file mode 100644 index 0000000..719fe9c --- /dev/null +++ b/pkg/utils/atomic/internal/gen-atomicint/main.go @@ -0,0 +1,116 @@ +// Copyright (c) 2020-2022 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +// gen-atomicint generates an atomic wrapper around an integer type. +// +// gen-atomicint -name Int32 -wrapped int32 -file out.go +// +// The generated wrapper will use the functions in the sync/atomic package +// named after the generated type. +package main + +import ( + "bytes" + "embed" + "errors" + "flag" + "fmt" + "go/format" + "io" + "log" + "os" + "text/template" + "time" +) + +func main() { + log.SetFlags(0) + if err := run(os.Args[1:]); err != nil { + log.Fatalf("%+v", err) + } +} + +func run(args []string) error { + var opts struct { + Name string + Wrapped string + File string + Unsigned bool + } + + flag := flag.NewFlagSet("gen-atomicint", flag.ContinueOnError) + + flag.StringVar(&opts.Name, "name", "", "name of the generated type (e.g. Int32)") + flag.StringVar(&opts.Wrapped, "wrapped", "", "name of the wrapped type (e.g. int32)") + flag.StringVar(&opts.File, "file", "", "output file path (default: stdout)") + flag.BoolVar(&opts.Unsigned, "unsigned", false, "whether the type is unsigned") + + if err := flag.Parse(args); err != nil { + return err + } + + if len(opts.Name) == 0 || len(opts.Wrapped) == 0 { + return errors.New("flags -name and -wrapped are required") + } + + var w io.Writer = os.Stdout + if file := opts.File; len(file) > 0 { + f, err := os.Create(file) + if err != nil { + return fmt.Errorf("create %q: %v", file, err) + } + defer f.Close() + + w = f + } + + data := struct { + Name string + Wrapped string + Unsigned bool + ToYear int + }{ + Name: opts.Name, + Wrapped: opts.Wrapped, + Unsigned: opts.Unsigned, + ToYear: time.Now().Year(), + } + + var buff bytes.Buffer + if err := _tmpl.ExecuteTemplate(&buff, "wrapper.tmpl", data); err != nil { + return fmt.Errorf("render template: %v", err) + } + + bs, err := format.Source(buff.Bytes()) + if err != nil { + return fmt.Errorf("reformat source: %v", err) + } + + io.WriteString(w, "// @generated Code generated by gen-atomicint.\n\n") + _, err = w.Write(bs) + return err +} + +var ( + //go:embed *.tmpl + _tmplFS embed.FS + + _tmpl = template.Must(template.New("atomicint").ParseFS(_tmplFS, "*.tmpl")) +) diff --git a/pkg/utils/atomic/internal/gen-atomicint/wrapper.tmpl b/pkg/utils/atomic/internal/gen-atomicint/wrapper.tmpl new file mode 100644 index 0000000..502fadc --- /dev/null +++ b/pkg/utils/atomic/internal/gen-atomicint/wrapper.tmpl @@ -0,0 +1,117 @@ +// Copyright (c) 2020-{{.ToYear}} Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "encoding/json" + "strconv" + "sync/atomic" +) + +// {{ .Name }} is an atomic wrapper around {{ .Wrapped }}. +type {{ .Name }} struct { + _ nocmp // disallow non-atomic comparison + + v {{ .Wrapped }} +} + +// New{{ .Name }} creates a new {{ .Name }}. +func New{{ .Name }}(val {{ .Wrapped }}) *{{ .Name }} { + return &{{ .Name }}{v: val} +} + +// Load atomically loads the wrapped value. +func (i *{{ .Name }}) Load() {{ .Wrapped }} { + return atomic.Load{{ .Name }}(&i.v) +} + +// Add atomically adds to the wrapped {{ .Wrapped }} and returns the new value. +func (i *{{ .Name }}) Add(delta {{ .Wrapped }}) {{ .Wrapped }} { + return atomic.Add{{ .Name }}(&i.v, delta) +} + +// Sub atomically subtracts from the wrapped {{ .Wrapped }} and returns the new value. +func (i *{{ .Name }}) Sub(delta {{ .Wrapped }}) {{ .Wrapped }} { + return atomic.Add{{ .Name }}(&i.v, + {{- if .Unsigned -}} + ^(delta - 1) + {{- else -}} + -delta + {{- end -}} + ) +} + +// Inc atomically increments the wrapped {{ .Wrapped }} and returns the new value. +func (i *{{ .Name }}) Inc() {{ .Wrapped }} { + return i.Add(1) +} + +// Dec atomically decrements the wrapped {{ .Wrapped }} and returns the new value. +func (i *{{ .Name }}) Dec() {{ .Wrapped }} { + return i.Sub(1) +} + +// CAS is an atomic compare-and-swap. +// +// Deprecated: Use CompareAndSwap. +func (i *{{ .Name }}) CAS(old, new {{ .Wrapped }}) (swapped bool) { + return i.CompareAndSwap(old, new) +} + +// CompareAndSwap is an atomic compare-and-swap. +func (i *{{ .Name }}) CompareAndSwap(old, new {{ .Wrapped }}) (swapped bool) { + return atomic.CompareAndSwap{{ .Name }}(&i.v, old, new) +} + +// Store atomically stores the passed value. +func (i *{{ .Name }}) Store(val {{ .Wrapped }}) { + atomic.Store{{ .Name }}(&i.v, val) +} + +// Swap atomically swaps the wrapped {{ .Wrapped }} and returns the old value. +func (i *{{ .Name }}) Swap(val {{ .Wrapped }}) (old {{ .Wrapped }}) { + return atomic.Swap{{ .Name }}(&i.v, val) +} + +// MarshalJSON encodes the wrapped {{ .Wrapped }} into JSON. +func (i *{{ .Name }}) MarshalJSON() (by, er) { + return json.Marshal(i.Load()) +} + +// UnmarshalJSON decodes JSON into the wrapped {{ .Wrapped }}. +func (i *{{ .Name }}) UnmarshalJSON(b by) er { + var v {{ .Wrapped }} + if err := json.Unmarshal(b, &v); err != nil { + return err + } + i.Store(v) + return nil +} + +// String encodes the wrapped value as a string. +func (i *{{ .Name }}) String() string { + v := i.Load() + {{ if .Unsigned -}} + return strconv.FormatUint(uint64(v), 10) + {{- else -}} + return strconv.FormatInt(int64(v), 10) + {{- end }} +} diff --git a/pkg/utils/atomic/internal/gen-atomicwrapper/main.go b/pkg/utils/atomic/internal/gen-atomicwrapper/main.go new file mode 100644 index 0000000..26683cd --- /dev/null +++ b/pkg/utils/atomic/internal/gen-atomicwrapper/main.go @@ -0,0 +1,203 @@ +// Copyright (c) 2020-2022 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +// gen-atomicwrapper generates wrapper types around other atomic types. +// +// It supports plugging in functions which convert the value inside the atomic +// type to the user-facing value. For example, +// +// Given, atomic.Value and the functions, +// +// func packString(string) enveloper{} +// func unpackString(enveloper{}) string +// +// We can run the following command: +// +// gen-atomicwrapper -name String -wrapped Value \ +// -type string -pack fromString -unpack tostring +// +// This wil generate approximately, +// +// type String struct{ v Value } +// +// func (s *String) Load() string { +// return unpackString(v.Load()) +// } +// +// func (s *String) Store(s string) { +// return s.v.Store(packString(s)) +// } +// +// The packing/unpacking logic allows the stored value to be different from +// the user-facing value. +package main + +import ( + "bytes" + "embed" + "errors" + "flag" + "fmt" + "go/format" + "io" + "log" + "os" + "sort" + "strings" + "text/template" + "time" +) + +func main() { + log.SetFlags(0) + if err := run(os.Args[1:]); err != nil { + log.Fatalf("%+v", err) + } +} + +type stringList []string + +func (sl *stringList) String() string { + return strings.Join(*sl, ",") +} + +func (sl *stringList) Set(s string) error { + for _, i := range strings.Split(s, ",") { + *sl = append(*sl, strings.TrimSpace(i)) + } + return nil +} + +func run(args []string) error { + var opts struct { + Name string + Wrapped string + Type string + + Imports stringList + Pack, Unpack string + + CAS bool + CompareAndSwap bool + Swap bool + JSON bool + + File string + ToYear int + } + + opts.ToYear = time.Now().Year() + + fl := flag.NewFlagSet("gen-atomicwrapper", flag.ContinueOnError) + + // Required flags + fl.StringVar(&opts.Name, "name", "", + "name of the generated type (e.g. Duration)") + fl.StringVar(&opts.Wrapped, "wrapped", "", + "name of the wrapped atomic (e.g. Int64)") + fl.StringVar(&opts.Type, "type", "", + "name of the type exposed by the atomic (e.g. time.Duration)") + + // Optional flags + fl.Var(&opts.Imports, "imports", + "comma separated list of imports to add") + fl.StringVar(&opts.Pack, "pack", "", + "function to transform values with before storage") + fl.StringVar(&opts.Unpack, "unpack", "", + "function to reverse packing on loading") + fl.StringVar(&opts.File, "file", "", + "output file path (default: stdout)") + + // Switches for individual methods. Underlying atomics must support + // these. + fl.BoolVar(&opts.CAS, "cas", false, + "generate a deprecated `CAS(old, new) bool` method; requires -pack") + fl.BoolVar(&opts.CompareAndSwap, "compareandswap", false, + "generate a `CompareAndSwap(old, new) bool` method; requires -pack") + fl.BoolVar(&opts.Swap, "swap", false, + "generate a `Swap(new) old` method; requires -pack and -unpack") + fl.BoolVar(&opts.JSON, "json", false, + "generate `Marshal/UnmarshJSON` methods") + + if err := fl.Parse(args); err != nil { + return err + } + + if len(opts.Name) == 0 || + len(opts.Wrapped) == 0 || + len(opts.Type) == 0 || + len(opts.Pack) == 0 || + len(opts.Unpack) == 0 { + return errors.New("flags -name, -wrapped, -pack, -unpack and -type are required") + } + + if opts.CAS { + opts.CompareAndSwap = true + } + + var w io.Writer = os.Stdout + if file := opts.File; len(file) > 0 { + f, err := os.Create(file) + if err != nil { + return fmt.Errorf("create %q: %v", file, err) + } + defer f.Close() + + w = f + } + + // Import encoding/json if needed. + if opts.JSON { + found := false + for _, imp := range opts.Imports { + if imp == "encoding/json" { + found = true + break + } + } + + if !found { + opts.Imports = append(opts.Imports, "encoding/json") + } + } + + sort.Strings(opts.Imports) + + var buff bytes.Buffer + if err := _tmpl.ExecuteTemplate(&buff, "wrapper.tmpl", opts); err != nil { + return fmt.Errorf("render template: %v", err) + } + + bs, err := format.Source(buff.Bytes()) + if err != nil { + return fmt.Errorf("reformat source: %v", err) + } + + io.WriteString(w, "// @generated Code generated by gen-atomicwrapper.\n\n") + _, err = w.Write(bs) + return err +} + +var ( + //go:embed *.tmpl + _tmplFS embed.FS + + _tmpl = template.Must(template.New("atomicwrapper").ParseFS(_tmplFS, "*.tmpl")) +) diff --git a/pkg/utils/atomic/internal/gen-atomicwrapper/wrapper.tmpl b/pkg/utils/atomic/internal/gen-atomicwrapper/wrapper.tmpl new file mode 100644 index 0000000..6ed6a9e --- /dev/null +++ b/pkg/utils/atomic/internal/gen-atomicwrapper/wrapper.tmpl @@ -0,0 +1,120 @@ +// Copyright (c) 2020-{{.ToYear}} Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +{{ with .Imports }} +import ( + {{ range . -}} + {{ printf "%q" . }} + {{ end }} +) +{{ end }} + +// {{ .Name }} is an atomic type-safe wrapper for {{ .Type }} values. +type {{ .Name }} struct{ + _ nocmp // disallow non-atomic comparison + + v {{ .Wrapped }} +} + +var _zero{{ .Name }} {{ .Type }} + + +// New{{ .Name }} creates a new {{ .Name }}. +func New{{ .Name }}(val {{ .Type }}) *{{ .Name }} { + x := &{{ .Name }}{} + if val != _zero{{ .Name }} { + x.Store(val) + } + return x +} + +// Load atomically loads the wrapped {{ .Type }}. +func (x *{{ .Name }}) Load() {{ .Type }} { + {{ if .Unpack -}} + return {{ .Unpack }}(x.v.Load()) + {{- else -}} + if v := x.v.Load(); v != nil { + return v.({{ .Type }}) + } + return _zero{{ .Name }} + {{- end }} +} + +// Store atomically stores the passed {{ .Type }}. +func (x *{{ .Name }}) Store(val {{ .Type }}) { + x.v.Store({{ .Pack }}(val)) +} + +{{ if .CAS -}} + // CAS is an atomic compare-and-swap for {{ .Type }} values. + // + // Deprecated: Use CompareAndSwap. + func (x *{{ .Name }}) CAS(old, new {{ .Type }}) (swapped bool) { + return x.CompareAndSwap(old, new) + } +{{- end }} + +{{ if .CompareAndSwap -}} + // CompareAndSwap is an atomic compare-and-swap for {{ .Type }} values. + func (x *{{ .Name }}) CompareAndSwap(old, new {{ .Type }}) (swapped bool) { + {{ if eq .Wrapped "Value" -}} + if x.v.CompareAndSwap({{ .Pack }}(old), {{ .Pack }}(new)) { + return true + } + + if old == _zero{{ .Name }} { + // If the old value is the empty value, then it's possible the + // underlying Value hasn't been set and is nil, so retry with nil. + return x.v.CompareAndSwap(nil, {{ .Pack }}(new)) + } + + return false + {{- else -}} + return x.v.CompareAndSwap({{ .Pack }}(old), {{ .Pack }}(new)) + {{- end }} + } +{{- end }} + +{{ if .Swap -}} + // Swap atomically stores the given {{ .Type }} and returns the old + // value. + func (x *{{ .Name }}) Swap(val {{ .Type }}) (old {{ .Type }}) { + return {{ .Unpack }}(x.v.Swap({{ .Pack }}(val))) + } +{{- end }} + +{{ if .JSON -}} + // MarshalJSON encodes the wrapped {{ .Type }} into JSON. + func (x *{{ .Name }}) MarshalJSON() (by, er) { + return json.Marshal(x.Load()) + } + + // UnmarshalJSON decodes a {{ .Type }} from JSON. + func (x *{{ .Name }}) UnmarshalJSON(b by) er { + var v {{ .Type }} + if err := json.Unmarshal(b, &v); err != nil { + return err + } + x.Store(v) + return nil + } +{{- end }} diff --git a/pkg/utils/atomic/nocmp.go b/pkg/utils/atomic/nocmp.go new file mode 100644 index 0000000..54b7417 --- /dev/null +++ b/pkg/utils/atomic/nocmp.go @@ -0,0 +1,35 @@ +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +// nocmp is an uncomparable struct. Embed this inside another struct to make +// it uncomparable. +// +// type Foo struct { +// nocmp +// // ... +// } +// +// This DOES NOT: +// +// - Disallow shallow copies of structs +// - Disallow comparison of pointers to uncomparable structs +type nocmp [0]func() diff --git a/pkg/utils/atomic/nocmp_test.go b/pkg/utils/atomic/nocmp_test.go new file mode 100644 index 0000000..8719421 --- /dev/null +++ b/pkg/utils/atomic/nocmp_test.go @@ -0,0 +1,164 @@ +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "bytes" + "os" + "os/exec" + "path/filepath" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNocmpComparability(t *testing.T) { + tests := []struct { + desc string + give interface{} + comparable bool + }{ + { + desc: "nocmp struct", + give: nocmp{}, + }, + { + desc: "struct with nocmp embedded", + give: struct{ nocmp }{}, + }, + { + desc: "pointer to struct with nocmp embedded", + give: &struct{ nocmp }{}, + comparable: true, + }, + + // All exported types must be uncomparable. + {desc: "Bool", give: Bool{}}, + {desc: "Duration", give: Duration{}}, + {desc: "Error", give: Error{}}, + {desc: "Float64", give: Float64{}}, + {desc: "Int32", give: Int32{}}, + {desc: "Int64", give: Int64{}}, + {desc: "String", give: String{}}, + {desc: "Uint32", give: Uint32{}}, + {desc: "Uint64", give: Uint64{}}, + {desc: "Value", give: Value{}}, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + typ := reflect.TypeOf(tt.give) + assert.Equalf(t, tt.comparable, typ.Comparable(), + "type %v comparablity mismatch", typ) + }) + } +} + +// nocmp must not add to the size of a struct in-memory. +func TestNocmpSize(t *testing.T) { + type x struct{ _ int } + + before := reflect.TypeOf(x{}).Size() + + type y struct { + _ nocmp + _ x + } + + after := reflect.TypeOf(y{}).Size() + + assert.Equal(t, before, after, + "expected nocmp to have no effect on struct size") +} + +// This test will fail to compile if we disallow copying of nocmp. +// +// We need to allow this so that users can do, +// +// var x atomic.Int32 +// x = atomic.NewInt32(1) +func TestNocmpCopy(t *testing.T) { + type foo struct{ _ nocmp } + + t.Run("struct copy", func(t *testing.T) { + a := foo{} + b := a + _ = b // unused + }) + + t.Run("pointer copy", func(t *testing.T) { + a := &foo{} + b := *a + _ = b // unused + }) +} + +// Fake go.mod with no dependencies. +const _exampleGoMod = `module example.com/nocmp` + +const _badFile = `package atomic + +import "fmt" + +type Int64 struct { + nocmp + + v int64 +} + +func shouldNotCompile() { + var x, y Int64 + fmt.Println(x == y) +} +` + +func TestNocmpIntegration(t *testing.T) { + tempdir := t.TempDir() + + nocmp, err := os.ReadFile("nocmp.go") + require.NoError(t, err, "unable to read nocmp.go") + + require.NoError(t, + os.WriteFile(filepath.Join(tempdir, "go.mod"), []byte(_exampleGoMod), 0o644), + "unable to write go.mod") + + require.NoError(t, + os.WriteFile(filepath.Join(tempdir, "nocmp.go"), nocmp, 0o644), + "unable to write nocmp.go") + + require.NoError(t, + os.WriteFile(filepath.Join(tempdir, "bad.go"), []byte(_badFile), 0o644), + "unable to write bad.go") + + var stderr bytes.Buffer + cmd := exec.Command("go", "build") + cmd.Dir = tempdir + // Create a minimal build environment with only HOME set so that "go + // build" has somewhere to put the cache and other Go files in. + cmd.Env = []string{"HOME=" + filepath.Join(tempdir, "home")} + cmd.Stderr = &stderr + require.Error(t, cmd.Run(), "bad.go must not compile") + + assert.Contains(t, stderr.String(), + "struct containing nocmp cannot be compared") +} diff --git a/pkg/utils/atomic/pointer_test.go b/pkg/utils/atomic/pointer_test.go new file mode 100644 index 0000000..837bd45 --- /dev/null +++ b/pkg/utils/atomic/pointer_test.go @@ -0,0 +1,100 @@ +// Copyright (c) 2022 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +//go:build go1.18 +// +build go1.18 + +package atomic + +// +// import ( +// "fmt" +// "testing" +// +// "github.com/stretchr/testify/require" +// ) +// +// func TestPointer(t *testing.T) { +// type foo struct{ v int } +// +// i := foo{42} +// j := foo{0} +// k := foo{1} +// +// tests := []struct { +// desc string +// newAtomic func() *Pointer[foo] +// initial *foo +// }{ +// { +// desc: "New", +// newAtomic: func() *Pointer[foo] { +// return NewPointer(&i) +// }, +// initial: &i, +// }, +// { +// desc: "New/nil", +// newAtomic: func() *Pointer[foo] { +// return NewPointer[foo](nil) +// }, +// initial: nil, +// }, +// { +// desc: "zero value", +// newAtomic: func() *Pointer[foo] { +// var p Pointer[foo] +// return &p +// }, +// initial: nil, +// }, +// } +// +// for _, tt := range tests { +// t.Run(tt.desc, func(t *testing.T) { +// t.Run("Load", func(t *testing.T) { +// atom := tt.newAtomic() +// require.Equal(t, tt.initial, atom.Load(), "Load should report nil.") +// }) +// +// t.Run("Swap", func(t *testing.T) { +// atom := tt.newAtomic() +// require.Equal(t, tt.initial, atom.Swap(&k), "Swap didn't return the old value.") +// require.Equal(t, &k, atom.Load(), "Swap didn't set the correct value.") +// }) +// +// t.Run("CAS", func(t *testing.T) { +// atom := tt.newAtomic() +// require.True(t, atom.CompareAndSwap(tt.initial, &j), "CAS didn't report a swap.") +// require.Equal(t, &j, atom.Load(), "CAS didn't set the correct value.") +// }) +// +// t.Run("Store", func(t *testing.T) { +// atom := tt.newAtomic() +// atom.Store(&i) +// require.Equal(t, &i, atom.Load(), "Store didn't set the correct value.") +// }) +// t.Run("String", func(t *testing.T) { +// atom := tt.newAtomic() +// require.Equal(t, fmt.Sprint(tt.initial), atom.String(), "String did not return the correct value.") +// }) +// }) +// } +// } diff --git a/pkg/utils/atomic/stress_test.go b/pkg/utils/atomic/stress_test.go new file mode 100644 index 0000000..0ac7ac5 --- /dev/null +++ b/pkg/utils/atomic/stress_test.go @@ -0,0 +1,289 @@ +// Copyright (c) 2016 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "errors" + "math" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" +) + +const ( + _parallelism = 4 + _iterations = 1000 +) + +var _stressTests = map[string]func() func(){ + "i32/std": stressStdInt32, + "i32": stressInt32, + "i64/std": stressStdInt64, + "i64": stressInt64, + "u32/std": stressStdUint32, + "u32": stressUint32, + "u64/std": stressStdUint64, + "u64": stressUint64, + "f64": stressFloat64, + "bool": stressBool, + "string": stressString, + "duration": stressDuration, + "error": stressError, + "time": stressTime, +} + +func TestStress(t *testing.T) { + for name, ff := range _stressTests { + t.Run(name, func(t *testing.T) { + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(_parallelism)) + + start := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(_parallelism) + f := ff() + for i := 0; i < _parallelism; i++ { + go func() { + defer wg.Done() + <-start + for j := 0; j < _iterations; j++ { + f() + } + }() + } + close(start) + wg.Wait() + }) + } +} + +func BenchmarkStress(b *testing.B) { + for name, ff := range _stressTests { + b.Run(name, func(b *testing.B) { + f := ff() + + b.Run("serial", func(b *testing.B) { + for i := 0; i < b.N; i++ { + f() + } + }) + + b.Run("parallel", func(b *testing.B) { + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + f() + } + }) + }) + }) + } +} + +func stressStdInt32() func() { + var atom int32 + return func() { + atomic.LoadInt32(&atom) + atomic.AddInt32(&atom, 1) + atomic.AddInt32(&atom, -2) + atomic.AddInt32(&atom, 1) + atomic.AddInt32(&atom, -1) + atomic.CompareAndSwapInt32(&atom, 1, 0) + atomic.SwapInt32(&atom, 5) + atomic.StoreInt32(&atom, 1) + } +} + +func stressInt32() func() { + var atom Int32 + return func() { + atom.Load() + atom.Add(1) + atom.Sub(2) + atom.Inc() + atom.Dec() + atom.CAS(1, 0) + atom.Swap(5) + atom.Store(1) + } +} + +func stressStdInt64() func() { + var atom int64 + return func() { + atomic.LoadInt64(&atom) + atomic.AddInt64(&atom, 1) + atomic.AddInt64(&atom, -2) + atomic.AddInt64(&atom, 1) + atomic.AddInt64(&atom, -1) + atomic.CompareAndSwapInt64(&atom, 1, 0) + atomic.SwapInt64(&atom, 5) + atomic.StoreInt64(&atom, 1) + } +} + +func stressInt64() func() { + var atom Int64 + return func() { + atom.Load() + atom.Add(1) + atom.Sub(2) + atom.Inc() + atom.Dec() + atom.CAS(1, 0) + atom.Swap(5) + atom.Store(1) + } +} + +func stressStdUint32() func() { + var atom uint32 + return func() { + atomic.LoadUint32(&atom) + atomic.AddUint32(&atom, 1) + // Adding `MaxUint32` is the same as subtracting 1 + atomic.AddUint32(&atom, math.MaxUint32-1) + atomic.AddUint32(&atom, 1) + atomic.AddUint32(&atom, math.MaxUint32) + atomic.CompareAndSwapUint32(&atom, 1, 0) + atomic.SwapUint32(&atom, 5) + atomic.StoreUint32(&atom, 1) + } +} + +func stressUint32() func() { + var atom Uint32 + return func() { + atom.Load() + atom.Add(1) + atom.Sub(2) + atom.Inc() + atom.Dec() + atom.CAS(1, 0) + atom.Swap(5) + atom.Store(1) + } +} + +func stressStdUint64() func() { + var atom uint64 + return func() { + atomic.LoadUint64(&atom) + atomic.AddUint64(&atom, 1) + // Adding `MaxUint64` is the same as subtracting 1 + atomic.AddUint64(&atom, math.MaxUint64-1) + atomic.AddUint64(&atom, 1) + atomic.AddUint64(&atom, math.MaxUint64) + atomic.CompareAndSwapUint64(&atom, 1, 0) + atomic.SwapUint64(&atom, 5) + atomic.StoreUint64(&atom, 1) + } +} + +func stressUint64() func() { + var atom Uint64 + return func() { + atom.Load() + atom.Add(1) + atom.Sub(2) + atom.Inc() + atom.Dec() + atom.CAS(1, 0) + atom.Swap(5) + atom.Store(1) + } +} + +func stressFloat64() func() { + var atom Float64 + return func() { + atom.Load() + atom.CAS(1.0, 0.1) + atom.Add(1.1) + atom.Sub(0.2) + atom.Store(1.0) + } +} + +func stressBool() func() { + var atom Bool + return func() { + atom.Load() + atom.Store(false) + atom.Swap(true) + atom.CAS(true, false) + atom.CAS(true, false) + atom.Load() + atom.Toggle() + atom.Toggle() + } +} + +func stressString() func() { + var atom String + return func() { + atom.Load() + atom.Store("abc") + atom.Load() + atom.Store("def") + atom.Load() + atom.Store("") + } +} + +func stressDuration() func() { + var atom = NewDuration(0) + return func() { + atom.Load() + atom.Add(1) + atom.Sub(2) + atom.CAS(1, 0) + atom.Swap(5) + atom.Store(1) + } +} + +func stressError() func() { + var atom = NewError(nil) + var err1 = errors.New("err1") + var err2 = errors.New("err2") + return func() { + atom.Load() + atom.Store(err1) + atom.Load() + atom.Store(err2) + atom.Load() + atom.Store(nil) + } +} + +func stressTime() func() { + var atom = NewTime(time.Date(2021, 6, 17, 9, 0, 0, 0, time.UTC)) + var dayAgo = time.Date(2021, 6, 16, 9, 0, 0, 0, time.UTC) + var weekAgo = time.Date(2021, 6, 10, 9, 0, 0, 0, time.UTC) + return func() { + atom.Load() + atom.Store(dayAgo) + atom.Load() + atom.Store(weekAgo) + atom.Store(time.Time{}) + } +} diff --git a/pkg/utils/atomic/string.go b/pkg/utils/atomic/string.go new file mode 100644 index 0000000..061466c --- /dev/null +++ b/pkg/utils/atomic/string.go @@ -0,0 +1,72 @@ +// @generated Code generated by gen-atomicwrapper. + +// Copyright (c) 2020-2023 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +// String is an atomic type-safe wrapper for string values. +type String struct { + _ nocmp // disallow non-atomic comparison + + v Value +} + +var _zeroString string + +// NewString creates a new String. +func NewString(val string) *String { + x := &String{} + if val != _zeroString { + x.Store(val) + } + return x +} + +// Load atomically loads the wrapped string. +func (x *String) Load() string { + return unpackString(x.v.Load()) +} + +// Store atomically stores the passed string. +func (x *String) Store(val string) { + x.v.Store(packString(val)) +} + +// CompareAndSwap is an atomic compare-and-swap for string values. +func (x *String) CompareAndSwap(old, new string) (swapped bool) { + if x.v.CompareAndSwap(packString(old), packString(new)) { + return true + } + + if old == _zeroString { + // If the old value is the empty value, then it's possible the + // underlying Value hasn't been set and is nil, so retry with nil. + return x.v.CompareAndSwap(nil, packString(new)) + } + + return false +} + +// Swap atomically stores the given string and returns the old +// value. +func (x *String) Swap(val string) (old string) { + return unpackString(x.v.Swap(packString(val))) +} diff --git a/pkg/utils/atomic/string_ext.go b/pkg/utils/atomic/string_ext.go new file mode 100644 index 0000000..019109c --- /dev/null +++ b/pkg/utils/atomic/string_ext.go @@ -0,0 +1,54 @@ +// Copyright (c) 2020-2023 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +//go:generate bin/gen-atomicwrapper -name=String -type=string -wrapped Value -pack packString -unpack unpackString -compareandswap -swap -file=string.go + +func packString(s string) interface{} { + return s +} + +func unpackString(v interface{}) string { + if s, ok := v.(string); ok { + return s + } + return "" +} + +// String returns the wrapped value. +func (s *String) String() string { + return s.Load() +} + +// MarshalText encodes the wrapped string into a textual form. +// +// This makes it encodable as JSON, YAML, XML, and more. +func (s *String) MarshalText() ([]byte, error) { + return []byte(s.Load()), nil +} + +// UnmarshalText decodes text and replaces the wrapped string with it. +// +// This makes it decodable from JSON, YAML, XML, and more. +func (s *String) UnmarshalText(b []byte) error { + s.Store(string(b)) + return nil +} diff --git a/pkg/utils/atomic/string_test.go b/pkg/utils/atomic/string_test.go new file mode 100644 index 0000000..6163113 --- /dev/null +++ b/pkg/utils/atomic/string_test.go @@ -0,0 +1,170 @@ +// Copyright (c) 2016-2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "encoding/json" + "encoding/xml" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStringNoInitialValue(t *testing.T) { + atom := &String{} + require.Equal(t, "", atom.Load(), "Initial value should be blank string") +} + +func TestString(t *testing.T) { + atom := NewString("") + require.Equal(t, "", atom.Load(), "Expected Load to return initialized value") + + atom.Store("abc") + require.Equal(t, "abc", atom.Load(), "Unexpected value after Store") + + atom = NewString("bcd") + require.Equal(t, "bcd", atom.Load(), "Expected Load to return initialized value") + + t.Run("JSON/Marshal", func(t *testing.T) { + bytes, err := json.Marshal(atom) + require.NoError(t, err, "json.Marshal errored unexpectedly.") + require.Equal(t, []byte(`"bcd"`), bytes, "json.Marshal encoded the wrong bytes.") + }) + + t.Run("JSON/Unmarshal", func(t *testing.T) { + err := json.Unmarshal([]byte(`"abc"`), &atom) + require.NoError(t, err, "json.Unmarshal errored unexpectedly.") + require.Equal(t, "abc", atom.Load(), "json.Unmarshal didn't set the correct value.") + }) + + t.Run("JSON/Unmarshal/Error", func(t *testing.T) { + err := json.Unmarshal([]byte("42"), &atom) + require.Error(t, err, "json.Unmarshal didn't error as expected.") + assertErrorJSONUnmarshalType(t, err, + "json.Unmarshal failed with unexpected error %v, want UnmarshalTypeError.", err) + }) + + atom = NewString("foo") + + t.Run("XML/Marshal", func(t *testing.T) { + bytes, err := xml.Marshal(atom) + require.NoError(t, err, "xml.Marshal errored unexpectedly.") + require.Equal(t, []byte("foo"), bytes, + "xml.Marshal encoded the wrong bytes.") + }) + + t.Run("XML/Unmarshal", func(t *testing.T) { + err := xml.Unmarshal([]byte("bar"), &atom) + require.NoError(t, err, "xml.Unmarshal errored unexpectedly.") + require.Equal(t, "bar", atom.Load(), "xml.Unmarshal didn't set the correct value.") + }) + + t.Run("String", func(t *testing.T) { + atom := NewString("foo") + assert.Equal(t, "foo", atom.String(), + "String() returned an unexpected value.") + }) + + t.Run("CompareAndSwap", func(t *testing.T) { + atom := NewString("foo") + + swapped := atom.CompareAndSwap("bar", "bar") + require.False(t, swapped, "swapped isn't false") + require.Equal(t, atom.Load(), "foo", "Load returned wrong value") + + swapped = atom.CompareAndSwap("foo", "bar") + require.True(t, swapped, "swapped isn't true") + require.Equal(t, atom.Load(), "bar", "Load returned wrong value") + }) + + t.Run("Swap", func(t *testing.T) { + atom := NewString("foo") + + old := atom.Swap("bar") + require.Equal(t, old, "foo", "Swap returned wrong value") + require.Equal(t, atom.Load(), "bar", "Load returned wrong value") + }) +} + +func TestString_InitializeDefault(t *testing.T) { + tests := []struct { + msg string + newStr func() *String + }{ + { + msg: "Uninitialized", + newStr: func() *String { + var s String + return &s + }, + }, + { + msg: "NewString with default", + newStr: func() *String { + return NewString("") + }, + }, + { + msg: "String swapped with default", + newStr: func() *String { + s := NewString("initial") + s.Swap("") + return s + }, + }, + { + msg: "String CAS'd with default", + newStr: func() *String { + s := NewString("initial") + s.CompareAndSwap("initial", "") + return s + }, + }, + } + + for _, tt := range tests { + t.Run(tt.msg, func(t *testing.T) { + t.Run("MarshalText", func(t *testing.T) { + str := tt.newStr() + text, err := str.MarshalText() + require.NoError(t, err) + assert.Equal(t, "", string(text), "") + }) + + t.Run("String", func(t *testing.T) { + str := tt.newStr() + assert.Equal(t, "", str.String()) + }) + + t.Run("CompareAndSwap", func(t *testing.T) { + str := tt.newStr() + require.True(t, str.CompareAndSwap("", "new")) + assert.Equal(t, "new", str.Load()) + }) + + t.Run("Swap", func(t *testing.T) { + str := tt.newStr() + assert.Equal(t, "", str.Swap("new")) + }) + }) + } +} diff --git a/pkg/utils/atomic/time.go b/pkg/utils/atomic/time.go new file mode 100644 index 0000000..cc2a230 --- /dev/null +++ b/pkg/utils/atomic/time.go @@ -0,0 +1,55 @@ +// @generated Code generated by gen-atomicwrapper. + +// Copyright (c) 2020-2023 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "time" +) + +// Time is an atomic type-safe wrapper for time.Time values. +type Time struct { + _ nocmp // disallow non-atomic comparison + + v Value +} + +var _zeroTime time.Time + +// NewTime creates a new Time. +func NewTime(val time.Time) *Time { + x := &Time{} + if val != _zeroTime { + x.Store(val) + } + return x +} + +// Load atomically loads the wrapped time.Time. +func (x *Time) Load() time.Time { + return unpackTime(x.v.Load()) +} + +// Store atomically stores the passed time.Time. +func (x *Time) Store(val time.Time) { + x.v.Store(packTime(val)) +} diff --git a/pkg/utils/atomic/time_ext.go b/pkg/utils/atomic/time_ext.go new file mode 100644 index 0000000..1e3dc97 --- /dev/null +++ b/pkg/utils/atomic/time_ext.go @@ -0,0 +1,36 @@ +// Copyright (c) 2021 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import "time" + +//go:generate bin/gen-atomicwrapper -name=Time -type=time.Time -wrapped=Value -pack=packTime -unpack=unpackTime -imports time -file=time.go + +func packTime(t time.Time) interface{} { + return t +} + +func unpackTime(v interface{}) time.Time { + if t, ok := v.(time.Time); ok { + return t + } + return time.Time{} +} diff --git a/pkg/utils/atomic/time_test.go b/pkg/utils/atomic/time_test.go new file mode 100644 index 0000000..83ac022 --- /dev/null +++ b/pkg/utils/atomic/time_test.go @@ -0,0 +1,86 @@ +// Copyright (c) 2021 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTime(t *testing.T) { + start := time.Date(2021, 6, 17, 9, 10, 0, 0, time.UTC) + atom := NewTime(start) + + require.Equal(t, start, atom.Load(), "Load didn't work") + require.Equal(t, time.Time{}, NewTime(time.Time{}).Load(), "Default time value is wrong") +} + +func TestTimeLocation(t *testing.T) { + // Check TZ data hasn't been lost from load/store. + ny, err := time.LoadLocation("America/New_York") + require.NoError(t, err, "Failed to load location") + nyTime := NewTime(time.Date(2021, 1, 1, 0, 0, 0, 0, ny)) + + var atom Time + atom.Store(nyTime.Load()) + + assert.Equal(t, ny, atom.Load().Location(), "Location information is wrong") +} + +func TestLargeTime(t *testing.T) { + // Check "large/small" time that are beyond int64 ns + // representation (< year 1678 or > year 2262) can be + // correctly load/store'd. + t.Parallel() + + t.Run("future", func(t *testing.T) { + future := time.Date(2262, 12, 31, 0, 0, 0, 0, time.UTC) + atom := NewTime(future) + dayAfterFuture := atom.Load().AddDate(0, 1, 0) + + atom.Store(dayAfterFuture) + assert.Equal(t, 2263, atom.Load().Year()) + }) + + t.Run("past", func(t *testing.T) { + past := time.Date(1678, 1, 1, 0, 0, 0, 0, time.UTC) + atom := NewTime(past) + dayBeforePast := atom.Load().AddDate(0, -1, 0) + + atom.Store(dayBeforePast) + assert.Equal(t, 1677, atom.Load().Year()) + }) +} + +func TestMonotonic(t *testing.T) { + before := NewTime(time.Now()) + time.Sleep(15 * time.Millisecond) + after := NewTime(time.Now()) + + // try loading/storing before and test monotonic clock value hasn't been lost + bt := before.Load() + before.Store(bt) + d := after.Load().Sub(before.Load()) + assert.True(t, 15 <= d.Milliseconds()) +} diff --git a/pkg/utils/atomic/tools/tools.go b/pkg/utils/atomic/tools/tools.go new file mode 100644 index 0000000..6c8e7e8 --- /dev/null +++ b/pkg/utils/atomic/tools/tools.go @@ -0,0 +1,30 @@ +// Copyright (c) 2019 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +//go:build tools +// +build tools + +package tools + +import ( + // Tools used during development. + _ "golang.org/x/lint/golint" + _ "honnef.co/go/tools/cmd/staticcheck" +) diff --git a/pkg/utils/atomic/uint32.go b/pkg/utils/atomic/uint32.go new file mode 100644 index 0000000..4adc294 --- /dev/null +++ b/pkg/utils/atomic/uint32.go @@ -0,0 +1,109 @@ +// @generated Code generated by gen-atomicint. + +// Copyright (c) 2020-2023 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "encoding/json" + "strconv" + "sync/atomic" +) + +// Uint32 is an atomic wrapper around uint32. +type Uint32 struct { + _ nocmp // disallow non-atomic comparison + + v uint32 +} + +// NewUint32 creates a new Uint32. +func NewUint32(val uint32) *Uint32 { + return &Uint32{v: val} +} + +// Load atomically loads the wrapped value. +func (i *Uint32) Load() uint32 { + return atomic.LoadUint32(&i.v) +} + +// Add atomically adds to the wrapped uint32 and returns the new value. +func (i *Uint32) Add(delta uint32) uint32 { + return atomic.AddUint32(&i.v, delta) +} + +// Sub atomically subtracts from the wrapped uint32 and returns the new value. +func (i *Uint32) Sub(delta uint32) uint32 { + return atomic.AddUint32(&i.v, ^(delta - 1)) +} + +// Inc atomically increments the wrapped uint32 and returns the new value. +func (i *Uint32) Inc() uint32 { + return i.Add(1) +} + +// Dec atomically decrements the wrapped uint32 and returns the new value. +func (i *Uint32) Dec() uint32 { + return i.Sub(1) +} + +// CAS is an atomic compare-and-swap. +// +// Deprecated: Use CompareAndSwap. +func (i *Uint32) CAS(old, new uint32) (swapped bool) { + return i.CompareAndSwap(old, new) +} + +// CompareAndSwap is an atomic compare-and-swap. +func (i *Uint32) CompareAndSwap(old, new uint32) (swapped bool) { + return atomic.CompareAndSwapUint32(&i.v, old, new) +} + +// Store atomically stores the passed value. +func (i *Uint32) Store(val uint32) { + atomic.StoreUint32(&i.v, val) +} + +// Swap atomically swaps the wrapped uint32 and returns the old value. +func (i *Uint32) Swap(val uint32) (old uint32) { + return atomic.SwapUint32(&i.v, val) +} + +// MarshalJSON encodes the wrapped uint32 into JSON. +func (i *Uint32) MarshalJSON() ([]byte, error) { + return json.Marshal(i.Load()) +} + +// UnmarshalJSON decodes JSON into the wrapped uint32. +func (i *Uint32) UnmarshalJSON(b []byte) error { + var v uint32 + if err := json.Unmarshal(b, &v); err != nil { + return err + } + i.Store(v) + return nil +} + +// String encodes the wrapped value as a string. +func (i *Uint32) String() string { + v := i.Load() + return strconv.FormatUint(uint64(v), 10) +} diff --git a/pkg/utils/atomic/uint32_test.go b/pkg/utils/atomic/uint32_test.go new file mode 100644 index 0000000..8bfcda2 --- /dev/null +++ b/pkg/utils/atomic/uint32_test.go @@ -0,0 +1,77 @@ +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "encoding/json" + "math" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUint32(t *testing.T) { + atom := NewUint32(42) + + require.Equal(t, uint32(42), atom.Load(), "Load didn't work.") + require.Equal(t, uint32(46), atom.Add(4), "Add didn't work.") + require.Equal(t, uint32(44), atom.Sub(2), "Sub didn't work.") + require.Equal(t, uint32(45), atom.Inc(), "Inc didn't work.") + require.Equal(t, uint32(44), atom.Dec(), "Dec didn't work.") + + require.True(t, atom.CAS(44, 0), "CAS didn't report a swap.") + require.Equal(t, uint32(0), atom.Load(), "CAS didn't set the correct value.") + + require.Equal(t, uint32(0), atom.Swap(1), "Swap didn't return the old value.") + require.Equal(t, uint32(1), atom.Load(), "Swap didn't set the correct value.") + + atom.Store(42) + require.Equal(t, uint32(42), atom.Load(), "Store didn't set the correct value.") + + t.Run("JSON/Marshal", func(t *testing.T) { + bytes, err := json.Marshal(atom) + require.NoError(t, err, "json.Marshal errored unexpectedly.") + require.Equal(t, []byte("42"), bytes, "json.Marshal encoded the wrong bytes.") + }) + + t.Run("JSON/Unmarshal", func(t *testing.T) { + err := json.Unmarshal([]byte("40"), &atom) + require.NoError(t, err, "json.Unmarshal errored unexpectedly.") + require.Equal(t, uint32(40), atom.Load(), + "json.Unmarshal didn't set the correct value.") + }) + + t.Run("JSON/Unmarshal/Error", func(t *testing.T) { + err := json.Unmarshal([]byte(`"40"`), &atom) + require.Error(t, err, "json.Unmarshal didn't error as expected.") + assertErrorJSONUnmarshalType(t, err, + "json.Unmarshal failed with unexpected error %v, want UnmarshalTypeError.", err) + }) + + t.Run("String", func(t *testing.T) { + // Use an integer with the signed bit set. If we're converting + // incorrectly, we'll get a negative value here. + atom := NewUint32(math.MaxUint32) + assert.Equal(t, "4294967295", atom.String(), + "String() returned an unexpected value.") + }) +} diff --git a/pkg/utils/atomic/uint64.go b/pkg/utils/atomic/uint64.go new file mode 100644 index 0000000..0e2eddb --- /dev/null +++ b/pkg/utils/atomic/uint64.go @@ -0,0 +1,109 @@ +// @generated Code generated by gen-atomicint. + +// Copyright (c) 2020-2023 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "encoding/json" + "strconv" + "sync/atomic" +) + +// Uint64 is an atomic wrapper around uint64. +type Uint64 struct { + _ nocmp // disallow non-atomic comparison + + v uint64 +} + +// NewUint64 creates a new Uint64. +func NewUint64(val uint64) *Uint64 { + return &Uint64{v: val} +} + +// Load atomically loads the wrapped value. +func (i *Uint64) Load() uint64 { + return atomic.LoadUint64(&i.v) +} + +// Add atomically adds to the wrapped uint64 and returns the new value. +func (i *Uint64) Add(delta uint64) uint64 { + return atomic.AddUint64(&i.v, delta) +} + +// Sub atomically subtracts from the wrapped uint64 and returns the new value. +func (i *Uint64) Sub(delta uint64) uint64 { + return atomic.AddUint64(&i.v, ^(delta - 1)) +} + +// Inc atomically increments the wrapped uint64 and returns the new value. +func (i *Uint64) Inc() uint64 { + return i.Add(1) +} + +// Dec atomically decrements the wrapped uint64 and returns the new value. +func (i *Uint64) Dec() uint64 { + return i.Sub(1) +} + +// CAS is an atomic compare-and-swap. +// +// Deprecated: Use CompareAndSwap. +func (i *Uint64) CAS(old, new uint64) (swapped bool) { + return i.CompareAndSwap(old, new) +} + +// CompareAndSwap is an atomic compare-and-swap. +func (i *Uint64) CompareAndSwap(old, new uint64) (swapped bool) { + return atomic.CompareAndSwapUint64(&i.v, old, new) +} + +// Store atomically stores the passed value. +func (i *Uint64) Store(val uint64) { + atomic.StoreUint64(&i.v, val) +} + +// Swap atomically swaps the wrapped uint64 and returns the old value. +func (i *Uint64) Swap(val uint64) (old uint64) { + return atomic.SwapUint64(&i.v, val) +} + +// MarshalJSON encodes the wrapped uint64 into JSON. +func (i *Uint64) MarshalJSON() ([]byte, error) { + return json.Marshal(i.Load()) +} + +// UnmarshalJSON decodes JSON into the wrapped uint64. +func (i *Uint64) UnmarshalJSON(b []byte) error { + var v uint64 + if err := json.Unmarshal(b, &v); err != nil { + return err + } + i.Store(v) + return nil +} + +// String encodes the wrapped value as a string. +func (i *Uint64) String() string { + v := i.Load() + return strconv.FormatUint(uint64(v), 10) +} diff --git a/pkg/utils/atomic/uint64_test.go b/pkg/utils/atomic/uint64_test.go new file mode 100644 index 0000000..1141e5a --- /dev/null +++ b/pkg/utils/atomic/uint64_test.go @@ -0,0 +1,77 @@ +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "encoding/json" + "math" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUint64(t *testing.T) { + atom := NewUint64(42) + + require.Equal(t, uint64(42), atom.Load(), "Load didn't work.") + require.Equal(t, uint64(46), atom.Add(4), "Add didn't work.") + require.Equal(t, uint64(44), atom.Sub(2), "Sub didn't work.") + require.Equal(t, uint64(45), atom.Inc(), "Inc didn't work.") + require.Equal(t, uint64(44), atom.Dec(), "Dec didn't work.") + + require.True(t, atom.CAS(44, 0), "CAS didn't report a swap.") + require.Equal(t, uint64(0), atom.Load(), "CAS didn't set the correct value.") + + require.Equal(t, uint64(0), atom.Swap(1), "Swap didn't return the old value.") + require.Equal(t, uint64(1), atom.Load(), "Swap didn't set the correct value.") + + atom.Store(42) + require.Equal(t, uint64(42), atom.Load(), "Store didn't set the correct value.") + + t.Run("JSON/Marshal", func(t *testing.T) { + bytes, err := json.Marshal(atom) + require.NoError(t, err, "json.Marshal errored unexpectedly.") + require.Equal(t, []byte("42"), bytes, "json.Marshal encoded the wrong bytes.") + }) + + t.Run("JSON/Unmarshal", func(t *testing.T) { + err := json.Unmarshal([]byte("40"), &atom) + require.NoError(t, err, "json.Unmarshal errored unexpectedly.") + require.Equal(t, uint64(40), atom.Load(), + "json.Unmarshal didn't set the correct value.") + }) + + t.Run("JSON/Unmarshal/Error", func(t *testing.T) { + err := json.Unmarshal([]byte(`"40"`), &atom) + require.Error(t, err, "json.Unmarshal didn't error as expected.") + assertErrorJSONUnmarshalType(t, err, + "json.Unmarshal failed with unexpected error %v, want UnmarshalTypeError.", err) + }) + + t.Run("String", func(t *testing.T) { + // Use an integer with the signed bit set. If we're converting + // incorrectly, we'll get a negative value here. + atom := NewUint64(math.MaxUint64) + assert.Equal(t, "18446744073709551615", atom.String(), + "String() returned an unexpected value.") + }) +} diff --git a/pkg/utils/atomic/uintptr.go b/pkg/utils/atomic/uintptr.go new file mode 100644 index 0000000..7d5b000 --- /dev/null +++ b/pkg/utils/atomic/uintptr.go @@ -0,0 +1,109 @@ +// @generated Code generated by gen-atomicint. + +// Copyright (c) 2020-2023 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "encoding/json" + "strconv" + "sync/atomic" +) + +// Uintptr is an atomic wrapper around uintptr. +type Uintptr struct { + _ nocmp // disallow non-atomic comparison + + v uintptr +} + +// NewUintptr creates a new Uintptr. +func NewUintptr(val uintptr) *Uintptr { + return &Uintptr{v: val} +} + +// Load atomically loads the wrapped value. +func (i *Uintptr) Load() uintptr { + return atomic.LoadUintptr(&i.v) +} + +// Add atomically adds to the wrapped uintptr and returns the new value. +func (i *Uintptr) Add(delta uintptr) uintptr { + return atomic.AddUintptr(&i.v, delta) +} + +// Sub atomically subtracts from the wrapped uintptr and returns the new value. +func (i *Uintptr) Sub(delta uintptr) uintptr { + return atomic.AddUintptr(&i.v, ^(delta - 1)) +} + +// Inc atomically increments the wrapped uintptr and returns the new value. +func (i *Uintptr) Inc() uintptr { + return i.Add(1) +} + +// Dec atomically decrements the wrapped uintptr and returns the new value. +func (i *Uintptr) Dec() uintptr { + return i.Sub(1) +} + +// CAS is an atomic compare-and-swap. +// +// Deprecated: Use CompareAndSwap. +func (i *Uintptr) CAS(old, new uintptr) (swapped bool) { + return i.CompareAndSwap(old, new) +} + +// CompareAndSwap is an atomic compare-and-swap. +func (i *Uintptr) CompareAndSwap(old, new uintptr) (swapped bool) { + return atomic.CompareAndSwapUintptr(&i.v, old, new) +} + +// Store atomically stores the passed value. +func (i *Uintptr) Store(val uintptr) { + atomic.StoreUintptr(&i.v, val) +} + +// Swap atomically swaps the wrapped uintptr and returns the old value. +func (i *Uintptr) Swap(val uintptr) (old uintptr) { + return atomic.SwapUintptr(&i.v, val) +} + +// MarshalJSON encodes the wrapped uintptr into JSON. +func (i *Uintptr) MarshalJSON() ([]byte, error) { + return json.Marshal(i.Load()) +} + +// UnmarshalJSON decodes JSON into the wrapped uintptr. +func (i *Uintptr) UnmarshalJSON(b []byte) error { + var v uintptr + if err := json.Unmarshal(b, &v); err != nil { + return err + } + i.Store(v) + return nil +} + +// String encodes the wrapped value as a string. +func (i *Uintptr) String() string { + v := i.Load() + return strconv.FormatUint(uint64(v), 10) +} diff --git a/pkg/utils/atomic/uintptr_test.go b/pkg/utils/atomic/uintptr_test.go new file mode 100644 index 0000000..7d8ac39 --- /dev/null +++ b/pkg/utils/atomic/uintptr_test.go @@ -0,0 +1,80 @@ +// Copyright (c) 2021 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "encoding/json" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUintptr(t *testing.T) { + atom := NewUintptr(42) + + require.Equal(t, uintptr(42), atom.Load(), "Load didn't work.") + require.Equal(t, uintptr(46), atom.Add(4), "Add didn't work.") + require.Equal(t, uintptr(44), atom.Sub(2), "Sub didn't work.") + require.Equal(t, uintptr(45), atom.Inc(), "Inc didn't work.") + require.Equal(t, uintptr(44), atom.Dec(), "Dec didn't work.") + + require.True(t, atom.CAS(44, 0), "CAS didn't report a swap.") + require.Equal(t, uintptr(0), atom.Load(), "CAS didn't set the correct value.") + + require.Equal(t, uintptr(0), atom.Swap(1), "Swap didn't return the old value.") + require.Equal(t, uintptr(1), atom.Load(), "Swap didn't set the correct value.") + + atom.Store(42) + require.Equal(t, uintptr(42), atom.Load(), "Store didn't set the correct value.") + + t.Run("JSON/Marshal", func(t *testing.T) { + bytes, err := json.Marshal(atom) + require.NoError(t, err, "json.Marshal errored unexpectedly.") + require.Equal(t, []byte("42"), bytes, "json.Marshal encoded the wrong bytes.") + }) + + t.Run("JSON/Unmarshal", func(t *testing.T) { + err := json.Unmarshal([]byte("40"), &atom) + require.NoError(t, err, "json.Unmarshal errored unexpectedly.") + require.Equal(t, uintptr(40), atom.Load(), + "json.Unmarshal didn't set the correct value.") + }) + + t.Run("JSON/Unmarshal/Error", func(t *testing.T) { + err := json.Unmarshal([]byte(`"40"`), &atom) + require.Error(t, err, "json.Unmarshal didn't error as expected.") + assertErrorJSONUnmarshalType(t, err, + "json.Unmarshal failed with unexpected error %v, want UnmarshalTypeError.", err) + }) + + t.Run("String", func(t *testing.T) { + // Use an integer with the signed bit set. If we're converting + // incorrectly, we'll get a negative value here. + // Use an int variable, as constants cause compile-time overflows. + negative := -1 + atom := NewUintptr(uintptr(negative)) + want := fmt.Sprint(uintptr(negative)) + assert.Equal(t, want, atom.String(), + "String() returned an unexpected value.") + }) +} diff --git a/pkg/utils/atomic/unsafe_pointer.go b/pkg/utils/atomic/unsafe_pointer.go new file mode 100644 index 0000000..34868ba --- /dev/null +++ b/pkg/utils/atomic/unsafe_pointer.go @@ -0,0 +1,65 @@ +// Copyright (c) 2021-2022 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "sync/atomic" + "unsafe" +) + +// UnsafePointer is an atomic wrapper around unsafe.Pointer. +type UnsafePointer struct { + _ nocmp // disallow non-atomic comparison + + v unsafe.Pointer +} + +// NewUnsafePointer creates a new UnsafePointer. +func NewUnsafePointer(val unsafe.Pointer) *UnsafePointer { + return &UnsafePointer{v: val} +} + +// Load atomically loads the wrapped value. +func (p *UnsafePointer) Load() unsafe.Pointer { + return atomic.LoadPointer(&p.v) +} + +// Store atomically stores the passed value. +func (p *UnsafePointer) Store(val unsafe.Pointer) { + atomic.StorePointer(&p.v, val) +} + +// Swap atomically swaps the wrapped unsafe.Pointer and returns the old value. +func (p *UnsafePointer) Swap(val unsafe.Pointer) (old unsafe.Pointer) { + return atomic.SwapPointer(&p.v, val) +} + +// CAS is an atomic compare-and-swap. +// +// Deprecated: Use CompareAndSwap +func (p *UnsafePointer) CAS(old, new unsafe.Pointer) (swapped bool) { + return p.CompareAndSwap(old, new) +} + +// CompareAndSwap is an atomic compare-and-swap. +func (p *UnsafePointer) CompareAndSwap(old, new unsafe.Pointer) (swapped bool) { + return atomic.CompareAndSwapPointer(&p.v, old, new) +} diff --git a/pkg/utils/atomic/unsafe_pointer_test.go b/pkg/utils/atomic/unsafe_pointer_test.go new file mode 100644 index 0000000..f0193df --- /dev/null +++ b/pkg/utils/atomic/unsafe_pointer_test.go @@ -0,0 +1,83 @@ +// Copyright (c) 2021 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "testing" + "unsafe" + + "github.com/stretchr/testify/require" +) + +func TestUnsafePointer(t *testing.T) { + i := int64(42) + j := int64(0) + k := int64(1) + + tests := []struct { + desc string + newAtomic func() *UnsafePointer + initial unsafe.Pointer + }{ + { + desc: "non-empty", + newAtomic: func() *UnsafePointer { + return NewUnsafePointer(unsafe.Pointer(&i)) + }, + initial: unsafe.Pointer(&i), + }, + { + desc: "nil", + newAtomic: func() *UnsafePointer { + var p UnsafePointer + return &p + }, + initial: unsafe.Pointer(nil), + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + t.Run("Load", func(t *testing.T) { + atom := tt.newAtomic() + require.Equal(t, tt.initial, atom.Load(), "Load should report nil.") + }) + + t.Run("Swap", func(t *testing.T) { + atom := tt.newAtomic() + require.Equal(t, tt.initial, atom.Swap(unsafe.Pointer(&k)), "Swap didn't return the old value.") + require.Equal(t, unsafe.Pointer(&k), atom.Load(), "Swap didn't set the correct value.") + }) + + t.Run("CAS", func(t *testing.T) { + atom := tt.newAtomic() + require.True(t, atom.CAS(tt.initial, unsafe.Pointer(&j)), "CAS didn't report a swap.") + require.Equal(t, unsafe.Pointer(&j), atom.Load(), "CAS didn't set the correct value.") + }) + + t.Run("Store", func(t *testing.T) { + atom := tt.newAtomic() + atom.Store(unsafe.Pointer(&i)) + require.Equal(t, unsafe.Pointer(&i), atom.Load(), "Store didn't set the correct value.") + }) + }) + } +} diff --git a/pkg/utils/atomic/value.go b/pkg/utils/atomic/value.go new file mode 100644 index 0000000..52caedb --- /dev/null +++ b/pkg/utils/atomic/value.go @@ -0,0 +1,31 @@ +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import "sync/atomic" + +// Value shadows the type of the same name from sync/atomic +// https://godoc.org/sync/atomic#Value +type Value struct { + _ nocmp // disallow non-atomic comparison + + atomic.Value +} diff --git a/pkg/utils/atomic/value_test.go b/pkg/utils/atomic/value_test.go new file mode 100644 index 0000000..bb9f301 --- /dev/null +++ b/pkg/utils/atomic/value_test.go @@ -0,0 +1,40 @@ +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package atomic + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestValue(t *testing.T) { + var v Value + assert.Nil(t, v.Load(), "initial Value is not nil") + + v.Store(42) + assert.Equal(t, 42, v.Load()) + + v.Store(84) + assert.Equal(t, 84, v.Load()) + + assert.Panics(t, func() { v.Store("foo") }) +} diff --git a/pkg/utils/go.mod b/pkg/utils/go.mod index 5ddcafc..19b06e2 100644 --- a/pkg/utils/go.mod +++ b/pkg/utils/go.mod @@ -4,16 +4,28 @@ go 1.25.0 require ( encoders.orly v0.0.0-00010101000000-000000000000 + github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 + github.com/stretchr/testify v1.11.1 + go.uber.org/atomic v1.11.0 + golang.org/x/lint v0.0.0-20241112194109-818c5a804067 + honnef.co/go/tools v0.6.1 lol.mleku.dev v1.0.2 ) require ( + github.com/BurntSushi/toml v1.4.1-0.20240526193622-a339e1f7089c // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/fatih/color v1.18.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b // indirect + golang.org/x/exp/typeparams v0.0.0-20231108232855-2478ac86f678 // indirect + golang.org/x/mod v0.27.0 // indirect + golang.org/x/sync v0.16.0 // indirect golang.org/x/sys v0.35.0 // indirect + golang.org/x/tools v0.36.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) replace ( diff --git a/pkg/utils/go.sum b/pkg/utils/go.sum index b1dd4bf..661bcb1 100644 --- a/pkg/utils/go.sum +++ b/pkg/utils/go.sum @@ -1,16 +1,57 @@ +github.com/BurntSushi/toml v1.4.1-0.20240526193622-a339e1f7089c h1:pxW6RcqyfI9/kWtOwnv/G+AzdKuy2ZrqINhenH4HyNs= +github.com/BurntSushi/toml v1.4.1-0.20240526193622-a339e1f7089c/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 h1:iQTw/8FWTuc7uiaSepXwyf3o52HaUYcV+Tu66S3F5GA= +github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0/go.mod h1:1NbS8ALrpOvjt0rHPNLyCIeMtbizbir8U//inJ+zuB8= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b h1:DXr+pvt3nC887026GRP39Ej11UATqWDmWuS99x26cD0= golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b/go.mod h1:4QTo5u+SEIbbKW1RacMZq1YEfOBqeXa19JeshGi+zc4= +golang.org/x/exp/typeparams v0.0.0-20231108232855-2478ac86f678 h1:1P7xPZEwZMoBoz0Yze5Nx2/4pxj6nw9ZqHWXqP0iRgQ= +golang.org/x/exp/typeparams v0.0.0-20231108232855-2478ac86f678/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk= +golang.org/x/lint v0.0.0-20241112194109-818c5a804067 h1:adDmSQyFTCiv19j015EGKJBoaa7ElV0Q1Wovb/4G7NA= +golang.org/x/lint v0.0.0-20241112194109-818c5a804067/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/mod v0.27.0 h1:kb+q2PyFnEADO2IEF935ehFUXlWiNjJWtRNgBLSfbxQ= +golang.org/x/mod v0.27.0/go.mod h1:rWI627Fq0DEoudcK+MBkNkCe0EetEaDSwJJkCcjpazc= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg= +golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s= +golang.org/x/tools/go/expect v0.1.1-deprecated h1:jpBZDwmgPhXsKZC6WhL20P4b/wmnpsEAGHaNy0n/rJM= +golang.org/x/tools/go/expect v0.1.1-deprecated/go.mod h1:eihoPOH+FgIqa3FpoTwguz/bVUSGBlGQU67vpBeOrBY= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +honnef.co/go/tools v0.6.1 h1:R094WgE8K4JirYjBaOpz/AvTyUu/3wbmAoskKN/pxTI= +honnef.co/go/tools v0.6.1/go.mod h1:3puzxxljPCe8RGJX7BIy1plGbxEOZni5mR2aXe3/uk4= lol.mleku.dev v1.0.2 h1:bSV1hHnkmt1hq+9nSvRwN6wgcI7itbM3XRZ4dMB438c= lol.mleku.dev v1.0.2/go.mod h1:DQ0WnmkntA9dPLCXgvtIgYt5G0HSqx3wSTLolHgWeLA= lukechampine.com/frand v1.5.1 h1:fg0eRtdmGFIxhP5zQJzM1lFDbD6CUfu/f+7WgAZd5/w= diff --git a/pkg/utils/interrupt/README.md b/pkg/utils/interrupt/README.md new file mode 100644 index 0000000..40f997f --- /dev/null +++ b/pkg/utils/interrupt/README.md @@ -0,0 +1,2 @@ +# interrupt +Handle shutdowns cleanly and enable hot reload diff --git a/pkg/utils/interrupt/main.go b/pkg/utils/interrupt/main.go new file mode 100644 index 0000000..69ce322 --- /dev/null +++ b/pkg/utils/interrupt/main.go @@ -0,0 +1,153 @@ +// Package interrupt is a library for providing handling for Ctrl-C/Interrupt +// handling and triggering callbacks for such things as closing files, flushing +// buffers, and other elements of graceful shutdowns. +package interrupt + +import ( + "fmt" + "os" + "os/signal" + "runtime" + + "lol.mleku.dev/log" + "utils.orly/atomic" + "utils.orly/qu" +) + +// HandlerWithSource is an interrupt handling closure and the source location +// that it was sent from. +type HandlerWithSource struct { + Source string + Fn func() +} + +var ( + // RestartRequested is set true after restart is requested. + RestartRequested bool // = true + requested atomic.Bool + + // ch is used to receive SIGINT (Ctrl+C) signals. + ch chan os.Signal + + // signals is the list of signals that cause the interrupt + signals = []os.Signal{os.Interrupt} + + // ShutdownRequestChan is a channel that can receive shutdown requests + ShutdownRequestChan = qu.T() + + // addHandlerChan is used to add an interrupt handler to the list of + // handlers to be invoked on SIGINT (Ctrl+C) signals. + addHandlerChan = make(chan HandlerWithSource) + + // HandlersDone is closed after all interrupt handlers run the first time an + // interrupt is signaled. + HandlersDone = make(qu.C) + + interruptCallbacks []func() + interruptCallbackSources []string +) + +// Listener listens for interrupt signals, registers interrupt callbacks, and +// responds to custom shutdown signals as required +func Listener() { + invokeCallbacks := func() { + // run handlers in LIFO order. + for i := range interruptCallbacks { + idx := len(interruptCallbacks) - 1 - i + log.T.F( + "running callback %d from %s", idx, + interruptCallbackSources[idx], + ) + interruptCallbacks[idx]() + } + log.D.Ln("interrupt handlers finished") + HandlersDone.Q() + if RestartRequested { + Restart() + } else { + os.Exit(0) + } + } +out: + for { + select { + case _ = <-ch: + fmt.Fprintf(os.Stderr, "\r") + requested.Store(true) + invokeCallbacks() + break out + + case <-ShutdownRequestChan.Wait(): + log.W.Ln("received shutdown request - shutting down...") + requested.Store(true) + invokeCallbacks() + break out + + case handler := <-addHandlerChan: + interruptCallbacks = append(interruptCallbacks, handler.Fn) + interruptCallbackSources = append( + interruptCallbackSources, + handler.Source, + ) + + case <-HandlersDone.Wait(): + break out + } + } +} + +// AddHandler adds a handler to call when a SIGINT (Ctrl+C) is received. +func AddHandler(handler func()) { + // Create the channel and start the main interrupt handler which invokes all + // other callbacks and exits if not already done. + _, loc, line, _ := runtime.Caller(1) + msg := fmt.Sprintf("%s:%d", loc, line) + if ch == nil { + ch = make(chan os.Signal) + signal.Notify(ch, signals...) + go Listener() + } + addHandlerChan <- HandlerWithSource{ + msg, handler, + } +} + +// Request programmatically requests a shutdown +func Request() { + _, f, l, _ := runtime.Caller(1) + log.D.Ln("interrupt requested", f, l, requested.Load()) + if requested.Load() { + log.D.Ln("requested again") + return + } + requested.Store(true) + ShutdownRequestChan.Q() + var ok bool + select { + case _, ok = <-ShutdownRequestChan: + default: + } + if ok { + close(ShutdownRequestChan) + } +} + +// GoroutineDump returns a string with the current goroutine dump in order to +// show what's going on in case of timeout. +func GoroutineDump() string { + buf := make([]byte, 1<<18) + n := runtime.Stack(buf, true) + return string(buf[:n]) +} + +// RequestRestart sets the reset flag and requests a restart +func RequestRestart() { + RestartRequested = true + log.D.Ln("requesting restart") + Request() +} + +// Requested returns true if an interrupt has been requested +func Requested() bool { + return requested.Load() +} diff --git a/pkg/utils/interrupt/restart.go b/pkg/utils/interrupt/restart.go new file mode 100644 index 0000000..41143b1 --- /dev/null +++ b/pkg/utils/interrupt/restart.go @@ -0,0 +1,26 @@ +//go:build linux + +package interrupt + +import ( + "lol.mleku.dev/log" + "os" + "syscall" + + "github.com/kardianos/osext" +) + +// Restart uses syscall.Exec to restart the process. macOS and Windows are not +// implemented, currently. +func Restart() { + log.D.Ln("restarting") + file, e := osext.Executable() + if e != nil { + log.E.Ln(e) + return + } + e = syscall.Exec(file, os.Args, os.Environ()) + if e != nil { + log.F.Ln(e) + } +} diff --git a/pkg/utils/interrupt/restart_darwin.go b/pkg/utils/interrupt/restart_darwin.go new file mode 100644 index 0000000..e740315 --- /dev/null +++ b/pkg/utils/interrupt/restart_darwin.go @@ -0,0 +1,20 @@ +package interrupt + +func Restart() { + // TODO: test this thing actually works! + // log.D.Ln("doing windows restart") + // // procAttr := new(os.ProcAttr) + // // procAttr.Files = []*os.File{os.Stdin, os.Stdout, os.Stderr} + // // os.StartProcess(os.Args[0], os.Args[1:], procAttr) + // var s []string + // // s = []string{"cmd.exe", "/C", "start"} + // s = append(s, os.Args[0]) + // // s = append(s, "--delaystart") + // s = append(s, os.Args[1:]...) + // cmd := exec.Command(s[0], s[1:]...) + // log.D.Ln("windows restart done") + // if err := cmd.Start(); log.Fail(err) { + // } + // // select{} + // os.Exit(0) +} diff --git a/pkg/utils/interrupt/restart_windows.go b/pkg/utils/interrupt/restart_windows.go new file mode 100644 index 0000000..e740315 --- /dev/null +++ b/pkg/utils/interrupt/restart_windows.go @@ -0,0 +1,20 @@ +package interrupt + +func Restart() { + // TODO: test this thing actually works! + // log.D.Ln("doing windows restart") + // // procAttr := new(os.ProcAttr) + // // procAttr.Files = []*os.File{os.Stdin, os.Stdout, os.Stderr} + // // os.StartProcess(os.Args[0], os.Args[1:], procAttr) + // var s []string + // // s = []string{"cmd.exe", "/C", "start"} + // s = append(s, os.Args[0]) + // // s = append(s, "--delaystart") + // s = append(s, os.Args[1:]...) + // cmd := exec.Command(s[0], s[1:]...) + // log.D.Ln("windows restart done") + // if err := cmd.Start(); log.Fail(err) { + // } + // // select{} + // os.Exit(0) +} diff --git a/pkg/utils/interrupt/sigterm.go b/pkg/utils/interrupt/sigterm.go new file mode 100644 index 0000000..d8287de --- /dev/null +++ b/pkg/utils/interrupt/sigterm.go @@ -0,0 +1,12 @@ +//go:build darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris + +package interrupt + +import ( + "os" + "syscall" +) + +func init() { + signals = []os.Signal{os.Interrupt, syscall.SIGTERM} +} diff --git a/pkg/utils/normalize/normalize.go b/pkg/utils/normalize/normalize.go index a9e4cce..d51a7a5 100644 --- a/pkg/utils/normalize/normalize.go +++ b/pkg/utils/normalize/normalize.go @@ -4,6 +4,7 @@ package normalize import ( "bytes" + "errors" "fmt" "net/url" @@ -105,6 +106,15 @@ func Msg(prefix Reason, format string, params ...any) []byte { return []byte(fmt.Sprintf(prefix.S()+": "+format, params...)) } +// MsgString constructs a properly formatted message with a machine-readable prefix +// for OK and CLOSED envelopes. +func MsgString(prefix Reason, format string, params ...any) string { + if len(prefix) < 1 { + prefix = Error + } + return fmt.Sprintf(prefix.S()+": "+format, params...) +} + // Reason is the machine-readable prefix before the colon in an OK or CLOSED // envelope message. Below are the most common kinds that are mentioned in // NIP-01. @@ -141,3 +151,12 @@ func (r Reason) F(format string, params ...any) []byte { r, format, params..., ) } + +// Errorf allows creation of a full Reason text with a printf style as an error. +func (r Reason) Errorf(format string, params ...any) (err error) { + return errors.New( + MsgString( + r, format, params..., + ), + ) +} diff --git a/pkg/utils/qu/README.adoc b/pkg/utils/qu/README.adoc new file mode 100644 index 0000000..0ad2e8a --- /dev/null +++ b/pkg/utils/qu/README.adoc @@ -0,0 +1,60 @@ += qu + +===== observable signal channels + +simple channels that act as breakers or momentary one-shot triggers. + +can enable logging to get detailed information on channel state, and channels do +not panic if closed channels are attempted to be closed or signalled with. + +provides a neat function based syntax for usage. + +wait function does require use of the `<-` receive operator prefix to be used in +a select statement. + +== usage + +=== creating channels: + +==== unbuffered + +---- +newSigChan := qu.T() +---- + +==== buffered + +---- +newBufferedSigChan := qu.Ts(5) +---- + +==== closing + +---- +newSigChan.Q() +---- + +==== signalling + +---- +newBufferedSigChan.Signal() +---- + +==== logging features + +---- +numberOpenUnbufferedChannels := GetOpenUnbufferedChanCount() + +numberOpenBufferedChannels := GetOpenBufferedChanCount() +---- + +print a list of closed and open channels known by qu: + +---- +PrintChanState() +---- + +== garbage collection + +this library automatically cleans up closed channels once a minute to free +resources that have become unused. \ No newline at end of file diff --git a/pkg/utils/qu/qu.go b/pkg/utils/qu/qu.go new file mode 100644 index 0000000..d849bdc --- /dev/null +++ b/pkg/utils/qu/qu.go @@ -0,0 +1,245 @@ +// Package qu is a library for making handling signal (chan struct{}) channels +// simpler, as well as monitoring the state of the signal channels in an +// application. +package qu + +import ( + "fmt" + "strings" + "sync" + "time" + + "go.uber.org/atomic" + "lol.mleku.dev" + "lol.mleku.dev/log" +) + +// C is your basic empty struct signal channel +type C chan struct{} + +var ( + createdList []string + createdChannels []C + createdChannelBufferCounts []int + mx sync.Mutex + logEnabled = atomic.NewBool(false) +) + +// SetLogging switches on and off the channel logging +func SetLogging(on bool) { + logEnabled.Store(on) +} + +func l(a ...interface{}) { + if logEnabled.Load() { + log.D.Ln(a...) + } +} + +func lc(cl func() string) { + if logEnabled.Load() { + log.D.Ln(cl()) + } +} + +// T creates an unbuffered chan struct{} for trigger and quit signalling +// (momentary and breaker switches) +func T() C { + mx.Lock() + defer mx.Unlock() + msg := fmt.Sprintf("chan from %s", lol.GetLoc(1)) + l("created", msg) + createdList = append(createdList, msg) + o := make(C) + createdChannels = append(createdChannels, o) + createdChannelBufferCounts = append(createdChannelBufferCounts, 0) + return o +} + +// Ts creates a buffered chan struct{} which is specifically intended for +// signalling without blocking, generally one is the size of buffer to be used, +// though there might be conceivable cases where the channel should accept more +// signals without blocking the caller +func Ts(n int) C { + mx.Lock() + defer mx.Unlock() + msg := fmt.Sprintf("buffered chan (%d) from %s", n, lol.GetLoc(1)) + l("created", msg) + createdList = append(createdList, msg) + o := make(C, n) + createdChannels = append(createdChannels, o) + createdChannelBufferCounts = append(createdChannelBufferCounts, n) + return o +} + +// Q closes the channel, which makes it emit a nil every time it is selected. +func (c C) Q() { + open := !testChanIsClosed(c) + lc( + func() (o string) { + lo := getLocForChan(c) + mx.Lock() + defer mx.Unlock() + if open { + return "closing chan from " + lo + "\n" + strings.Repeat( + " ", + 48, + ) + "from" + lol.GetLoc(1) + } else { + return "from" + lol.GetLoc(1) + "\n" + strings.Repeat(" ", 48) + + "channel " + lo + " was already closed" + } + }, + ) + if open { + close(c) + } +} + +// Signal sends struct{}{} on the channel which functions as a momentary switch, +// useful in pairs for stop/start +func (c C) Signal() { + lc(func() (o string) { return "signalling " + getLocForChan(c) }) + if !testChanIsClosed(c) { + c <- struct{}{} + } +} + +// Wait should be placed with a `<-` in a select case in addition to the channel +// variable name +func (c C) Wait() <-chan struct{} { + lc( + func() (o string) { + return fmt.Sprint( + "waiting on "+getLocForChan(c)+"at", + lol.GetLoc(1), + ) + }, + ) + return c +} + +// IsClosed exposes a test to see if the channel is closed +func (c C) IsClosed() bool { + return testChanIsClosed(c) +} + +// testChanIsClosed allows you to see whether the channel has been closed so you +// can avoid a panic by trying to close or signal on it +func testChanIsClosed(ch C) (o bool) { + if ch == nil { + return true + } + select { + case <-ch: + o = true + default: + } + return +} + +// getLocForChan finds which record connects to the channel in question +func getLocForChan(c C) (s string) { + s = "not found" + mx.Lock() + for i := range createdList { + if i >= len(createdChannels) { + break + } + if createdChannels[i] == c { + s = createdList[i] + } + } + mx.Unlock() + return +} + +// once a minute clean up the channel cache to remove closed channels no longer +// in use +func init() { + go func() { + for { + <-time.After(time.Minute) + l("cleaning up closed channels") + var c []C + var ll []string + mx.Lock() + for i := range createdChannels { + if i >= len(createdList) { + break + } + if testChanIsClosed(createdChannels[i]) { + } else { + c = append(c, createdChannels[i]) + ll = append(ll, createdList[i]) + } + } + createdChannels = c + createdList = ll + mx.Unlock() + } + }() +} + +// PrintChanState creates an output showing the current state of the channels +// being monitored This is a function for use by the programmer while debugging +func PrintChanState() { + mx.Lock() + for i := range createdChannels { + if i >= len(createdList) { + break + } + if testChanIsClosed(createdChannels[i]) { + log.T.Ln(">>> closed", createdList[i]) + } else { + log.T.Ln("<<< open", createdList[i]) + } + } + mx.Unlock() +} + +// GetOpenUnbufferedChanCount returns the number of qu channels that are still +// open +func GetOpenUnbufferedChanCount() (o int) { + mx.Lock() + var c int + for i := range createdChannels { + if i >= len(createdChannels) { + break + } + // skip buffered channels + if createdChannelBufferCounts[i] > 0 { + continue + } + if testChanIsClosed(createdChannels[i]) { + c++ + } else { + o++ + } + } + mx.Unlock() + return +} + +// GetOpenBufferedChanCount returns the number of qu channels that are still +// open +func GetOpenBufferedChanCount() (o int) { + mx.Lock() + var c int + for i := range createdChannels { + if i >= len(createdChannels) { + break + } + // skip unbuffered channels + if createdChannelBufferCounts[i] < 1 { + continue + } + if testChanIsClosed(createdChannels[i]) { + c++ + } else { + o++ + } + } + mx.Unlock() + return +} diff --git a/pprof.go b/pprof.go index 3e8d6b0..e68e8ad 100644 --- a/pprof.go +++ b/pprof.go @@ -2,18 +2,19 @@ package main import ( "github.com/pkg/profile" + "utils.orly/interrupt" ) func startProfiler(mode string) { switch mode { case "cpu": prof := profile.Start(profile.CPUProfile) - defer prof.Stop() + interrupt.AddHandler(prof.Stop) case "memory": prof := profile.Start(profile.MemProfile) - defer prof.Stop() + interrupt.AddHandler(prof.Stop) case "allocation": prof := profile.Start(profile.MemProfileAllocs) - defer prof.Stop() + interrupt.AddHandler(prof.Stop) } }