update filters to correctly handle unknown fields in filter per nip-01

This commit is contained in:
2025-12-04 09:29:41 +00:00
parent 511bba3a12
commit 40bf8ac057
3 changed files with 518 additions and 1 deletions

View File

@@ -37,6 +37,12 @@ type F struct {
Until *timestamp.T `json:"until,omitempty"`
Search []byte `json:"search,omitempty"`
Limit *uint `json:"limit,omitempty"`
// Extra holds unknown JSON fields for relay extensions (e.g., _graph).
// Per NIP-01, unknown fields should be ignored by relays that don't support them,
// but relays implementing extensions can access them here.
// Keys are stored without quotes, values are raw JSON bytes.
Extra map[string][]byte `json:"-"`
}
// New creates a new, reasonably initialized filter that will be ready for most uses without
@@ -510,8 +516,19 @@ func (f *F) Unmarshal(b []byte) (r []byte, err error) {
// log.I.Ln("betweenKV")
}
default:
// Per NIP-01, unknown filter fields should be ignored by relays that
// don't support them. Store them in Extra for extensions to use.
var val []byte
if val, r, err = skipJSONValue(r); err != nil {
goto invalid
}
if f.Extra == nil {
f.Extra = make(map[string][]byte)
}
// Store the raw JSON value keyed by the field name
f.Extra[string(key)] = val
state = betweenKV
}
key = key[:0]
case betweenKV:
if len(r) == 0 {

161
encoders/filter/skip.go Normal file
View File

@@ -0,0 +1,161 @@
package filter
import (
"lol.mleku.dev/errorf"
)
// skipJSONValue skips over an arbitrary JSON value and returns the raw bytes and remainder.
// It handles: objects {}, arrays [], strings "", numbers, true, false, null.
// The input `b` should start at the first character of the value (after the colon in "key":value).
func skipJSONValue(b []byte) (val []byte, r []byte, err error) {
if len(b) == 0 {
err = errorf.E("empty input")
return
}
start := 0
end := 0
switch b[0] {
case '{':
// Object - find matching closing brace
end, err = findMatchingBrace(b, '{', '}')
case '[':
// Array - find matching closing bracket
end, err = findMatchingBrace(b, '[', ']')
case '"':
// String - find closing quote (handling escapes)
end, err = findClosingQuote(b)
case 't':
// true
if len(b) >= 4 && string(b[:4]) == "true" {
end = 4
} else {
err = errorf.E("invalid JSON value starting with 't'")
}
case 'f':
// false
if len(b) >= 5 && string(b[:5]) == "false" {
end = 5
} else {
err = errorf.E("invalid JSON value starting with 'f'")
}
case 'n':
// null
if len(b) >= 4 && string(b[:4]) == "null" {
end = 4
} else {
err = errorf.E("invalid JSON value starting with 'n'")
}
case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
// Number - scan until we hit a non-number character
end = scanNumber(b)
default:
err = errorf.E("invalid JSON value starting with '%c'", b[0])
}
if err != nil {
return
}
val = b[start:end]
r = b[end:]
return
}
// findMatchingBrace finds the index after the closing brace/bracket that matches the opening one.
// It handles nested structures and strings.
func findMatchingBrace(b []byte, open, close byte) (end int, err error) {
if len(b) == 0 || b[0] != open {
err = errorf.E("expected '%c'", open)
return
}
depth := 0
inString := false
escaped := false
for i := 0; i < len(b); i++ {
c := b[i]
if escaped {
escaped = false
continue
}
if c == '\\' && inString {
escaped = true
continue
}
if c == '"' {
inString = !inString
continue
}
if inString {
continue
}
if c == open {
depth++
} else if c == close {
depth--
if depth == 0 {
end = i + 1
return
}
}
}
err = errorf.E("unmatched '%c'", open)
return
}
// findClosingQuote finds the index after the closing quote of a JSON string.
// Handles escape sequences.
func findClosingQuote(b []byte) (end int, err error) {
if len(b) == 0 || b[0] != '"' {
err = errorf.E("expected '\"'")
return
}
escaped := false
for i := 1; i < len(b); i++ {
c := b[i]
if escaped {
escaped = false
continue
}
if c == '\\' {
escaped = true
continue
}
if c == '"' {
end = i + 1
return
}
}
err = errorf.E("unclosed string")
return
}
// scanNumber scans a JSON number and returns the index after it.
// Handles integers, decimals, and scientific notation.
func scanNumber(b []byte) (end int) {
for i := 0; i < len(b); i++ {
c := b[i]
// Number characters: digits, minus, plus, dot, e, E
if (c >= '0' && c <= '9') || c == '-' || c == '+' || c == '.' || c == 'e' || c == 'E' {
continue
}
end = i
return
}
end = len(b)
return
}

View File

@@ -0,0 +1,339 @@
package filter
import (
"testing"
)
func TestSkipJSONValue(t *testing.T) {
tests := []struct {
name string
input string
wantVal string
wantRem string
wantErr bool
}{
// Objects
{
name: "empty object",
input: `{}`,
wantVal: `{}`,
wantRem: "",
},
{
name: "simple object",
input: `{"foo":"bar"}`,
wantVal: `{"foo":"bar"}`,
wantRem: "",
},
{
name: "nested object",
input: `{"a":{"b":"c"}},"next"`,
wantVal: `{"a":{"b":"c"}}`,
wantRem: `,"next"`,
},
{
name: "object with array",
input: `{"items":[1,2,3]}rest`,
wantVal: `{"items":[1,2,3]}`,
wantRem: `rest`,
},
// Arrays
{
name: "empty array",
input: `[]`,
wantVal: `[]`,
wantRem: "",
},
{
name: "number array",
input: `[1,2,3]`,
wantVal: `[1,2,3]`,
wantRem: "",
},
{
name: "string array",
input: `["a","b","c"],rest`,
wantVal: `["a","b","c"]`,
wantRem: `,rest`,
},
{
name: "nested array",
input: `[[1,2],[3,4]]more`,
wantVal: `[[1,2],[3,4]]`,
wantRem: `more`,
},
// Strings
{
name: "simple string",
input: `"hello"`,
wantVal: `"hello"`,
wantRem: "",
},
{
name: "string with escapes",
input: `"hello \"world\""rest`,
wantVal: `"hello \"world\""`,
wantRem: `rest`,
},
{
name: "string with backslash",
input: `"path\\to\\file",next`,
wantVal: `"path\\to\\file"`,
wantRem: `,next`,
},
// Numbers
{
name: "integer",
input: `123`,
wantVal: `123`,
wantRem: "",
},
{
name: "negative integer",
input: `-456,next`,
wantVal: `-456`,
wantRem: `,next`,
},
{
name: "decimal",
input: `3.14159}`,
wantVal: `3.14159`,
wantRem: `}`,
},
{
name: "scientific notation",
input: `1.23e-4,next`,
wantVal: `1.23e-4`,
wantRem: `,next`,
},
// Booleans
{
name: "true",
input: `true`,
wantVal: `true`,
wantRem: "",
},
{
name: "false",
input: `false,next`,
wantVal: `false`,
wantRem: `,next`,
},
// Null
{
name: "null",
input: `null`,
wantVal: `null`,
wantRem: "",
},
{
name: "null with remainder",
input: `null}`,
wantVal: `null`,
wantRem: `}`,
},
// Complex nested structures
{
name: "graph query object",
input: `{"method":"follows","seed":"abc123","depth":2,"inbound_refs":[{"kinds":[7],"from_depth":1}]},rest`,
wantVal: `{"method":"follows","seed":"abc123","depth":2,"inbound_refs":[{"kinds":[7],"from_depth":1}]}`,
wantRem: `,rest`,
},
// Error cases
{
name: "empty input",
input: ``,
wantErr: true,
},
{
name: "unclosed object",
input: `{"foo":"bar"`,
wantErr: true,
},
{
name: "unclosed array",
input: `[1,2,3`,
wantErr: true,
},
{
name: "unclosed string",
input: `"hello`,
wantErr: true,
},
{
name: "invalid true",
input: `tru`,
wantErr: true,
},
{
name: "invalid false",
input: `fals`,
wantErr: true,
},
{
name: "invalid null",
input: `nul`,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
val, rem, err := skipJSONValue([]byte(tt.input))
if tt.wantErr {
if err == nil {
t.Errorf("expected error, got nil")
}
return
}
if err != nil {
t.Errorf("unexpected error: %v", err)
return
}
if string(val) != tt.wantVal {
t.Errorf("val = %q, want %q", string(val), tt.wantVal)
}
if string(rem) != tt.wantRem {
t.Errorf("rem = %q, want %q", string(rem), tt.wantRem)
}
})
}
}
func TestUnmarshalWithUnknownFields(t *testing.T) {
tests := []struct {
name string
input string
wantKinds []int
wantExtra map[string]string
wantErr bool
}{
{
name: "simple filter with _graph extension",
input: `{"kinds":[1,7],"_graph":{"method":"follows","seed":"abc123","depth":2}}`,
wantKinds: []int{1, 7},
wantExtra: map[string]string{
"_graph": `{"method":"follows","seed":"abc123","depth":2}`,
},
},
{
name: "filter with unknown string field",
input: `{"kinds":[1],"_custom":"value"}`,
wantKinds: []int{1},
wantExtra: map[string]string{
"_custom": `"value"`,
},
},
{
name: "filter with multiple unknown fields",
input: `{"kinds":[1],"_foo":123,"_bar":["a","b"]}`,
wantKinds: []int{1},
wantExtra: map[string]string{
"_foo": `123`,
"_bar": `["a","b"]`,
},
},
{
name: "filter with complex _graph extension",
input: `{"kinds":[0],"_graph":{"method":"follows","seed":"abc","depth":2,"inbound_refs":[{"kinds":[7],"from_depth":1}]}}`,
wantKinds: []int{0},
wantExtra: map[string]string{
"_graph": `{"method":"follows","seed":"abc","depth":2,"inbound_refs":[{"kinds":[7],"from_depth":1}]}`,
},
},
{
name: "unknown field before known fields",
input: `{"_unknown":true,"kinds":[3]}`,
wantKinds: []int{3},
wantExtra: map[string]string{
"_unknown": `true`,
},
},
{
name: "unknown field with null value",
input: `{"kinds":[1],"_nullable":null}`,
wantKinds: []int{1},
wantExtra: map[string]string{
"_nullable": `null`,
},
},
{
name: "standard filter without unknown fields",
input: `{"kinds":[1,7]}`,
wantKinds: []int{1, 7},
wantExtra: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
f := &F{}
_, err := f.Unmarshal([]byte(tt.input))
if tt.wantErr {
if err == nil {
t.Errorf("expected error, got nil")
}
return
}
if err != nil {
t.Errorf("unexpected error: %v", err)
return
}
// Check kinds
if f.Kinds != nil {
if f.Kinds.Len() != len(tt.wantKinds) {
t.Errorf("kinds len = %d, want %d", f.Kinds.Len(), len(tt.wantKinds))
} else {
for i, k := range f.Kinds.K {
if int(k.K) != tt.wantKinds[i] {
t.Errorf("kinds[%d] = %d, want %d", i, k.K, tt.wantKinds[i])
}
}
}
} else if len(tt.wantKinds) > 0 {
t.Errorf("kinds is nil, want %v", tt.wantKinds)
}
// Check extra fields
if tt.wantExtra == nil {
if f.Extra != nil && len(f.Extra) > 0 {
t.Errorf("extra = %v, want nil", f.Extra)
}
} else {
if f.Extra == nil {
t.Errorf("extra is nil, want %v", tt.wantExtra)
return
}
for key, wantVal := range tt.wantExtra {
gotVal, ok := f.Extra[key]
if !ok {
t.Errorf("extra[%q] not found", key)
continue
}
if string(gotVal) != wantVal {
t.Errorf("extra[%q] = %q, want %q", key, string(gotVal), wantVal)
}
}
for key := range f.Extra {
if _, ok := tt.wantExtra[key]; !ok {
t.Errorf("unexpected extra key %q", key)
}
}
}
})
}
}