diff --git a/encoders/filter/filter.go b/encoders/filter/filter.go index bb517a9..f77ff65 100644 --- a/encoders/filter/filter.go +++ b/encoders/filter/filter.go @@ -37,6 +37,12 @@ type F struct { Until *timestamp.T `json:"until,omitempty"` Search []byte `json:"search,omitempty"` Limit *uint `json:"limit,omitempty"` + + // Extra holds unknown JSON fields for relay extensions (e.g., _graph). + // Per NIP-01, unknown fields should be ignored by relays that don't support them, + // but relays implementing extensions can access them here. + // Keys are stored without quotes, values are raw JSON bytes. + Extra map[string][]byte `json:"-"` } // New creates a new, reasonably initialized filter that will be ready for most uses without @@ -510,7 +516,18 @@ func (f *F) Unmarshal(b []byte) (r []byte, err error) { // log.I.Ln("betweenKV") } default: - goto invalid + // Per NIP-01, unknown filter fields should be ignored by relays that + // don't support them. Store them in Extra for extensions to use. + var val []byte + if val, r, err = skipJSONValue(r); err != nil { + goto invalid + } + if f.Extra == nil { + f.Extra = make(map[string][]byte) + } + // Store the raw JSON value keyed by the field name + f.Extra[string(key)] = val + state = betweenKV } key = key[:0] case betweenKV: diff --git a/encoders/filter/skip.go b/encoders/filter/skip.go new file mode 100644 index 0000000..6354ad9 --- /dev/null +++ b/encoders/filter/skip.go @@ -0,0 +1,161 @@ +package filter + +import ( + "lol.mleku.dev/errorf" +) + +// skipJSONValue skips over an arbitrary JSON value and returns the raw bytes and remainder. +// It handles: objects {}, arrays [], strings "", numbers, true, false, null. +// The input `b` should start at the first character of the value (after the colon in "key":value). +func skipJSONValue(b []byte) (val []byte, r []byte, err error) { + if len(b) == 0 { + err = errorf.E("empty input") + return + } + + start := 0 + end := 0 + + switch b[0] { + case '{': + // Object - find matching closing brace + end, err = findMatchingBrace(b, '{', '}') + case '[': + // Array - find matching closing bracket + end, err = findMatchingBrace(b, '[', ']') + case '"': + // String - find closing quote (handling escapes) + end, err = findClosingQuote(b) + case 't': + // true + if len(b) >= 4 && string(b[:4]) == "true" { + end = 4 + } else { + err = errorf.E("invalid JSON value starting with 't'") + } + case 'f': + // false + if len(b) >= 5 && string(b[:5]) == "false" { + end = 5 + } else { + err = errorf.E("invalid JSON value starting with 'f'") + } + case 'n': + // null + if len(b) >= 4 && string(b[:4]) == "null" { + end = 4 + } else { + err = errorf.E("invalid JSON value starting with 'n'") + } + case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': + // Number - scan until we hit a non-number character + end = scanNumber(b) + default: + err = errorf.E("invalid JSON value starting with '%c'", b[0]) + } + + if err != nil { + return + } + + val = b[start:end] + r = b[end:] + return +} + +// findMatchingBrace finds the index after the closing brace/bracket that matches the opening one. +// It handles nested structures and strings. +func findMatchingBrace(b []byte, open, close byte) (end int, err error) { + if len(b) == 0 || b[0] != open { + err = errorf.E("expected '%c'", open) + return + } + + depth := 0 + inString := false + escaped := false + + for i := 0; i < len(b); i++ { + c := b[i] + + if escaped { + escaped = false + continue + } + + if c == '\\' && inString { + escaped = true + continue + } + + if c == '"' { + inString = !inString + continue + } + + if inString { + continue + } + + if c == open { + depth++ + } else if c == close { + depth-- + if depth == 0 { + end = i + 1 + return + } + } + } + + err = errorf.E("unmatched '%c'", open) + return +} + +// findClosingQuote finds the index after the closing quote of a JSON string. +// Handles escape sequences. +func findClosingQuote(b []byte) (end int, err error) { + if len(b) == 0 || b[0] != '"' { + err = errorf.E("expected '\"'") + return + } + + escaped := false + for i := 1; i < len(b); i++ { + c := b[i] + + if escaped { + escaped = false + continue + } + + if c == '\\' { + escaped = true + continue + } + + if c == '"' { + end = i + 1 + return + } + } + + err = errorf.E("unclosed string") + return +} + +// scanNumber scans a JSON number and returns the index after it. +// Handles integers, decimals, and scientific notation. +func scanNumber(b []byte) (end int) { + for i := 0; i < len(b); i++ { + c := b[i] + // Number characters: digits, minus, plus, dot, e, E + if (c >= '0' && c <= '9') || c == '-' || c == '+' || c == '.' || c == 'e' || c == 'E' { + continue + } + end = i + return + } + end = len(b) + return +} diff --git a/encoders/filter/skip_test.go b/encoders/filter/skip_test.go new file mode 100644 index 0000000..9d3bbdc --- /dev/null +++ b/encoders/filter/skip_test.go @@ -0,0 +1,339 @@ +package filter + +import ( + "testing" +) + +func TestSkipJSONValue(t *testing.T) { + tests := []struct { + name string + input string + wantVal string + wantRem string + wantErr bool + }{ + // Objects + { + name: "empty object", + input: `{}`, + wantVal: `{}`, + wantRem: "", + }, + { + name: "simple object", + input: `{"foo":"bar"}`, + wantVal: `{"foo":"bar"}`, + wantRem: "", + }, + { + name: "nested object", + input: `{"a":{"b":"c"}},"next"`, + wantVal: `{"a":{"b":"c"}}`, + wantRem: `,"next"`, + }, + { + name: "object with array", + input: `{"items":[1,2,3]}rest`, + wantVal: `{"items":[1,2,3]}`, + wantRem: `rest`, + }, + + // Arrays + { + name: "empty array", + input: `[]`, + wantVal: `[]`, + wantRem: "", + }, + { + name: "number array", + input: `[1,2,3]`, + wantVal: `[1,2,3]`, + wantRem: "", + }, + { + name: "string array", + input: `["a","b","c"],rest`, + wantVal: `["a","b","c"]`, + wantRem: `,rest`, + }, + { + name: "nested array", + input: `[[1,2],[3,4]]more`, + wantVal: `[[1,2],[3,4]]`, + wantRem: `more`, + }, + + // Strings + { + name: "simple string", + input: `"hello"`, + wantVal: `"hello"`, + wantRem: "", + }, + { + name: "string with escapes", + input: `"hello \"world\""rest`, + wantVal: `"hello \"world\""`, + wantRem: `rest`, + }, + { + name: "string with backslash", + input: `"path\\to\\file",next`, + wantVal: `"path\\to\\file"`, + wantRem: `,next`, + }, + + // Numbers + { + name: "integer", + input: `123`, + wantVal: `123`, + wantRem: "", + }, + { + name: "negative integer", + input: `-456,next`, + wantVal: `-456`, + wantRem: `,next`, + }, + { + name: "decimal", + input: `3.14159}`, + wantVal: `3.14159`, + wantRem: `}`, + }, + { + name: "scientific notation", + input: `1.23e-4,next`, + wantVal: `1.23e-4`, + wantRem: `,next`, + }, + + // Booleans + { + name: "true", + input: `true`, + wantVal: `true`, + wantRem: "", + }, + { + name: "false", + input: `false,next`, + wantVal: `false`, + wantRem: `,next`, + }, + + // Null + { + name: "null", + input: `null`, + wantVal: `null`, + wantRem: "", + }, + { + name: "null with remainder", + input: `null}`, + wantVal: `null`, + wantRem: `}`, + }, + + // Complex nested structures + { + name: "graph query object", + input: `{"method":"follows","seed":"abc123","depth":2,"inbound_refs":[{"kinds":[7],"from_depth":1}]},rest`, + wantVal: `{"method":"follows","seed":"abc123","depth":2,"inbound_refs":[{"kinds":[7],"from_depth":1}]}`, + wantRem: `,rest`, + }, + + // Error cases + { + name: "empty input", + input: ``, + wantErr: true, + }, + { + name: "unclosed object", + input: `{"foo":"bar"`, + wantErr: true, + }, + { + name: "unclosed array", + input: `[1,2,3`, + wantErr: true, + }, + { + name: "unclosed string", + input: `"hello`, + wantErr: true, + }, + { + name: "invalid true", + input: `tru`, + wantErr: true, + }, + { + name: "invalid false", + input: `fals`, + wantErr: true, + }, + { + name: "invalid null", + input: `nul`, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + val, rem, err := skipJSONValue([]byte(tt.input)) + + if tt.wantErr { + if err == nil { + t.Errorf("expected error, got nil") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if string(val) != tt.wantVal { + t.Errorf("val = %q, want %q", string(val), tt.wantVal) + } + + if string(rem) != tt.wantRem { + t.Errorf("rem = %q, want %q", string(rem), tt.wantRem) + } + }) + } +} + +func TestUnmarshalWithUnknownFields(t *testing.T) { + tests := []struct { + name string + input string + wantKinds []int + wantExtra map[string]string + wantErr bool + }{ + { + name: "simple filter with _graph extension", + input: `{"kinds":[1,7],"_graph":{"method":"follows","seed":"abc123","depth":2}}`, + wantKinds: []int{1, 7}, + wantExtra: map[string]string{ + "_graph": `{"method":"follows","seed":"abc123","depth":2}`, + }, + }, + { + name: "filter with unknown string field", + input: `{"kinds":[1],"_custom":"value"}`, + wantKinds: []int{1}, + wantExtra: map[string]string{ + "_custom": `"value"`, + }, + }, + { + name: "filter with multiple unknown fields", + input: `{"kinds":[1],"_foo":123,"_bar":["a","b"]}`, + wantKinds: []int{1}, + wantExtra: map[string]string{ + "_foo": `123`, + "_bar": `["a","b"]`, + }, + }, + { + name: "filter with complex _graph extension", + input: `{"kinds":[0],"_graph":{"method":"follows","seed":"abc","depth":2,"inbound_refs":[{"kinds":[7],"from_depth":1}]}}`, + wantKinds: []int{0}, + wantExtra: map[string]string{ + "_graph": `{"method":"follows","seed":"abc","depth":2,"inbound_refs":[{"kinds":[7],"from_depth":1}]}`, + }, + }, + { + name: "unknown field before known fields", + input: `{"_unknown":true,"kinds":[3]}`, + wantKinds: []int{3}, + wantExtra: map[string]string{ + "_unknown": `true`, + }, + }, + { + name: "unknown field with null value", + input: `{"kinds":[1],"_nullable":null}`, + wantKinds: []int{1}, + wantExtra: map[string]string{ + "_nullable": `null`, + }, + }, + { + name: "standard filter without unknown fields", + input: `{"kinds":[1,7]}`, + wantKinds: []int{1, 7}, + wantExtra: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := &F{} + _, err := f.Unmarshal([]byte(tt.input)) + + if tt.wantErr { + if err == nil { + t.Errorf("expected error, got nil") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + // Check kinds + if f.Kinds != nil { + if f.Kinds.Len() != len(tt.wantKinds) { + t.Errorf("kinds len = %d, want %d", f.Kinds.Len(), len(tt.wantKinds)) + } else { + for i, k := range f.Kinds.K { + if int(k.K) != tt.wantKinds[i] { + t.Errorf("kinds[%d] = %d, want %d", i, k.K, tt.wantKinds[i]) + } + } + } + } else if len(tt.wantKinds) > 0 { + t.Errorf("kinds is nil, want %v", tt.wantKinds) + } + + // Check extra fields + if tt.wantExtra == nil { + if f.Extra != nil && len(f.Extra) > 0 { + t.Errorf("extra = %v, want nil", f.Extra) + } + } else { + if f.Extra == nil { + t.Errorf("extra is nil, want %v", tt.wantExtra) + return + } + for key, wantVal := range tt.wantExtra { + gotVal, ok := f.Extra[key] + if !ok { + t.Errorf("extra[%q] not found", key) + continue + } + if string(gotVal) != wantVal { + t.Errorf("extra[%q] = %q, want %q", key, string(gotVal), wantVal) + } + } + for key := range f.Extra { + if _, ok := tt.wantExtra[key]; !ok { + t.Errorf("unexpected extra key %q", key) + } + } + } + }) + } +}