From 350f29d921cceff60bc0de06b9df0cec33182712 Mon Sep 17 00:00:00 2001 From: Michael Eischer Date: Sat, 22 Nov 2025 19:55:50 +0100 Subject: [PATCH] data: replace Tree with TreeNodeIterator The TreeNodeIterator decodes nodes while iterating over a tree blob. This should reduce peak memory usage as now only the serialized tree blob and a single node have to alive at the same time. Using the iterator has implications for the error handling however. Now it is necessary that all loops that iterate through a tree check for errors before using the node returned by the iterator. The other change is that it is no longer possible to iterate over a tree multiple times. Instead it must be loaded a second time. This only affects the tree rewriting code. --- cmd/restic/cmd_copy.go | 9 +- cmd/restic/cmd_diff.go | 34 +++-- cmd/restic/cmd_dump.go | 9 +- cmd/restic/cmd_recover.go | 6 +- cmd/restic/cmd_repair_snapshots.go | 5 +- internal/archiver/archiver.go | 36 +++-- internal/archiver/archiver_test.go | 24 ++-- internal/archiver/testing.go | 17 +-- internal/checker/checker.go | 13 +- internal/checker/checker_test.go | 18 +-- internal/data/find.go | 11 +- internal/data/testing.go | 28 ++-- internal/data/tree.go | 222 ++++++++++++++++++----------- internal/data/tree_stream.go | 39 ++++- internal/data/tree_test.go | 47 +++--- internal/dump/common.go | 10 +- internal/fuse/dir.go | 22 ++- internal/fuse/fuse_test.go | 7 +- internal/restorer/restorer.go | 11 +- internal/restorer/restorer_test.go | 19 +-- internal/walker/rewriter.go | 17 ++- internal/walker/rewriter_test.go | 10 +- internal/walker/walker.go | 13 +- 23 files changed, 394 insertions(+), 233 deletions(-) diff --git a/cmd/restic/cmd_copy.go b/cmd/restic/cmd_copy.go index 46e70f120..d17ded7c9 100644 --- a/cmd/restic/cmd_copy.go +++ b/cmd/restic/cmd_copy.go @@ -278,7 +278,7 @@ func copyTree(ctx context.Context, srcRepo restic.Repository, dstRepo restic.Rep visited := visitedTrees.Has(handle) visitedTrees.Insert(handle) return visited - }, func(treeID restic.ID, err error, tree *data.Tree) error { + }, func(treeID restic.ID, err error, nodes data.TreeNodeIterator) error { if err != nil { return fmt.Errorf("LoadTree(%v) returned error %v", treeID.Str(), err) } @@ -286,10 +286,13 @@ func copyTree(ctx context.Context, srcRepo restic.Repository, dstRepo restic.Rep // copy raw tree bytes to avoid problems if the serialization changes enqueue(restic.BlobHandle{ID: treeID, Type: restic.TreeBlob}) - for _, entry := range tree.Nodes { + for item := range nodes { + if item.Error != nil { + return item.Error + } // Recursion into directories is handled by StreamTrees // Copy the blobs for this file. - for _, blobID := range entry.Content { + for _, blobID := range item.Node.Content { enqueue(restic.BlobHandle{Type: restic.DataBlob, ID: blobID}) } } diff --git a/cmd/restic/cmd_diff.go b/cmd/restic/cmd_diff.go index 58138e9e5..3f8e16c3a 100644 --- a/cmd/restic/cmd_diff.go +++ b/cmd/restic/cmd_diff.go @@ -184,11 +184,14 @@ func (c *Comparer) printDir(ctx context.Context, mode string, stats *DiffStat, b return err } - for _, node := range tree.Nodes { + for item := range tree { + if item.Error != nil { + return item.Error + } if ctx.Err() != nil { return ctx.Err() } - + node := item.Node name := path.Join(prefix, node.Name) if node.Type == data.NodeTypeDir { name += "/" @@ -215,11 +218,15 @@ func (c *Comparer) collectDir(ctx context.Context, blobs restic.AssociatedBlobSe return err } - for _, node := range tree.Nodes { + for item := range tree { + if item.Error != nil { + return item.Error + } if ctx.Err() != nil { return ctx.Err() } + node := item.Node addBlobs(blobs, node) if node.Type == data.NodeTypeDir { @@ -233,16 +240,24 @@ func (c *Comparer) collectDir(ctx context.Context, blobs restic.AssociatedBlobSe return ctx.Err() } -func uniqueNodeNames(tree1, tree2 *data.Tree) (tree1Nodes, tree2Nodes map[string]*data.Node, uniqueNames []string) { +func uniqueNodeNames(tree1, tree2 data.TreeNodeIterator) (tree1Nodes, tree2Nodes map[string]*data.Node, uniqueNames []string, err error) { names := make(map[string]struct{}) tree1Nodes = make(map[string]*data.Node) - for _, node := range tree1.Nodes { + for item := range tree1 { + if item.Error != nil { + return nil, nil, nil, item.Error + } + node := item.Node tree1Nodes[node.Name] = node names[node.Name] = struct{}{} } tree2Nodes = make(map[string]*data.Node) - for _, node := range tree2.Nodes { + for item := range tree2 { + if item.Error != nil { + return nil, nil, nil, item.Error + } + node := item.Node tree2Nodes[node.Name] = node names[node.Name] = struct{}{} } @@ -253,7 +268,7 @@ func uniqueNodeNames(tree1, tree2 *data.Tree) (tree1Nodes, tree2Nodes map[string } sort.Strings(uniqueNames) - return tree1Nodes, tree2Nodes, uniqueNames + return tree1Nodes, tree2Nodes, uniqueNames, nil } func (c *Comparer) diffTree(ctx context.Context, stats *DiffStatsContainer, prefix string, id1, id2 restic.ID) error { @@ -268,7 +283,10 @@ func (c *Comparer) diffTree(ctx context.Context, stats *DiffStatsContainer, pref return err } - tree1Nodes, tree2Nodes, names := uniqueNodeNames(tree1, tree2) + tree1Nodes, tree2Nodes, names, err := uniqueNodeNames(tree1, tree2) + if err != nil { + return err + } for _, name := range names { if ctx.Err() != nil { diff --git a/cmd/restic/cmd_dump.go b/cmd/restic/cmd_dump.go index da45cc303..abea6e52e 100644 --- a/cmd/restic/cmd_dump.go +++ b/cmd/restic/cmd_dump.go @@ -80,7 +80,7 @@ func splitPath(p string) []string { return append(s, f) } -func printFromTree(ctx context.Context, tree *data.Tree, repo restic.BlobLoader, prefix string, pathComponents []string, d *dump.Dumper, canWriteArchiveFunc func() error) error { +func printFromTree(ctx context.Context, tree data.TreeNodeIterator, repo restic.BlobLoader, prefix string, pathComponents []string, d *dump.Dumper, canWriteArchiveFunc func() error) error { // If we print / we need to assume that there are multiple nodes at that // level in the tree. if pathComponents[0] == "" { @@ -92,11 +92,14 @@ func printFromTree(ctx context.Context, tree *data.Tree, repo restic.BlobLoader, item := filepath.Join(prefix, pathComponents[0]) l := len(pathComponents) - for _, node := range tree.Nodes { + for it := range tree { + if it.Error != nil { + return it.Error + } if ctx.Err() != nil { return ctx.Err() } - + node := it.Node // If dumping something in the highest level it will just take the // first item it finds and dump that according to the switch case below. if node.Name == pathComponents[0] { diff --git a/cmd/restic/cmd_recover.go b/cmd/restic/cmd_recover.go index 8cb5f3cc7..bbb71972f 100644 --- a/cmd/restic/cmd_recover.go +++ b/cmd/restic/cmd_recover.go @@ -97,7 +97,11 @@ func runRecover(ctx context.Context, gopts global.Options, term ui.Terminal) err continue } - for _, node := range tree.Nodes { + for item := range tree { + if item.Error != nil { + return item.Error + } + node := item.Node if node.Type == data.NodeTypeDir && node.Subtree != nil { trees[*node.Subtree] = true } diff --git a/cmd/restic/cmd_repair_snapshots.go b/cmd/restic/cmd_repair_snapshots.go index 226da3d44..104d3dff3 100644 --- a/cmd/restic/cmd_repair_snapshots.go +++ b/cmd/restic/cmd_repair_snapshots.go @@ -2,6 +2,7 @@ package main import ( "context" + "slices" "github.com/restic/restic/internal/data" "github.com/restic/restic/internal/errors" @@ -130,7 +131,7 @@ func runRepairSnapshots(ctx context.Context, gopts global.Options, opts RepairOp node.Size = newSize return node }, - RewriteFailedTree: func(_ restic.ID, path string, _ error) (*data.Tree, error) { + RewriteFailedTree: func(_ restic.ID, path string, _ error) (data.TreeNodeIterator, error) { if path == "/" { printer.P(" dir %q: not readable", path) // remove snapshots with invalid root node @@ -138,7 +139,7 @@ func runRepairSnapshots(ctx context.Context, gopts global.Options, opts RepairOp } // If a subtree fails to load, remove it printer.P(" dir %q: replaced with empty directory", path) - return &data.Tree{}, nil + return slices.Values([]data.NodeOrError{}), nil }, AllowUnstableSerialization: true, }) diff --git a/internal/archiver/archiver.go b/internal/archiver/archiver.go index 0ed37eb5b..996e79e6a 100644 --- a/internal/archiver/archiver.go +++ b/internal/archiver/archiver.go @@ -281,7 +281,7 @@ func (arch *Archiver) nodeFromFileInfo(snPath, filename string, meta ToNoder, ig // loadSubtree tries to load the subtree referenced by node. In case of an error, nil is returned. // If there is no node to load, then nil is returned without an error. -func (arch *Archiver) loadSubtree(ctx context.Context, node *data.Node) (*data.Tree, error) { +func (arch *Archiver) loadSubtree(ctx context.Context, node *data.Node) (data.TreeNodeIterator, error) { if node == nil || node.Type != data.NodeTypeDir || node.Subtree == nil { return nil, nil } @@ -307,7 +307,7 @@ func (arch *Archiver) wrapLoadTreeError(id restic.ID, err error) error { // saveDir stores a directory in the repo and returns the node. snPath is the // path within the current snapshot. -func (arch *Archiver) saveDir(ctx context.Context, snPath string, dir string, meta fs.File, previous *data.Tree, complete fileCompleteFunc) (d futureNode, err error) { +func (arch *Archiver) saveDir(ctx context.Context, snPath string, dir string, meta fs.File, previous data.TreeNodeIterator, complete fileCompleteFunc) (d futureNode, err error) { debug.Log("%v %v", snPath, dir) treeNode, names, err := arch.dirToNodeAndEntries(snPath, dir, meta) @@ -317,6 +317,9 @@ func (arch *Archiver) saveDir(ctx context.Context, snPath string, dir string, me nodes := make([]futureNode, 0, len(names)) + finder := data.NewTreeFinder(previous) + defer finder.Close() + for _, name := range names { // test if context has been cancelled if ctx.Err() != nil { @@ -325,7 +328,11 @@ func (arch *Archiver) saveDir(ctx context.Context, snPath string, dir string, me } pathname := arch.FS.Join(dir, name) - oldNode := previous.Find(name) + oldNode, err := finder.Find(name) + err = arch.error(pathname, err) + if err != nil { + return futureNode{}, err + } snItem := join(snPath, name) fn, excluded, err := arch.save(ctx, snItem, pathname, oldNode) @@ -645,7 +652,7 @@ func join(elem ...string) string { // saveTree stores a Tree in the repo, returned is the tree. snPath is the path // within the current snapshot. -func (arch *Archiver) saveTree(ctx context.Context, snPath string, atree *tree, previous *data.Tree, complete fileCompleteFunc) (futureNode, int, error) { +func (arch *Archiver) saveTree(ctx context.Context, snPath string, atree *tree, previous data.TreeNodeIterator, complete fileCompleteFunc) (futureNode, int, error) { var node *data.Node if snPath != "/" { @@ -663,10 +670,13 @@ func (arch *Archiver) saveTree(ctx context.Context, snPath string, atree *tree, node = &data.Node{} } - debug.Log("%v (%v nodes), parent %v", snPath, len(atree.Nodes), previous) + debug.Log("%v (%v nodes)", snPath, len(atree.Nodes)) nodeNames := atree.NodeNames() nodes := make([]futureNode, 0, len(nodeNames)) + finder := data.NewTreeFinder(previous) + defer finder.Close() + // iterate over the nodes of atree in lexicographic (=deterministic) order for _, name := range nodeNames { subatree := atree.Nodes[name] @@ -678,7 +688,13 @@ func (arch *Archiver) saveTree(ctx context.Context, snPath string, atree *tree, // this is a leaf node if subatree.Leaf() { - fn, excluded, err := arch.save(ctx, join(snPath, name), subatree.Path, previous.Find(name)) + pathname := join(snPath, name) + oldNode, err := finder.Find(name) + err = arch.error(pathname, err) + if err != nil { + return futureNode{}, 0, err + } + fn, excluded, err := arch.save(ctx, pathname, subatree.Path, oldNode) if err != nil { err = arch.error(subatree.Path, err) @@ -698,7 +714,11 @@ func (arch *Archiver) saveTree(ctx context.Context, snPath string, atree *tree, snItem := join(snPath, name) + "/" start := time.Now() - oldNode := previous.Find(name) + oldNode, err := finder.Find(name) + err = arch.error(snItem, err) + if err != nil { + return futureNode{}, 0, err + } oldSubtree, err := arch.loadSubtree(ctx, oldNode) if err != nil { err = arch.error(join(snPath, name), err) @@ -801,7 +821,7 @@ type SnapshotOptions struct { } // loadParentTree loads a tree referenced by snapshot id. If id is null, nil is returned. -func (arch *Archiver) loadParentTree(ctx context.Context, sn *data.Snapshot) *data.Tree { +func (arch *Archiver) loadParentTree(ctx context.Context, sn *data.Snapshot) data.TreeNodeIterator { if sn == nil { return nil } diff --git a/internal/archiver/archiver_test.go b/internal/archiver/archiver_test.go index a4012d585..b13734655 100644 --- a/internal/archiver/archiver_test.go +++ b/internal/archiver/archiver_test.go @@ -877,11 +877,7 @@ func TestArchiverSaveDir(t *testing.T) { } node.Name = targetNodeName - tree := &data.Tree{Nodes: []*data.Node{node}} - treeID, err = data.SaveTree(ctx, uploader, tree) - if err != nil { - t.Fatal(err) - } + treeID = data.TestSaveNodes(t, ctx, uploader, []*data.Node{node}) arch.stopWorkers() return wg.Wait() }) @@ -2256,19 +2252,15 @@ func snapshot(t testing.TB, repo archiverRepo, fs fs.FS, parent *data.Snapshot, ParentSnapshot: parent, } snapshot, _, _, err := arch.Snapshot(ctx, []string{filename}, sopts) - if err != nil { - t.Fatal(err) - } - + rtest.OK(t, err) tree, err := data.LoadTree(ctx, repo, *snapshot.Tree) - if err != nil { - t.Fatal(err) - } + rtest.OK(t, err) - node := tree.Find(filename) - if node == nil { - t.Fatalf("unable to find node for testfile in snapshot") - } + finder := data.NewTreeFinder(tree) + defer finder.Close() + node, err := finder.Find(filename) + rtest.OK(t, err) + rtest.Assert(t, node != nil, "unable to find node for testfile in snapshot") return snapshot, node } diff --git a/internal/archiver/testing.go b/internal/archiver/testing.go index 666a8c556..6f1195c29 100644 --- a/internal/archiver/testing.go +++ b/internal/archiver/testing.go @@ -15,6 +15,7 @@ import ( "github.com/restic/restic/internal/debug" "github.com/restic/restic/internal/fs" "github.com/restic/restic/internal/restic" + rtest "github.com/restic/restic/internal/test" ) // TestSnapshot creates a new snapshot of path. @@ -265,19 +266,14 @@ func TestEnsureTree(ctx context.Context, t testing.TB, prefix string, repo resti t.Helper() tree, err := data.LoadTree(ctx, repo, treeID) - if err != nil { - t.Fatal(err) - return - } + rtest.OK(t, err) var nodeNames []string - for _, node := range tree.Nodes { - nodeNames = append(nodeNames, node.Name) - } - debug.Log("%v (%v) %v", prefix, treeID.Str(), nodeNames) - checked := make(map[string]struct{}) - for _, node := range tree.Nodes { + for item := range tree { + rtest.OK(t, item.Error) + node := item.Node + nodeNames = append(nodeNames, node.Name) nodePrefix := path.Join(prefix, node.Name) entry, ok := dir[node.Name] @@ -316,6 +312,7 @@ func TestEnsureTree(ctx context.Context, t testing.TB, prefix string, repo resti } } } + debug.Log("%v (%v) %v", prefix, treeID.Str(), nodeNames) for name := range dir { _, ok := checked[name] diff --git a/internal/checker/checker.go b/internal/checker/checker.go index 164e6b053..0d3a908b8 100644 --- a/internal/checker/checker.go +++ b/internal/checker/checker.go @@ -162,14 +162,14 @@ func (c *Checker) Structure(ctx context.Context, p *progress.Counter, errChan ch c.blobRefs.M.Insert(h) c.blobRefs.Unlock() return blobReferenced - }, func(treeID restic.ID, err error, tree *data.Tree) error { + }, func(treeID restic.ID, err error, nodes data.TreeNodeIterator) error { debug.Log("check tree %v (err %v)", treeID, err) var errs []error if err != nil { errs = append(errs, err) } else { - errs = c.checkTree(treeID, tree) + errs = c.checkTree(treeID, nodes) } if len(errs) == 0 { return nil @@ -193,10 +193,15 @@ func (c *Checker) Structure(ctx context.Context, p *progress.Counter, errChan ch } } -func (c *Checker) checkTree(id restic.ID, tree *data.Tree) (errs []error) { +func (c *Checker) checkTree(id restic.ID, tree data.TreeNodeIterator) (errs []error) { debug.Log("checking tree %v", id) - for _, node := range tree.Nodes { + for item := range tree { + if item.Error != nil { + errs = append(errs, &Error{TreeID: id, Err: errors.Errorf("failed to decode tree %v: %w", id, item.Error)}) + break + } + node := item.Node switch node.Type { case data.NodeTypeFile: if node.Content == nil { diff --git a/internal/checker/checker_test.go b/internal/checker/checker_test.go index ea20b7302..106ffd6b3 100644 --- a/internal/checker/checker_test.go +++ b/internal/checker/checker_test.go @@ -522,15 +522,12 @@ func TestCheckerBlobTypeConfusion(t *testing.T) { Size: 42, Content: restic.IDs{restic.TestParseID("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef")}, } - damagedTree := &data.Tree{ - Nodes: []*data.Node{damagedNode}, - } + damagedNodes := []*data.Node{damagedNode} var id restic.ID test.OK(t, repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { - var err error - id, err = data.SaveTree(ctx, uploader, damagedTree) - return err + id = data.TestSaveNodes(t, ctx, uploader, damagedNodes) + return nil })) buf, err := repo.LoadBlob(ctx, restic.TreeBlob, id, nil) @@ -556,15 +553,12 @@ func TestCheckerBlobTypeConfusion(t *testing.T) { Subtree: &id, } - rootTree := &data.Tree{ - Nodes: []*data.Node{malNode, dirNode}, - } + rootNodes := []*data.Node{malNode, dirNode} var rootID restic.ID test.OK(t, repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { - var err error - rootID, err = data.SaveTree(ctx, uploader, rootTree) - return err + rootID = data.TestSaveNodes(t, ctx, uploader, rootNodes) + return nil })) snapshot, err := data.NewSnapshot([]string{"/damaged"}, []string{"test"}, "foo", time.Now()) diff --git a/internal/data/find.go b/internal/data/find.go index a009c6496..5fadcba93 100644 --- a/internal/data/find.go +++ b/internal/data/find.go @@ -22,16 +22,19 @@ func FindUsedBlobs(ctx context.Context, repo restic.Loader, treeIDs restic.IDs, blobs.Insert(h) lock.Unlock() return blobReferenced - }, func(_ restic.ID, err error, tree *Tree) error { + }, func(_ restic.ID, err error, nodes TreeNodeIterator) error { if err != nil { return err } - for _, node := range tree.Nodes { + for item := range nodes { + if item.Error != nil { + return item.Error + } lock.Lock() - switch node.Type { + switch item.Node.Type { case NodeTypeFile: - for _, blob := range node.Content { + for _, blob := range item.Node.Content { blobs.Insert(restic.BlobHandle{ID: blob, Type: restic.DataBlob}) } } diff --git a/internal/data/testing.go b/internal/data/testing.go index 524dd4eb7..52abc8864 100644 --- a/internal/data/testing.go +++ b/internal/data/testing.go @@ -5,12 +5,15 @@ import ( "fmt" "io" "math/rand" + "slices" + "strings" "testing" "time" "github.com/restic/chunker" "github.com/restic/restic/internal/restic" "github.com/restic/restic/internal/test" + rtest "github.com/restic/restic/internal/test" ) // fakeFile returns a reader which yields deterministic pseudo-random data. @@ -72,9 +75,8 @@ func (fs *fakeFileSystem) saveTree(ctx context.Context, uploader restic.BlobSave rnd := rand.NewSource(seed) numNodes := int(rnd.Int63() % maxNodes) - var tree Tree + var nodes []*Node for i := 0; i < numNodes; i++ { - // randomly select the type of the node, either tree (p = 1/4) or file (p = 3/4). if depth > 1 && rnd.Int63()%4 == 0 { treeSeed := rnd.Int63() % maxSeed @@ -87,7 +89,7 @@ func (fs *fakeFileSystem) saveTree(ctx context.Context, uploader restic.BlobSave Subtree: &id, } - tree.Nodes = append(tree.Nodes, node) + nodes = append(nodes, node) continue } @@ -102,14 +104,24 @@ func (fs *fakeFileSystem) saveTree(ctx context.Context, uploader restic.BlobSave } node.Content = fs.saveFile(ctx, uploader, fakeFile(fileSeed, fileSize)) - tree.Nodes = append(tree.Nodes, node) + nodes = append(nodes, node) } - tree.Sort() - id, err := SaveTree(ctx, uploader, &tree) - if err != nil { - fs.t.Fatalf("SaveTree returned error: %v", err) + return TestSaveNodes(fs.t, ctx, uploader, nodes) +} + +//nolint:revive // as this is a test helper, t should go first +func TestSaveNodes(t testing.TB, ctx context.Context, uploader restic.BlobSaver, nodes []*Node) restic.ID { + slices.SortFunc(nodes, func(a, b *Node) int { + return strings.Compare(a.Name, b.Name) + }) + treeWriter := NewTreeWriter(uploader) + for _, node := range nodes { + err := treeWriter.AddNode(node) + rtest.OK(t, err) } + id, err := treeWriter.Finalize(ctx) + rtest.OK(t, err) return id } diff --git a/internal/data/tree.go b/internal/data/tree.go index 763a1aa83..026ad9cdb 100644 --- a/internal/data/tree.go +++ b/internal/data/tree.go @@ -5,124 +5,164 @@ import ( "context" "encoding/json" "fmt" + + "io" + "iter" "path" - "sort" "strings" "github.com/restic/restic/internal/errors" "github.com/restic/restic/internal/restic" - - "github.com/restic/restic/internal/debug" ) +// For documentation purposes only: +// // Tree is an ordered list of nodes. +// type Tree struct { +// Nodes []*Node `json:"nodes"` +// } + var ErrTreeNotOrdered = errors.New("nodes are not ordered or duplicate") -// Tree is an ordered list of nodes. -type Tree struct { - Nodes []*Node `json:"nodes"` +type treeIterator struct { + dec json.Decoder + started bool } -// NewTree creates a new tree object with the given initial capacity. -func NewTree(capacity int) *Tree { - return &Tree{ - Nodes: make([]*Node, 0, capacity), - } +type NodeOrError struct { + Node *Node + Error error } -func (t *Tree) String() string { - return fmt.Sprintf("Tree<%d nodes>", len(t.Nodes)) -} +type TreeNodeIterator = iter.Seq[NodeOrError] -// Equals returns true if t and other have exactly the same nodes. -func (t *Tree) Equals(other *Tree) bool { - if len(t.Nodes) != len(other.Nodes) { - debug.Log("tree.Equals(): trees have different number of nodes") - return false +func NewTreeNodeIterator(rd io.Reader) (TreeNodeIterator, error) { + t := &treeIterator{ + dec: *json.NewDecoder(rd), } - for i := 0; i < len(t.Nodes); i++ { - if !t.Nodes[i].Equals(*other.Nodes[i]) { - debug.Log("tree.Equals(): node %d is different:", i) - debug.Log(" %#v", t.Nodes[i]) - debug.Log(" %#v", other.Nodes[i]) - return false + err := t.init() + if err != nil { + return nil, err + } + + return func(yield func(NodeOrError) bool) { + if t.started { + panic("tree iterator is single use only") } - } - - return true + t.started = true + for { + n, err := t.next() + if err != nil && errors.Is(err, io.EOF) { + return + } + if !yield(NodeOrError{Node: n, Error: err}) { + return + } + // errors are final + if err != nil { + return + } + } + }, nil } -// Insert adds a new node at the correct place in the tree. -func (t *Tree) Insert(node *Node) error { - pos, found := t.find(node.Name) - if found != nil { - return errors.Errorf("node %q already present", node.Name) +func (t *treeIterator) init() error { + // `{"nodes":[` `]}` + + if err := t.assertToken(json.Delim('{')); err != nil { + return err + } + if err := t.assertToken("nodes"); err != nil { + return err + } + if err := t.assertToken(json.Delim('[')); err != nil { + return err } - - // https://github.com/golang/go/wiki/SliceTricks - t.Nodes = append(t.Nodes, nil) - copy(t.Nodes[pos+1:], t.Nodes[pos:]) - t.Nodes[pos] = node - return nil } -func (t *Tree) find(name string) (int, *Node) { - pos := sort.Search(len(t.Nodes), func(i int) bool { - return t.Nodes[i].Name >= name - }) - - if pos < len(t.Nodes) && t.Nodes[pos].Name == name { - return pos, t.Nodes[pos] - } - - return pos, nil -} - -// Find returns a node with the given name, or nil if none could be found. -func (t *Tree) Find(name string) *Node { - if t == nil { - return nil - } - - _, node := t.find(name) - return node -} - -// Sort sorts the nodes by name. -func (t *Tree) Sort() { - list := Nodes(t.Nodes) - sort.Sort(list) - t.Nodes = list -} - -// Subtrees returns a slice of all subtree IDs of the tree. -func (t *Tree) Subtrees() (trees restic.IDs) { - for _, node := range t.Nodes { - if node.Type == NodeTypeDir && node.Subtree != nil { - trees = append(trees, *node.Subtree) +func (t *treeIterator) next() (*Node, error) { + if t.dec.More() { + var n Node + err := t.dec.Decode(&n) + if err != nil { + return nil, err } + return &n, nil } - return trees + if err := t.assertToken(json.Delim(']')); err != nil { + return nil, err + } + if err := t.assertToken(json.Delim('}')); err != nil { + return nil, err + } + return nil, io.EOF } -// LoadTree loads a tree from the repository. -func LoadTree(ctx context.Context, r restic.BlobLoader, id restic.ID) (*Tree, error) { - debug.Log("load tree %v", id) +func (t *treeIterator) assertToken(token json.Token) error { + to, err := t.dec.Token() + if err != nil { + return err + } + if to != token { + return errors.Errorf("error decoding tree: expected %v, got %v", token, to) + } + return nil +} - buf, err := r.LoadBlob(ctx, restic.TreeBlob, id, nil) +func LoadTree(ctx context.Context, loader restic.BlobLoader, content restic.ID) (TreeNodeIterator, error) { + rd, err := loader.LoadBlob(ctx, restic.TreeBlob, content, nil) if err != nil { return nil, err } + return NewTreeNodeIterator(bytes.NewReader(rd)) +} - t := &Tree{} - err = json.Unmarshal(buf, t) - if err != nil { - return nil, err +type TreeFinder struct { + next func() (NodeOrError, bool) + stop func() + current *Node +} + +func NewTreeFinder(tree TreeNodeIterator) *TreeFinder { + if tree == nil { + return &TreeFinder{stop: func() {}} + } + next, stop := iter.Pull(tree) + return &TreeFinder{next: next, stop: stop} +} + +// Find finds the node with the given name. If the node is not found, it returns nil. +// If Find was called before, the new name must be strictly greater than the last name. +func (t *TreeFinder) Find(name string) (*Node, error) { + if t.next == nil { + return nil, nil + } + // loop until `t.current.Name` is >= name + for t.current == nil || t.current.Name < name { + current, ok := t.next() + if current.Error != nil { + return nil, current.Error + } + if !ok { + return nil, nil + } + t.current = current.Node } - return t, nil + if t.current.Name == name { + // forget the current node to free memory as early as possible + current := t.current + t.current = nil + return current, nil + } + // we have already passed the name + return nil, nil +} + +func (t *TreeFinder) Close() { + t.stop() } type TreeWriter struct { @@ -148,10 +188,13 @@ func (t *TreeWriter) Finalize(ctx context.Context) (restic.ID, error) { return id, err } -func SaveTree(ctx context.Context, saver restic.BlobSaver, t *Tree) (restic.ID, error) { +func SaveTree(ctx context.Context, saver restic.BlobSaver, nodes TreeNodeIterator) (restic.ID, error) { treeWriter := NewTreeWriter(saver) - for _, node := range t.Nodes { - err := treeWriter.AddNode(node) + for item := range nodes { + if item.Error != nil { + return restic.ID{}, item.Error + } + err := treeWriter.AddNode(item.Node) if err != nil { return restic.ID{}, err } @@ -214,7 +257,12 @@ func FindTreeDirectory(ctx context.Context, repo restic.BlobLoader, id *restic.I if err != nil { return nil, fmt.Errorf("path %s: %w", subfolder, err) } - node := tree.Find(name) + finder := NewTreeFinder(tree) + node, err := finder.Find(name) + finder.Close() + if err != nil { + return nil, fmt.Errorf("path %s: %w", subfolder, err) + } if node == nil { return nil, fmt.Errorf("path %s: not found", subfolder) } diff --git a/internal/data/tree_stream.go b/internal/data/tree_stream.go index 042a55f7e..1f832a731 100644 --- a/internal/data/tree_stream.go +++ b/internal/data/tree_stream.go @@ -23,12 +23,36 @@ type trackedID struct { rootIdx int } +// subtreesCollector wraps a TreeNodeIterator and returns a new iterator that collects the subtrees. +func subtreesCollector(tree TreeNodeIterator) (TreeNodeIterator, func() restic.IDs) { + subtrees := restic.IDs{} + isComplete := false + + return func(yield func(NodeOrError) bool) { + for item := range tree { + if !yield(item) { + return + } + // be defensive and check for nil subtree as this code is also used by the checker + if item.Node != nil && item.Node.Type == NodeTypeDir && item.Node.Subtree != nil { + subtrees = append(subtrees, *item.Node.Subtree) + } + } + isComplete = true + }, func() restic.IDs { + if !isComplete { + panic("tree was not read completely") + } + return subtrees + } +} + // loadTreeWorker loads trees from repo and sends them to out. func loadTreeWorker( ctx context.Context, repo restic.Loader, in <-chan trackedID, - process func(id restic.ID, error error, tree *Tree) error, + process func(id restic.ID, error error, nodes TreeNodeIterator) error, out chan<- trackedTreeItem, ) error { @@ -39,14 +63,21 @@ func loadTreeWorker( } debug.Log("load tree %v (%v) returned err: %v", tree, treeID, err) + // wrap iterator to collect subtrees while `process` iterates over `tree` + var collectSubtrees func() restic.IDs + if tree != nil { + tree, collectSubtrees = subtreesCollector(tree) + } + err = process(treeID.ID, err, tree) if err != nil { return err } + // assume that the number of subtrees is within reasonable limits, such that the memory usage is not a problem var subtrees restic.IDs - if tree != nil { - subtrees = tree.Subtrees() + if collectSubtrees != nil { + subtrees = collectSubtrees() } job := trackedTreeItem{ID: treeID.ID, Subtrees: subtrees, rootIdx: treeID.rootIdx} @@ -159,7 +190,7 @@ func StreamTrees( trees restic.IDs, p *progress.Counter, skip func(tree restic.ID) bool, - process func(id restic.ID, error error, tree *Tree) error, + process func(id restic.ID, error error, nodes TreeNodeIterator) error, ) error { loaderChan := make(chan trackedID) hugeTreeChan := make(chan trackedID, 10) diff --git a/internal/data/tree_test.go b/internal/data/tree_test.go index 7433d6b2e..56752850e 100644 --- a/internal/data/tree_test.go +++ b/internal/data/tree_test.go @@ -6,6 +6,7 @@ import ( "errors" "os" "path/filepath" + "slices" "strconv" "testing" @@ -105,37 +106,45 @@ func TestNodeComparison(t *testing.T) { func TestEmptyLoadTree(t *testing.T) { repo := repository.TestRepository(t) - tree := data.NewTree(0) + nodes := []*data.Node{} var id restic.ID rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { - var err error // save tree - id, err = data.SaveTree(ctx, uploader, tree) - return err + id = data.TestSaveNodes(t, ctx, uploader, nodes) + return nil })) // load tree again - tree2, err := data.LoadTree(context.TODO(), repo, id) + it, err := data.LoadTree(context.TODO(), repo, id) rtest.OK(t, err) + nodes2 := []*data.Node{} + for item := range it { + rtest.OK(t, item.Error) + nodes2 = append(nodes2, item.Node) + } - rtest.Assert(t, tree.Equals(tree2), - "trees are not equal: want %v, got %v", - tree, tree2) + rtest.Assert(t, slices.Equal(nodes, nodes2), + "tree nodes are not equal: want %v, got %v", + nodes, nodes2) +} + +// Basic type for comparing the serialization of the tree +type Tree struct { + Nodes []*data.Node `json:"nodes"` } func TestTreeEqualSerialization(t *testing.T) { files := []string{"node.go", "tree.go", "tree_test.go"} for i := 1; i <= len(files); i++ { - tree := data.NewTree(i) + tree := Tree{Nodes: make([]*data.Node, 0, i)} builder := data.NewTreeJSONBuilder() for _, fn := range files[:i] { node := nodeForFile(t, fn) - rtest.OK(t, tree.Insert(node)) + tree.Nodes = append(tree.Nodes, node) rtest.OK(t, builder.AddNode(node)) - rtest.Assert(t, tree.Insert(node) != nil, "no error on duplicate node") rtest.Assert(t, builder.AddNode(node) != nil, "no error on duplicate node") rtest.Assert(t, errors.Is(builder.AddNode(node), data.ErrTreeNotOrdered), "wrong error returned") } @@ -144,11 +153,11 @@ func TestTreeEqualSerialization(t *testing.T) { treeBytes = append(treeBytes, '\n') rtest.OK(t, err) - stiBytes, err := builder.Finalize() + buf, err := builder.Finalize() rtest.OK(t, err) // compare serialization of an individual node and the SaveTreeIterator - rtest.Equals(t, treeBytes, stiBytes) + rtest.Equals(t, treeBytes, buf) } } @@ -165,11 +174,12 @@ func BenchmarkBuildTree(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { - t := data.NewTree(size) - + t := data.NewTreeJSONBuilder() for i := range nodes { - _ = t.Insert(&nodes[i]) + rtest.OK(b, t.AddNode(&nodes[i])) } + _, err := t.Finalize() + rtest.OK(b, err) } } @@ -186,8 +196,11 @@ func testLoadTree(t *testing.T, version uint) { repo, _, _ := repository.TestRepositoryWithVersion(t, version) sn := archiver.TestSnapshot(t, repo, rtest.BenchArchiveDirectory, nil) - _, err := data.LoadTree(context.TODO(), repo, *sn.Tree) + nodes, err := data.LoadTree(context.TODO(), repo, *sn.Tree) rtest.OK(t, err) + for item := range nodes { + rtest.OK(t, item.Error) + } } func BenchmarkLoadTree(t *testing.B) { diff --git a/internal/dump/common.go b/internal/dump/common.go index 1a6a9c2b2..2c0edf67a 100644 --- a/internal/dump/common.go +++ b/internal/dump/common.go @@ -30,7 +30,7 @@ func New(format string, repo restic.Loader, w io.Writer) *Dumper { } } -func (d *Dumper) DumpTree(ctx context.Context, tree *data.Tree, rootPath string) error { +func (d *Dumper) DumpTree(ctx context.Context, tree data.TreeNodeIterator, rootPath string) error { wg, ctx := errgroup.WithContext(ctx) // ch is buffered to deal with variable download/write speeds. @@ -52,10 +52,14 @@ func (d *Dumper) DumpTree(ctx context.Context, tree *data.Tree, rootPath string) return wg.Wait() } -func sendTrees(ctx context.Context, repo restic.BlobLoader, tree *data.Tree, rootPath string, ch chan *data.Node) error { +func sendTrees(ctx context.Context, repo restic.BlobLoader, nodes data.TreeNodeIterator, rootPath string, ch chan *data.Node) error { defer close(ch) - for _, node := range tree.Nodes { + for item := range nodes { + if item.Error != nil { + return item.Error + } + node := item.Node node.Path = path.Join(rootPath, node.Name) if err := sendNodes(ctx, repo, node, ch); err != nil { return err diff --git a/internal/fuse/dir.go b/internal/fuse/dir.go index 28f7ba9a7..df558ac1f 100644 --- a/internal/fuse/dir.go +++ b/internal/fuse/dir.go @@ -7,6 +7,7 @@ import ( "errors" "os" "path/filepath" + "slices" "sync" "syscall" @@ -65,13 +66,13 @@ func unwrapCtxCanceled(err error) error { // replaceSpecialNodes replaces nodes with name "." and "/" by their contents. // Otherwise, the node is returned. -func replaceSpecialNodes(ctx context.Context, repo restic.BlobLoader, node *data.Node) ([]*data.Node, error) { +func replaceSpecialNodes(ctx context.Context, repo restic.BlobLoader, node *data.Node) (data.TreeNodeIterator, error) { if node.Type != data.NodeTypeDir || node.Subtree == nil { - return []*data.Node{node}, nil + return slices.Values([]data.NodeOrError{{Node: node}}), nil } if node.Name != "." && node.Name != "/" { - return []*data.Node{node}, nil + return slices.Values([]data.NodeOrError{{Node: node}}), nil } tree, err := data.LoadTree(ctx, repo, *node.Subtree) @@ -79,7 +80,7 @@ func replaceSpecialNodes(ctx context.Context, repo restic.BlobLoader, node *data return nil, unwrapCtxCanceled(err) } - return tree.Nodes, nil + return tree, nil } func newDirFromSnapshot(root *Root, forget forgetFn, inode uint64, snapshot *data.Snapshot) (*dir, error) { @@ -115,18 +116,25 @@ func (d *dir) open(ctx context.Context) error { return unwrapCtxCanceled(err) } items := make(map[string]*data.Node) - for _, n := range tree.Nodes { + for item := range tree { + if item.Error != nil { + return unwrapCtxCanceled(item.Error) + } if ctx.Err() != nil { return ctx.Err() } + n := item.Node nodes, err := replaceSpecialNodes(ctx, d.root.repo, n) if err != nil { debug.Log(" replaceSpecialNodes(%v) failed: %v", n, err) return err } - for _, node := range nodes { - items[cleanupNodeName(node.Name)] = node + for item := range nodes { + if item.Error != nil { + return unwrapCtxCanceled(item.Error) + } + items[cleanupNodeName(item.Node.Name)] = item.Node } } d.items = items diff --git a/internal/fuse/fuse_test.go b/internal/fuse/fuse_test.go index 05e6c5340..c82252458 100644 --- a/internal/fuse/fuse_test.go +++ b/internal/fuse/fuse_test.go @@ -59,7 +59,7 @@ func loadFirstSnapshot(t testing.TB, repo restic.ListerLoaderUnpacked) *data.Sna return sn } -func loadTree(t testing.TB, repo restic.Loader, id restic.ID) *data.Tree { +func loadTree(t testing.TB, repo restic.Loader, id restic.ID) data.TreeNodeIterator { tree, err := data.LoadTree(context.TODO(), repo, id) rtest.OK(t, err) return tree @@ -79,8 +79,9 @@ func TestFuseFile(t *testing.T) { tree := loadTree(t, repo, *sn.Tree) var content restic.IDs - for _, node := range tree.Nodes { - content = append(content, node.Content...) + for item := range tree { + rtest.OK(t, item.Error) + content = append(content, item.Node.Content...) } t.Logf("tree loaded, content: %v", content) diff --git a/internal/restorer/restorer.go b/internal/restorer/restorer.go index 22ab196a5..8454591e4 100644 --- a/internal/restorer/restorer.go +++ b/internal/restorer/restorer.go @@ -163,15 +163,18 @@ func (res *Restorer) traverseTreeInner(ctx context.Context, target, location str } if res.opts.Delete { - filenames = make([]string, 0, len(tree.Nodes)) + filenames = make([]string, 0) } - for i, node := range tree.Nodes { + for item := range tree { + if item.Error != nil { + debug.Log("error iterating tree %v: %v", treeID, item.Error) + return nil, hasRestored, res.sanitizeError(location, item.Error) + } + node := item.Node if ctx.Err() != nil { return nil, hasRestored, ctx.Err() } - // allow GC of tree node - tree.Nodes[i] = nil if res.opts.Delete { // just track all files included in the tree node to simplify the control flow. // tracking too many files does not matter except for a slightly elevated memory usage diff --git a/internal/restorer/restorer_test.go b/internal/restorer/restorer_test.go index 29f15d343..337a99918 100644 --- a/internal/restorer/restorer_test.go +++ b/internal/restorer/restorer_test.go @@ -79,7 +79,7 @@ func saveDir(t testing.TB, repo restic.BlobSaver, nodes map[string]Node, inode u ctx, cancel := context.WithCancel(context.Background()) defer cancel() - tree := &data.Tree{} + tree := make([]*data.Node, 0, len(nodes)) for name, n := range nodes { inode++ switch node := n.(type) { @@ -107,7 +107,7 @@ func saveDir(t testing.TB, repo restic.BlobSaver, nodes map[string]Node, inode u if mode == 0 { mode = 0644 } - err := tree.Insert(&data.Node{ + tree = append(tree, &data.Node{ Type: data.NodeTypeFile, Mode: mode, ModTime: node.ModTime, @@ -120,9 +120,8 @@ func saveDir(t testing.TB, repo restic.BlobSaver, nodes map[string]Node, inode u Links: lc, GenericAttributes: getGenericAttributes(node.attributes, false), }) - rtest.OK(t, err) case Symlink: - err := tree.Insert(&data.Node{ + tree = append(tree, &data.Node{ Type: data.NodeTypeSymlink, Mode: os.ModeSymlink | 0o777, ModTime: node.ModTime, @@ -133,7 +132,6 @@ func saveDir(t testing.TB, repo restic.BlobSaver, nodes map[string]Node, inode u Inode: inode, Links: 1, }) - rtest.OK(t, err) case Dir: id := saveDir(t, repo, node.Nodes, inode, getGenericAttributes) @@ -142,7 +140,7 @@ func saveDir(t testing.TB, repo restic.BlobSaver, nodes map[string]Node, inode u mode = 0755 } - err := tree.Insert(&data.Node{ + tree = append(tree, &data.Node{ Type: data.NodeTypeDir, Mode: mode, ModTime: node.ModTime, @@ -152,19 +150,12 @@ func saveDir(t testing.TB, repo restic.BlobSaver, nodes map[string]Node, inode u Subtree: &id, GenericAttributes: getGenericAttributes(node.attributes, false), }) - rtest.OK(t, err) default: t.Fatalf("unknown node type %T", node) } } - tree.Sort() - id, err := data.SaveTree(ctx, repo, tree) - if err != nil { - t.Fatal(err) - } - - return id + return data.TestSaveNodes(t, ctx, repo, tree) } func saveSnapshot(t testing.TB, repo restic.Repository, snapshot Snapshot, getGenericAttributes func(attr *FileAttributes, isDir bool) (genericAttributes map[data.GenericAttributeType]json.RawMessage)) (*data.Snapshot, restic.ID) { diff --git a/internal/walker/rewriter.go b/internal/walker/rewriter.go index b3445e438..f53577e5a 100644 --- a/internal/walker/rewriter.go +++ b/internal/walker/rewriter.go @@ -11,7 +11,7 @@ import ( ) type NodeRewriteFunc func(node *data.Node, path string) *data.Node -type FailedTreeRewriteFunc func(nodeID restic.ID, path string, err error) (*data.Tree, error) +type FailedTreeRewriteFunc func(nodeID restic.ID, path string, err error) (data.TreeNodeIterator, error) type QueryRewrittenSizeFunc func() SnapshotSize type SnapshotSize struct { @@ -52,7 +52,7 @@ func NewTreeRewriter(opts RewriteOpts) *TreeRewriter { } if rw.opts.RewriteFailedTree == nil { // fail with error by default - rw.opts.RewriteFailedTree = func(_ restic.ID, _ string, err error) (*data.Tree, error) { + rw.opts.RewriteFailedTree = func(_ restic.ID, _ string, err error) (data.TreeNodeIterator, error) { return nil, err } } @@ -117,15 +117,26 @@ func (t *TreeRewriter) RewriteTree(ctx context.Context, loader restic.BlobLoader if nodeID != testID { return restic.ID{}, fmt.Errorf("cannot encode tree at %q without losing information", nodepath) } + + // reload the tree to get a new iterator + curTree, err = data.LoadTree(ctx, loader, nodeID) + if err != nil { + // shouldn't fail as the first load was successful + return restic.ID{}, fmt.Errorf("failed to reload tree %v: %w", nodeID, err) + } } debug.Log("filterTree: %s, nodeId: %s\n", nodepath, nodeID.Str()) tb := data.NewTreeWriter(saver) - for _, node := range curTree.Nodes { + for item := range curTree { if ctx.Err() != nil { return restic.ID{}, ctx.Err() } + if item.Error != nil { + return restic.ID{}, err + } + node := item.Node path := path.Join(nodepath, node.Name) node = t.opts.RewriteNode(node, path) diff --git a/internal/walker/rewriter_test.go b/internal/walker/rewriter_test.go index 9290a62d5..b26449a58 100644 --- a/internal/walker/rewriter_test.go +++ b/internal/walker/rewriter_test.go @@ -2,6 +2,7 @@ package walker import ( "context" + "slices" "testing" "github.com/pkg/errors" @@ -405,16 +406,15 @@ func TestRewriterTreeLoadError(t *testing.T) { t.Fatal("missing error on unloadable tree") } - replacementTree := &data.Tree{Nodes: []*data.Node{{Name: "replacement", Type: data.NodeTypeFile, Size: 42}}} - replacementID, err := data.SaveTree(ctx, tm, replacementTree) - test.OK(t, err) + replacementNode := &data.Node{Name: "replacement", Type: data.NodeTypeFile, Size: 42} + replacementID := data.TestSaveNodes(t, ctx, tm, []*data.Node{replacementNode}) rewriter = NewTreeRewriter(RewriteOpts{ - RewriteFailedTree: func(nodeID restic.ID, path string, err error) (*data.Tree, error) { + RewriteFailedTree: func(nodeID restic.ID, path string, err error) (data.TreeNodeIterator, error) { if nodeID != id || path != "/" { t.Fail() } - return replacementTree, nil + return slices.Values([]data.NodeOrError{{Node: replacementNode}}), nil }, }) newRoot, err := rewriter.RewriteTree(ctx, tm, tm, "/", id) diff --git a/internal/walker/walker.go b/internal/walker/walker.go index 67c4a9d03..8347c28c4 100644 --- a/internal/walker/walker.go +++ b/internal/walker/walker.go @@ -3,7 +3,6 @@ package walker import ( "context" "path" - "sort" "github.com/pkg/errors" @@ -52,15 +51,15 @@ func Walk(ctx context.Context, repo restic.BlobLoader, root restic.ID, visitor W // walk recursively traverses the tree, ignoring subtrees when the ID of the // subtree is in ignoreTrees. If err is nil and ignore is true, the subtree ID // will be added to ignoreTrees by walk. -func walk(ctx context.Context, repo restic.BlobLoader, prefix string, parentTreeID restic.ID, tree *data.Tree, visitor WalkVisitor) (err error) { - sort.Slice(tree.Nodes, func(i, j int) bool { - return tree.Nodes[i].Name < tree.Nodes[j].Name - }) - - for _, node := range tree.Nodes { +func walk(ctx context.Context, repo restic.BlobLoader, prefix string, parentTreeID restic.ID, tree data.TreeNodeIterator, visitor WalkVisitor) (err error) { + for item := range tree { + if item.Error != nil { + return item.Error + } if ctx.Err() != nil { return ctx.Err() } + node := item.Node p := path.Join(prefix, node.Name)