diff --git a/internal/data/tree.go b/internal/data/tree.go index eafb23adf..1bfcbf660 100644 --- a/internal/data/tree.go +++ b/internal/data/tree.go @@ -67,18 +67,37 @@ func NewTreeNodeIterator(rd io.Reader) (TreeNodeIterator, error) { } func (t *treeIterator) init() error { - // `{"nodes":[` `]}` + // A tree is expected to be encoded as a JSON object with a single key "nodes". + // However, for future-proofness, we allow unknown keys before and after the "nodes" key. + // The following is the expected format: + // `{"nodes":[...]}` if err := t.assertToken(json.Delim('{')); err != nil { return err } - if err := t.assertToken("nodes"); err != nil { - return err + // Skip unknown keys until we find "nodes" + for { + token, err := t.dec.Token() + if err != nil { + return err + } + key, ok := token.(string) + if !ok { + return errors.Errorf("error decoding tree: expected string key, got %v", token) + } + if key == "nodes" { + // Found "nodes", proceed to read the array + if err := t.assertToken(json.Delim('[')); err != nil { + return err + } + return nil + } + // Unknown key, decode its value into RawMessage and discard it + var raw json.RawMessage + if err := t.dec.Decode(&raw); err != nil { + return err + } } - if err := t.assertToken(json.Delim('[')); err != nil { - return err - } - return nil } func (t *treeIterator) next() (*Node, error) { @@ -94,10 +113,21 @@ func (t *treeIterator) next() (*Node, error) { if err := t.assertToken(json.Delim(']')); err != nil { return nil, err } - if err := t.assertToken(json.Delim('}')); err != nil { - return nil, err + // Skip unknown keys after the array until we find the closing brace + for { + token, err := t.dec.Token() + if err != nil { + return nil, err + } + if token == json.Delim('}') { + return nil, io.EOF + } + // We have an unknown key, decode its value into RawMessage and discard it + var raw json.RawMessage + if err := t.dec.Decode(&raw); err != nil { + return nil, err + } } - return nil, io.EOF } func (t *treeIterator) assertToken(token json.Token) error { diff --git a/internal/data/tree_test.go b/internal/data/tree_test.go index c6e2a517a..47fc4b9a0 100644 --- a/internal/data/tree_test.go +++ b/internal/data/tree_test.go @@ -225,6 +225,75 @@ func testLoadTree(t *testing.T, version uint) { } } +func TestTreeIteratorUnknownKeys(t *testing.T) { + tests := []struct { + name string + jsonData string + wantNodes []string + }{ + { + name: "unknown key before nodes", + jsonData: `{"extra": "value", "nodes": [{"name": "test1"}, {"name": "test2"}]}`, + wantNodes: []string{"test1", "test2"}, + }, + { + name: "unknown key after nodes", + jsonData: `{"nodes": [{"name": "test1"}, {"name": "test2"}], "extra": "value"}`, + wantNodes: []string{"test1", "test2"}, + }, + { + name: "multiple unknown keys before nodes", + jsonData: `{"key1": "value1", "key2": 42, "nodes": [{"name": "test1"}]}`, + wantNodes: []string{"test1"}, + }, + { + name: "multiple unknown keys after nodes", + jsonData: `{"nodes": [{"name": "test1"}], "key1": "value1", "key2": 42}`, + wantNodes: []string{"test1"}, + }, + { + name: "unknown keys before and after nodes", + jsonData: `{"before": "value", "nodes": [{"name": "test1"}], "after": "value"}`, + wantNodes: []string{"test1"}, + }, + { + name: "nested object as unknown value", + jsonData: `{"extra": {"nested": "value"}, "nodes": [{"name": "test1"}]}`, + wantNodes: []string{"test1"}, + }, + { + name: "nested array as unknown value", + jsonData: `{"extra": [1, 2, 3], "nodes": [{"name": "test1"}]}`, + wantNodes: []string{"test1"}, + }, + { + name: "complex nested structure as unknown value", + jsonData: `{"extra": {"obj": {"arr": [1, {"nested": true}]}}, "nodes": [{"name": "test1"}]}`, + wantNodes: []string{"test1"}, + }, + { + name: "empty nodes array with unknown keys", + jsonData: `{"extra": "value", "nodes": []}`, + wantNodes: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + it, err := data.NewTreeNodeIterator(strings.NewReader(tt.jsonData + "\n")) + rtest.OK(t, err) + + var gotNodes []string + for item := range it { + rtest.OK(t, item.Error) + gotNodes = append(gotNodes, item.Node.Name) + } + + rtest.Equals(t, tt.wantNodes, gotNodes, "nodes mismatch") + }) + } +} + func BenchmarkLoadTree(t *testing.B) { repository.BenchmarkAllVersions(t, benchmarkLoadTree) }