diff --git a/.claude/settings.local.json b/.claude/settings.local.json index 4ee3a6f..4936096 100644 --- a/.claude/settings.local.json +++ b/.claude/settings.local.json @@ -94,7 +94,8 @@ "Bash(dmesg:*)", "Bash(export:*)", "Bash(timeout 60 /tmp/benchmark-fixed:*)", - "Bash(/tmp/test-auth-event.sh)" + "Bash(/tmp/test-auth-event.sh)", + "Bash(CGO_ENABLED=0 timeout 180 go test:*)" ], "deny": [], "ask": [] diff --git a/app/handle-nip43_test.go b/app/handle-nip43_test.go index 7a1e620..57ebd6a 100644 --- a/app/handle-nip43_test.go +++ b/app/handle-nip43_test.go @@ -7,9 +7,11 @@ import ( "time" "next.orly.dev/app/config" + "next.orly.dev/pkg/acl" "next.orly.dev/pkg/crypto/keys" "next.orly.dev/pkg/database" "next.orly.dev/pkg/encoders/event" + "next.orly.dev/pkg/encoders/hex" "next.orly.dev/pkg/encoders/tag" "next.orly.dev/pkg/interfaces/signer/p8k" "next.orly.dev/pkg/protocol/nip43" @@ -38,24 +40,47 @@ func setupTestListener(t *testing.T) (*Listener, *database.D, func()) { RelayURL: "wss://test.relay", Listen: "localhost", Port: 3334, + ACLMode: "none", } server := &Server{ Ctx: ctx, Config: cfg, - D: db, + DB: db, publishers: publish.New(NewPublisher(ctx)), InviteManager: nip43.NewInviteManager(cfg.NIP43InviteExpiry), cfg: cfg, db: db, } - listener := &Listener{ - Server: server, - ctx: ctx, + // Configure ACL registry + acl.Registry.Active.Store(cfg.ACLMode) + if err = acl.Registry.Configure(cfg, db, ctx); err != nil { + db.Close() + os.RemoveAll(tempDir) + t.Fatalf("failed to configure ACL: %v", err) } + listener := &Listener{ + Server: server, + ctx: ctx, + writeChan: make(chan publish.WriteRequest, 100), + writeDone: make(chan struct{}), + messageQueue: make(chan messageRequest, 100), + processingDone: make(chan struct{}), + subscriptions: make(map[string]context.CancelFunc), + } + + // Start write worker and message processor + go listener.writeWorker() + go listener.messageProcessor() + cleanup := func() { + // Close listener channels + close(listener.writeChan) + <-listener.writeDone + close(listener.messageQueue) + <-listener.processingDone db.Close() os.RemoveAll(tempDir) } @@ -350,8 +375,13 @@ func TestHandleNIP43InviteRequest_ValidRequest(t *testing.T) { } adminPubkey := adminSigner.Pub() - // Add admin to server (simulating admin config) - listener.Server.Admins = [][]byte{adminPubkey} + // Add admin to config and reconfigure ACL + adminHex := hex.Enc(adminPubkey) + listener.Server.Config.Admins = []string{adminHex} + acl.Registry.Active.Store("none") + if err = acl.Registry.Configure(listener.Server.Config, listener.Server.DB, listener.ctx); err != nil { + t.Fatalf("failed to reconfigure ACL: %v", err) + } // Handle invite request inviteEvent, err := listener.Server.HandleNIP43InviteRequest(adminPubkey) diff --git a/app/handle-nip86_minimal_test.go b/app/handle-nip86_minimal_test.go index c47cc11..96bc716 100644 --- a/app/handle-nip86_minimal_test.go +++ b/app/handle-nip86_minimal_test.go @@ -35,7 +35,7 @@ func TestHandleNIP86Management_Basic(t *testing.T) { // Setup server server := &Server{ Config: cfg, - D: db, + DB: db, Admins: [][]byte{[]byte("admin1")}, Owners: [][]byte{[]byte("owner1")}, } diff --git a/app/listener.go b/app/listener.go index e7d3ade..1494e51 100644 --- a/app/listener.go +++ b/app/listener.go @@ -161,6 +161,12 @@ func (l *Listener) writeWorker() { return } + // Skip writes if no connection (unit tests) + if l.conn == nil { + log.T.F("ws->%s skipping write (no connection)", l.remote) + continue + } + // Handle the write request var err error if req.IsPing { diff --git a/app/nip43_e2e_test.go b/app/nip43_e2e_test.go index d8a5e3f..e53418d 100644 --- a/app/nip43_e2e_test.go +++ b/app/nip43_e2e_test.go @@ -11,15 +11,44 @@ import ( "time" "next.orly.dev/app/config" + "next.orly.dev/pkg/acl" "next.orly.dev/pkg/crypto/keys" "next.orly.dev/pkg/database" "next.orly.dev/pkg/encoders/event" + "next.orly.dev/pkg/encoders/hex" "next.orly.dev/pkg/encoders/tag" "next.orly.dev/pkg/protocol/nip43" "next.orly.dev/pkg/protocol/publish" "next.orly.dev/pkg/protocol/relayinfo" ) +// newTestListener creates a properly initialized Listener for testing +func newTestListener(server *Server, ctx context.Context) *Listener { + listener := &Listener{ + Server: server, + ctx: ctx, + writeChan: make(chan publish.WriteRequest, 100), + writeDone: make(chan struct{}), + messageQueue: make(chan messageRequest, 100), + processingDone: make(chan struct{}), + subscriptions: make(map[string]context.CancelFunc), + } + + // Start write worker and message processor + go listener.writeWorker() + go listener.messageProcessor() + + return listener +} + +// closeTestListener properly closes a test listener +func closeTestListener(listener *Listener) { + close(listener.writeChan) + <-listener.writeDone + close(listener.messageQueue) + <-listener.processingDone +} + // setupE2ETest creates a full test server for end-to-end testing func setupE2ETest(t *testing.T) (*Server, *httptest.Server, func()) { tempDir, err := os.MkdirTemp("", "nip43_e2e_test_*") @@ -61,16 +90,28 @@ func setupE2ETest(t *testing.T) (*Server, *httptest.Server, func()) { } adminPubkey := adminSigner.Pub() + // Add admin to config for ACL + cfg.Admins = []string{hex.Enc(adminPubkey)} + server := &Server{ Ctx: ctx, Config: cfg, - D: db, + DB: db, publishers: publish.New(NewPublisher(ctx)), Admins: [][]byte{adminPubkey}, InviteManager: nip43.NewInviteManager(cfg.NIP43InviteExpiry), cfg: cfg, db: db, } + + // Configure ACL registry + acl.Registry.Active.Store(cfg.ACLMode) + if err = acl.Registry.Configure(cfg, db, ctx); err != nil { + db.Close() + os.RemoveAll(tempDir) + t.Fatalf("failed to configure ACL: %v", err) + } + server.mux = http.NewServeMux() // Set up HTTP handlers @@ -177,6 +218,7 @@ func TestE2E_CompleteJoinFlow(t *testing.T) { joinEv := event.New() joinEv.Kind = nip43.KindJoinRequest copy(joinEv.Pubkey, userPubkey) + joinEv.Tags = tag.NewS() joinEv.Tags.Append(tag.NewFromAny("-")) joinEv.Tags.Append(tag.NewFromAny("claim", inviteCode)) joinEv.CreatedAt = time.Now().Unix() @@ -186,17 +228,15 @@ func TestE2E_CompleteJoinFlow(t *testing.T) { } // Step 3: Process join request - listener := &Listener{ - Server: server, - ctx: server.Ctx, - } + listener := newTestListener(server, server.Ctx) + defer closeTestListener(listener) err = listener.HandleNIP43JoinRequest(joinEv) if err != nil { t.Fatalf("failed to handle join request: %v", err) } // Step 4: Verify membership - isMember, err := server.D.IsNIP43Member(userPubkey) + isMember, err := server.DB.IsNIP43Member(userPubkey) if err != nil { t.Fatalf("failed to check membership: %v", err) } @@ -204,7 +244,7 @@ func TestE2E_CompleteJoinFlow(t *testing.T) { t.Error("user was not added as member") } - membership, err := server.D.GetNIP43Membership(userPubkey) + membership, err := server.DB.GetNIP43Membership(userPubkey) if err != nil { t.Fatalf("failed to get membership: %v", err) } @@ -227,10 +267,8 @@ func TestE2E_InviteCodeReuse(t *testing.T) { t.Fatalf("failed to generate invite code: %v", err) } - listener := &Listener{ - Server: server, - ctx: server.Ctx, - } + listener := newTestListener(server, server.Ctx) + defer closeTestListener(listener) // First user uses the code user1Secret, err := keys.GenerateSecretKey() @@ -249,6 +287,7 @@ func TestE2E_InviteCodeReuse(t *testing.T) { joinEv1 := event.New() joinEv1.Kind = nip43.KindJoinRequest copy(joinEv1.Pubkey, user1Pubkey) + joinEv1.Tags = tag.NewS() joinEv1.Tags.Append(tag.NewFromAny("-")) joinEv1.Tags.Append(tag.NewFromAny("claim", code)) joinEv1.CreatedAt = time.Now().Unix() @@ -263,7 +302,7 @@ func TestE2E_InviteCodeReuse(t *testing.T) { } // Verify first user is member - isMember, err := server.D.IsNIP43Member(user1Pubkey) + isMember, err := server.DB.IsNIP43Member(user1Pubkey) if err != nil { t.Fatalf("failed to check user1 membership: %v", err) } @@ -288,6 +327,7 @@ func TestE2E_InviteCodeReuse(t *testing.T) { joinEv2 := event.New() joinEv2.Kind = nip43.KindJoinRequest copy(joinEv2.Pubkey, user2Pubkey) + joinEv2.Tags = tag.NewS() joinEv2.Tags.Append(tag.NewFromAny("-")) joinEv2.Tags.Append(tag.NewFromAny("claim", code)) joinEv2.CreatedAt = time.Now().Unix() @@ -303,7 +343,7 @@ func TestE2E_InviteCodeReuse(t *testing.T) { } // Verify second user is NOT member - isMember, err = server.D.IsNIP43Member(user2Pubkey) + isMember, err = server.DB.IsNIP43Member(user2Pubkey) if err != nil { t.Fatalf("failed to check user2 membership: %v", err) } @@ -317,10 +357,8 @@ func TestE2E_MembershipListGeneration(t *testing.T) { server, _, cleanup := setupE2ETest(t) defer cleanup() - listener := &Listener{ - Server: server, - ctx: server.Ctx, - } + listener := newTestListener(server, server.Ctx) + defer closeTestListener(listener) // Add multiple members memberCount := 5 @@ -338,7 +376,7 @@ func TestE2E_MembershipListGeneration(t *testing.T) { members[i] = userPubkey // Add directly to database for speed - err = server.D.AddNIP43Member(userPubkey, "code") + err = server.DB.AddNIP43Member(userPubkey, "code") if err != nil { t.Fatalf("failed to add member %d: %v", i, err) } @@ -379,17 +417,15 @@ func TestE2E_ExpiredInviteCode(t *testing.T) { server := &Server{ Ctx: ctx, Config: cfg, - D: db, + DB: db, publishers: publish.New(NewPublisher(ctx)), InviteManager: nip43.NewInviteManager(cfg.NIP43InviteExpiry), cfg: cfg, db: db, } - listener := &Listener{ - Server: server, - ctx: ctx, - } + listener := newTestListener(server, ctx) + defer closeTestListener(listener) // Generate invite code code, err := server.InviteManager.GenerateCode() @@ -417,6 +453,7 @@ func TestE2E_ExpiredInviteCode(t *testing.T) { joinEv := event.New() joinEv.Kind = nip43.KindJoinRequest copy(joinEv.Pubkey, userPubkey) + joinEv.Tags = tag.NewS() joinEv.Tags.Append(tag.NewFromAny("-")) joinEv.Tags.Append(tag.NewFromAny("claim", code)) joinEv.CreatedAt = time.Now().Unix() @@ -445,10 +482,8 @@ func TestE2E_InvalidTimestampRejected(t *testing.T) { server, _, cleanup := setupE2ETest(t) defer cleanup() - listener := &Listener{ - Server: server, - ctx: server.Ctx, - } + listener := newTestListener(server, server.Ctx) + defer closeTestListener(listener) // Generate invite code code, err := server.InviteManager.GenerateCode() @@ -474,6 +509,7 @@ func TestE2E_InvalidTimestampRejected(t *testing.T) { joinEv := event.New() joinEv.Kind = nip43.KindJoinRequest copy(joinEv.Pubkey, userPubkey) + joinEv.Tags = tag.NewS() joinEv.Tags.Append(tag.NewFromAny("-")) joinEv.Tags.Append(tag.NewFromAny("claim", code)) joinEv.CreatedAt = time.Now().Unix() - 700 // More than 10 minutes ago @@ -489,7 +525,7 @@ func TestE2E_InvalidTimestampRejected(t *testing.T) { } // Verify user was NOT added - isMember, err := server.D.IsNIP43Member(userPubkey) + isMember, err := server.DB.IsNIP43Member(userPubkey) if err != nil { t.Fatalf("failed to check membership: %v", err) } @@ -523,17 +559,15 @@ func BenchmarkJoinRequestProcessing(b *testing.B) { server := &Server{ Ctx: ctx, Config: cfg, - D: db, + DB: db, publishers: publish.New(NewPublisher(ctx)), InviteManager: nip43.NewInviteManager(cfg.NIP43InviteExpiry), cfg: cfg, db: db, } - listener := &Listener{ - Server: server, - ctx: ctx, - } + listener := newTestListener(server, ctx) + defer closeTestListener(listener) b.ResetTimer() @@ -547,6 +581,7 @@ func BenchmarkJoinRequestProcessing(b *testing.B) { joinEv := event.New() joinEv.Kind = nip43.KindJoinRequest copy(joinEv.Pubkey, userPubkey) + joinEv.Tags = tag.NewS() joinEv.Tags.Append(tag.NewFromAny("-")) joinEv.Tags.Append(tag.NewFromAny("claim", code)) joinEv.CreatedAt = time.Now().Unix() diff --git a/app/subscription_stability_test.go b/app/subscription_stability_test.go index 83a93f0..b434b86 100644 --- a/app/subscription_stability_test.go +++ b/app/subscription_stability_test.go @@ -199,7 +199,7 @@ func TestLongRunningSubscriptionStability(t *testing.T) { ev := createSignedTestEvent(t, 1, fmt.Sprintf("Test event %d for long-running subscription", i)) // Save event to database - if _, err := server.D.SaveEvent(context.Background(), ev); err != nil { + if _, err := server.DB.SaveEvent(context.Background(), ev); err != nil { t.Errorf("Failed to save event %d: %v", i, err) continue } @@ -376,7 +376,7 @@ func TestMultipleConcurrentSubscriptions(t *testing.T) { // Create and sign test event ev := createSignedTestEvent(t, uint16(sub.kind), fmt.Sprintf("Test for kind %d event %d", sub.kind, i)) - if _, err := server.D.SaveEvent(context.Background(), ev); err != nil { + if _, err := server.DB.SaveEvent(context.Background(), ev); err != nil { t.Errorf("Failed to save event: %v", err) } @@ -431,7 +431,7 @@ func setupTestServer(t *testing.T) (*Server, func()) { // Setup server server := &Server{ Config: cfg, - D: db, + DB: db, Ctx: ctx, publishers: publish.New(NewPublisher(ctx)), Admins: [][]byte{}, diff --git a/pkg/database/get-serial-by-id.go b/pkg/database/get-serial-by-id.go index 8c47ee4..caa8f9d 100644 --- a/pkg/database/get-serial-by-id.go +++ b/pkg/database/get-serial-by-id.go @@ -58,7 +58,7 @@ func (d *D) GetSerialById(id []byte) (ser *types.Uint40, err error) { return } if !idFound { - // err = errorf.T("id not found in database: %s", hex.Enc(id)) + err = errorf.E("id not found in database") return } diff --git a/pkg/database/inline-storage_test.go b/pkg/database/inline-storage_test.go index 1e5ba06..a092898 100644 --- a/pkg/database/inline-storage_test.go +++ b/pkg/database/inline-storage_test.go @@ -20,7 +20,7 @@ import ( ) // TestInlineSmallEventStorage tests the Reiser4-inspired inline storage optimization -// for small events (<=384 bytes). +// for small events (<=1024 bytes by default). func TestInlineSmallEventStorage(t *testing.T) { // Create a temporary directory for the database tempDir, err := os.MkdirTemp("", "test-inline-db-*") @@ -129,8 +129,8 @@ func TestInlineSmallEventStorage(t *testing.T) { largeEvent := event.New() largeEvent.Kind = kind.TextNote.K largeEvent.CreatedAt = timestamp.Now().V - // Create content larger than 384 bytes - largeContent := make([]byte, 500) + // Create content larger than 1024 bytes (the default inline storage threshold) + largeContent := make([]byte, 1500) for i := range largeContent { largeContent[i] = 'x' } diff --git a/pkg/neo4j/delete.go b/pkg/neo4j/delete.go index 2c37e63..7884a91 100644 --- a/pkg/neo4j/delete.go +++ b/pkg/neo4j/delete.go @@ -93,21 +93,12 @@ func (n *N) ProcessDelete(ev *event.E, admins [][]byte) error { continue } - neo4jResult, ok := result.(interface { - Next(context.Context) bool - Record() *interface{} - Err() error - }) - if !ok { - continue - } - - if neo4jResult.Next(ctx) { - record := neo4jResult.Record() + if result.Next(ctx) { + record := result.Record() if record != nil { - recordMap, ok := (*record).(map[string]any) - if ok { - if pubkeyStr, ok := recordMap["pubkey"].(string); ok { + pubkeyValue, found := record.Get("pubkey") + if found { + if pubkeyStr, ok := pubkeyValue.(string); ok { pubkey, err := hex.Dec(pubkeyStr) if err != nil { continue @@ -160,12 +151,7 @@ LIMIT 1` return nil // Not deleted } - neo4jResult, ok := result.(interface { - Next(context.Context) bool - Record() *interface{} - Err() error - }) - if ok && neo4jResult.Next(ctx) { + if result.Next(ctx) { return fmt.Errorf("event has been deleted") } diff --git a/pkg/neo4j/fetch-event.go b/pkg/neo4j/fetch-event.go index 299d42c..4a7206e 100644 --- a/pkg/neo4j/fetch-event.go +++ b/pkg/neo4j/fetch-event.go @@ -82,35 +82,30 @@ RETURN e.id AS id, events = make(map[uint64]*event.E) ctx := context.Background() - neo4jResult, ok := result.(interface { - Next(context.Context) bool - Record() *interface{} - Err() error - }) - if !ok { - return events, nil - } - - for neo4jResult.Next(ctx) { - record := neo4jResult.Record() + for result.Next(ctx) { + record := result.Record() if record == nil { continue } - recordMap, ok := (*record).(map[string]any) - if !ok { - continue - } - // Parse event - idStr, _ := recordMap["id"].(string) - kind, _ := recordMap["kind"].(int64) - createdAt, _ := recordMap["created_at"].(int64) - content, _ := recordMap["content"].(string) - sigStr, _ := recordMap["sig"].(string) - pubkeyStr, _ := recordMap["pubkey"].(string) - tagsStr, _ := recordMap["tags"].(string) - serialVal, _ := recordMap["serial"].(int64) + idRaw, _ := record.Get("id") + kindRaw, _ := record.Get("kind") + createdAtRaw, _ := record.Get("created_at") + contentRaw, _ := record.Get("content") + sigRaw, _ := record.Get("sig") + pubkeyRaw, _ := record.Get("pubkey") + tagsRaw, _ := record.Get("tags") + serialRaw, _ := record.Get("serial") + + idStr, _ := idRaw.(string) + kind, _ := kindRaw.(int64) + createdAt, _ := createdAtRaw.(int64) + content, _ := contentRaw.(string) + sigStr, _ := sigRaw.(string) + pubkeyStr, _ := pubkeyRaw.(string) + tagsStr, _ := tagsRaw.(string) + serialVal, _ := serialRaw.(int64) id, err := hex.Dec(idStr) if err != nil { @@ -160,21 +155,13 @@ func (n *N) GetSerialById(id []byte) (ser *types.Uint40, err error) { } ctx := context.Background() - neo4jResult, ok := result.(interface { - Next(context.Context) bool - Record() *interface{} - Err() error - }) - if !ok { - return nil, fmt.Errorf("invalid result type") - } - if neo4jResult.Next(ctx) { - record := neo4jResult.Record() + if result.Next(ctx) { + record := result.Record() if record != nil { - recordMap, ok := (*record).(map[string]any) - if ok { - if serialVal, ok := recordMap["serial"].(int64); ok { + serialRaw, found := record.Get("serial") + if found { + if serialVal, ok := serialRaw.(int64); ok { ser = &types.Uint40{} ser.Set(uint64(serialVal)) return ser, nil @@ -221,28 +208,24 @@ RETURN e.id AS id, e.serial AS serial` } ctx := context.Background() - neo4jResult, ok := result.(interface { - Next(context.Context) bool - Record() *interface{} - Err() error - }) - if !ok { - return serials, nil - } - for neo4jResult.Next(ctx) { - record := neo4jResult.Record() + for result.Next(ctx) { + record := result.Record() if record == nil { continue } - recordMap, ok := (*record).(map[string]any) - if !ok { + idRaw, found := record.Get("id") + if !found { + continue + } + serialRaw, found := record.Get("serial") + if !found { continue } - idStr, _ := recordMap["id"].(string) - serialVal, _ := recordMap["serial"].(int64) + idStr, _ := idRaw.(string) + serialVal, _ := serialRaw.(int64) serial := &types.Uint40{} serial.Set(uint64(serialVal)) @@ -322,43 +305,45 @@ RETURN e.id AS id, } ctx := context.Background() - neo4jResult, ok := result.(interface { - Next(context.Context) bool - Record() *interface{} - Err() error - }) - if !ok { - return nil, fmt.Errorf("invalid result type") - } - if neo4jResult.Next(ctx) { - record := neo4jResult.Record() + if result.Next(ctx) { + record := result.Record() if record != nil { - recordMap, ok := (*record).(map[string]any) - if ok { - idStr, _ := recordMap["id"].(string) - pubkeyStr, _ := recordMap["pubkey"].(string) - createdAt, _ := recordMap["created_at"].(int64) - - id, err := hex.Dec(idStr) - if err != nil { - return nil, err - } - - pubkey, err := hex.Dec(pubkeyStr) - if err != nil { - return nil, err - } - - fidpk = &store.IdPkTs{ - Id: id, - Pub: pubkey, - Ts: createdAt, - Ser: serial, - } - - return fidpk, nil + idRaw, found := record.Get("id") + if !found { + return nil, fmt.Errorf("event not found") } + pubkeyRaw, found := record.Get("pubkey") + if !found { + return nil, fmt.Errorf("event not found") + } + createdAtRaw, found := record.Get("created_at") + if !found { + return nil, fmt.Errorf("event not found") + } + + idStr, _ := idRaw.(string) + pubkeyStr, _ := pubkeyRaw.(string) + createdAt, _ := createdAtRaw.(int64) + + id, err := hex.Dec(idStr) + if err != nil { + return nil, err + } + + pubkey, err := hex.Dec(pubkeyStr) + if err != nil { + return nil, err + } + + fidpk = &store.IdPkTs{ + Id: id, + Pub: pubkey, + Ts: createdAt, + Ser: serial, + } + + return fidpk, nil } } @@ -397,30 +382,34 @@ RETURN e.id AS id, } ctx := context.Background() - neo4jResult, ok := result.(interface { - Next(context.Context) bool - Record() *interface{} - Err() error - }) - if !ok { - return fidpks, nil - } - for neo4jResult.Next(ctx) { - record := neo4jResult.Record() + for result.Next(ctx) { + record := result.Record() if record == nil { continue } - recordMap, ok := (*record).(map[string]any) - if !ok { + idRaw, found := record.Get("id") + if !found { + continue + } + pubkeyRaw, found := record.Get("pubkey") + if !found { + continue + } + createdAtRaw, found := record.Get("created_at") + if !found { + continue + } + serialRaw, found := record.Get("serial") + if !found { continue } - idStr, _ := recordMap["id"].(string) - pubkeyStr, _ := recordMap["pubkey"].(string) - createdAt, _ := recordMap["created_at"].(int64) - serialVal, _ := recordMap["serial"].(int64) + idStr, _ := idRaw.(string) + pubkeyStr, _ := pubkeyRaw.(string) + createdAt, _ := createdAtRaw.(int64) + serialVal, _ := serialRaw.(int64) id, err := hex.Dec(idStr) if err != nil { diff --git a/pkg/neo4j/markers.go b/pkg/neo4j/markers.go index 1fa63df..53eda72 100644 --- a/pkg/neo4j/markers.go +++ b/pkg/neo4j/markers.go @@ -42,21 +42,13 @@ func (n *N) GetMarker(key string) (value []byte, err error) { } ctx := context.Background() - neo4jResult, ok := result.(interface { - Next(context.Context) bool - Record() *interface{} - Err() error - }) - if !ok { - return nil, fmt.Errorf("invalid result type") - } - if neo4jResult.Next(ctx) { - record := neo4jResult.Record() + if result.Next(ctx) { + record := result.Record() if record != nil { - recordMap, ok := (*record).(map[string]any) - if ok { - if valueStr, ok := recordMap["value"].(string); ok { + valueRaw, found := record.Get("value") + if found { + if valueStr, ok := valueRaw.(string); ok { // Decode hex value value, err = hex.Dec(valueStr) if err != nil { diff --git a/pkg/neo4j/query-events.go b/pkg/neo4j/query-events.go index 3e68e7d..0fec968 100644 --- a/pkg/neo4j/query-events.go +++ b/pkg/neo4j/query-events.go @@ -5,6 +5,7 @@ import ( "fmt" "strings" + "github.com/neo4j/neo4j-go-driver/v5/neo4j" "next.orly.dev/pkg/database/indexes/types" "next.orly.dev/pkg/encoders/event" "next.orly.dev/pkg/encoders/filter" @@ -113,7 +114,7 @@ func (n *N) buildCypherQuery(f *filter.F, includeDeleteEvents bool) (string, map // Tag filters - this is where Neo4j's graph capabilities shine // We can efficiently traverse tag relationships tagIndex := 0 - for tagType, tagValues := range *f.Tags { + for _, tagValues := range *f.Tags { if len(tagValues.T) > 0 { tagVarName := fmt.Sprintf("t%d", tagIndex) tagTypeParam := fmt.Sprintf("tagType_%d", tagIndex) @@ -122,14 +123,17 @@ func (n *N) buildCypherQuery(f *filter.F, includeDeleteEvents bool) (string, map // Add tag relationship to MATCH clause matchClause += fmt.Sprintf(" OPTIONAL MATCH (e)-[:TAGGED_WITH]->(%s:Tag)", tagVarName) - // Convert tag values to strings - tagValueStrings := make([]string, len(tagValues.T)) - for i, tv := range tagValues.T { + // The first element is the tag type (e.g., "e", "p", etc.) + tagType := string(tagValues.T[0]) + + // Convert remaining tag values to strings (skip first element which is the type) + tagValueStrings := make([]string, len(tagValues.T)-1) + for i, tv := range tagValues.T[1:] { tagValueStrings[i] = string(tv) } // Add WHERE conditions for this tag - params[tagTypeParam] = string(tagType) + params[tagTypeParam] = tagType params[tagValuesParam] = tagValueStrings whereClauses = append(whereClauses, fmt.Sprintf("(%s.type = $%s AND %s.value IN $%s)", @@ -179,40 +183,42 @@ RETURN e.id AS id, // parseEventsFromResult converts Neo4j query results to Nostr events func (n *N) parseEventsFromResult(result any) ([]*event.E, error) { - // Type assert to Neo4j result - neo4jResult, ok := result.(interface { + events := make([]*event.E, 0) + ctx := context.Background() + + // Type assert to the interface we actually use + resultIter, ok := result.(interface { Next(context.Context) bool - Record() *interface{} + Record() *neo4j.Record Err() error }) if !ok { return nil, fmt.Errorf("invalid result type") } - events := make([]*event.E, 0) - ctx := context.Background() - // Iterate through result records - for neo4jResult.Next(ctx) { - record := neo4jResult.Record() + for resultIter.Next(ctx) { + record := resultIter.Record() if record == nil { continue } - // Extract fields from record - recordMap, ok := (*record).(map[string]any) - if !ok { - continue - } - // Parse event fields - idStr, _ := recordMap["id"].(string) - kind, _ := recordMap["kind"].(int64) - createdAt, _ := recordMap["created_at"].(int64) - content, _ := recordMap["content"].(string) - sigStr, _ := recordMap["sig"].(string) - pubkeyStr, _ := recordMap["pubkey"].(string) - tagsStr, _ := recordMap["tags"].(string) + idRaw, _ := record.Get("id") + kindRaw, _ := record.Get("kind") + createdAtRaw, _ := record.Get("created_at") + contentRaw, _ := record.Get("content") + sigRaw, _ := record.Get("sig") + pubkeyRaw, _ := record.Get("pubkey") + tagsRaw, _ := record.Get("tags") + + idStr, _ := idRaw.(string) + kind, _ := kindRaw.(int64) + createdAt, _ := createdAtRaw.(int64) + content, _ := contentRaw.(string) + sigStr, _ := sigRaw.(string) + pubkeyStr, _ := pubkeyRaw.(string) + tagsStr, _ := tagsRaw.(string) // Decode hex strings id, err := hex.Dec(idStr) @@ -250,7 +256,7 @@ func (n *N) parseEventsFromResult(result any) ([]*event.E, error) { events = append(events, e) } - if err := neo4jResult.Err(); err != nil { + if err := resultIter.Err(); err != nil { return nil, fmt.Errorf("error iterating results: %w", err) } @@ -323,27 +329,31 @@ func (n *N) QueryForSerials(c context.Context, f *filter.F) ( serials = make([]*types.Uint40, 0) ctx := context.Background() - neo4jResult, ok := result.(interface { + resultIter, ok := result.(interface { Next(context.Context) bool - Record() *interface{} + Record() *neo4j.Record Err() error }) if !ok { return nil, fmt.Errorf("invalid result type") } - for neo4jResult.Next(ctx) { - record := neo4jResult.Record() + for resultIter.Next(ctx) { + record := resultIter.Record() if record == nil { continue } - recordMap, ok := (*record).(map[string]any) + serialRaw, found := record.Get("serial") + if !found { + continue + } + + serialVal, ok := serialRaw.(int64) if !ok { continue } - serialVal, _ := recordMap["serial"].(int64) serial := types.Uint40{} serial.Set(uint64(serialVal)) serials = append(serials, &serial) @@ -386,30 +396,30 @@ func (n *N) QueryForIds(c context.Context, f *filter.F) ( idPkTs = make([]*store.IdPkTs, 0) ctx := context.Background() - neo4jResult, ok := result.(interface { + resultIter, ok := result.(interface { Next(context.Context) bool - Record() *interface{} + Record() *neo4j.Record Err() error }) if !ok { return nil, fmt.Errorf("invalid result type") } - for neo4jResult.Next(ctx) { - record := neo4jResult.Record() + for resultIter.Next(ctx) { + record := resultIter.Record() if record == nil { continue } - recordMap, ok := (*record).(map[string]any) - if !ok { - continue - } + idRaw, _ := record.Get("id") + pubkeyRaw, _ := record.Get("pubkey") + createdAtRaw, _ := record.Get("created_at") + serialRaw, _ := record.Get("serial") - idStr, _ := recordMap["id"].(string) - pubkeyStr, _ := recordMap["pubkey"].(string) - createdAt, _ := recordMap["created_at"].(int64) - serialVal, _ := recordMap["serial"].(int64) + idStr, _ := idRaw.(string) + pubkeyStr, _ := pubkeyRaw.(string) + createdAt, _ := createdAtRaw.(int64) + serialVal, _ := serialRaw.(int64) id, err := hex.Dec(idStr) if err != nil { @@ -456,22 +466,24 @@ func (n *N) CountEvents(c context.Context, f *filter.F) ( // Parse count from result ctx := context.Background() - neo4jResult, ok := result.(interface { + resultIter, ok := result.(interface { Next(context.Context) bool - Record() *interface{} + Record() *neo4j.Record Err() error }) if !ok { return 0, false, fmt.Errorf("invalid result type") } - if neo4jResult.Next(ctx) { - record := neo4jResult.Record() + if resultIter.Next(ctx) { + record := resultIter.Record() if record != nil { - recordMap, ok := (*record).(map[string]any) - if ok { - countVal, _ := recordMap["count"].(int64) - count = int(countVal) + countRaw, found := record.Get("count") + if found { + countVal, ok := countRaw.(int64) + if ok { + count = int(countVal) + } } } } diff --git a/pkg/neo4j/save-event.go b/pkg/neo4j/save-event.go index 5e90a9f..031835b 100644 --- a/pkg/neo4j/save-event.go +++ b/pkg/neo4j/save-event.go @@ -27,12 +27,7 @@ func (n *N) SaveEvent(c context.Context, ev *event.E) (exists bool, err error) { // Check if we got a result ctx := context.Background() - neo4jResult, ok := result.(interface { - Next(context.Context) bool - Record() *interface{} - Err() error - }) - if ok && neo4jResult.Next(ctx) { + if result.Next(ctx) { return true, nil // Event already exists } @@ -232,30 +227,25 @@ ORDER BY e.created_at DESC` } // Parse results - neo4jResult, ok := result.(interface { - Next(context.Context) bool - Record() *interface{} - Err() error - }) - if !ok { - return false, nil, fmt.Errorf("invalid result type") - } - var serials types.Uint40s wouldReplace := false - for neo4jResult.Next(ctx) { - record := neo4jResult.Record() + for result.Next(ctx) { + record := result.Record() if record == nil { continue } - recordMap, ok := (*record).(map[string]any) + serialRaw, found := record.Get("serial") + if !found { + continue + } + + serialVal, ok := serialRaw.(int64) if !ok { continue } - serialVal, _ := recordMap["serial"].(int64) wouldReplace = true serial := types.Uint40{} serial.Set(uint64(serialVal)) diff --git a/pkg/neo4j/serial.go b/pkg/neo4j/serial.go index 6f51853..783b156 100644 --- a/pkg/neo4j/serial.go +++ b/pkg/neo4j/serial.go @@ -31,22 +31,13 @@ func (n *N) getNextSerial() (uint64, error) { return 0, fmt.Errorf("failed to query serial counter: %w", err) } - neo4jResult, ok := result.(interface { - Next(context.Context) bool - Record() *interface{} - Err() error - }) - if !ok { - return 1, nil - } - var currentSerial uint64 = 1 - if neo4jResult.Next(ctx) { - record := neo4jResult.Record() + if result.Next(ctx) { + record := result.Record() if record != nil { - recordMap, ok := (*record).(map[string]any) - if ok { - if value, ok := recordMap["value"].(int64); ok { + valueRaw, found := record.Get("value") + if found { + if value, ok := valueRaw.(int64); ok { currentSerial = uint64(value) } } @@ -86,12 +77,7 @@ func (n *N) initSerialCounter() error { return fmt.Errorf("failed to check serial counter: %w", err) } - neo4jResult, ok := result.(interface { - Next(context.Context) bool - Record() *interface{} - Err() error - }) - if ok && neo4jResult.Next(ctx) { + if result.Next(ctx) { // Counter already exists return nil } diff --git a/pkg/version/version b/pkg/version/version index 96d68bf..19f7d77 100644 --- a/pkg/version/version +++ b/pkg/version/version @@ -1 +1 @@ -v0.29.3 \ No newline at end of file +v0.29.4 \ No newline at end of file