Add export functionality and fix privilege checks

- Added `Export` method in `database/database.go` to export events to an io.Writer
- Implemented detailed logic for exporting all or specific pubkeys' events
- Removed placeholder `Export` function with TODO comment from `database/database.go`
- Updated error handling in `handleReq.go` and `publisher.go` by using `err != nil` instead of `chk.E(err)`
- Added more detailed logging in privilege check conditions in both `publisher.go` and `handleReq.go`
- Introduced new imports such as `"fmt"` in `connection.go` for improved error message formatting
- Created a new file `export.go` under the `database` package with complete implementation of export functionality
This commit is contained in:
2025-07-22 11:34:57 +01:00
parent 53d649c64e
commit 651791aec1
12 changed files with 217 additions and 52 deletions

View File

@@ -5,7 +5,6 @@ import (
"orly.dev/pkg/encoders/event"
"orly.dev/pkg/encoders/filter"
"orly.dev/pkg/encoders/hex"
"orly.dev/pkg/encoders/kind"
"orly.dev/pkg/encoders/kinds"
"orly.dev/pkg/encoders/tag"
"orly.dev/pkg/protocol/ws"
@@ -17,12 +16,12 @@ import (
)
func (s *Server) SpiderFetch(
k *kind.T, noFetch bool, pubkeys ...[]byte,
k *kinds.T, noFetch, noExtract bool, pubkeys ...[]byte,
) (pks [][]byte, err error) {
// first search the local database
pkList := tag.New(pubkeys...)
f := &filter.F{
Kinds: kinds.New(k),
Kinds: k,
Authors: pkList,
}
var evs event.S
@@ -30,23 +29,37 @@ func (s *Server) SpiderFetch(
// none were found, so we need to scan the spiders
err = nil
}
var kindsList string
for i, kk := range k.K {
if i > 0 {
kindsList += ","
}
kindsList += kk.Name()
}
log.I.F("%d events found of type %s", len(evs), kindsList)
// for _, ev := range evs {
// o += fmt.Sprintf("%s\n\n", ev.Marshal(nil))
// }
// log.I.F("%s", o)
if len(evs) < len(pubkeys) && !noFetch {
// we need to search the spider seeds.
// Break up pubkeys into batches of 512
for i := 0; i < len(pubkeys); i += 512 {
end := i + 512
// Break up pubkeys into batches of 128
for i := 0; i < len(pubkeys); i += 128 {
end := i + 128
if end > len(pubkeys) {
end = len(pubkeys)
}
batchPubkeys := pubkeys[i:end]
log.I.F(
"processing batch %d to %d of %d for kind %s",
i, end, len(pubkeys), k.Name(),
i, end, len(pubkeys), kindsList,
)
batchPkList := tag.New(batchPubkeys...)
lim := uint(batchPkList.Len())
batchFilter := &filter.F{
Kinds: kinds.New(k),
Kinds: k,
Authors: batchPkList,
Limit: &lim,
}
var mx sync.Mutex
@@ -76,6 +89,16 @@ func (s *Server) SpiderFetch(
return
}
mx.Lock()
// save the events to the database
for _, ev := range evss {
log.I.F("saving event:\n%s", ev.Marshal(nil))
if _, _, err = s.Storage().SaveEvent(
s.Ctx, ev,
); chk.E(err) {
err = nil
continue
}
}
for _, ev := range evss {
evs = append(evs, ev)
}
@@ -84,13 +107,6 @@ func (s *Server) SpiderFetch(
}
wg.Wait()
}
// save the events to the database
for _, ev := range evs {
if _, _, err = s.Storage().SaveEvent(s.Ctx, ev); chk.E(err) {
err = nil
continue
}
}
}
// deduplicate and take the newest
var tmp event.S
@@ -108,7 +124,10 @@ func (s *Server) SpiderFetch(
tmp = append(tmp, evm[0])
}
evs = tmp
// we have all we're going to get now
// we have all we're going to get now, extract the p tags
if noExtract {
return
}
pkMap := make(map[string]struct{})
for _, ev := range evs {
t := ev.Tags.GetAll(tag.New("p"))
@@ -118,7 +137,7 @@ func (s *Server) SpiderFetch(
continue
}
pk := make([]byte, schnorr.PubKeyBytesLen)
if _, err = hex.DecBytes(pk, pkh); chk.E(err) {
if _, err = hex.DecBytes(pk, pkh); err != nil {
err = nil
continue
}

View File

@@ -6,6 +6,7 @@ import (
"orly.dev/pkg/encoders/bech32encoding"
"orly.dev/pkg/encoders/hex"
"orly.dev/pkg/encoders/kind"
"orly.dev/pkg/encoders/kinds"
"orly.dev/pkg/utils/chk"
"orly.dev/pkg/utils/log"
)
@@ -52,21 +53,22 @@ func (s *Server) Spider(noFetch ...bool) (err error) {
log.I.F("getting ownersFollowed")
var ownersFollowed [][]byte
if ownersFollowed, err = s.SpiderFetch(
kind.FollowList, dontFetch, ownersPubkeys...,
kinds.New(kind.FollowList), dontFetch, false, ownersPubkeys...,
); chk.E(err) {
return
}
// log.I.S(ownersFollowed)
log.I.F("getting followedFollows")
var followedFollows [][]byte
if followedFollows, err = s.SpiderFetch(
kind.FollowList, dontFetch, ownersFollowed...,
kinds.New(kind.FollowList), dontFetch, false, ownersFollowed...,
); chk.E(err) {
return
}
log.I.F("getting ownersMuted")
var ownersMuted [][]byte
if ownersMuted, err = s.SpiderFetch(
kind.MuteList, dontFetch, ownersPubkeys...,
kinds.New(kind.MuteList), dontFetch, false, ownersPubkeys...,
); chk.E(err) {
return
}
@@ -74,22 +76,17 @@ func (s *Server) Spider(noFetch ...bool) (err error) {
// list
filteredFollows := make([][]byte, 0, len(followedFollows))
for _, follow := range followedFollows {
found := false
for _, owner := range ownersFollowed {
if bytes.Equal(follow, owner) {
found = true
break
}
}
for _, owner := range ownersMuted {
if bytes.Equal(follow, owner) {
found = true
break
}
}
if !found {
filteredFollows = append(filteredFollows, follow)
}
filteredFollows = append(filteredFollows, follow)
}
followedFollows = filteredFollows
own := "owner"
@@ -115,7 +112,7 @@ func (s *Server) Spider(noFetch ...bool) (err error) {
len(followedFollows), folfol,
len(ownersMuted), mut,
)
// add the owners
// add the owners to the ownersFollowed
ownersFollowed = append(ownersFollowed, ownersPubkeys...)
s.SetOwnersPubkeys(ownersPubkeys)
s.SetOwnersFollowed(ownersFollowed)
@@ -125,9 +122,12 @@ func (s *Server) Spider(noFetch ...bool) (err error) {
if !dontFetch {
go func() {
everyone := append(ownersFollowed, followedFollows...)
s.SpiderFetch(kind.ProfileMetadata, false, everyone...)
s.SpiderFetch(kind.RelayListMetadata, false, everyone...)
s.SpiderFetch(kind.DMRelaysList, false, everyone...)
s.SpiderFetch(
kinds.New(
kind.ProfileMetadata, kind.RelayListMetadata,
kind.DMRelaysList,
), false, true, everyone...,
)
}()
}
}()

View File

@@ -80,11 +80,6 @@ func (d *D) Import(r io.Reader) {
panic("implement me")
}
func (d *D) Export(c context.T, w io.Writer, pubkeys ...[]byte) {
// TODO implement me
panic("implement me")
}
func (d *D) SetLogLevel(level string) {
d.Logger.SetLogLevel(lol.GetLogLevel(level))
}

96
pkg/database/export.go Normal file
View File

@@ -0,0 +1,96 @@
package database
import (
"bytes"
"github.com/dgraph-io/badger/v4"
"io"
"orly.dev/pkg/database/indexes"
"orly.dev/pkg/database/indexes/types"
"orly.dev/pkg/encoders/codecbuf"
"orly.dev/pkg/encoders/event"
"orly.dev/pkg/utils/chk"
"orly.dev/pkg/utils/context"
"orly.dev/pkg/utils/units"
)
// Export the complete database of stored events to an io.Writer in line structured minified
// JSON.
func (d *D) Export(c context.T, w io.Writer, pubkeys ...[]byte) {
var err error
if len(pubkeys) == 0 {
if err = d.View(
func(txn *badger.Txn) (err error) {
buf := codecbuf.Get()
defer codecbuf.Put(buf)
if err = indexes.EventEnc(nil).MarshalWrite(buf); chk.E(err) {
return
}
it := txn.NewIterator(badger.IteratorOptions{Prefix: buf.Bytes()})
evB := make([]byte, 0, units.Mb)
defer it.Close()
for it.Rewind(); it.Valid(); it.Next() {
item := it.Item()
if evB, err = item.ValueCopy(evB); chk.E(err) {
continue
}
evBuf := bytes.NewBuffer(evB)
ev := event.New()
if err = ev.UnmarshalBinary(evBuf); chk.E(err) {
continue
}
if _, err = evBuf.WriteTo(w); chk.E(err) {
continue
}
if _, err = w.Write([]byte{'\n'}); chk.E(err) {
continue
}
}
return
},
); err != nil {
return
}
} else {
for _, pubkey := range pubkeys {
if err = d.View(
func(txn *badger.Txn) (err error) {
pkBuf := codecbuf.Get()
defer codecbuf.Put(pkBuf)
ph := &types.PubHash{}
if err = ph.FromPubkey(pubkey); chk.E(err) {
return
}
if err = indexes.PubkeyEnc(
ph, nil, nil,
).MarshalWrite(pkBuf); chk.E(err) {
return
}
it := txn.NewIterator(badger.IteratorOptions{Prefix: pkBuf.Bytes()})
evB := make([]byte, 0, units.Mb)
defer it.Close()
for it.Rewind(); it.Valid(); it.Next() {
item := it.Item()
if evB, err = item.ValueCopy(evB); chk.E(err) {
continue
}
evBuf := bytes.NewBuffer(evB)
ev := event.New()
if err = ev.UnmarshalBinary(evBuf); chk.E(err) {
continue
}
if _, err = evBuf.WriteTo(w); chk.E(err) {
continue
}
if _, err = w.Write([]byte{'\n'}); chk.E(err) {
continue
}
}
return
},
); err != nil {
return
}
}
}
return
}

View File

@@ -32,7 +32,7 @@ func (d *D) QueryEvents(c context.T, f *filter.F) (evs event.S, err error) {
}
// fetch the events
var ev *event.E
if ev, err = d.FetchEventBySerial(ser); chk.E(err) {
if ev, err = d.FetchEventBySerial(ser); err != nil {
continue
}
evs = append(evs, ev)

View File

@@ -145,7 +145,7 @@ func (en *Result) Unmarshal(b []byte) (r []byte, err error) {
return
}
en.Event = event.New()
if r, err = en.Event.Unmarshal(r); chk.E(err) {
if r, err = en.Event.Unmarshal(r); err != nil {
return
}
if r, err = envelopes.SkipToTheEnd(r); chk.E(err) {
@@ -158,7 +158,7 @@ func (en *Result) Unmarshal(b []byte) (r []byte, err error) {
// envelope into it.
func ParseResult(b []byte) (t *Result, rem []byte, err error) {
t = NewResult()
if rem, err = t.Unmarshal(b); chk.T(err) {
if rem, err = t.Unmarshal(b); err != nil {
return
}
return

View File

@@ -2,6 +2,7 @@ package event
import (
"bytes"
"fmt"
"github.com/minio/sha256-simd"
"io"
"orly.dev/pkg/crypto/ec/schnorr"
@@ -300,7 +301,7 @@ AfterClose:
}
return
invalid:
err = errorf.E(
err = fmt.Errorf(
"invalid key,\n'%s'\n'%s'\n'%s'", string(b), string(b[:len(r)]),
string(r),
)

File diff suppressed because one or more lines are too long

View File

@@ -78,8 +78,7 @@ func (a *A) HandleReq(
continue
}
}
if events, err = sto.QueryEvents(c, f); chk.E(err) {
log.E.F("eventstore: %v", err)
if events, err = sto.QueryEvents(c, f); err != nil {
if errors.Is(err, badger.ErrDBClosed) {
return
}
@@ -91,8 +90,9 @@ func (a *A) HandleReq(
for _, ev := range events {
if !auth.CheckPrivilege(a.Listener.AuthedPubkey(), ev) {
log.W.F(
"not privileged %0x ev pubkey %0x",
a.Listener.AuthedPubkey(), ev.Pubkey,
"not privileged %0x ev pubkey %0x kind %s privileged: %v",
a.Listener.AuthedPubkey(), ev.Pubkey, ev.Kind.Name(),
ev.Kind.IsPrivileged(),
)
continue
}

View File

@@ -132,17 +132,22 @@ func (p *S) Deliver(ev *event.E) {
p.Mx.Lock()
defer p.Mx.Unlock()
for w, subs := range p.Map {
log.I.F("%v %s", subs, w.RealRemote())
// log.I.F("%v %s", subs, w.RealRemote())
for id, subscriber := range subs {
log.T.F(
"subscriber %s\n%s", w.RealRemote(),
subscriber.Marshal(nil),
)
// log.T.F(
// "subscriber %s\n%s", w.RealRemote(),
// subscriber.Marshal(nil),
// )
if !subscriber.Match(ev) {
continue
}
if p.Server.AuthRequired() {
if !auth.CheckPrivilege(w.AuthedPubkey(), ev) {
log.W.F(
"not privileged %0x ev pubkey %0x kind %s privileged: %v",
w.AuthedPubkey(), ev.Pubkey, ev.Kind.Name(),
ev.Kind.IsPrivileged(),
)
continue
}
var res *eventenvelope.Result

View File

@@ -234,9 +234,10 @@ func (r *Client) ConnectWithTLS(ctx context.T, tlsConfig *tls.Config) error {
// general message reader loop
go func() {
buf := new(bytes.Buffer)
var err error
for {
buf.Reset()
if err := conn.ReadMessage(r.connectionContext, buf); chk.T(err) {
if err = conn.ReadMessage(r.connectionContext, buf); err != nil {
r.ConnectionError = err
r.Close()
break
@@ -270,10 +271,12 @@ func (r *Client) ConnectWithTLS(ctx context.T, tlsConfig *tls.Config) error {
}
r.challenge = env.Challenge
case eventenvelope.L:
// log.I.F("message: %s", message)
env := eventenvelope.NewResult()
if env, message, err = eventenvelope.ParseResult(message); chk.E(err) {
if env, message, err = eventenvelope.ParseResult(message); err != nil {
continue
}
// log.I.F("%s", env.Event.Marshal(nil))
if len(env.Subscription.T) == 0 {
continue
}
@@ -497,12 +500,13 @@ func (r *Client) PrepareSubscription(
return sub
}
// QuerySync is only used in tests. The realy query method is synchronous now
// QuerySync is only used in tests. The relay query method is synchronous now
// anyway (it ensures sort order is respected).
func (r *Client) QuerySync(
ctx context.T, f *filter.F,
opts ...SubscriptionOption,
) ([]*event.E, error) {
// log.T.F("QuerySync:\n%s", f.Marshal(nil))
sub, err := r.Subscribe(ctx, filters.New(f), opts...)
if err != nil {
return nil, err

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"compress/flate"
"crypto/tls"
"fmt"
"github.com/gobwas/httphead"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsflate"
@@ -170,7 +171,7 @@ func (cn *Connection) ReadMessage(c context.T, buf io.Writer) (err error) {
h, err := cn.reader.NextFrame()
if err != nil {
cn.conn.Close()
return errorf.E(
return fmt.Errorf(
"%s failed to advance frame: %s",
cn.conn.RemoteAddr(),
err.Error(),