fix handleevents not prompting auth for event publish with auth-required
Some checks failed
Go / build (push) Has been cancelled
Go / release (push) Has been cancelled

This commit is contained in:
2025-11-18 20:26:36 +00:00
parent d5c0e3abfc
commit d4fb6cbf49
15 changed files with 302 additions and 275 deletions

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"strings"
"github.com/neo4j/neo4j-go-driver/v5/neo4j"
"next.orly.dev/pkg/database/indexes/types"
"next.orly.dev/pkg/encoders/event"
"next.orly.dev/pkg/encoders/filter"
@@ -113,7 +114,7 @@ func (n *N) buildCypherQuery(f *filter.F, includeDeleteEvents bool) (string, map
// Tag filters - this is where Neo4j's graph capabilities shine
// We can efficiently traverse tag relationships
tagIndex := 0
for tagType, tagValues := range *f.Tags {
for _, tagValues := range *f.Tags {
if len(tagValues.T) > 0 {
tagVarName := fmt.Sprintf("t%d", tagIndex)
tagTypeParam := fmt.Sprintf("tagType_%d", tagIndex)
@@ -122,14 +123,17 @@ func (n *N) buildCypherQuery(f *filter.F, includeDeleteEvents bool) (string, map
// Add tag relationship to MATCH clause
matchClause += fmt.Sprintf(" OPTIONAL MATCH (e)-[:TAGGED_WITH]->(%s:Tag)", tagVarName)
// Convert tag values to strings
tagValueStrings := make([]string, len(tagValues.T))
for i, tv := range tagValues.T {
// The first element is the tag type (e.g., "e", "p", etc.)
tagType := string(tagValues.T[0])
// Convert remaining tag values to strings (skip first element which is the type)
tagValueStrings := make([]string, len(tagValues.T)-1)
for i, tv := range tagValues.T[1:] {
tagValueStrings[i] = string(tv)
}
// Add WHERE conditions for this tag
params[tagTypeParam] = string(tagType)
params[tagTypeParam] = tagType
params[tagValuesParam] = tagValueStrings
whereClauses = append(whereClauses,
fmt.Sprintf("(%s.type = $%s AND %s.value IN $%s)",
@@ -179,40 +183,42 @@ RETURN e.id AS id,
// parseEventsFromResult converts Neo4j query results to Nostr events
func (n *N) parseEventsFromResult(result any) ([]*event.E, error) {
// Type assert to Neo4j result
neo4jResult, ok := result.(interface {
events := make([]*event.E, 0)
ctx := context.Background()
// Type assert to the interface we actually use
resultIter, ok := result.(interface {
Next(context.Context) bool
Record() *interface{}
Record() *neo4j.Record
Err() error
})
if !ok {
return nil, fmt.Errorf("invalid result type")
}
events := make([]*event.E, 0)
ctx := context.Background()
// Iterate through result records
for neo4jResult.Next(ctx) {
record := neo4jResult.Record()
for resultIter.Next(ctx) {
record := resultIter.Record()
if record == nil {
continue
}
// Extract fields from record
recordMap, ok := (*record).(map[string]any)
if !ok {
continue
}
// Parse event fields
idStr, _ := recordMap["id"].(string)
kind, _ := recordMap["kind"].(int64)
createdAt, _ := recordMap["created_at"].(int64)
content, _ := recordMap["content"].(string)
sigStr, _ := recordMap["sig"].(string)
pubkeyStr, _ := recordMap["pubkey"].(string)
tagsStr, _ := recordMap["tags"].(string)
idRaw, _ := record.Get("id")
kindRaw, _ := record.Get("kind")
createdAtRaw, _ := record.Get("created_at")
contentRaw, _ := record.Get("content")
sigRaw, _ := record.Get("sig")
pubkeyRaw, _ := record.Get("pubkey")
tagsRaw, _ := record.Get("tags")
idStr, _ := idRaw.(string)
kind, _ := kindRaw.(int64)
createdAt, _ := createdAtRaw.(int64)
content, _ := contentRaw.(string)
sigStr, _ := sigRaw.(string)
pubkeyStr, _ := pubkeyRaw.(string)
tagsStr, _ := tagsRaw.(string)
// Decode hex strings
id, err := hex.Dec(idStr)
@@ -250,7 +256,7 @@ func (n *N) parseEventsFromResult(result any) ([]*event.E, error) {
events = append(events, e)
}
if err := neo4jResult.Err(); err != nil {
if err := resultIter.Err(); err != nil {
return nil, fmt.Errorf("error iterating results: %w", err)
}
@@ -323,27 +329,31 @@ func (n *N) QueryForSerials(c context.Context, f *filter.F) (
serials = make([]*types.Uint40, 0)
ctx := context.Background()
neo4jResult, ok := result.(interface {
resultIter, ok := result.(interface {
Next(context.Context) bool
Record() *interface{}
Record() *neo4j.Record
Err() error
})
if !ok {
return nil, fmt.Errorf("invalid result type")
}
for neo4jResult.Next(ctx) {
record := neo4jResult.Record()
for resultIter.Next(ctx) {
record := resultIter.Record()
if record == nil {
continue
}
recordMap, ok := (*record).(map[string]any)
serialRaw, found := record.Get("serial")
if !found {
continue
}
serialVal, ok := serialRaw.(int64)
if !ok {
continue
}
serialVal, _ := recordMap["serial"].(int64)
serial := types.Uint40{}
serial.Set(uint64(serialVal))
serials = append(serials, &serial)
@@ -386,30 +396,30 @@ func (n *N) QueryForIds(c context.Context, f *filter.F) (
idPkTs = make([]*store.IdPkTs, 0)
ctx := context.Background()
neo4jResult, ok := result.(interface {
resultIter, ok := result.(interface {
Next(context.Context) bool
Record() *interface{}
Record() *neo4j.Record
Err() error
})
if !ok {
return nil, fmt.Errorf("invalid result type")
}
for neo4jResult.Next(ctx) {
record := neo4jResult.Record()
for resultIter.Next(ctx) {
record := resultIter.Record()
if record == nil {
continue
}
recordMap, ok := (*record).(map[string]any)
if !ok {
continue
}
idRaw, _ := record.Get("id")
pubkeyRaw, _ := record.Get("pubkey")
createdAtRaw, _ := record.Get("created_at")
serialRaw, _ := record.Get("serial")
idStr, _ := recordMap["id"].(string)
pubkeyStr, _ := recordMap["pubkey"].(string)
createdAt, _ := recordMap["created_at"].(int64)
serialVal, _ := recordMap["serial"].(int64)
idStr, _ := idRaw.(string)
pubkeyStr, _ := pubkeyRaw.(string)
createdAt, _ := createdAtRaw.(int64)
serialVal, _ := serialRaw.(int64)
id, err := hex.Dec(idStr)
if err != nil {
@@ -456,22 +466,24 @@ func (n *N) CountEvents(c context.Context, f *filter.F) (
// Parse count from result
ctx := context.Background()
neo4jResult, ok := result.(interface {
resultIter, ok := result.(interface {
Next(context.Context) bool
Record() *interface{}
Record() *neo4j.Record
Err() error
})
if !ok {
return 0, false, fmt.Errorf("invalid result type")
}
if neo4jResult.Next(ctx) {
record := neo4jResult.Record()
if resultIter.Next(ctx) {
record := resultIter.Record()
if record != nil {
recordMap, ok := (*record).(map[string]any)
if ok {
countVal, _ := recordMap["count"].(int64)
count = int(countVal)
countRaw, found := record.Get("count")
if found {
countVal, ok := countRaw.(int64)
if ok {
count = int(countVal)
}
}
}
}