data: replace Tree with TreeNodeIterator

The TreeNodeIterator decodes nodes while iterating over a tree blob.
This should reduce peak memory usage as now only the serialized tree
blob and a single node have to alive at the same time. Using the
iterator has implications for the error handling however. Now it is
necessary that all loops that iterate through a tree check for errors
before using the node returned by the iterator.

The other change is that it is no longer possible to iterate over a tree
multiple times. Instead it must be loaded a second time. This only
affects the tree rewriting code.
This commit is contained in:
Michael Eischer 2025-11-22 19:55:50 +01:00
parent 1e183509d4
commit 350f29d921
23 changed files with 394 additions and 233 deletions

View file

@ -278,7 +278,7 @@ func copyTree(ctx context.Context, srcRepo restic.Repository, dstRepo restic.Rep
visited := visitedTrees.Has(handle)
visitedTrees.Insert(handle)
return visited
}, func(treeID restic.ID, err error, tree *data.Tree) error {
}, func(treeID restic.ID, err error, nodes data.TreeNodeIterator) error {
if err != nil {
return fmt.Errorf("LoadTree(%v) returned error %v", treeID.Str(), err)
}
@ -286,10 +286,13 @@ func copyTree(ctx context.Context, srcRepo restic.Repository, dstRepo restic.Rep
// copy raw tree bytes to avoid problems if the serialization changes
enqueue(restic.BlobHandle{ID: treeID, Type: restic.TreeBlob})
for _, entry := range tree.Nodes {
for item := range nodes {
if item.Error != nil {
return item.Error
}
// Recursion into directories is handled by StreamTrees
// Copy the blobs for this file.
for _, blobID := range entry.Content {
for _, blobID := range item.Node.Content {
enqueue(restic.BlobHandle{Type: restic.DataBlob, ID: blobID})
}
}

View file

@ -184,11 +184,14 @@ func (c *Comparer) printDir(ctx context.Context, mode string, stats *DiffStat, b
return err
}
for _, node := range tree.Nodes {
for item := range tree {
if item.Error != nil {
return item.Error
}
if ctx.Err() != nil {
return ctx.Err()
}
node := item.Node
name := path.Join(prefix, node.Name)
if node.Type == data.NodeTypeDir {
name += "/"
@ -215,11 +218,15 @@ func (c *Comparer) collectDir(ctx context.Context, blobs restic.AssociatedBlobSe
return err
}
for _, node := range tree.Nodes {
for item := range tree {
if item.Error != nil {
return item.Error
}
if ctx.Err() != nil {
return ctx.Err()
}
node := item.Node
addBlobs(blobs, node)
if node.Type == data.NodeTypeDir {
@ -233,16 +240,24 @@ func (c *Comparer) collectDir(ctx context.Context, blobs restic.AssociatedBlobSe
return ctx.Err()
}
func uniqueNodeNames(tree1, tree2 *data.Tree) (tree1Nodes, tree2Nodes map[string]*data.Node, uniqueNames []string) {
func uniqueNodeNames(tree1, tree2 data.TreeNodeIterator) (tree1Nodes, tree2Nodes map[string]*data.Node, uniqueNames []string, err error) {
names := make(map[string]struct{})
tree1Nodes = make(map[string]*data.Node)
for _, node := range tree1.Nodes {
for item := range tree1 {
if item.Error != nil {
return nil, nil, nil, item.Error
}
node := item.Node
tree1Nodes[node.Name] = node
names[node.Name] = struct{}{}
}
tree2Nodes = make(map[string]*data.Node)
for _, node := range tree2.Nodes {
for item := range tree2 {
if item.Error != nil {
return nil, nil, nil, item.Error
}
node := item.Node
tree2Nodes[node.Name] = node
names[node.Name] = struct{}{}
}
@ -253,7 +268,7 @@ func uniqueNodeNames(tree1, tree2 *data.Tree) (tree1Nodes, tree2Nodes map[string
}
sort.Strings(uniqueNames)
return tree1Nodes, tree2Nodes, uniqueNames
return tree1Nodes, tree2Nodes, uniqueNames, nil
}
func (c *Comparer) diffTree(ctx context.Context, stats *DiffStatsContainer, prefix string, id1, id2 restic.ID) error {
@ -268,7 +283,10 @@ func (c *Comparer) diffTree(ctx context.Context, stats *DiffStatsContainer, pref
return err
}
tree1Nodes, tree2Nodes, names := uniqueNodeNames(tree1, tree2)
tree1Nodes, tree2Nodes, names, err := uniqueNodeNames(tree1, tree2)
if err != nil {
return err
}
for _, name := range names {
if ctx.Err() != nil {

View file

@ -80,7 +80,7 @@ func splitPath(p string) []string {
return append(s, f)
}
func printFromTree(ctx context.Context, tree *data.Tree, repo restic.BlobLoader, prefix string, pathComponents []string, d *dump.Dumper, canWriteArchiveFunc func() error) error {
func printFromTree(ctx context.Context, tree data.TreeNodeIterator, repo restic.BlobLoader, prefix string, pathComponents []string, d *dump.Dumper, canWriteArchiveFunc func() error) error {
// If we print / we need to assume that there are multiple nodes at that
// level in the tree.
if pathComponents[0] == "" {
@ -92,11 +92,14 @@ func printFromTree(ctx context.Context, tree *data.Tree, repo restic.BlobLoader,
item := filepath.Join(prefix, pathComponents[0])
l := len(pathComponents)
for _, node := range tree.Nodes {
for it := range tree {
if it.Error != nil {
return it.Error
}
if ctx.Err() != nil {
return ctx.Err()
}
node := it.Node
// If dumping something in the highest level it will just take the
// first item it finds and dump that according to the switch case below.
if node.Name == pathComponents[0] {

View file

@ -97,7 +97,11 @@ func runRecover(ctx context.Context, gopts global.Options, term ui.Terminal) err
continue
}
for _, node := range tree.Nodes {
for item := range tree {
if item.Error != nil {
return item.Error
}
node := item.Node
if node.Type == data.NodeTypeDir && node.Subtree != nil {
trees[*node.Subtree] = true
}

View file

@ -2,6 +2,7 @@ package main
import (
"context"
"slices"
"github.com/restic/restic/internal/data"
"github.com/restic/restic/internal/errors"
@ -130,7 +131,7 @@ func runRepairSnapshots(ctx context.Context, gopts global.Options, opts RepairOp
node.Size = newSize
return node
},
RewriteFailedTree: func(_ restic.ID, path string, _ error) (*data.Tree, error) {
RewriteFailedTree: func(_ restic.ID, path string, _ error) (data.TreeNodeIterator, error) {
if path == "/" {
printer.P(" dir %q: not readable", path)
// remove snapshots with invalid root node
@ -138,7 +139,7 @@ func runRepairSnapshots(ctx context.Context, gopts global.Options, opts RepairOp
}
// If a subtree fails to load, remove it
printer.P(" dir %q: replaced with empty directory", path)
return &data.Tree{}, nil
return slices.Values([]data.NodeOrError{}), nil
},
AllowUnstableSerialization: true,
})

View file

@ -281,7 +281,7 @@ func (arch *Archiver) nodeFromFileInfo(snPath, filename string, meta ToNoder, ig
// loadSubtree tries to load the subtree referenced by node. In case of an error, nil is returned.
// If there is no node to load, then nil is returned without an error.
func (arch *Archiver) loadSubtree(ctx context.Context, node *data.Node) (*data.Tree, error) {
func (arch *Archiver) loadSubtree(ctx context.Context, node *data.Node) (data.TreeNodeIterator, error) {
if node == nil || node.Type != data.NodeTypeDir || node.Subtree == nil {
return nil, nil
}
@ -307,7 +307,7 @@ func (arch *Archiver) wrapLoadTreeError(id restic.ID, err error) error {
// saveDir stores a directory in the repo and returns the node. snPath is the
// path within the current snapshot.
func (arch *Archiver) saveDir(ctx context.Context, snPath string, dir string, meta fs.File, previous *data.Tree, complete fileCompleteFunc) (d futureNode, err error) {
func (arch *Archiver) saveDir(ctx context.Context, snPath string, dir string, meta fs.File, previous data.TreeNodeIterator, complete fileCompleteFunc) (d futureNode, err error) {
debug.Log("%v %v", snPath, dir)
treeNode, names, err := arch.dirToNodeAndEntries(snPath, dir, meta)
@ -317,6 +317,9 @@ func (arch *Archiver) saveDir(ctx context.Context, snPath string, dir string, me
nodes := make([]futureNode, 0, len(names))
finder := data.NewTreeFinder(previous)
defer finder.Close()
for _, name := range names {
// test if context has been cancelled
if ctx.Err() != nil {
@ -325,7 +328,11 @@ func (arch *Archiver) saveDir(ctx context.Context, snPath string, dir string, me
}
pathname := arch.FS.Join(dir, name)
oldNode := previous.Find(name)
oldNode, err := finder.Find(name)
err = arch.error(pathname, err)
if err != nil {
return futureNode{}, err
}
snItem := join(snPath, name)
fn, excluded, err := arch.save(ctx, snItem, pathname, oldNode)
@ -645,7 +652,7 @@ func join(elem ...string) string {
// saveTree stores a Tree in the repo, returned is the tree. snPath is the path
// within the current snapshot.
func (arch *Archiver) saveTree(ctx context.Context, snPath string, atree *tree, previous *data.Tree, complete fileCompleteFunc) (futureNode, int, error) {
func (arch *Archiver) saveTree(ctx context.Context, snPath string, atree *tree, previous data.TreeNodeIterator, complete fileCompleteFunc) (futureNode, int, error) {
var node *data.Node
if snPath != "/" {
@ -663,10 +670,13 @@ func (arch *Archiver) saveTree(ctx context.Context, snPath string, atree *tree,
node = &data.Node{}
}
debug.Log("%v (%v nodes), parent %v", snPath, len(atree.Nodes), previous)
debug.Log("%v (%v nodes)", snPath, len(atree.Nodes))
nodeNames := atree.NodeNames()
nodes := make([]futureNode, 0, len(nodeNames))
finder := data.NewTreeFinder(previous)
defer finder.Close()
// iterate over the nodes of atree in lexicographic (=deterministic) order
for _, name := range nodeNames {
subatree := atree.Nodes[name]
@ -678,7 +688,13 @@ func (arch *Archiver) saveTree(ctx context.Context, snPath string, atree *tree,
// this is a leaf node
if subatree.Leaf() {
fn, excluded, err := arch.save(ctx, join(snPath, name), subatree.Path, previous.Find(name))
pathname := join(snPath, name)
oldNode, err := finder.Find(name)
err = arch.error(pathname, err)
if err != nil {
return futureNode{}, 0, err
}
fn, excluded, err := arch.save(ctx, pathname, subatree.Path, oldNode)
if err != nil {
err = arch.error(subatree.Path, err)
@ -698,7 +714,11 @@ func (arch *Archiver) saveTree(ctx context.Context, snPath string, atree *tree,
snItem := join(snPath, name) + "/"
start := time.Now()
oldNode := previous.Find(name)
oldNode, err := finder.Find(name)
err = arch.error(snItem, err)
if err != nil {
return futureNode{}, 0, err
}
oldSubtree, err := arch.loadSubtree(ctx, oldNode)
if err != nil {
err = arch.error(join(snPath, name), err)
@ -801,7 +821,7 @@ type SnapshotOptions struct {
}
// loadParentTree loads a tree referenced by snapshot id. If id is null, nil is returned.
func (arch *Archiver) loadParentTree(ctx context.Context, sn *data.Snapshot) *data.Tree {
func (arch *Archiver) loadParentTree(ctx context.Context, sn *data.Snapshot) data.TreeNodeIterator {
if sn == nil {
return nil
}

View file

@ -877,11 +877,7 @@ func TestArchiverSaveDir(t *testing.T) {
}
node.Name = targetNodeName
tree := &data.Tree{Nodes: []*data.Node{node}}
treeID, err = data.SaveTree(ctx, uploader, tree)
if err != nil {
t.Fatal(err)
}
treeID = data.TestSaveNodes(t, ctx, uploader, []*data.Node{node})
arch.stopWorkers()
return wg.Wait()
})
@ -2256,19 +2252,15 @@ func snapshot(t testing.TB, repo archiverRepo, fs fs.FS, parent *data.Snapshot,
ParentSnapshot: parent,
}
snapshot, _, _, err := arch.Snapshot(ctx, []string{filename}, sopts)
if err != nil {
t.Fatal(err)
}
rtest.OK(t, err)
tree, err := data.LoadTree(ctx, repo, *snapshot.Tree)
if err != nil {
t.Fatal(err)
}
rtest.OK(t, err)
node := tree.Find(filename)
if node == nil {
t.Fatalf("unable to find node for testfile in snapshot")
}
finder := data.NewTreeFinder(tree)
defer finder.Close()
node, err := finder.Find(filename)
rtest.OK(t, err)
rtest.Assert(t, node != nil, "unable to find node for testfile in snapshot")
return snapshot, node
}

View file

@ -15,6 +15,7 @@ import (
"github.com/restic/restic/internal/debug"
"github.com/restic/restic/internal/fs"
"github.com/restic/restic/internal/restic"
rtest "github.com/restic/restic/internal/test"
)
// TestSnapshot creates a new snapshot of path.
@ -265,19 +266,14 @@ func TestEnsureTree(ctx context.Context, t testing.TB, prefix string, repo resti
t.Helper()
tree, err := data.LoadTree(ctx, repo, treeID)
if err != nil {
t.Fatal(err)
return
}
rtest.OK(t, err)
var nodeNames []string
for _, node := range tree.Nodes {
nodeNames = append(nodeNames, node.Name)
}
debug.Log("%v (%v) %v", prefix, treeID.Str(), nodeNames)
checked := make(map[string]struct{})
for _, node := range tree.Nodes {
for item := range tree {
rtest.OK(t, item.Error)
node := item.Node
nodeNames = append(nodeNames, node.Name)
nodePrefix := path.Join(prefix, node.Name)
entry, ok := dir[node.Name]
@ -316,6 +312,7 @@ func TestEnsureTree(ctx context.Context, t testing.TB, prefix string, repo resti
}
}
}
debug.Log("%v (%v) %v", prefix, treeID.Str(), nodeNames)
for name := range dir {
_, ok := checked[name]

View file

@ -162,14 +162,14 @@ func (c *Checker) Structure(ctx context.Context, p *progress.Counter, errChan ch
c.blobRefs.M.Insert(h)
c.blobRefs.Unlock()
return blobReferenced
}, func(treeID restic.ID, err error, tree *data.Tree) error {
}, func(treeID restic.ID, err error, nodes data.TreeNodeIterator) error {
debug.Log("check tree %v (err %v)", treeID, err)
var errs []error
if err != nil {
errs = append(errs, err)
} else {
errs = c.checkTree(treeID, tree)
errs = c.checkTree(treeID, nodes)
}
if len(errs) == 0 {
return nil
@ -193,10 +193,15 @@ func (c *Checker) Structure(ctx context.Context, p *progress.Counter, errChan ch
}
}
func (c *Checker) checkTree(id restic.ID, tree *data.Tree) (errs []error) {
func (c *Checker) checkTree(id restic.ID, tree data.TreeNodeIterator) (errs []error) {
debug.Log("checking tree %v", id)
for _, node := range tree.Nodes {
for item := range tree {
if item.Error != nil {
errs = append(errs, &Error{TreeID: id, Err: errors.Errorf("failed to decode tree %v: %w", id, item.Error)})
break
}
node := item.Node
switch node.Type {
case data.NodeTypeFile:
if node.Content == nil {

View file

@ -522,15 +522,12 @@ func TestCheckerBlobTypeConfusion(t *testing.T) {
Size: 42,
Content: restic.IDs{restic.TestParseID("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef")},
}
damagedTree := &data.Tree{
Nodes: []*data.Node{damagedNode},
}
damagedNodes := []*data.Node{damagedNode}
var id restic.ID
test.OK(t, repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaverWithAsync) error {
var err error
id, err = data.SaveTree(ctx, uploader, damagedTree)
return err
id = data.TestSaveNodes(t, ctx, uploader, damagedNodes)
return nil
}))
buf, err := repo.LoadBlob(ctx, restic.TreeBlob, id, nil)
@ -556,15 +553,12 @@ func TestCheckerBlobTypeConfusion(t *testing.T) {
Subtree: &id,
}
rootTree := &data.Tree{
Nodes: []*data.Node{malNode, dirNode},
}
rootNodes := []*data.Node{malNode, dirNode}
var rootID restic.ID
test.OK(t, repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaverWithAsync) error {
var err error
rootID, err = data.SaveTree(ctx, uploader, rootTree)
return err
rootID = data.TestSaveNodes(t, ctx, uploader, rootNodes)
return nil
}))
snapshot, err := data.NewSnapshot([]string{"/damaged"}, []string{"test"}, "foo", time.Now())

View file

@ -22,16 +22,19 @@ func FindUsedBlobs(ctx context.Context, repo restic.Loader, treeIDs restic.IDs,
blobs.Insert(h)
lock.Unlock()
return blobReferenced
}, func(_ restic.ID, err error, tree *Tree) error {
}, func(_ restic.ID, err error, nodes TreeNodeIterator) error {
if err != nil {
return err
}
for _, node := range tree.Nodes {
for item := range nodes {
if item.Error != nil {
return item.Error
}
lock.Lock()
switch node.Type {
switch item.Node.Type {
case NodeTypeFile:
for _, blob := range node.Content {
for _, blob := range item.Node.Content {
blobs.Insert(restic.BlobHandle{ID: blob, Type: restic.DataBlob})
}
}

View file

@ -5,12 +5,15 @@ import (
"fmt"
"io"
"math/rand"
"slices"
"strings"
"testing"
"time"
"github.com/restic/chunker"
"github.com/restic/restic/internal/restic"
"github.com/restic/restic/internal/test"
rtest "github.com/restic/restic/internal/test"
)
// fakeFile returns a reader which yields deterministic pseudo-random data.
@ -72,9 +75,8 @@ func (fs *fakeFileSystem) saveTree(ctx context.Context, uploader restic.BlobSave
rnd := rand.NewSource(seed)
numNodes := int(rnd.Int63() % maxNodes)
var tree Tree
var nodes []*Node
for i := 0; i < numNodes; i++ {
// randomly select the type of the node, either tree (p = 1/4) or file (p = 3/4).
if depth > 1 && rnd.Int63()%4 == 0 {
treeSeed := rnd.Int63() % maxSeed
@ -87,7 +89,7 @@ func (fs *fakeFileSystem) saveTree(ctx context.Context, uploader restic.BlobSave
Subtree: &id,
}
tree.Nodes = append(tree.Nodes, node)
nodes = append(nodes, node)
continue
}
@ -102,14 +104,24 @@ func (fs *fakeFileSystem) saveTree(ctx context.Context, uploader restic.BlobSave
}
node.Content = fs.saveFile(ctx, uploader, fakeFile(fileSeed, fileSize))
tree.Nodes = append(tree.Nodes, node)
nodes = append(nodes, node)
}
tree.Sort()
id, err := SaveTree(ctx, uploader, &tree)
if err != nil {
fs.t.Fatalf("SaveTree returned error: %v", err)
return TestSaveNodes(fs.t, ctx, uploader, nodes)
}
//nolint:revive // as this is a test helper, t should go first
func TestSaveNodes(t testing.TB, ctx context.Context, uploader restic.BlobSaver, nodes []*Node) restic.ID {
slices.SortFunc(nodes, func(a, b *Node) int {
return strings.Compare(a.Name, b.Name)
})
treeWriter := NewTreeWriter(uploader)
for _, node := range nodes {
err := treeWriter.AddNode(node)
rtest.OK(t, err)
}
id, err := treeWriter.Finalize(ctx)
rtest.OK(t, err)
return id
}

View file

@ -5,124 +5,164 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"iter"
"path"
"sort"
"strings"
"github.com/restic/restic/internal/errors"
"github.com/restic/restic/internal/restic"
"github.com/restic/restic/internal/debug"
)
// For documentation purposes only:
// // Tree is an ordered list of nodes.
// type Tree struct {
// Nodes []*Node `json:"nodes"`
// }
var ErrTreeNotOrdered = errors.New("nodes are not ordered or duplicate")
// Tree is an ordered list of nodes.
type Tree struct {
Nodes []*Node `json:"nodes"`
type treeIterator struct {
dec json.Decoder
started bool
}
// NewTree creates a new tree object with the given initial capacity.
func NewTree(capacity int) *Tree {
return &Tree{
Nodes: make([]*Node, 0, capacity),
}
type NodeOrError struct {
Node *Node
Error error
}
func (t *Tree) String() string {
return fmt.Sprintf("Tree<%d nodes>", len(t.Nodes))
}
type TreeNodeIterator = iter.Seq[NodeOrError]
// Equals returns true if t and other have exactly the same nodes.
func (t *Tree) Equals(other *Tree) bool {
if len(t.Nodes) != len(other.Nodes) {
debug.Log("tree.Equals(): trees have different number of nodes")
return false
func NewTreeNodeIterator(rd io.Reader) (TreeNodeIterator, error) {
t := &treeIterator{
dec: *json.NewDecoder(rd),
}
for i := 0; i < len(t.Nodes); i++ {
if !t.Nodes[i].Equals(*other.Nodes[i]) {
debug.Log("tree.Equals(): node %d is different:", i)
debug.Log(" %#v", t.Nodes[i])
debug.Log(" %#v", other.Nodes[i])
return false
err := t.init()
if err != nil {
return nil, err
}
return func(yield func(NodeOrError) bool) {
if t.started {
panic("tree iterator is single use only")
}
}
return true
t.started = true
for {
n, err := t.next()
if err != nil && errors.Is(err, io.EOF) {
return
}
if !yield(NodeOrError{Node: n, Error: err}) {
return
}
// errors are final
if err != nil {
return
}
}
}, nil
}
// Insert adds a new node at the correct place in the tree.
func (t *Tree) Insert(node *Node) error {
pos, found := t.find(node.Name)
if found != nil {
return errors.Errorf("node %q already present", node.Name)
func (t *treeIterator) init() error {
// `{"nodes":[` `]}`
if err := t.assertToken(json.Delim('{')); err != nil {
return err
}
if err := t.assertToken("nodes"); err != nil {
return err
}
if err := t.assertToken(json.Delim('[')); err != nil {
return err
}
// https://github.com/golang/go/wiki/SliceTricks
t.Nodes = append(t.Nodes, nil)
copy(t.Nodes[pos+1:], t.Nodes[pos:])
t.Nodes[pos] = node
return nil
}
func (t *Tree) find(name string) (int, *Node) {
pos := sort.Search(len(t.Nodes), func(i int) bool {
return t.Nodes[i].Name >= name
})
if pos < len(t.Nodes) && t.Nodes[pos].Name == name {
return pos, t.Nodes[pos]
}
return pos, nil
}
// Find returns a node with the given name, or nil if none could be found.
func (t *Tree) Find(name string) *Node {
if t == nil {
return nil
}
_, node := t.find(name)
return node
}
// Sort sorts the nodes by name.
func (t *Tree) Sort() {
list := Nodes(t.Nodes)
sort.Sort(list)
t.Nodes = list
}
// Subtrees returns a slice of all subtree IDs of the tree.
func (t *Tree) Subtrees() (trees restic.IDs) {
for _, node := range t.Nodes {
if node.Type == NodeTypeDir && node.Subtree != nil {
trees = append(trees, *node.Subtree)
func (t *treeIterator) next() (*Node, error) {
if t.dec.More() {
var n Node
err := t.dec.Decode(&n)
if err != nil {
return nil, err
}
return &n, nil
}
return trees
if err := t.assertToken(json.Delim(']')); err != nil {
return nil, err
}
if err := t.assertToken(json.Delim('}')); err != nil {
return nil, err
}
return nil, io.EOF
}
// LoadTree loads a tree from the repository.
func LoadTree(ctx context.Context, r restic.BlobLoader, id restic.ID) (*Tree, error) {
debug.Log("load tree %v", id)
func (t *treeIterator) assertToken(token json.Token) error {
to, err := t.dec.Token()
if err != nil {
return err
}
if to != token {
return errors.Errorf("error decoding tree: expected %v, got %v", token, to)
}
return nil
}
buf, err := r.LoadBlob(ctx, restic.TreeBlob, id, nil)
func LoadTree(ctx context.Context, loader restic.BlobLoader, content restic.ID) (TreeNodeIterator, error) {
rd, err := loader.LoadBlob(ctx, restic.TreeBlob, content, nil)
if err != nil {
return nil, err
}
return NewTreeNodeIterator(bytes.NewReader(rd))
}
t := &Tree{}
err = json.Unmarshal(buf, t)
if err != nil {
return nil, err
type TreeFinder struct {
next func() (NodeOrError, bool)
stop func()
current *Node
}
func NewTreeFinder(tree TreeNodeIterator) *TreeFinder {
if tree == nil {
return &TreeFinder{stop: func() {}}
}
next, stop := iter.Pull(tree)
return &TreeFinder{next: next, stop: stop}
}
// Find finds the node with the given name. If the node is not found, it returns nil.
// If Find was called before, the new name must be strictly greater than the last name.
func (t *TreeFinder) Find(name string) (*Node, error) {
if t.next == nil {
return nil, nil
}
// loop until `t.current.Name` is >= name
for t.current == nil || t.current.Name < name {
current, ok := t.next()
if current.Error != nil {
return nil, current.Error
}
if !ok {
return nil, nil
}
t.current = current.Node
}
return t, nil
if t.current.Name == name {
// forget the current node to free memory as early as possible
current := t.current
t.current = nil
return current, nil
}
// we have already passed the name
return nil, nil
}
func (t *TreeFinder) Close() {
t.stop()
}
type TreeWriter struct {
@ -148,10 +188,13 @@ func (t *TreeWriter) Finalize(ctx context.Context) (restic.ID, error) {
return id, err
}
func SaveTree(ctx context.Context, saver restic.BlobSaver, t *Tree) (restic.ID, error) {
func SaveTree(ctx context.Context, saver restic.BlobSaver, nodes TreeNodeIterator) (restic.ID, error) {
treeWriter := NewTreeWriter(saver)
for _, node := range t.Nodes {
err := treeWriter.AddNode(node)
for item := range nodes {
if item.Error != nil {
return restic.ID{}, item.Error
}
err := treeWriter.AddNode(item.Node)
if err != nil {
return restic.ID{}, err
}
@ -214,7 +257,12 @@ func FindTreeDirectory(ctx context.Context, repo restic.BlobLoader, id *restic.I
if err != nil {
return nil, fmt.Errorf("path %s: %w", subfolder, err)
}
node := tree.Find(name)
finder := NewTreeFinder(tree)
node, err := finder.Find(name)
finder.Close()
if err != nil {
return nil, fmt.Errorf("path %s: %w", subfolder, err)
}
if node == nil {
return nil, fmt.Errorf("path %s: not found", subfolder)
}

View file

@ -23,12 +23,36 @@ type trackedID struct {
rootIdx int
}
// subtreesCollector wraps a TreeNodeIterator and returns a new iterator that collects the subtrees.
func subtreesCollector(tree TreeNodeIterator) (TreeNodeIterator, func() restic.IDs) {
subtrees := restic.IDs{}
isComplete := false
return func(yield func(NodeOrError) bool) {
for item := range tree {
if !yield(item) {
return
}
// be defensive and check for nil subtree as this code is also used by the checker
if item.Node != nil && item.Node.Type == NodeTypeDir && item.Node.Subtree != nil {
subtrees = append(subtrees, *item.Node.Subtree)
}
}
isComplete = true
}, func() restic.IDs {
if !isComplete {
panic("tree was not read completely")
}
return subtrees
}
}
// loadTreeWorker loads trees from repo and sends them to out.
func loadTreeWorker(
ctx context.Context,
repo restic.Loader,
in <-chan trackedID,
process func(id restic.ID, error error, tree *Tree) error,
process func(id restic.ID, error error, nodes TreeNodeIterator) error,
out chan<- trackedTreeItem,
) error {
@ -39,14 +63,21 @@ func loadTreeWorker(
}
debug.Log("load tree %v (%v) returned err: %v", tree, treeID, err)
// wrap iterator to collect subtrees while `process` iterates over `tree`
var collectSubtrees func() restic.IDs
if tree != nil {
tree, collectSubtrees = subtreesCollector(tree)
}
err = process(treeID.ID, err, tree)
if err != nil {
return err
}
// assume that the number of subtrees is within reasonable limits, such that the memory usage is not a problem
var subtrees restic.IDs
if tree != nil {
subtrees = tree.Subtrees()
if collectSubtrees != nil {
subtrees = collectSubtrees()
}
job := trackedTreeItem{ID: treeID.ID, Subtrees: subtrees, rootIdx: treeID.rootIdx}
@ -159,7 +190,7 @@ func StreamTrees(
trees restic.IDs,
p *progress.Counter,
skip func(tree restic.ID) bool,
process func(id restic.ID, error error, tree *Tree) error,
process func(id restic.ID, error error, nodes TreeNodeIterator) error,
) error {
loaderChan := make(chan trackedID)
hugeTreeChan := make(chan trackedID, 10)

View file

@ -6,6 +6,7 @@ import (
"errors"
"os"
"path/filepath"
"slices"
"strconv"
"testing"
@ -105,37 +106,45 @@ func TestNodeComparison(t *testing.T) {
func TestEmptyLoadTree(t *testing.T) {
repo := repository.TestRepository(t)
tree := data.NewTree(0)
nodes := []*data.Node{}
var id restic.ID
rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaverWithAsync) error {
var err error
// save tree
id, err = data.SaveTree(ctx, uploader, tree)
return err
id = data.TestSaveNodes(t, ctx, uploader, nodes)
return nil
}))
// load tree again
tree2, err := data.LoadTree(context.TODO(), repo, id)
it, err := data.LoadTree(context.TODO(), repo, id)
rtest.OK(t, err)
nodes2 := []*data.Node{}
for item := range it {
rtest.OK(t, item.Error)
nodes2 = append(nodes2, item.Node)
}
rtest.Assert(t, tree.Equals(tree2),
"trees are not equal: want %v, got %v",
tree, tree2)
rtest.Assert(t, slices.Equal(nodes, nodes2),
"tree nodes are not equal: want %v, got %v",
nodes, nodes2)
}
// Basic type for comparing the serialization of the tree
type Tree struct {
Nodes []*data.Node `json:"nodes"`
}
func TestTreeEqualSerialization(t *testing.T) {
files := []string{"node.go", "tree.go", "tree_test.go"}
for i := 1; i <= len(files); i++ {
tree := data.NewTree(i)
tree := Tree{Nodes: make([]*data.Node, 0, i)}
builder := data.NewTreeJSONBuilder()
for _, fn := range files[:i] {
node := nodeForFile(t, fn)
rtest.OK(t, tree.Insert(node))
tree.Nodes = append(tree.Nodes, node)
rtest.OK(t, builder.AddNode(node))
rtest.Assert(t, tree.Insert(node) != nil, "no error on duplicate node")
rtest.Assert(t, builder.AddNode(node) != nil, "no error on duplicate node")
rtest.Assert(t, errors.Is(builder.AddNode(node), data.ErrTreeNotOrdered), "wrong error returned")
}
@ -144,11 +153,11 @@ func TestTreeEqualSerialization(t *testing.T) {
treeBytes = append(treeBytes, '\n')
rtest.OK(t, err)
stiBytes, err := builder.Finalize()
buf, err := builder.Finalize()
rtest.OK(t, err)
// compare serialization of an individual node and the SaveTreeIterator
rtest.Equals(t, treeBytes, stiBytes)
rtest.Equals(t, treeBytes, buf)
}
}
@ -165,11 +174,12 @@ func BenchmarkBuildTree(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
t := data.NewTree(size)
t := data.NewTreeJSONBuilder()
for i := range nodes {
_ = t.Insert(&nodes[i])
rtest.OK(b, t.AddNode(&nodes[i]))
}
_, err := t.Finalize()
rtest.OK(b, err)
}
}
@ -186,8 +196,11 @@ func testLoadTree(t *testing.T, version uint) {
repo, _, _ := repository.TestRepositoryWithVersion(t, version)
sn := archiver.TestSnapshot(t, repo, rtest.BenchArchiveDirectory, nil)
_, err := data.LoadTree(context.TODO(), repo, *sn.Tree)
nodes, err := data.LoadTree(context.TODO(), repo, *sn.Tree)
rtest.OK(t, err)
for item := range nodes {
rtest.OK(t, item.Error)
}
}
func BenchmarkLoadTree(t *testing.B) {

View file

@ -30,7 +30,7 @@ func New(format string, repo restic.Loader, w io.Writer) *Dumper {
}
}
func (d *Dumper) DumpTree(ctx context.Context, tree *data.Tree, rootPath string) error {
func (d *Dumper) DumpTree(ctx context.Context, tree data.TreeNodeIterator, rootPath string) error {
wg, ctx := errgroup.WithContext(ctx)
// ch is buffered to deal with variable download/write speeds.
@ -52,10 +52,14 @@ func (d *Dumper) DumpTree(ctx context.Context, tree *data.Tree, rootPath string)
return wg.Wait()
}
func sendTrees(ctx context.Context, repo restic.BlobLoader, tree *data.Tree, rootPath string, ch chan *data.Node) error {
func sendTrees(ctx context.Context, repo restic.BlobLoader, nodes data.TreeNodeIterator, rootPath string, ch chan *data.Node) error {
defer close(ch)
for _, node := range tree.Nodes {
for item := range nodes {
if item.Error != nil {
return item.Error
}
node := item.Node
node.Path = path.Join(rootPath, node.Name)
if err := sendNodes(ctx, repo, node, ch); err != nil {
return err

View file

@ -7,6 +7,7 @@ import (
"errors"
"os"
"path/filepath"
"slices"
"sync"
"syscall"
@ -65,13 +66,13 @@ func unwrapCtxCanceled(err error) error {
// replaceSpecialNodes replaces nodes with name "." and "/" by their contents.
// Otherwise, the node is returned.
func replaceSpecialNodes(ctx context.Context, repo restic.BlobLoader, node *data.Node) ([]*data.Node, error) {
func replaceSpecialNodes(ctx context.Context, repo restic.BlobLoader, node *data.Node) (data.TreeNodeIterator, error) {
if node.Type != data.NodeTypeDir || node.Subtree == nil {
return []*data.Node{node}, nil
return slices.Values([]data.NodeOrError{{Node: node}}), nil
}
if node.Name != "." && node.Name != "/" {
return []*data.Node{node}, nil
return slices.Values([]data.NodeOrError{{Node: node}}), nil
}
tree, err := data.LoadTree(ctx, repo, *node.Subtree)
@ -79,7 +80,7 @@ func replaceSpecialNodes(ctx context.Context, repo restic.BlobLoader, node *data
return nil, unwrapCtxCanceled(err)
}
return tree.Nodes, nil
return tree, nil
}
func newDirFromSnapshot(root *Root, forget forgetFn, inode uint64, snapshot *data.Snapshot) (*dir, error) {
@ -115,18 +116,25 @@ func (d *dir) open(ctx context.Context) error {
return unwrapCtxCanceled(err)
}
items := make(map[string]*data.Node)
for _, n := range tree.Nodes {
for item := range tree {
if item.Error != nil {
return unwrapCtxCanceled(item.Error)
}
if ctx.Err() != nil {
return ctx.Err()
}
n := item.Node
nodes, err := replaceSpecialNodes(ctx, d.root.repo, n)
if err != nil {
debug.Log(" replaceSpecialNodes(%v) failed: %v", n, err)
return err
}
for _, node := range nodes {
items[cleanupNodeName(node.Name)] = node
for item := range nodes {
if item.Error != nil {
return unwrapCtxCanceled(item.Error)
}
items[cleanupNodeName(item.Node.Name)] = item.Node
}
}
d.items = items

View file

@ -59,7 +59,7 @@ func loadFirstSnapshot(t testing.TB, repo restic.ListerLoaderUnpacked) *data.Sna
return sn
}
func loadTree(t testing.TB, repo restic.Loader, id restic.ID) *data.Tree {
func loadTree(t testing.TB, repo restic.Loader, id restic.ID) data.TreeNodeIterator {
tree, err := data.LoadTree(context.TODO(), repo, id)
rtest.OK(t, err)
return tree
@ -79,8 +79,9 @@ func TestFuseFile(t *testing.T) {
tree := loadTree(t, repo, *sn.Tree)
var content restic.IDs
for _, node := range tree.Nodes {
content = append(content, node.Content...)
for item := range tree {
rtest.OK(t, item.Error)
content = append(content, item.Node.Content...)
}
t.Logf("tree loaded, content: %v", content)

View file

@ -163,15 +163,18 @@ func (res *Restorer) traverseTreeInner(ctx context.Context, target, location str
}
if res.opts.Delete {
filenames = make([]string, 0, len(tree.Nodes))
filenames = make([]string, 0)
}
for i, node := range tree.Nodes {
for item := range tree {
if item.Error != nil {
debug.Log("error iterating tree %v: %v", treeID, item.Error)
return nil, hasRestored, res.sanitizeError(location, item.Error)
}
node := item.Node
if ctx.Err() != nil {
return nil, hasRestored, ctx.Err()
}
// allow GC of tree node
tree.Nodes[i] = nil
if res.opts.Delete {
// just track all files included in the tree node to simplify the control flow.
// tracking too many files does not matter except for a slightly elevated memory usage

View file

@ -79,7 +79,7 @@ func saveDir(t testing.TB, repo restic.BlobSaver, nodes map[string]Node, inode u
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
tree := &data.Tree{}
tree := make([]*data.Node, 0, len(nodes))
for name, n := range nodes {
inode++
switch node := n.(type) {
@ -107,7 +107,7 @@ func saveDir(t testing.TB, repo restic.BlobSaver, nodes map[string]Node, inode u
if mode == 0 {
mode = 0644
}
err := tree.Insert(&data.Node{
tree = append(tree, &data.Node{
Type: data.NodeTypeFile,
Mode: mode,
ModTime: node.ModTime,
@ -120,9 +120,8 @@ func saveDir(t testing.TB, repo restic.BlobSaver, nodes map[string]Node, inode u
Links: lc,
GenericAttributes: getGenericAttributes(node.attributes, false),
})
rtest.OK(t, err)
case Symlink:
err := tree.Insert(&data.Node{
tree = append(tree, &data.Node{
Type: data.NodeTypeSymlink,
Mode: os.ModeSymlink | 0o777,
ModTime: node.ModTime,
@ -133,7 +132,6 @@ func saveDir(t testing.TB, repo restic.BlobSaver, nodes map[string]Node, inode u
Inode: inode,
Links: 1,
})
rtest.OK(t, err)
case Dir:
id := saveDir(t, repo, node.Nodes, inode, getGenericAttributes)
@ -142,7 +140,7 @@ func saveDir(t testing.TB, repo restic.BlobSaver, nodes map[string]Node, inode u
mode = 0755
}
err := tree.Insert(&data.Node{
tree = append(tree, &data.Node{
Type: data.NodeTypeDir,
Mode: mode,
ModTime: node.ModTime,
@ -152,19 +150,12 @@ func saveDir(t testing.TB, repo restic.BlobSaver, nodes map[string]Node, inode u
Subtree: &id,
GenericAttributes: getGenericAttributes(node.attributes, false),
})
rtest.OK(t, err)
default:
t.Fatalf("unknown node type %T", node)
}
}
tree.Sort()
id, err := data.SaveTree(ctx, repo, tree)
if err != nil {
t.Fatal(err)
}
return id
return data.TestSaveNodes(t, ctx, repo, tree)
}
func saveSnapshot(t testing.TB, repo restic.Repository, snapshot Snapshot, getGenericAttributes func(attr *FileAttributes, isDir bool) (genericAttributes map[data.GenericAttributeType]json.RawMessage)) (*data.Snapshot, restic.ID) {

View file

@ -11,7 +11,7 @@ import (
)
type NodeRewriteFunc func(node *data.Node, path string) *data.Node
type FailedTreeRewriteFunc func(nodeID restic.ID, path string, err error) (*data.Tree, error)
type FailedTreeRewriteFunc func(nodeID restic.ID, path string, err error) (data.TreeNodeIterator, error)
type QueryRewrittenSizeFunc func() SnapshotSize
type SnapshotSize struct {
@ -52,7 +52,7 @@ func NewTreeRewriter(opts RewriteOpts) *TreeRewriter {
}
if rw.opts.RewriteFailedTree == nil {
// fail with error by default
rw.opts.RewriteFailedTree = func(_ restic.ID, _ string, err error) (*data.Tree, error) {
rw.opts.RewriteFailedTree = func(_ restic.ID, _ string, err error) (data.TreeNodeIterator, error) {
return nil, err
}
}
@ -117,15 +117,26 @@ func (t *TreeRewriter) RewriteTree(ctx context.Context, loader restic.BlobLoader
if nodeID != testID {
return restic.ID{}, fmt.Errorf("cannot encode tree at %q without losing information", nodepath)
}
// reload the tree to get a new iterator
curTree, err = data.LoadTree(ctx, loader, nodeID)
if err != nil {
// shouldn't fail as the first load was successful
return restic.ID{}, fmt.Errorf("failed to reload tree %v: %w", nodeID, err)
}
}
debug.Log("filterTree: %s, nodeId: %s\n", nodepath, nodeID.Str())
tb := data.NewTreeWriter(saver)
for _, node := range curTree.Nodes {
for item := range curTree {
if ctx.Err() != nil {
return restic.ID{}, ctx.Err()
}
if item.Error != nil {
return restic.ID{}, err
}
node := item.Node
path := path.Join(nodepath, node.Name)
node = t.opts.RewriteNode(node, path)

View file

@ -2,6 +2,7 @@ package walker
import (
"context"
"slices"
"testing"
"github.com/pkg/errors"
@ -405,16 +406,15 @@ func TestRewriterTreeLoadError(t *testing.T) {
t.Fatal("missing error on unloadable tree")
}
replacementTree := &data.Tree{Nodes: []*data.Node{{Name: "replacement", Type: data.NodeTypeFile, Size: 42}}}
replacementID, err := data.SaveTree(ctx, tm, replacementTree)
test.OK(t, err)
replacementNode := &data.Node{Name: "replacement", Type: data.NodeTypeFile, Size: 42}
replacementID := data.TestSaveNodes(t, ctx, tm, []*data.Node{replacementNode})
rewriter = NewTreeRewriter(RewriteOpts{
RewriteFailedTree: func(nodeID restic.ID, path string, err error) (*data.Tree, error) {
RewriteFailedTree: func(nodeID restic.ID, path string, err error) (data.TreeNodeIterator, error) {
if nodeID != id || path != "/" {
t.Fail()
}
return replacementTree, nil
return slices.Values([]data.NodeOrError{{Node: replacementNode}}), nil
},
})
newRoot, err := rewriter.RewriteTree(ctx, tm, tm, "/", id)

View file

@ -3,7 +3,6 @@ package walker
import (
"context"
"path"
"sort"
"github.com/pkg/errors"
@ -52,15 +51,15 @@ func Walk(ctx context.Context, repo restic.BlobLoader, root restic.ID, visitor W
// walk recursively traverses the tree, ignoring subtrees when the ID of the
// subtree is in ignoreTrees. If err is nil and ignore is true, the subtree ID
// will be added to ignoreTrees by walk.
func walk(ctx context.Context, repo restic.BlobLoader, prefix string, parentTreeID restic.ID, tree *data.Tree, visitor WalkVisitor) (err error) {
sort.Slice(tree.Nodes, func(i, j int) bool {
return tree.Nodes[i].Name < tree.Nodes[j].Name
})
for _, node := range tree.Nodes {
func walk(ctx context.Context, repo restic.BlobLoader, prefix string, parentTreeID restic.ID, tree data.TreeNodeIterator, visitor WalkVisitor) (err error) {
for item := range tree {
if item.Error != nil {
return item.Error
}
if ctx.Err() != nil {
return ctx.Err()
}
node := item.Node
p := path.Join(prefix, node.Name)