diff --git a/cmd/restic/cmd_copy.go b/cmd/restic/cmd_copy.go index f209015f0..d17ded7c9 100644 --- a/cmd/restic/cmd_copy.go +++ b/cmd/restic/cmd_copy.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "iter" + "sync" "time" "github.com/restic/restic/internal/data" @@ -14,7 +15,6 @@ import ( "github.com/restic/restic/internal/restic" "github.com/restic/restic/internal/ui" "github.com/restic/restic/internal/ui/progress" - "golang.org/x/sync/errgroup" "github.com/spf13/cobra" "github.com/spf13/pflag" @@ -257,19 +257,13 @@ func copyTreeBatched(ctx context.Context, srcRepo restic.Repository, dstRepo res func copyTree(ctx context.Context, srcRepo restic.Repository, dstRepo restic.Repository, visitedTrees restic.AssociatedBlobSet, rootTreeID restic.ID, printer progress.Printer, uploader restic.BlobSaverWithAsync) (uint64, error) { - wg, wgCtx := errgroup.WithContext(ctx) - - treeStream := data.StreamTrees(wgCtx, wg, srcRepo, restic.IDs{rootTreeID}, func(treeID restic.ID) bool { - handle := restic.BlobHandle{ID: treeID, Type: restic.TreeBlob} - visited := visitedTrees.Has(handle) - visitedTrees.Insert(handle) - return visited - }, nil) - copyBlobs := srcRepo.NewAssociatedBlobSet() packList := restic.NewIDSet() + var lock sync.Mutex enqueue := func(h restic.BlobHandle) { + lock.Lock() + defer lock.Unlock() if _, ok := dstRepo.LookupBlobSize(h.Type, h.ID); !ok { pb := srcRepo.LookupBlob(h.Type, h.ID) copyBlobs.Insert(h) @@ -279,26 +273,31 @@ func copyTree(ctx context.Context, srcRepo restic.Repository, dstRepo restic.Rep } } - wg.Go(func() error { - for tree := range treeStream { - if tree.Error != nil { - return fmt.Errorf("LoadTree(%v) returned error %v", tree.ID.Str(), tree.Error) + err := data.StreamTrees(ctx, srcRepo, restic.IDs{rootTreeID}, nil, func(treeID restic.ID) bool { + handle := restic.BlobHandle{ID: treeID, Type: restic.TreeBlob} + visited := visitedTrees.Has(handle) + visitedTrees.Insert(handle) + return visited + }, func(treeID restic.ID, err error, nodes data.TreeNodeIterator) error { + if err != nil { + return fmt.Errorf("LoadTree(%v) returned error %v", treeID.Str(), err) + } + + // copy raw tree bytes to avoid problems if the serialization changes + enqueue(restic.BlobHandle{ID: treeID, Type: restic.TreeBlob}) + + for item := range nodes { + if item.Error != nil { + return item.Error } - - // copy raw tree bytes to avoid problems if the serialization changes - enqueue(restic.BlobHandle{ID: tree.ID, Type: restic.TreeBlob}) - - for _, entry := range tree.Nodes { - // Recursion into directories is handled by StreamTrees - // Copy the blobs for this file. - for _, blobID := range entry.Content { - enqueue(restic.BlobHandle{Type: restic.DataBlob, ID: blobID}) - } + // Recursion into directories is handled by StreamTrees + // Copy the blobs for this file. 
+ for _, blobID := range item.Node.Content { + enqueue(restic.BlobHandle{Type: restic.DataBlob, ID: blobID}) } } return nil }) - err := wg.Wait() if err != nil { return 0, err } diff --git a/cmd/restic/cmd_diff.go b/cmd/restic/cmd_diff.go index 58138e9e5..30b0878c0 100644 --- a/cmd/restic/cmd_diff.go +++ b/cmd/restic/cmd_diff.go @@ -5,7 +5,6 @@ import ( "encoding/json" "path" "reflect" - "sort" "github.com/restic/restic/internal/data" "github.com/restic/restic/internal/debug" @@ -184,11 +183,14 @@ func (c *Comparer) printDir(ctx context.Context, mode string, stats *DiffStat, b return err } - for _, node := range tree.Nodes { + for item := range tree { + if item.Error != nil { + return item.Error + } if ctx.Err() != nil { return ctx.Err() } - + node := item.Node name := path.Join(prefix, node.Name) if node.Type == data.NodeTypeDir { name += "/" @@ -215,11 +217,15 @@ func (c *Comparer) collectDir(ctx context.Context, blobs restic.AssociatedBlobSe return err } - for _, node := range tree.Nodes { + for item := range tree { + if item.Error != nil { + return item.Error + } if ctx.Err() != nil { return ctx.Err() } + node := item.Node addBlobs(blobs, node) if node.Type == data.NodeTypeDir { @@ -233,29 +239,6 @@ func (c *Comparer) collectDir(ctx context.Context, blobs restic.AssociatedBlobSe return ctx.Err() } -func uniqueNodeNames(tree1, tree2 *data.Tree) (tree1Nodes, tree2Nodes map[string]*data.Node, uniqueNames []string) { - names := make(map[string]struct{}) - tree1Nodes = make(map[string]*data.Node) - for _, node := range tree1.Nodes { - tree1Nodes[node.Name] = node - names[node.Name] = struct{}{} - } - - tree2Nodes = make(map[string]*data.Node) - for _, node := range tree2.Nodes { - tree2Nodes[node.Name] = node - names[node.Name] = struct{}{} - } - - uniqueNames = make([]string, 0, len(names)) - for name := range names { - uniqueNames = append(uniqueNames, name) - } - - sort.Strings(uniqueNames) - return tree1Nodes, tree2Nodes, uniqueNames -} - func (c *Comparer) diffTree(ctx context.Context, stats *DiffStatsContainer, prefix string, id1, id2 restic.ID) error { debug.Log("diffing %v to %v", id1, id2) tree1, err := data.LoadTree(ctx, c.repo, id1) @@ -268,21 +251,29 @@ func (c *Comparer) diffTree(ctx context.Context, stats *DiffStatsContainer, pref return err } - tree1Nodes, tree2Nodes, names := uniqueNodeNames(tree1, tree2) - - for _, name := range names { + for dt := range data.DualTreeIterator(tree1, tree2) { + if dt.Error != nil { + return dt.Error + } if ctx.Err() != nil { return ctx.Err() } - node1, t1 := tree1Nodes[name] - node2, t2 := tree2Nodes[name] + node1 := dt.Tree1 + node2 := dt.Tree2 + + var name string + if node1 != nil { + name = node1.Name + } else { + name = node2.Name + } addBlobs(stats.BlobsBefore, node1) addBlobs(stats.BlobsAfter, node2) switch { - case t1 && t2: + case node1 != nil && node2 != nil: name := path.Join(prefix, name) mod := "" @@ -328,7 +319,7 @@ func (c *Comparer) diffTree(ctx context.Context, stats *DiffStatsContainer, pref c.printError("error: %v", err) } } - case t1 && !t2: + case node1 != nil && node2 == nil: prefix := path.Join(prefix, name) if node1.Type == data.NodeTypeDir { prefix += "/" @@ -342,7 +333,7 @@ func (c *Comparer) diffTree(ctx context.Context, stats *DiffStatsContainer, pref c.printError("error: %v", err) } } - case !t1 && t2: + case node1 == nil && node2 != nil: prefix := path.Join(prefix, name) if node2.Type == data.NodeTypeDir { prefix += "/" diff --git a/cmd/restic/cmd_dump.go b/cmd/restic/cmd_dump.go index da45cc303..abea6e52e 100644 
--- a/cmd/restic/cmd_dump.go +++ b/cmd/restic/cmd_dump.go @@ -80,7 +80,7 @@ func splitPath(p string) []string { return append(s, f) } -func printFromTree(ctx context.Context, tree *data.Tree, repo restic.BlobLoader, prefix string, pathComponents []string, d *dump.Dumper, canWriteArchiveFunc func() error) error { +func printFromTree(ctx context.Context, tree data.TreeNodeIterator, repo restic.BlobLoader, prefix string, pathComponents []string, d *dump.Dumper, canWriteArchiveFunc func() error) error { // If we print / we need to assume that there are multiple nodes at that // level in the tree. if pathComponents[0] == "" { @@ -92,11 +92,14 @@ func printFromTree(ctx context.Context, tree *data.Tree, repo restic.BlobLoader, item := filepath.Join(prefix, pathComponents[0]) l := len(pathComponents) - for _, node := range tree.Nodes { + for it := range tree { + if it.Error != nil { + return it.Error + } if ctx.Err() != nil { return ctx.Err() } - + node := it.Node // If dumping something in the highest level it will just take the // first item it finds and dump that according to the switch case below. if node.Name == pathComponents[0] { diff --git a/cmd/restic/cmd_recover.go b/cmd/restic/cmd_recover.go index fec4c44b5..bbb71972f 100644 --- a/cmd/restic/cmd_recover.go +++ b/cmd/restic/cmd_recover.go @@ -97,7 +97,11 @@ func runRecover(ctx context.Context, gopts global.Options, term ui.Terminal) err continue } - for _, node := range tree.Nodes { + for item := range tree { + if item.Error != nil { + return item.Error + } + node := item.Node if node.Type == data.NodeTypeDir && node.Subtree != nil { trees[*node.Subtree] = true } @@ -134,28 +138,28 @@ func runRecover(ctx context.Context, gopts global.Options, term ui.Terminal) err return ctx.Err() } - tree := data.NewTree(len(roots)) - for id := range roots { - var subtreeID = id - node := data.Node{ - Type: data.NodeTypeDir, - Name: id.Str(), - Mode: 0755, - Subtree: &subtreeID, - AccessTime: time.Now(), - ModTime: time.Now(), - ChangeTime: time.Now(), - } - err := tree.Insert(&node) - if err != nil { - return err - } - } - var treeID restic.ID err = repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { var err error - treeID, err = data.SaveTree(ctx, uploader, tree) + tw := data.NewTreeWriter(uploader) + for id := range roots { + var subtreeID = id + node := data.Node{ + Type: data.NodeTypeDir, + Name: id.Str(), + Mode: 0755, + Subtree: &subtreeID, + AccessTime: time.Now(), + ModTime: time.Now(), + ChangeTime: time.Now(), + } + err := tw.AddNode(&node) + if err != nil { + return err + } + } + + treeID, err = tw.Finalize(ctx) if err != nil { return errors.Fatalf("unable to save new tree to the repository: %v", err) } diff --git a/cmd/restic/cmd_repair_snapshots.go b/cmd/restic/cmd_repair_snapshots.go index 226da3d44..104d3dff3 100644 --- a/cmd/restic/cmd_repair_snapshots.go +++ b/cmd/restic/cmd_repair_snapshots.go @@ -2,6 +2,7 @@ package main import ( "context" + "slices" "github.com/restic/restic/internal/data" "github.com/restic/restic/internal/errors" @@ -130,7 +131,7 @@ func runRepairSnapshots(ctx context.Context, gopts global.Options, opts RepairOp node.Size = newSize return node }, - RewriteFailedTree: func(_ restic.ID, path string, _ error) (*data.Tree, error) { + RewriteFailedTree: func(_ restic.ID, path string, _ error) (data.TreeNodeIterator, error) { if path == "/" { printer.P(" dir %q: not readable", path) // remove snapshots with invalid root node @@ -138,7 +139,7 @@ func runRepairSnapshots(ctx 
context.Context, gopts global.Options, opts RepairOp } // If a subtree fails to load, remove it printer.P(" dir %q: replaced with empty directory", path) - return &data.Tree{}, nil + return slices.Values([]data.NodeOrError{}), nil }, AllowUnstableSerialization: true, }) diff --git a/internal/archiver/archiver.go b/internal/archiver/archiver.go index 0ed37eb5b..996e79e6a 100644 --- a/internal/archiver/archiver.go +++ b/internal/archiver/archiver.go @@ -281,7 +281,7 @@ func (arch *Archiver) nodeFromFileInfo(snPath, filename string, meta ToNoder, ig // loadSubtree tries to load the subtree referenced by node. In case of an error, nil is returned. // If there is no node to load, then nil is returned without an error. -func (arch *Archiver) loadSubtree(ctx context.Context, node *data.Node) (*data.Tree, error) { +func (arch *Archiver) loadSubtree(ctx context.Context, node *data.Node) (data.TreeNodeIterator, error) { if node == nil || node.Type != data.NodeTypeDir || node.Subtree == nil { return nil, nil } @@ -307,7 +307,7 @@ func (arch *Archiver) wrapLoadTreeError(id restic.ID, err error) error { // saveDir stores a directory in the repo and returns the node. snPath is the // path within the current snapshot. -func (arch *Archiver) saveDir(ctx context.Context, snPath string, dir string, meta fs.File, previous *data.Tree, complete fileCompleteFunc) (d futureNode, err error) { +func (arch *Archiver) saveDir(ctx context.Context, snPath string, dir string, meta fs.File, previous data.TreeNodeIterator, complete fileCompleteFunc) (d futureNode, err error) { debug.Log("%v %v", snPath, dir) treeNode, names, err := arch.dirToNodeAndEntries(snPath, dir, meta) @@ -317,6 +317,9 @@ func (arch *Archiver) saveDir(ctx context.Context, snPath string, dir string, me nodes := make([]futureNode, 0, len(names)) + finder := data.NewTreeFinder(previous) + defer finder.Close() + for _, name := range names { // test if context has been cancelled if ctx.Err() != nil { @@ -325,7 +328,11 @@ func (arch *Archiver) saveDir(ctx context.Context, snPath string, dir string, me } pathname := arch.FS.Join(dir, name) - oldNode := previous.Find(name) + oldNode, err := finder.Find(name) + err = arch.error(pathname, err) + if err != nil { + return futureNode{}, err + } snItem := join(snPath, name) fn, excluded, err := arch.save(ctx, snItem, pathname, oldNode) @@ -645,7 +652,7 @@ func join(elem ...string) string { // saveTree stores a Tree in the repo, returned is the tree. snPath is the path // within the current snapshot. 
-func (arch *Archiver) saveTree(ctx context.Context, snPath string, atree *tree, previous *data.Tree, complete fileCompleteFunc) (futureNode, int, error) { +func (arch *Archiver) saveTree(ctx context.Context, snPath string, atree *tree, previous data.TreeNodeIterator, complete fileCompleteFunc) (futureNode, int, error) { var node *data.Node if snPath != "/" { @@ -663,10 +670,13 @@ func (arch *Archiver) saveTree(ctx context.Context, snPath string, atree *tree, node = &data.Node{} } - debug.Log("%v (%v nodes), parent %v", snPath, len(atree.Nodes), previous) + debug.Log("%v (%v nodes)", snPath, len(atree.Nodes)) nodeNames := atree.NodeNames() nodes := make([]futureNode, 0, len(nodeNames)) + finder := data.NewTreeFinder(previous) + defer finder.Close() + // iterate over the nodes of atree in lexicographic (=deterministic) order for _, name := range nodeNames { subatree := atree.Nodes[name] @@ -678,7 +688,13 @@ func (arch *Archiver) saveTree(ctx context.Context, snPath string, atree *tree, // this is a leaf node if subatree.Leaf() { - fn, excluded, err := arch.save(ctx, join(snPath, name), subatree.Path, previous.Find(name)) + pathname := join(snPath, name) + oldNode, err := finder.Find(name) + err = arch.error(pathname, err) + if err != nil { + return futureNode{}, 0, err + } + fn, excluded, err := arch.save(ctx, pathname, subatree.Path, oldNode) if err != nil { err = arch.error(subatree.Path, err) @@ -698,7 +714,11 @@ func (arch *Archiver) saveTree(ctx context.Context, snPath string, atree *tree, snItem := join(snPath, name) + "/" start := time.Now() - oldNode := previous.Find(name) + oldNode, err := finder.Find(name) + err = arch.error(snItem, err) + if err != nil { + return futureNode{}, 0, err + } oldSubtree, err := arch.loadSubtree(ctx, oldNode) if err != nil { err = arch.error(join(snPath, name), err) @@ -801,7 +821,7 @@ type SnapshotOptions struct { } // loadParentTree loads a tree referenced by snapshot id. If id is null, nil is returned. 
-func (arch *Archiver) loadParentTree(ctx context.Context, sn *data.Snapshot) *data.Tree { +func (arch *Archiver) loadParentTree(ctx context.Context, sn *data.Snapshot) data.TreeNodeIterator { if sn == nil { return nil } diff --git a/internal/archiver/archiver_test.go b/internal/archiver/archiver_test.go index a4012d585..b13734655 100644 --- a/internal/archiver/archiver_test.go +++ b/internal/archiver/archiver_test.go @@ -877,11 +877,7 @@ func TestArchiverSaveDir(t *testing.T) { } node.Name = targetNodeName - tree := &data.Tree{Nodes: []*data.Node{node}} - treeID, err = data.SaveTree(ctx, uploader, tree) - if err != nil { - t.Fatal(err) - } + treeID = data.TestSaveNodes(t, ctx, uploader, []*data.Node{node}) arch.stopWorkers() return wg.Wait() }) @@ -2256,19 +2252,15 @@ func snapshot(t testing.TB, repo archiverRepo, fs fs.FS, parent *data.Snapshot, ParentSnapshot: parent, } snapshot, _, _, err := arch.Snapshot(ctx, []string{filename}, sopts) - if err != nil { - t.Fatal(err) - } - + rtest.OK(t, err) tree, err := data.LoadTree(ctx, repo, *snapshot.Tree) - if err != nil { - t.Fatal(err) - } + rtest.OK(t, err) - node := tree.Find(filename) - if node == nil { - t.Fatalf("unable to find node for testfile in snapshot") - } + finder := data.NewTreeFinder(tree) + defer finder.Close() + node, err := finder.Find(filename) + rtest.OK(t, err) + rtest.Assert(t, node != nil, "unable to find node for testfile in snapshot") return snapshot, node } diff --git a/internal/archiver/testing.go b/internal/archiver/testing.go index 666a8c556..6f1195c29 100644 --- a/internal/archiver/testing.go +++ b/internal/archiver/testing.go @@ -15,6 +15,7 @@ import ( "github.com/restic/restic/internal/debug" "github.com/restic/restic/internal/fs" "github.com/restic/restic/internal/restic" + rtest "github.com/restic/restic/internal/test" ) // TestSnapshot creates a new snapshot of path. 
@@ -265,19 +266,14 @@ func TestEnsureTree(ctx context.Context, t testing.TB, prefix string, repo resti t.Helper() tree, err := data.LoadTree(ctx, repo, treeID) - if err != nil { - t.Fatal(err) - return - } + rtest.OK(t, err) var nodeNames []string - for _, node := range tree.Nodes { - nodeNames = append(nodeNames, node.Name) - } - debug.Log("%v (%v) %v", prefix, treeID.Str(), nodeNames) - checked := make(map[string]struct{}) - for _, node := range tree.Nodes { + for item := range tree { + rtest.OK(t, item.Error) + node := item.Node + nodeNames = append(nodeNames, node.Name) nodePrefix := path.Join(prefix, node.Name) entry, ok := dir[node.Name] @@ -316,6 +312,7 @@ func TestEnsureTree(ctx context.Context, t testing.TB, prefix string, repo resti } } } + debug.Log("%v (%v) %v", prefix, treeID.Str(), nodeNames) for name := range dir { _, ok := checked[name] diff --git a/internal/backend/location/location_test.go b/internal/backend/location/location_test.go index fe550a586..ccb1db6ae 100644 --- a/internal/backend/location/location_test.go +++ b/internal/backend/location/location_test.go @@ -29,7 +29,7 @@ func TestParse(t *testing.T) { u, err := location.Parse(registry, path) test.OK(t, err) test.Equals(t, "local", u.Scheme) - test.Equals(t, &testConfig{loc: path}, u.Config) + test.Equals(t, any(&testConfig{loc: path}), u.Config) } func TestParseFallback(t *testing.T) { diff --git a/internal/backend/util/defaults_test.go b/internal/backend/util/defaults_test.go index b0efc336f..6cdd058f8 100644 --- a/internal/backend/util/defaults_test.go +++ b/internal/backend/util/defaults_test.go @@ -37,7 +37,7 @@ func TestDefaultLoad(t *testing.T) { return rd, nil }, func(ird io.Reader) error { - rtest.Equals(t, rd, ird) + rtest.Equals(t, io.Reader(rd), ird) return nil }) rtest.OK(t, err) diff --git a/internal/checker/checker.go b/internal/checker/checker.go index c985951fd..0d3a908b8 100644 --- a/internal/checker/checker.go +++ b/internal/checker/checker.go @@ -3,7 +3,6 @@ package checker import ( "context" "fmt" - "runtime" "sync" "github.com/restic/restic/internal/data" @@ -12,7 +11,6 @@ import ( "github.com/restic/restic/internal/repository" "github.com/restic/restic/internal/restic" "github.com/restic/restic/internal/ui/progress" - "golang.org/x/sync/errgroup" ) // Checker runs various checks on a repository. It is advisable to create an @@ -92,31 +90,6 @@ func (e *TreeError) Error() string { return fmt.Sprintf("tree %v: %v", e.ID, e.Errors) } -// checkTreeWorker checks the trees received and sends out errors to errChan. 
-func (c *Checker) checkTreeWorker(ctx context.Context, trees <-chan data.TreeItem, out chan<- error) { - for job := range trees { - debug.Log("check tree %v (tree %v, err %v)", job.ID, job.Tree, job.Error) - - var errs []error - if job.Error != nil { - errs = append(errs, job.Error) - } else { - errs = c.checkTree(job.ID, job.Tree) - } - - if len(errs) == 0 { - continue - } - treeError := &TreeError{ID: job.ID, Errors: errs} - select { - case <-ctx.Done(): - return - case out <- treeError: - debug.Log("tree %v: sent %d errors", treeError.ID, len(treeError.Errors)) - } - } -} - func loadSnapshotTreeIDs(ctx context.Context, lister restic.Lister, repo restic.LoaderUnpacked) (ids restic.IDs, errs []error) { err := data.ForAllSnapshots(ctx, lister, repo, nil, func(id restic.ID, sn *data.Snapshot, err error) error { if err != nil { @@ -171,6 +144,7 @@ func (c *Checker) Structure(ctx context.Context, p *progress.Counter, errChan ch p.SetMax(uint64(len(trees))) debug.Log("need to check %d trees from snapshots, %d errs returned", len(trees), len(errs)) + defer close(errChan) for _, err := range errs { select { case <-ctx.Done(): @@ -179,8 +153,7 @@ func (c *Checker) Structure(ctx context.Context, p *progress.Counter, errChan ch } } - wg, ctx := errgroup.WithContext(ctx) - treeStream := data.StreamTrees(ctx, wg, c.repo, trees, func(treeID restic.ID) bool { + err := data.StreamTrees(ctx, c.repo, trees, p, func(treeID restic.ID) bool { // blobRefs may be accessed in parallel by checkTree c.blobRefs.Lock() h := restic.BlobHandle{ID: treeID, Type: restic.TreeBlob} @@ -189,30 +162,46 @@ func (c *Checker) Structure(ctx context.Context, p *progress.Counter, errChan ch c.blobRefs.M.Insert(h) c.blobRefs.Unlock() return blobReferenced - }, p) + }, func(treeID restic.ID, err error, nodes data.TreeNodeIterator) error { + debug.Log("check tree %v (err %v)", treeID, err) - defer close(errChan) - // The checkTree worker only processes already decoded trees and is thus CPU-bound - workerCount := runtime.GOMAXPROCS(0) - for i := 0; i < workerCount; i++ { - wg.Go(func() error { - c.checkTreeWorker(ctx, treeStream, errChan) + var errs []error + if err != nil { + errs = append(errs, err) + } else { + errs = c.checkTree(treeID, nodes) + } + if len(errs) == 0 { return nil - }) - } + } - // the wait group should not return an error because no worker returns an + treeError := &TreeError{ID: treeID, Errors: errs} + select { + case <-ctx.Done(): + return nil + case errChan <- treeError: + debug.Log("tree %v: sent %d errors", treeError.ID, len(treeError.Errors)) + } + + return nil + }) + + // StreamTrees should not return an error because no worker returns an // error, so panic if that has changed somehow. 
- err := wg.Wait() if err != nil { panic(err) } } -func (c *Checker) checkTree(id restic.ID, tree *data.Tree) (errs []error) { +func (c *Checker) checkTree(id restic.ID, tree data.TreeNodeIterator) (errs []error) { debug.Log("checking tree %v", id) - for _, node := range tree.Nodes { + for item := range tree { + if item.Error != nil { + errs = append(errs, &Error{TreeID: id, Err: errors.Errorf("failed to decode tree %v: %w", id, item.Error)}) + break + } + node := item.Node switch node.Type { case data.NodeTypeFile: if node.Content == nil { diff --git a/internal/checker/checker_test.go b/internal/checker/checker_test.go index ea20b7302..106ffd6b3 100644 --- a/internal/checker/checker_test.go +++ b/internal/checker/checker_test.go @@ -522,15 +522,12 @@ func TestCheckerBlobTypeConfusion(t *testing.T) { Size: 42, Content: restic.IDs{restic.TestParseID("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef")}, } - damagedTree := &data.Tree{ - Nodes: []*data.Node{damagedNode}, - } + damagedNodes := []*data.Node{damagedNode} var id restic.ID test.OK(t, repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { - var err error - id, err = data.SaveTree(ctx, uploader, damagedTree) - return err + id = data.TestSaveNodes(t, ctx, uploader, damagedNodes) + return nil })) buf, err := repo.LoadBlob(ctx, restic.TreeBlob, id, nil) @@ -556,15 +553,12 @@ func TestCheckerBlobTypeConfusion(t *testing.T) { Subtree: &id, } - rootTree := &data.Tree{ - Nodes: []*data.Node{malNode, dirNode}, - } + rootNodes := []*data.Node{malNode, dirNode} var rootID restic.ID test.OK(t, repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { - var err error - rootID, err = data.SaveTree(ctx, uploader, rootTree) - return err + rootID = data.TestSaveNodes(t, ctx, uploader, rootNodes) + return nil })) snapshot, err := data.NewSnapshot([]string{"/damaged"}, []string{"test"}, "foo", time.Now()) diff --git a/internal/data/find.go b/internal/data/find.go index 14d64670e..5fadcba93 100644 --- a/internal/data/find.go +++ b/internal/data/find.go @@ -6,7 +6,6 @@ import ( "github.com/restic/restic/internal/restic" "github.com/restic/restic/internal/ui/progress" - "golang.org/x/sync/errgroup" ) // FindUsedBlobs traverses the tree ID and adds all seen blobs (trees and data @@ -14,8 +13,7 @@ import ( func FindUsedBlobs(ctx context.Context, repo restic.Loader, treeIDs restic.IDs, blobs restic.FindBlobSet, p *progress.Counter) error { var lock sync.Mutex - wg, ctx := errgroup.WithContext(ctx) - treeStream := StreamTrees(ctx, wg, repo, treeIDs, func(treeID restic.ID) bool { + return StreamTrees(ctx, repo, treeIDs, p, func(treeID restic.ID) bool { // locking is necessary the goroutine below concurrently adds data blobs lock.Lock() h := restic.BlobHandle{ID: treeID, Type: restic.TreeBlob} @@ -24,26 +22,24 @@ func FindUsedBlobs(ctx context.Context, repo restic.Loader, treeIDs restic.IDs, blobs.Insert(h) lock.Unlock() return blobReferenced - }, p) + }, func(_ restic.ID, err error, nodes TreeNodeIterator) error { + if err != nil { + return err + } - wg.Go(func() error { - for tree := range treeStream { - if tree.Error != nil { - return tree.Error + for item := range nodes { + if item.Error != nil { + return item.Error } - lock.Lock() - for _, node := range tree.Nodes { - switch node.Type { - case NodeTypeFile: - for _, blob := range node.Content { - blobs.Insert(restic.BlobHandle{ID: blob, Type: restic.DataBlob}) - } + switch item.Node.Type { + case NodeTypeFile: 
+ for _, blob := range item.Node.Content { + blobs.Insert(restic.BlobHandle{ID: blob, Type: restic.DataBlob}) } } lock.Unlock() } return nil }) - return wg.Wait() } diff --git a/internal/data/testdata/used_blobs_snapshot0 b/internal/data/testdata/used_blobs_snapshot0 index cc789f043..85ac807f0 100644 --- a/internal/data/testdata/used_blobs_snapshot0 +++ b/internal/data/testdata/used_blobs_snapshot0 @@ -1,9 +1,9 @@ {"ID":"05bddd650a800f83f7c0d844cecb1e02f99ce962df5652a53842be50386078e1","Type":"data"} {"ID":"087040b12f129e89e4eab2b86aa14467404366a17a6082efb0d11fa7e2f9f58e","Type":"data"} -{"ID":"08a650e4d7575177ddeabf6a96896b76fa7e621aa3dd75e77293f22ce6c0c420","Type":"tree"} {"ID":"1e0f0e5799b9d711e07883050366c7eee6b7481c0d884694093149f6c4e9789a","Type":"data"} -{"ID":"435b9207cd489b41a7d119e0d75eab2a861e2b3c8d4d12ac51873ff76be0cf73","Type":"tree"} +{"ID":"3cb26562e6849003adffc5e1dcf9a58a9d151ea4203bd451f6359d7cc0328104","Type":"tree"} {"ID":"4719f8a039f5b745e16cf90e5b84c9255c290d500da716f7dd25909cdabb85b6","Type":"data"} +{"ID":"4bb52083c8a467921e8ed4139f7be3e282bad8d25d0056145eadd3962aed0127","Type":"tree"} {"ID":"4e352975938a29711c3003c498185972235af261a6cf8cf700a8a6ee4f914b05","Type":"data"} {"ID":"606772eacb7fe1a79267088dcadd13431914854faf1d39d47fe99a26b9fecdcb","Type":"data"} {"ID":"6b5fd3a9baf615489c82a99a71f9917bf9a2d82d5f640d7f47d175412c4b8d19","Type":"data"} @@ -14,10 +14,10 @@ {"ID":"a69c8621776ca8bb34c6c90e5ad811ddc8e2e5cfd6bb0cec5e75cca70e0b9ade","Type":"data"} {"ID":"b11f4dd9d2722b3325186f57cd13a71a3af7791118477f355b49d101104e4c22","Type":"data"} {"ID":"b1f2ae9d748035e5bd9a87f2579405166d150c6560d8919496f02855e1c36cf9","Type":"data"} +{"ID":"b326b56e1b4c5c3b80e449fc40abcada21b5bd7ff12ce02236a2d289b89dcea7","Type":"tree"} {"ID":"b5ba06039224566a09555abd089de7a693660154991295122fa72b0a3adc4150","Type":"data"} {"ID":"b7040572b44cbfea8b784ecf8679c3d75cefc1cd3d12ed783ca0d8e5d124a60f","Type":"data"} {"ID":"b9e634143719742fe77feed78b61f09573d59d2efa23d6d54afe6c159d220503","Type":"data"} {"ID":"ca896fc9ebf95fcffd7c768b07b92110b21e332a47fef7e382bf15363b0ece1a","Type":"data"} {"ID":"e6fe3512ea23a4ebf040d30958c669f7ffe724400f155a756467a9f3cafc27c5","Type":"data"} {"ID":"ed00928ce97ac5acd27c862d9097e606536e9063af1c47481257811f66260f3a","Type":"data"} -{"ID":"fb62dd9093c4958b019b90e591b2d36320ff381a24bdc9c5db3b8960ff94d174","Type":"tree"} diff --git a/internal/data/testdata/used_blobs_snapshot1 b/internal/data/testdata/used_blobs_snapshot1 index aa840294a..e7d66c7ef 100644 --- a/internal/data/testdata/used_blobs_snapshot1 +++ b/internal/data/testdata/used_blobs_snapshot1 @@ -1,6 +1,8 @@ {"ID":"05bddd650a800f83f7c0d844cecb1e02f99ce962df5652a53842be50386078e1","Type":"data"} {"ID":"18dcaa1a676823c909aafabbb909652591915eebdde4f9a65cee955157583494","Type":"data"} +{"ID":"428b3f50bcfdcb9a85b87f9401d8947b2c8a2f807c19c00491626e3ee890075e","Type":"tree"} {"ID":"4719f8a039f5b745e16cf90e5b84c9255c290d500da716f7dd25909cdabb85b6","Type":"data"} +{"ID":"55416727dd211e5f208b70fc9a3d60d34484626279717be87843a7535f997404","Type":"tree"} {"ID":"6824d08e63a598c02b364e25f195e64758494b5944f06c921ff30029e1e4e4bf","Type":"data"} {"ID":"72b6eb0fd0d87e00392f8b91efc1a4c3f7f5c0c76f861b38aea054bc9d43463b","Type":"data"} {"ID":"8192279e4b56e1644dcff715d5e08d875cd5713349139d36d142ed28364d8e00","Type":"data"} @@ -10,6 +12,4 @@ {"ID":"ca896fc9ebf95fcffd7c768b07b92110b21e332a47fef7e382bf15363b0ece1a","Type":"data"} {"ID":"cc4cab5b20a3a88995f8cdb8b0698d67a32dbc5b54487f03cb612c30a626af39","Type":"data"} 
{"ID":"e6fe3512ea23a4ebf040d30958c669f7ffe724400f155a756467a9f3cafc27c5","Type":"data"} -{"ID":"e9f3c4fe78e903cba60d310a9668c42232c8274b3f29b5ecebb6ff1aaeabd7e3","Type":"tree"} {"ID":"ed00928ce97ac5acd27c862d9097e606536e9063af1c47481257811f66260f3a","Type":"data"} -{"ID":"ff58f76c2313e68aa9aaaece855183855ac4ff682910404c2ae33dc999ebaca2","Type":"tree"} diff --git a/internal/data/testdata/used_blobs_snapshot2 b/internal/data/testdata/used_blobs_snapshot2 index 3ed193f53..029bcbc6e 100644 --- a/internal/data/testdata/used_blobs_snapshot2 +++ b/internal/data/testdata/used_blobs_snapshot2 @@ -1,24 +1,24 @@ {"ID":"05bddd650a800f83f7c0d844cecb1e02f99ce962df5652a53842be50386078e1","Type":"data"} {"ID":"087040b12f129e89e4eab2b86aa14467404366a17a6082efb0d11fa7e2f9f58e","Type":"data"} {"ID":"0b88f99abc5ac71c54b3e8263c52ecb7d8903462779afdb3c8176ec5c4bb04fb","Type":"data"} -{"ID":"0e1a817fca83f569d1733b11eba14b6c9b176e41bca3644eed8b29cb907d84d3","Type":"tree"} {"ID":"1e0f0e5799b9d711e07883050366c7eee6b7481c0d884694093149f6c4e9789a","Type":"data"} {"ID":"27917462f89cecae77a4c8fb65a094b9b75a917f13794c628b1640b17f4c4981","Type":"data"} +{"ID":"2b8ebd79732fcfec50ec94429cfd404d531c93defed78832a597d0fe8de64b96","Type":"tree"} {"ID":"32745e4b26a5883ecec272c9fbfe7f3c9835c9ab41c9a2baa4d06f319697a0bd","Type":"data"} {"ID":"4719f8a039f5b745e16cf90e5b84c9255c290d500da716f7dd25909cdabb85b6","Type":"data"} {"ID":"4e352975938a29711c3003c498185972235af261a6cf8cf700a8a6ee4f914b05","Type":"data"} {"ID":"6824d08e63a598c02b364e25f195e64758494b5944f06c921ff30029e1e4e4bf","Type":"data"} {"ID":"6b5fd3a9baf615489c82a99a71f9917bf9a2d82d5f640d7f47d175412c4b8d19","Type":"data"} +{"ID":"721d803612a2565f9be9581048c5d899c14a65129dabafbb0e43bba89684a63a","Type":"tree"} +{"ID":"9103a50221ecf6684c1e1652adacb4a85afb322d81a74ecd7477930ecf4774fc","Type":"tree"} {"ID":"95c97192efa810ccb1cee112238dca28673fbffce205d75ce8cc990a31005a51","Type":"data"} {"ID":"99dab094430d3c1be22c801a6ad7364d490a8d2ce3f9dfa3d2677431446925f4","Type":"data"} {"ID":"a4c97189465344038584e76c965dd59100eaed051db1fa5ba0e143897e2c87f1","Type":"data"} {"ID":"a69c8621776ca8bb34c6c90e5ad811ddc8e2e5cfd6bb0cec5e75cca70e0b9ade","Type":"data"} +{"ID":"ac08ce34ba4f8123618661bef2425f7028ffb9ac740578a3ee88684d2523fee8","Type":"tree"} +{"ID":"b326b56e1b4c5c3b80e449fc40abcada21b5bd7ff12ce02236a2d289b89dcea7","Type":"tree"} {"ID":"b6a7e8d2aa717e0a6bd68abab512c6b566074b5a6ca2edf4cd446edc5857d732","Type":"data"} -{"ID":"bad84ed273c5fbfb40aa839a171675b7f16f5e67f3eaf4448730caa0ee27297c","Type":"tree"} -{"ID":"bfc2fdb527b0c9f66bbb8d4ff1c44023cc2414efcc7f0831c10debab06bb4388","Type":"tree"} {"ID":"ca896fc9ebf95fcffd7c768b07b92110b21e332a47fef7e382bf15363b0ece1a","Type":"data"} -{"ID":"d1d3137eb08de6d8c5d9f44788c45a9fea9bb082e173bed29a0945b3347f2661","Type":"tree"} {"ID":"e6fe3512ea23a4ebf040d30958c669f7ffe724400f155a756467a9f3cafc27c5","Type":"data"} {"ID":"ed00928ce97ac5acd27c862d9097e606536e9063af1c47481257811f66260f3a","Type":"data"} {"ID":"f3cd67d9c14d2a81663d63522ab914e465b021a3b65e2f1ea6caf7478f2ec139","Type":"data"} -{"ID":"fb62dd9093c4958b019b90e591b2d36320ff381a24bdc9c5db3b8960ff94d174","Type":"tree"} diff --git a/internal/data/testing.go b/internal/data/testing.go index 8187833a6..19f7fa7b3 100644 --- a/internal/data/testing.go +++ b/internal/data/testing.go @@ -5,12 +5,14 @@ import ( "fmt" "io" "math/rand" + "slices" + "strings" "testing" "time" "github.com/restic/chunker" "github.com/restic/restic/internal/restic" - "github.com/restic/restic/internal/test" + rtest 
"github.com/restic/restic/internal/test" ) // fakeFile returns a reader which yields deterministic pseudo-random data. @@ -72,22 +74,21 @@ func (fs *fakeFileSystem) saveTree(ctx context.Context, uploader restic.BlobSave rnd := rand.NewSource(seed) numNodes := int(rnd.Int63() % maxNodes) - var tree Tree + var nodes []*Node for i := 0; i < numNodes; i++ { - // randomly select the type of the node, either tree (p = 1/4) or file (p = 3/4). if depth > 1 && rnd.Int63()%4 == 0 { treeSeed := rnd.Int63() % maxSeed id := fs.saveTree(ctx, uploader, treeSeed, depth-1) node := &Node{ - Name: fmt.Sprintf("dir-%v", treeSeed), + Name: fmt.Sprintf("dir-%v", i), Type: NodeTypeDir, Mode: 0755, Subtree: &id, } - tree.Nodes = append(tree.Nodes, node) + nodes = append(nodes, node) continue } @@ -95,22 +96,31 @@ func (fs *fakeFileSystem) saveTree(ctx context.Context, uploader restic.BlobSave fileSize := (maxFileSize / maxSeed) * fileSeed node := &Node{ - Name: fmt.Sprintf("file-%v", fileSeed), + Name: fmt.Sprintf("file-%v", i), Type: NodeTypeFile, Mode: 0644, Size: uint64(fileSize), } node.Content = fs.saveFile(ctx, uploader, fakeFile(fileSeed, fileSize)) - tree.Nodes = append(tree.Nodes, node) + nodes = append(nodes, node) } - tree.Sort() + return TestSaveNodes(fs.t, ctx, uploader, nodes) +} - id, err := SaveTree(ctx, uploader, &tree) - if err != nil { - fs.t.Fatalf("SaveTree returned error: %v", err) +//nolint:revive // as this is a test helper, t should go first +func TestSaveNodes(t testing.TB, ctx context.Context, uploader restic.BlobSaver, nodes []*Node) restic.ID { + slices.SortFunc(nodes, func(a, b *Node) int { + return strings.Compare(a.Name, b.Name) + }) + treeWriter := NewTreeWriter(uploader) + for _, node := range nodes { + err := treeWriter.AddNode(node) + rtest.OK(t, err) } + id, err := treeWriter.Finalize(ctx) + rtest.OK(t, err) return id } @@ -136,7 +146,7 @@ func TestCreateSnapshot(t testing.TB, repo restic.Repository, at time.Time, dept } var treeID restic.ID - test.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { treeID = fs.saveTree(ctx, uploader, seed, depth) return nil })) @@ -188,3 +198,49 @@ func TestLoadAllSnapshots(ctx context.Context, repo restic.ListerLoaderUnpacked, return snapshots, nil } + +// TestTreeMap returns the trees from the map on LoadTree. +type TestTreeMap map[restic.ID][]byte + +func (t TestTreeMap) LoadBlob(_ context.Context, tpe restic.BlobType, id restic.ID, _ []byte) ([]byte, error) { + if tpe != restic.TreeBlob { + return nil, fmt.Errorf("can only load trees") + } + tree, ok := t[id] + if !ok { + return nil, fmt.Errorf("tree not found") + } + return tree, nil +} + +func (t TestTreeMap) Connections() uint { + return 2 +} + +// TestWritableTreeMap also support saving +type TestWritableTreeMap struct { + TestTreeMap +} + +func (t TestWritableTreeMap) SaveBlob(_ context.Context, tpe restic.BlobType, buf []byte, id restic.ID, _ bool) (newID restic.ID, known bool, size int, err error) { + if tpe != restic.TreeBlob { + return restic.ID{}, false, 0, fmt.Errorf("can only save trees") + } + + if id.IsNull() { + id = restic.Hash(buf) + } + _, ok := t.TestTreeMap[id] + if ok { + return id, false, 0, nil + } + + t.TestTreeMap[id] = append([]byte{}, buf...) 
+ return id, true, len(buf), nil +} + +func (t TestWritableTreeMap) Dump(test testing.TB) { + for k, v := range t.TestTreeMap { + test.Logf("%v: %v", k, string(v)) + } +} diff --git a/internal/data/tree.go b/internal/data/tree.go index 9031f3bf5..1bfcbf660 100644 --- a/internal/data/tree.go +++ b/internal/data/tree.go @@ -5,142 +5,237 @@ import ( "context" "encoding/json" "fmt" + + "io" + "iter" "path" - "sort" "strings" "github.com/restic/restic/internal/errors" "github.com/restic/restic/internal/restic" - - "github.com/restic/restic/internal/debug" ) -// Tree is an ordered list of nodes. -type Tree struct { - Nodes []*Node `json:"nodes"` +// For documentation purposes only: +// // Tree is an ordered list of nodes. +// type Tree struct { +// Nodes []*Node `json:"nodes"` +// } + +var ErrTreeNotOrdered = errors.New("nodes are not ordered or duplicate") + +type treeIterator struct { + dec json.Decoder + started bool } -// NewTree creates a new tree object with the given initial capacity. -func NewTree(capacity int) *Tree { - return &Tree{ - Nodes: make([]*Node, 0, capacity), - } +type NodeOrError struct { + Node *Node + Error error } -func (t *Tree) String() string { - return fmt.Sprintf("Tree<%d nodes>", len(t.Nodes)) -} +type TreeNodeIterator = iter.Seq[NodeOrError] -// Equals returns true if t and other have exactly the same nodes. -func (t *Tree) Equals(other *Tree) bool { - if len(t.Nodes) != len(other.Nodes) { - debug.Log("tree.Equals(): trees have different number of nodes") - return false +func NewTreeNodeIterator(rd io.Reader) (TreeNodeIterator, error) { + t := &treeIterator{ + dec: *json.NewDecoder(rd), } - for i := 0; i < len(t.Nodes); i++ { - if !t.Nodes[i].Equals(*other.Nodes[i]) { - debug.Log("tree.Equals(): node %d is different:", i) - debug.Log(" %#v", t.Nodes[i]) - debug.Log(" %#v", other.Nodes[i]) - return false + err := t.init() + if err != nil { + return nil, err + } + + return func(yield func(NodeOrError) bool) { + if t.started { + panic("tree iterator is single use only") + } + t.started = true + for { + n, err := t.next() + if err != nil && errors.Is(err, io.EOF) { + return + } + if !yield(NodeOrError{Node: n, Error: err}) { + return + } + // errors are final + if err != nil { + return + } + } + }, nil +} + +func (t *treeIterator) init() error { + // A tree is expected to be encoded as a JSON object with a single key "nodes". + // However, for future-proofness, we allow unknown keys before and after the "nodes" key. + // The following is the expected format: + // `{"nodes":[...]}` + + if err := t.assertToken(json.Delim('{')); err != nil { + return err + } + // Skip unknown keys until we find "nodes" + for { + token, err := t.dec.Token() + if err != nil { + return err + } + key, ok := token.(string) + if !ok { + return errors.Errorf("error decoding tree: expected string key, got %v", token) + } + if key == "nodes" { + // Found "nodes", proceed to read the array + if err := t.assertToken(json.Delim('[')); err != nil { + return err + } + return nil + } + // Unknown key, decode its value into RawMessage and discard it + var raw json.RawMessage + if err := t.dec.Decode(&raw); err != nil { + return err } } - - return true } -// Insert adds a new node at the correct place in the tree. 
-func (t *Tree) Insert(node *Node) error { - pos, found := t.find(node.Name) - if found != nil { - return errors.Errorf("node %q already present", node.Name) +func (t *treeIterator) next() (*Node, error) { + if t.dec.More() { + var n Node + err := t.dec.Decode(&n) + if err != nil { + return nil, err + } + return &n, nil } - // https://github.com/golang/go/wiki/SliceTricks - t.Nodes = append(t.Nodes, nil) - copy(t.Nodes[pos+1:], t.Nodes[pos:]) - t.Nodes[pos] = node + if err := t.assertToken(json.Delim(']')); err != nil { + return nil, err + } + // Skip unknown keys after the array until we find the closing brace + for { + token, err := t.dec.Token() + if err != nil { + return nil, err + } + if token == json.Delim('}') { + return nil, io.EOF + } + // We have an unknown key, decode its value into RawMessage and discard it + var raw json.RawMessage + if err := t.dec.Decode(&raw); err != nil { + return nil, err + } + } +} +func (t *treeIterator) assertToken(token json.Token) error { + to, err := t.dec.Token() + if err != nil { + return err + } + if to != token { + return errors.Errorf("error decoding tree: expected %v, got %v", token, to) + } return nil } -func (t *Tree) find(name string) (int, *Node) { - pos := sort.Search(len(t.Nodes), func(i int) bool { - return t.Nodes[i].Name >= name - }) - - if pos < len(t.Nodes) && t.Nodes[pos].Name == name { - return pos, t.Nodes[pos] +func LoadTree(ctx context.Context, loader restic.BlobLoader, content restic.ID) (TreeNodeIterator, error) { + rd, err := loader.LoadBlob(ctx, restic.TreeBlob, content, nil) + if err != nil { + return nil, err } - - return pos, nil + return NewTreeNodeIterator(bytes.NewReader(rd)) } -// Find returns a node with the given name, or nil if none could be found. -func (t *Tree) Find(name string) *Node { - if t == nil { - return nil +type TreeFinder struct { + next func() (NodeOrError, bool) + stop func() + current *Node + last string +} + +func NewTreeFinder(tree TreeNodeIterator) *TreeFinder { + if tree == nil { + return &TreeFinder{stop: func() {}} } - - _, node := t.find(name) - return node + next, stop := iter.Pull(tree) + return &TreeFinder{next: next, stop: stop} } -// Sort sorts the nodes by name. -func (t *Tree) Sort() { - list := Nodes(t.Nodes) - sort.Sort(list) - t.Nodes = list -} - -// Subtrees returns a slice of all subtree IDs of the tree. -func (t *Tree) Subtrees() (trees restic.IDs) { - for _, node := range t.Nodes { - if node.Type == NodeTypeDir && node.Subtree != nil { - trees = append(trees, *node.Subtree) +// Find finds the node with the given name. If the node is not found, it returns nil. +// If Find was called before, the new name must be strictly greater than the last name. +func (t *TreeFinder) Find(name string) (*Node, error) { + if t.next == nil { + return nil, nil + } + if name <= t.last { + return nil, errors.Errorf("name %q is not greater than last name %q", name, t.last) + } + t.last = name + // loop until `t.current.Name` is >= name + for t.current == nil || t.current.Name < name { + current, ok := t.next() + if current.Error != nil { + return nil, current.Error } + if !ok { + return nil, nil + } + t.current = current.Node } - return trees + if t.current.Name == name { + // forget the current node to free memory as early as possible + current := t.current + t.current = nil + return current, nil + } + // we have already passed the name + return nil, nil } -// LoadTree loads a tree from the repository. 
-func LoadTree(ctx context.Context, r restic.BlobLoader, id restic.ID) (*Tree, error) { - debug.Log("load tree %v", id) - - buf, err := r.LoadBlob(ctx, restic.TreeBlob, id, nil) - if err != nil { - return nil, err - } - - t := &Tree{} - err = json.Unmarshal(buf, t) - if err != nil { - return nil, err - } - - return t, nil +func (t *TreeFinder) Close() { + t.stop() } -// SaveTree stores a tree into the repository and returns the ID. The ID is -// checked against the index. The tree is only stored when the index does not -// contain the ID. -func SaveTree(ctx context.Context, r restic.BlobSaver, t *Tree) (restic.ID, error) { - buf, err := json.Marshal(t) +type TreeWriter struct { + builder *TreeJSONBuilder + saver restic.BlobSaver +} + +func NewTreeWriter(saver restic.BlobSaver) *TreeWriter { + builder := NewTreeJSONBuilder() + return &TreeWriter{builder: builder, saver: saver} +} + +func (t *TreeWriter) AddNode(node *Node) error { + return t.builder.AddNode(node) +} + +func (t *TreeWriter) Finalize(ctx context.Context) (restic.ID, error) { + buf, err := t.builder.Finalize() if err != nil { - return restic.ID{}, errors.Wrap(err, "MarshalJSON") + return restic.ID{}, err } - - // append a newline so that the data is always consistent (json.Encoder - // adds a newline after each object) - buf = append(buf, '\n') - - id, _, _, err := r.SaveBlob(ctx, restic.TreeBlob, buf, restic.ID{}, false) + id, _, _, err := t.saver.SaveBlob(ctx, restic.TreeBlob, buf, restic.ID{}, false) return id, err } -var ErrTreeNotOrdered = errors.New("nodes are not ordered or duplicate") +func SaveTree(ctx context.Context, saver restic.BlobSaver, nodes TreeNodeIterator) (restic.ID, error) { + treeWriter := NewTreeWriter(saver) + for item := range nodes { + if item.Error != nil { + return restic.ID{}, item.Error + } + err := treeWriter.AddNode(item.Node) + if err != nil { + return restic.ID{}, err + } + } + return treeWriter.Finalize(ctx) +} type TreeJSONBuilder struct { buf bytes.Buffer @@ -197,7 +292,12 @@ func FindTreeDirectory(ctx context.Context, repo restic.BlobLoader, id *restic.I if err != nil { return nil, fmt.Errorf("path %s: %w", subfolder, err) } - node := tree.Find(name) + finder := NewTreeFinder(tree) + node, err := finder.Find(name) + finder.Close() + if err != nil { + return nil, fmt.Errorf("path %s: %w", subfolder, err) + } if node == nil { return nil, fmt.Errorf("path %s: not found", subfolder) } @@ -208,3 +308,105 @@ func FindTreeDirectory(ctx context.Context, repo restic.BlobLoader, id *restic.I } return id, nil } + +type peekableNodeIterator struct { + iter func() (NodeOrError, bool) + stop func() + value *Node +} + +func newPeekableNodeIterator(tree TreeNodeIterator) (*peekableNodeIterator, error) { + iter, stop := iter.Pull(tree) + it := &peekableNodeIterator{iter: iter, stop: stop} + err := it.Next() + if err != nil { + it.Close() + return nil, err + } + return it, nil +} + +func (i *peekableNodeIterator) Next() error { + item, ok := i.iter() + if item.Error != nil || !ok { + i.value = nil + return item.Error + } + i.value = item.Node + return nil +} + +func (i *peekableNodeIterator) Peek() *Node { + return i.value +} + +func (i *peekableNodeIterator) Close() { + i.stop() +} + +type DualTree struct { + Tree1 *Node + Tree2 *Node + Error error +} + +// DualTreeIterator iterates over two trees in parallel. It returns a sequence of DualTree structs. +// The sequence is terminated when both trees are exhausted. The error field must be checked before +// accessing any of the nodes. 
+func DualTreeIterator(tree1, tree2 TreeNodeIterator) iter.Seq[DualTree] { + started := false + return func(yield func(DualTree) bool) { + if started { + panic("tree iterator is single use only") + } + started = true + iter1, err := newPeekableNodeIterator(tree1) + if err != nil { + yield(DualTree{Tree1: nil, Tree2: nil, Error: err}) + return + } + defer iter1.Close() + iter2, err := newPeekableNodeIterator(tree2) + if err != nil { + yield(DualTree{Tree1: nil, Tree2: nil, Error: err}) + return + } + defer iter2.Close() + + for { + node1 := iter1.Peek() + node2 := iter2.Peek() + if node1 == nil && node2 == nil { + // both iterators are exhausted + break + } else if node1 != nil && node2 != nil { + // if both nodes have a different name, only keep the first one + if node1.Name < node2.Name { + node2 = nil + } else if node1.Name > node2.Name { + node1 = nil + } + } + + // non-nil nodes will be processed in the following, so advance the corresponding iterator + if node1 != nil { + if err = iter1.Next(); err != nil { + break + } + } + if node2 != nil { + if err = iter2.Next(); err != nil { + break + } + } + + if !yield(DualTree{Tree1: node1, Tree2: node2, Error: err}) { + return + } + } + if err != nil { + yield(DualTree{Tree1: nil, Tree2: nil, Error: err}) + return + } + } +} diff --git a/internal/data/tree_stream.go b/internal/data/tree_stream.go index c7d3588b5..1f832a731 100644 --- a/internal/data/tree_stream.go +++ b/internal/data/tree_stream.go @@ -2,26 +2,20 @@ package data import ( "context" - "errors" "runtime" "sync" "github.com/restic/restic/internal/debug" + "github.com/restic/restic/internal/errors" "github.com/restic/restic/internal/restic" "github.com/restic/restic/internal/ui/progress" "golang.org/x/sync/errgroup" ) -// TreeItem is used to return either an error or the tree for a tree id -type TreeItem struct { - restic.ID - Error error - *Tree -} - type trackedTreeItem struct { - TreeItem - rootIdx int + restic.ID + Subtrees restic.IDs + rootIdx int } type trackedID struct { @@ -29,35 +23,87 @@ type trackedID struct { rootIdx int } +// subtreesCollector wraps a TreeNodeIterator and returns a new iterator that collects the subtrees. +func subtreesCollector(tree TreeNodeIterator) (TreeNodeIterator, func() restic.IDs) { + subtrees := restic.IDs{} + isComplete := false + + return func(yield func(NodeOrError) bool) { + for item := range tree { + if !yield(item) { + return + } + // be defensive and check for nil subtree as this code is also used by the checker + if item.Node != nil && item.Node.Type == NodeTypeDir && item.Node.Subtree != nil { + subtrees = append(subtrees, *item.Node.Subtree) + } + } + isComplete = true + }, func() restic.IDs { + if !isComplete { + panic("tree was not read completely") + } + return subtrees + } +} + // loadTreeWorker loads trees from repo and sends them to out. 
-func loadTreeWorker(ctx context.Context, repo restic.Loader, - in <-chan trackedID, out chan<- trackedTreeItem) { +func loadTreeWorker( + ctx context.Context, + repo restic.Loader, + in <-chan trackedID, + process func(id restic.ID, error error, nodes TreeNodeIterator) error, + out chan<- trackedTreeItem, +) error { for treeID := range in { tree, err := LoadTree(ctx, repo, treeID.ID) + if tree == nil && err == nil { + err = errors.New("tree is nil and error is nil") + } debug.Log("load tree %v (%v) returned err: %v", tree, treeID, err) - job := trackedTreeItem{TreeItem: TreeItem{ID: treeID.ID, Error: err, Tree: tree}, rootIdx: treeID.rootIdx} + + // wrap iterator to collect subtrees while `process` iterates over `tree` + var collectSubtrees func() restic.IDs + if tree != nil { + tree, collectSubtrees = subtreesCollector(tree) + } + + err = process(treeID.ID, err, tree) + if err != nil { + return err + } + + // assume that the number of subtrees is within reasonable limits, such that the memory usage is not a problem + var subtrees restic.IDs + if collectSubtrees != nil { + subtrees = collectSubtrees() + } + + job := trackedTreeItem{ID: treeID.ID, Subtrees: subtrees, rootIdx: treeID.rootIdx} select { case <-ctx.Done(): - return + return nil case out <- job: } } + return nil } +// filterTree receives the result of a tree load and queues new trees for loading and processing. func filterTrees(ctx context.Context, repo restic.Loader, trees restic.IDs, loaderChan chan<- trackedID, hugeTreeLoaderChan chan<- trackedID, - in <-chan trackedTreeItem, out chan<- TreeItem, skip func(tree restic.ID) bool, p *progress.Counter) { + in <-chan trackedTreeItem, skip func(tree restic.ID) bool, p *progress.Counter) { var ( inCh = in - outCh chan<- TreeItem loadCh chan<- trackedID - job TreeItem nextTreeID trackedID outstandingLoadTreeJobs = 0 ) + // tracks how many trees are currently waiting to be processed for a given root tree rootCounter := make([]int, len(trees)) + // build initial backlog backlog := make([]trackedID, 0, len(trees)) for idx, id := range trees { backlog = append(backlog, trackedID{ID: id, rootIdx: idx}) @@ -65,6 +111,7 @@ func filterTrees(ctx context.Context, repo restic.Loader, trees restic.IDs, load } for { + // if no tree is waiting to be sent, pick the next one if loadCh == nil && len(backlog) > 0 { // process last added ids first, that is traverse the tree in depth-first order ln := len(backlog) - 1 @@ -86,7 +133,8 @@ func filterTrees(ctx context.Context, repo restic.Loader, trees restic.IDs, load } } - if loadCh == nil && outCh == nil && outstandingLoadTreeJobs == 0 { + // loadCh is only nil at this point if the backlog is empty + if loadCh == nil && outstandingLoadTreeJobs == 0 { debug.Log("backlog is empty, all channels nil, exiting") return } @@ -103,7 +151,6 @@ func filterTrees(ctx context.Context, repo restic.Loader, trees restic.IDs, load if !ok { debug.Log("input channel closed") inCh = nil - in = nil continue } @@ -111,58 +158,47 @@ func filterTrees(ctx context.Context, repo restic.Loader, trees restic.IDs, load rootCounter[j.rootIdx]-- debug.Log("input job tree %v", j.ID) - - if j.Error != nil { - debug.Log("received job with error: %v (tree %v, ID %v)", j.Error, j.Tree, j.ID) - } else if j.Tree == nil { - debug.Log("received job with nil tree pointer: %v (ID %v)", j.Error, j.ID) - // send a new job with the new error instead of the old one - j = trackedTreeItem{TreeItem: TreeItem{ID: j.ID, Error: errors.New("tree is nil and error is nil")}, rootIdx: j.rootIdx} - } else { - 
subtrees := j.Tree.Subtrees() - debug.Log("subtrees for tree %v: %v", j.ID, subtrees) - // iterate backwards over subtree to compensate backwards traversal order of nextTreeID selection - for i := len(subtrees) - 1; i >= 0; i-- { - id := subtrees[i] - if id.IsNull() { - // We do not need to raise this error here, it is - // checked when the tree is checked. Just make sure - // that we do not add any null IDs to the backlog. - debug.Log("tree %v has nil subtree", j.ID) - continue - } - backlog = append(backlog, trackedID{ID: id, rootIdx: j.rootIdx}) - rootCounter[j.rootIdx]++ + // iterate backwards over subtree to compensate backwards traversal order of nextTreeID selection + for i := len(j.Subtrees) - 1; i >= 0; i-- { + id := j.Subtrees[i] + if id.IsNull() { + // We do not need to raise this error here, it is + // checked when the tree is checked. Just make sure + // that we do not add any null IDs to the backlog. + debug.Log("tree %v has nil subtree", j.ID) + continue } + backlog = append(backlog, trackedID{ID: id, rootIdx: j.rootIdx}) + rootCounter[j.rootIdx]++ } + // the progress check must happen after j.Subtrees was added to the backlog if p != nil && rootCounter[j.rootIdx] == 0 { p.Add(1) } - - job = j.TreeItem - outCh = out - inCh = nil - - case outCh <- job: - debug.Log("tree sent to process: %v", job.ID) - outCh = nil - inCh = in } } } // StreamTrees iteratively loads the given trees and their subtrees. The skip method -// is guaranteed to always be called from the same goroutine. To shutdown the started -// goroutines, either read all items from the channel or cancel the context. Then `Wait()` -// on the errgroup until all goroutines were stopped. -func StreamTrees(ctx context.Context, wg *errgroup.Group, repo restic.Loader, trees restic.IDs, skip func(tree restic.ID) bool, p *progress.Counter) <-chan TreeItem { +// is guaranteed to always be called from the same goroutine. The process function is +// directly called from the worker goroutines. It MUST read `nodes` until it returns an +// error or completes. If the process function returns an error, then StreamTrees will +// abort and return the error. 
+func StreamTrees( + ctx context.Context, + repo restic.Loader, + trees restic.IDs, + p *progress.Counter, + skip func(tree restic.ID) bool, + process func(id restic.ID, error error, nodes TreeNodeIterator) error, +) error { loaderChan := make(chan trackedID) hugeTreeChan := make(chan trackedID, 10) loadedTreeChan := make(chan trackedTreeItem) - treeStream := make(chan TreeItem) var loadTreeWg sync.WaitGroup + wg, ctx := errgroup.WithContext(ctx) // decoding a tree can take quite some time such that this can be both CPU- or IO-bound // one extra worker to handle huge tree blobs workerCount := int(repo.Connections()) + runtime.GOMAXPROCS(0) + 1 @@ -174,8 +210,7 @@ func StreamTrees(ctx context.Context, wg *errgroup.Group, repo restic.Loader, tr loadTreeWg.Add(1) wg.Go(func() error { defer loadTreeWg.Done() - loadTreeWorker(ctx, repo, workerLoaderChan, loadedTreeChan) - return nil + return loadTreeWorker(ctx, repo, workerLoaderChan, process, loadedTreeChan) }) } @@ -189,9 +224,8 @@ func StreamTrees(ctx context.Context, wg *errgroup.Group, repo restic.Loader, tr wg.Go(func() error { defer close(loaderChan) defer close(hugeTreeChan) - defer close(treeStream) - filterTrees(ctx, repo, trees, loaderChan, hugeTreeChan, loadedTreeChan, treeStream, skip, p) + filterTrees(ctx, repo, trees, loaderChan, hugeTreeChan, loadedTreeChan, skip, p) return nil }) - return treeStream + return wg.Wait() } diff --git a/internal/data/tree_test.go b/internal/data/tree_test.go index 054cf7c0a..47fc4b9a0 100644 --- a/internal/data/tree_test.go +++ b/internal/data/tree_test.go @@ -4,9 +4,12 @@ import ( "context" "encoding/json" "errors" + "fmt" "os" "path/filepath" + "slices" "strconv" + "strings" "testing" "github.com/restic/restic/internal/archiver" @@ -105,37 +108,45 @@ func TestNodeComparison(t *testing.T) { func TestEmptyLoadTree(t *testing.T) { repo := repository.TestRepository(t) - tree := data.NewTree(0) + nodes := []*data.Node{} var id restic.ID rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { - var err error // save tree - id, err = data.SaveTree(ctx, uploader, tree) - return err + id = data.TestSaveNodes(t, ctx, uploader, nodes) + return nil })) // load tree again - tree2, err := data.LoadTree(context.TODO(), repo, id) + it, err := data.LoadTree(context.TODO(), repo, id) rtest.OK(t, err) + nodes2 := []*data.Node{} + for item := range it { + rtest.OK(t, item.Error) + nodes2 = append(nodes2, item.Node) + } - rtest.Assert(t, tree.Equals(tree2), - "trees are not equal: want %v, got %v", - tree, tree2) + rtest.Assert(t, slices.Equal(nodes, nodes2), + "tree nodes are not equal: want %v, got %v", + nodes, nodes2) +} + +// Basic type for comparing the serialization of the tree +type Tree struct { + Nodes []*data.Node `json:"nodes"` } func TestTreeEqualSerialization(t *testing.T) { files := []string{"node.go", "tree.go", "tree_test.go"} for i := 1; i <= len(files); i++ { - tree := data.NewTree(i) + tree := Tree{Nodes: make([]*data.Node, 0, i)} builder := data.NewTreeJSONBuilder() for _, fn := range files[:i] { node := nodeForFile(t, fn) - rtest.OK(t, tree.Insert(node)) + tree.Nodes = append(tree.Nodes, node) rtest.OK(t, builder.AddNode(node)) - rtest.Assert(t, tree.Insert(node) != nil, "no error on duplicate node") rtest.Assert(t, builder.AddNode(node) != nil, "no error on duplicate node") rtest.Assert(t, errors.Is(builder.AddNode(node), data.ErrTreeNotOrdered), "wrong error returned") } @@ -144,14 +155,34 @@ func TestTreeEqualSerialization(t 
*testing.T) { treeBytes = append(treeBytes, '\n') rtest.OK(t, err) - stiBytes, err := builder.Finalize() + buf, err := builder.Finalize() rtest.OK(t, err) // compare serialization of an individual node and the SaveTreeIterator - rtest.Equals(t, treeBytes, stiBytes) + rtest.Equals(t, treeBytes, buf) } } +func TestTreeLoadSaveCycle(t *testing.T) { + files := []string{"node.go", "tree.go", "tree_test.go"} + builder := data.NewTreeJSONBuilder() + for _, fn := range files { + node := nodeForFile(t, fn) + rtest.OK(t, builder.AddNode(node)) + } + buf, err := builder.Finalize() + rtest.OK(t, err) + + tm := data.TestTreeMap{restic.Hash(buf): buf} + it, err := data.LoadTree(context.TODO(), tm, restic.Hash(buf)) + rtest.OK(t, err) + + mtm := data.TestWritableTreeMap{TestTreeMap: data.TestTreeMap{}} + id, err := data.SaveTree(context.TODO(), mtm, it) + rtest.OK(t, err) + rtest.Equals(t, restic.Hash(buf), id, "saved tree id mismatch") +} + func BenchmarkBuildTree(b *testing.B) { const size = 100 // Directories of this size are not uncommon. @@ -165,11 +196,12 @@ func BenchmarkBuildTree(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { - t := data.NewTree(size) - + t := data.NewTreeJSONBuilder() for i := range nodes { - _ = t.Insert(&nodes[i]) + rtest.OK(b, t.AddNode(&nodes[i])) } + _, err := t.Finalize() + rtest.OK(b, err) } } @@ -186,8 +218,80 @@ func testLoadTree(t *testing.T, version uint) { repo, _, _ := repository.TestRepositoryWithVersion(t, version) sn := archiver.TestSnapshot(t, repo, rtest.BenchArchiveDirectory, nil) - _, err := data.LoadTree(context.TODO(), repo, *sn.Tree) + nodes, err := data.LoadTree(context.TODO(), repo, *sn.Tree) rtest.OK(t, err) + for item := range nodes { + rtest.OK(t, item.Error) + } +} + +func TestTreeIteratorUnknownKeys(t *testing.T) { + tests := []struct { + name string + jsonData string + wantNodes []string + }{ + { + name: "unknown key before nodes", + jsonData: `{"extra": "value", "nodes": [{"name": "test1"}, {"name": "test2"}]}`, + wantNodes: []string{"test1", "test2"}, + }, + { + name: "unknown key after nodes", + jsonData: `{"nodes": [{"name": "test1"}, {"name": "test2"}], "extra": "value"}`, + wantNodes: []string{"test1", "test2"}, + }, + { + name: "multiple unknown keys before nodes", + jsonData: `{"key1": "value1", "key2": 42, "nodes": [{"name": "test1"}]}`, + wantNodes: []string{"test1"}, + }, + { + name: "multiple unknown keys after nodes", + jsonData: `{"nodes": [{"name": "test1"}], "key1": "value1", "key2": 42}`, + wantNodes: []string{"test1"}, + }, + { + name: "unknown keys before and after nodes", + jsonData: `{"before": "value", "nodes": [{"name": "test1"}], "after": "value"}`, + wantNodes: []string{"test1"}, + }, + { + name: "nested object as unknown value", + jsonData: `{"extra": {"nested": "value"}, "nodes": [{"name": "test1"}]}`, + wantNodes: []string{"test1"}, + }, + { + name: "nested array as unknown value", + jsonData: `{"extra": [1, 2, 3], "nodes": [{"name": "test1"}]}`, + wantNodes: []string{"test1"}, + }, + { + name: "complex nested structure as unknown value", + jsonData: `{"extra": {"obj": {"arr": [1, {"nested": true}]}}, "nodes": [{"name": "test1"}]}`, + wantNodes: []string{"test1"}, + }, + { + name: "empty nodes array with unknown keys", + jsonData: `{"extra": "value", "nodes": []}`, + wantNodes: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + it, err := data.NewTreeNodeIterator(strings.NewReader(tt.jsonData + "\n")) + rtest.OK(t, err) + + var gotNodes []string + for item := range it { + 
+				rtest.OK(t, item.Error)
+				gotNodes = append(gotNodes, item.Node.Name)
+			}
+
+			rtest.Equals(t, tt.wantNodes, gotNodes, "nodes mismatch")
+		})
+	}
 }
 
 func BenchmarkLoadTree(t *testing.B) {
@@ -211,6 +315,58 @@ func benchmarkLoadTree(t *testing.B, version uint) {
 	}
 }
 
+func TestTreeFinderNilIterator(t *testing.T) {
+	finder := data.NewTreeFinder(nil)
+	defer finder.Close()
+	node, err := finder.Find("foo")
+	rtest.OK(t, err)
+	rtest.Equals(t, node, nil, "finder should return nil node")
+}
+
+func TestTreeFinderError(t *testing.T) {
+	testErr := errors.New("error")
+	finder := data.NewTreeFinder(slices.Values([]data.NodeOrError{
+		{Node: &data.Node{Name: "a"}, Error: nil},
+		{Node: &data.Node{Name: "b"}, Error: nil},
+		{Node: nil, Error: testErr},
+	}))
+	defer finder.Close()
+	node, err := finder.Find("b")
+	rtest.OK(t, err)
+	rtest.Equals(t, node.Name, "b", "finder should return node with name b")
+
+	node, err = finder.Find("c")
+	rtest.Equals(t, err, testErr, "finder should return correct error")
+	rtest.Equals(t, node, nil, "finder should return nil node")
+}
+
+func TestTreeFinderNotFound(t *testing.T) {
+	finder := data.NewTreeFinder(slices.Values([]data.NodeOrError{
+		{Node: &data.Node{Name: "a"}, Error: nil},
+	}))
+	defer finder.Close()
+	node, err := finder.Find("b")
+	rtest.OK(t, err)
+	rtest.Equals(t, node, nil, "finder should return nil node")
+	// must also be ok multiple times
+	node, err = finder.Find("c")
+	rtest.OK(t, err)
+	rtest.Equals(t, node, nil, "finder should return nil node")
+}
+
+func TestTreeFinderWrongOrder(t *testing.T) {
+	finder := data.NewTreeFinder(slices.Values([]data.NodeOrError{
+		{Node: &data.Node{Name: "d"}, Error: nil},
+	}))
+	defer finder.Close()
+	node, err := finder.Find("b")
+	rtest.OK(t, err)
+	rtest.Equals(t, node, nil, "finder should return nil node")
+	node, err = finder.Find("a")
+	rtest.Assert(t, strings.Contains(err.Error(), "is not greater than"), "unexpected error: %v", err)
+	rtest.Equals(t, node, nil, "finder should return nil node")
+}
+
 func TestFindTreeDirectory(t *testing.T) {
 	repo := repository.TestRepository(t)
 	sn := data.TestCreateSnapshot(t, repo, parseTimeUTC("2017-07-07 07:07:08"), 3)
@@ -220,15 +376,15 @@ func TestFindTreeDirectory(t *testing.T) {
 		id restic.ID
 		err error
 	}{
-		{"", restic.TestParseID("c25199703a67455b34cc0c6e49a8ac8861b268a5dd09dc5b2e31e7380973fc97"), nil},
-		{"/", restic.TestParseID("c25199703a67455b34cc0c6e49a8ac8861b268a5dd09dc5b2e31e7380973fc97"), nil},
-		{".", restic.TestParseID("c25199703a67455b34cc0c6e49a8ac8861b268a5dd09dc5b2e31e7380973fc97"), nil},
+		{"", restic.TestParseID("8804a5505fc3012e7d08b2843e9bda1bf3dc7644f64b542470340e1b4059f09f"), nil},
+		{"/", restic.TestParseID("8804a5505fc3012e7d08b2843e9bda1bf3dc7644f64b542470340e1b4059f09f"), nil},
+		{".", restic.TestParseID("8804a5505fc3012e7d08b2843e9bda1bf3dc7644f64b542470340e1b4059f09f"), nil},
 		{"..", restic.ID{}, errors.New("path ..: not found")},
 		{"file-1", restic.ID{}, errors.New("path file-1: not a directory")},
-		{"dir-21", restic.TestParseID("76172f9dec15d7e4cb98d2993032e99f06b73b2f02ffea3b7cfd9e6b4d762712"), nil},
-		{"/dir-21", restic.TestParseID("76172f9dec15d7e4cb98d2993032e99f06b73b2f02ffea3b7cfd9e6b4d762712"), nil},
-		{"dir-21/", restic.TestParseID("76172f9dec15d7e4cb98d2993032e99f06b73b2f02ffea3b7cfd9e6b4d762712"), nil},
-		{"dir-21/dir-24", restic.TestParseID("74626b3fb2bd4b3e572b81a4059b3e912bcf2a8f69fecd9c187613b7173f13b1"), nil},
+		{"dir-7", restic.TestParseID("1af51eb70cd4457d51db40d649bb75446a3eaa29b265916d411bb7ae971d4849"), nil},
{"/dir-7", restic.TestParseID("1af51eb70cd4457d51db40d649bb75446a3eaa29b265916d411bb7ae971d4849"), nil}, + {"dir-7/", restic.TestParseID("1af51eb70cd4457d51db40d649bb75446a3eaa29b265916d411bb7ae971d4849"), nil}, + {"dir-7/dir-5", restic.TestParseID("f05534d2673964de698860e5069da1ee3c198acf21c187975c6feb49feb8e9c9"), nil}, } { t.Run("", func(t *testing.T) { id, err := data.FindTreeDirectory(context.TODO(), repo, sn.Tree, exp.subfolder) @@ -244,3 +400,187 @@ func TestFindTreeDirectory(t *testing.T) { _, err := data.FindTreeDirectory(context.TODO(), repo, nil, "") rtest.Assert(t, err != nil, "missing error on null tree id") } + +func TestDualTreeIterator(t *testing.T) { + testErr := errors.New("test error") + + tests := []struct { + name string + tree1 []data.NodeOrError + tree2 []data.NodeOrError + expected []data.DualTree + }{ + { + name: "both empty", + tree1: []data.NodeOrError{}, + tree2: []data.NodeOrError{}, + expected: []data.DualTree{}, + }, + { + name: "tree1 empty", + tree1: []data.NodeOrError{}, + tree2: []data.NodeOrError{ + {Node: &data.Node{Name: "a"}}, + {Node: &data.Node{Name: "b"}}, + }, + expected: []data.DualTree{ + {Tree1: nil, Tree2: &data.Node{Name: "a"}, Error: nil}, + {Tree1: nil, Tree2: &data.Node{Name: "b"}, Error: nil}, + }, + }, + { + name: "tree2 empty", + tree1: []data.NodeOrError{ + {Node: &data.Node{Name: "a"}}, + {Node: &data.Node{Name: "b"}}, + }, + tree2: []data.NodeOrError{}, + expected: []data.DualTree{ + {Tree1: &data.Node{Name: "a"}, Tree2: nil, Error: nil}, + {Tree1: &data.Node{Name: "b"}, Tree2: nil, Error: nil}, + }, + }, + { + name: "identical trees", + tree1: []data.NodeOrError{ + {Node: &data.Node{Name: "a"}}, + {Node: &data.Node{Name: "b"}}, + }, + tree2: []data.NodeOrError{ + {Node: &data.Node{Name: "a"}}, + {Node: &data.Node{Name: "b"}}, + }, + expected: []data.DualTree{ + {Tree1: &data.Node{Name: "a"}, Tree2: &data.Node{Name: "a"}, Error: nil}, + {Tree1: &data.Node{Name: "b"}, Tree2: &data.Node{Name: "b"}, Error: nil}, + }, + }, + { + name: "disjoint trees", + tree1: []data.NodeOrError{ + {Node: &data.Node{Name: "a"}}, + {Node: &data.Node{Name: "c"}}, + }, + tree2: []data.NodeOrError{ + {Node: &data.Node{Name: "b"}}, + {Node: &data.Node{Name: "d"}}, + }, + expected: []data.DualTree{ + {Tree1: &data.Node{Name: "a"}, Tree2: nil, Error: nil}, + {Tree1: nil, Tree2: &data.Node{Name: "b"}, Error: nil}, + {Tree1: &data.Node{Name: "c"}, Tree2: nil, Error: nil}, + {Tree1: nil, Tree2: &data.Node{Name: "d"}, Error: nil}, + }, + }, + { + name: "overlapping trees", + tree1: []data.NodeOrError{ + {Node: &data.Node{Name: "a"}}, + {Node: &data.Node{Name: "b"}}, + {Node: &data.Node{Name: "d"}}, + }, + tree2: []data.NodeOrError{ + {Node: &data.Node{Name: "b"}}, + {Node: &data.Node{Name: "c"}}, + {Node: &data.Node{Name: "d"}}, + }, + expected: []data.DualTree{ + {Tree1: &data.Node{Name: "a"}, Tree2: nil, Error: nil}, + {Tree1: &data.Node{Name: "b"}, Tree2: &data.Node{Name: "b"}, Error: nil}, + {Tree1: nil, Tree2: &data.Node{Name: "c"}, Error: nil}, + {Tree1: &data.Node{Name: "d"}, Tree2: &data.Node{Name: "d"}, Error: nil}, + }, + }, + { + name: "error in tree1 during iteration", + tree1: []data.NodeOrError{ + {Node: &data.Node{Name: "a"}}, + {Error: testErr}, + }, + tree2: []data.NodeOrError{ + {Node: &data.Node{Name: "c"}}, + }, + expected: []data.DualTree{ + {Tree1: nil, Tree2: nil, Error: testErr}, + }, + }, + { + name: "error in tree2 during iteration", + tree1: []data.NodeOrError{ + {Node: &data.Node{Name: "a"}}, + }, + tree2: []data.NodeOrError{ + {Node: 
&data.Node{Name: "b"}}, + {Error: testErr}, + }, + expected: []data.DualTree{ + {Tree1: &data.Node{Name: "a"}, Tree2: nil, Error: nil}, + {Tree1: nil, Tree2: nil, Error: testErr}, + }, + }, + { + name: "error at start of tree1", + tree1: []data.NodeOrError{{Error: testErr}}, + tree2: []data.NodeOrError{{Node: &data.Node{Name: "b"}}}, + expected: []data.DualTree{ + {Tree1: nil, Tree2: nil, Error: testErr}, + }, + }, + { + name: "error at start of tree2", + tree1: []data.NodeOrError{{Node: &data.Node{Name: "a"}}}, + tree2: []data.NodeOrError{{Error: testErr}}, + expected: []data.DualTree{ + {Tree1: nil, Tree2: nil, Error: testErr}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + iter1 := slices.Values(tt.tree1) + iter2 := slices.Values(tt.tree2) + + dualIter := data.DualTreeIterator(iter1, iter2) + var results []data.DualTree + for dt := range dualIter { + results = append(results, dt) + } + + rtest.Equals(t, len(tt.expected), len(results), "unexpected number of results") + for i, exp := range tt.expected { + rtest.Equals(t, exp.Error, results[i].Error, fmt.Sprintf("error mismatch at index %d", i)) + rtest.Equals(t, exp.Tree1, results[i].Tree1, fmt.Sprintf("Tree1 mismatch at index %d", i)) + rtest.Equals(t, exp.Tree2, results[i].Tree2, fmt.Sprintf("Tree2 mismatch at index %d", i)) + } + }) + } + + t.Run("single use restriction", func(t *testing.T) { + iter1 := slices.Values([]data.NodeOrError{{Node: &data.Node{Name: "a"}}}) + iter2 := slices.Values([]data.NodeOrError{{Node: &data.Node{Name: "b"}}}) + dualIter := data.DualTreeIterator(iter1, iter2) + + // First use should work + var count int + for range dualIter { + count++ + } + rtest.Assert(t, count > 0, "first iteration should produce results") + + // Second use should panic + func() { + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic on second use") + } + }() + count = 0 + for range dualIter { + // Should panic before reaching here + count++ + } + rtest.Equals(t, count, 0, "expected count to be 0") + }() + }) +} diff --git a/internal/dump/common.go b/internal/dump/common.go index aea5c1291..2c0edf67a 100644 --- a/internal/dump/common.go +++ b/internal/dump/common.go @@ -30,33 +30,42 @@ func New(format string, repo restic.Loader, w io.Writer) *Dumper { } } -func (d *Dumper) DumpTree(ctx context.Context, tree *data.Tree, rootPath string) error { - ctx, cancel := context.WithCancel(ctx) - defer cancel() +func (d *Dumper) DumpTree(ctx context.Context, tree data.TreeNodeIterator, rootPath string) error { + wg, ctx := errgroup.WithContext(ctx) // ch is buffered to deal with variable download/write speeds. 
ch := make(chan *data.Node, 10) - go sendTrees(ctx, d.repo, tree, rootPath, ch) + wg.Go(func() error { + return sendTrees(ctx, d.repo, tree, rootPath, ch) + }) - switch d.format { - case "tar": - return d.dumpTar(ctx, ch) - case "zip": - return d.dumpZip(ctx, ch) - default: - panic("unknown dump format") - } + wg.Go(func() error { + switch d.format { + case "tar": + return d.dumpTar(ctx, ch) + case "zip": + return d.dumpZip(ctx, ch) + default: + panic("unknown dump format") + } + }) + return wg.Wait() } -func sendTrees(ctx context.Context, repo restic.BlobLoader, tree *data.Tree, rootPath string, ch chan *data.Node) { +func sendTrees(ctx context.Context, repo restic.BlobLoader, nodes data.TreeNodeIterator, rootPath string, ch chan *data.Node) error { defer close(ch) - for _, root := range tree.Nodes { - root.Path = path.Join(rootPath, root.Name) - if sendNodes(ctx, repo, root, ch) != nil { - break + for item := range nodes { + if item.Error != nil { + return item.Error + } + node := item.Node + node.Path = path.Join(rootPath, node.Name) + if err := sendNodes(ctx, repo, node, ch); err != nil { + return err } } + return nil } func sendNodes(ctx context.Context, repo restic.BlobLoader, root *data.Node, ch chan *data.Node) error { diff --git a/internal/dump/common_test.go b/internal/dump/common_test.go index 5599e2717..bb0347189 100644 --- a/internal/dump/common_test.go +++ b/internal/dump/common_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/restic/restic/internal/archiver" + "github.com/restic/restic/internal/backend" "github.com/restic/restic/internal/data" "github.com/restic/restic/internal/fs" "github.com/restic/restic/internal/repository" @@ -13,13 +14,13 @@ import ( rtest "github.com/restic/restic/internal/test" ) -func prepareTempdirRepoSrc(t testing.TB, src archiver.TestDir) (string, restic.Repository) { +func prepareTempdirRepoSrc(t testing.TB, src archiver.TestDir) (string, restic.Repository, backend.Backend) { tempdir := rtest.TempDir(t) - repo := repository.TestRepository(t) + repo, _, be := repository.TestRepositoryWithVersion(t, 0) archiver.TestCreateFiles(t, tempdir, src) - return tempdir, repo + return tempdir, repo, be } type CheckDump func(t *testing.T, testDir string, testDump *bytes.Buffer) error @@ -67,13 +68,22 @@ func WriteTest(t *testing.T, format string, cd CheckDump) { }, target: "/", }, + { + name: "directory only", + args: archiver.TestDir{ + "firstDir": archiver.TestDir{ + "secondDir": archiver.TestDir{}, + }, + }, + target: "/", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - tmpdir, repo := prepareTempdirRepoSrc(t, tt.args) + tmpdir, repo, be := prepareTempdirRepoSrc(t, tt.args) arch := archiver.New(repo, fs.Track{FS: fs.Local{}}, archiver.Options{}) back := rtest.Chdir(t, tmpdir) @@ -93,6 +103,15 @@ func WriteTest(t *testing.T, format string, cd CheckDump) { if err := cd(t, tmpdir, dst); err != nil { t.Errorf("WriteDump() = does not match: %v", err) } + + // test that dump returns an error if the repository is broken + tree, err = data.LoadTree(ctx, repo, *sn.Tree) + rtest.OK(t, err) + rtest.OK(t, be.Delete(ctx)) + // use new dumper as the old one has the blobs cached + d = New(format, repo, dst) + err = d.DumpTree(ctx, tree, tt.target) + rtest.Assert(t, err != nil, "expected error, got nil") }) } } diff --git a/internal/fuse/dir.go b/internal/fuse/dir.go index 28f7ba9a7..df558ac1f 100644 --- a/internal/fuse/dir.go +++ b/internal/fuse/dir.go @@ -7,6 +7,7 @@ 
import ( "errors" "os" "path/filepath" + "slices" "sync" "syscall" @@ -65,13 +66,13 @@ func unwrapCtxCanceled(err error) error { // replaceSpecialNodes replaces nodes with name "." and "/" by their contents. // Otherwise, the node is returned. -func replaceSpecialNodes(ctx context.Context, repo restic.BlobLoader, node *data.Node) ([]*data.Node, error) { +func replaceSpecialNodes(ctx context.Context, repo restic.BlobLoader, node *data.Node) (data.TreeNodeIterator, error) { if node.Type != data.NodeTypeDir || node.Subtree == nil { - return []*data.Node{node}, nil + return slices.Values([]data.NodeOrError{{Node: node}}), nil } if node.Name != "." && node.Name != "/" { - return []*data.Node{node}, nil + return slices.Values([]data.NodeOrError{{Node: node}}), nil } tree, err := data.LoadTree(ctx, repo, *node.Subtree) @@ -79,7 +80,7 @@ func replaceSpecialNodes(ctx context.Context, repo restic.BlobLoader, node *data return nil, unwrapCtxCanceled(err) } - return tree.Nodes, nil + return tree, nil } func newDirFromSnapshot(root *Root, forget forgetFn, inode uint64, snapshot *data.Snapshot) (*dir, error) { @@ -115,18 +116,25 @@ func (d *dir) open(ctx context.Context) error { return unwrapCtxCanceled(err) } items := make(map[string]*data.Node) - for _, n := range tree.Nodes { + for item := range tree { + if item.Error != nil { + return unwrapCtxCanceled(item.Error) + } if ctx.Err() != nil { return ctx.Err() } + n := item.Node nodes, err := replaceSpecialNodes(ctx, d.root.repo, n) if err != nil { debug.Log(" replaceSpecialNodes(%v) failed: %v", n, err) return err } - for _, node := range nodes { - items[cleanupNodeName(node.Name)] = node + for item := range nodes { + if item.Error != nil { + return unwrapCtxCanceled(item.Error) + } + items[cleanupNodeName(item.Node.Name)] = item.Node } } d.items = items diff --git a/internal/fuse/fuse_test.go b/internal/fuse/fuse_test.go index 05e6c5340..c82252458 100644 --- a/internal/fuse/fuse_test.go +++ b/internal/fuse/fuse_test.go @@ -59,7 +59,7 @@ func loadFirstSnapshot(t testing.TB, repo restic.ListerLoaderUnpacked) *data.Sna return sn } -func loadTree(t testing.TB, repo restic.Loader, id restic.ID) *data.Tree { +func loadTree(t testing.TB, repo restic.Loader, id restic.ID) data.TreeNodeIterator { tree, err := data.LoadTree(context.TODO(), repo, id) rtest.OK(t, err) return tree @@ -79,8 +79,9 @@ func TestFuseFile(t *testing.T) { tree := loadTree(t, repo, *sn.Tree) var content restic.IDs - for _, node := range tree.Nodes { - content = append(content, node.Content...) + for item := range tree { + rtest.OK(t, item.Error) + content = append(content, item.Node.Content...) 
} t.Logf("tree loaded, content: %v", content) diff --git a/internal/repository/repository.go b/internal/repository/repository.go index 32d2e1aac..e7a1b8c17 100644 --- a/internal/repository/repository.go +++ b/internal/repository/repository.go @@ -562,6 +562,8 @@ func (r *Repository) removeUnpacked(ctx context.Context, t restic.FileType, id r } func (r *Repository) WithBlobUploader(ctx context.Context, fn func(ctx context.Context, uploader restic.BlobSaverWithAsync) error) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() wg, ctx := errgroup.WithContext(ctx) // pack uploader + wg.Go below + blob saver (CPU bound) wg.SetLimit(2 + runtime.GOMAXPROCS(0)) @@ -570,7 +572,18 @@ func (r *Repository) WithBlobUploader(ctx context.Context, fn func(ctx context.C // blob saver are spawned on demand, use wait group to keep track of them r.blobSaver = &sync.WaitGroup{} wg.Go(func() error { - if err := fn(ctx, &blobSaverRepo{repo: r}); err != nil { + inCallback := true + defer func() { + // when the defer is called while inCallback is true, this means + // that runtime.Goexit was called within `fn`. This should only happen + // if a test uses t.Fatal within `fn`. + if inCallback { + cancel() + } + }() + err := fn(ctx, &blobSaverRepo{repo: r}) + inCallback = false + if err != nil { return err } if err := r.flush(ctx); err != nil { diff --git a/internal/restorer/restorer.go b/internal/restorer/restorer.go index 22ab196a5..8454591e4 100644 --- a/internal/restorer/restorer.go +++ b/internal/restorer/restorer.go @@ -163,15 +163,18 @@ func (res *Restorer) traverseTreeInner(ctx context.Context, target, location str } if res.opts.Delete { - filenames = make([]string, 0, len(tree.Nodes)) + filenames = make([]string, 0) } - for i, node := range tree.Nodes { + for item := range tree { + if item.Error != nil { + debug.Log("error iterating tree %v: %v", treeID, item.Error) + return nil, hasRestored, res.sanitizeError(location, item.Error) + } + node := item.Node if ctx.Err() != nil { return nil, hasRestored, ctx.Err() } - // allow GC of tree node - tree.Nodes[i] = nil if res.opts.Delete { // just track all files included in the tree node to simplify the control flow. 
// tracking too many files does not matter except for a slightly elevated memory usage diff --git a/internal/restorer/restorer_test.go b/internal/restorer/restorer_test.go index 2e419f55c..337a99918 100644 --- a/internal/restorer/restorer_test.go +++ b/internal/restorer/restorer_test.go @@ -79,7 +79,7 @@ func saveDir(t testing.TB, repo restic.BlobSaver, nodes map[string]Node, inode u ctx, cancel := context.WithCancel(context.Background()) defer cancel() - tree := &data.Tree{} + tree := make([]*data.Node, 0, len(nodes)) for name, n := range nodes { inode++ switch node := n.(type) { @@ -107,7 +107,7 @@ func saveDir(t testing.TB, repo restic.BlobSaver, nodes map[string]Node, inode u if mode == 0 { mode = 0644 } - err := tree.Insert(&data.Node{ + tree = append(tree, &data.Node{ Type: data.NodeTypeFile, Mode: mode, ModTime: node.ModTime, @@ -120,9 +120,8 @@ func saveDir(t testing.TB, repo restic.BlobSaver, nodes map[string]Node, inode u Links: lc, GenericAttributes: getGenericAttributes(node.attributes, false), }) - rtest.OK(t, err) case Symlink: - err := tree.Insert(&data.Node{ + tree = append(tree, &data.Node{ Type: data.NodeTypeSymlink, Mode: os.ModeSymlink | 0o777, ModTime: node.ModTime, @@ -133,7 +132,6 @@ func saveDir(t testing.TB, repo restic.BlobSaver, nodes map[string]Node, inode u Inode: inode, Links: 1, }) - rtest.OK(t, err) case Dir: id := saveDir(t, repo, node.Nodes, inode, getGenericAttributes) @@ -142,7 +140,7 @@ func saveDir(t testing.TB, repo restic.BlobSaver, nodes map[string]Node, inode u mode = 0755 } - err := tree.Insert(&data.Node{ + tree = append(tree, &data.Node{ Type: data.NodeTypeDir, Mode: mode, ModTime: node.ModTime, @@ -152,18 +150,12 @@ func saveDir(t testing.TB, repo restic.BlobSaver, nodes map[string]Node, inode u Subtree: &id, GenericAttributes: getGenericAttributes(node.attributes, false), }) - rtest.OK(t, err) default: t.Fatalf("unknown node type %T", node) } } - id, err := data.SaveTree(ctx, repo, tree) - if err != nil { - t.Fatal(err) - } - - return id + return data.TestSaveNodes(t, ctx, repo, tree) } func saveSnapshot(t testing.TB, repo restic.Repository, snapshot Snapshot, getGenericAttributes func(attr *FileAttributes, isDir bool) (genericAttributes map[data.GenericAttributeType]json.RawMessage)) (*data.Snapshot, restic.ID) { diff --git a/internal/test/helpers.go b/internal/test/helpers.go index 3387d36df..e3fded66e 100644 --- a/internal/test/helpers.go +++ b/internal/test/helpers.go @@ -48,7 +48,7 @@ func OKs(tb testing.TB, errs []error) { // Equals fails the test if exp is not equal to act. // msg is optional message to be printed, first param being format string and rest being arguments. 
-func Equals(tb testing.TB, exp, act interface{}, msgs ...string) {
+func Equals[T any](tb testing.TB, exp, act T, msgs ...string) {
 	tb.Helper()
 	if !reflect.DeepEqual(exp, act) {
 		var msgString string
diff --git a/internal/ui/termstatus/status_test.go b/internal/ui/termstatus/status_test.go
index f65bb096f..b19e00557 100644
--- a/internal/ui/termstatus/status_test.go
+++ b/internal/ui/termstatus/status_test.go
@@ -128,7 +128,7 @@ func TestRawInputOutput(t *testing.T) {
 	defer cancel()
 	rtest.Equals(t, input, term.InputRaw())
 	rtest.Equals(t, false, term.InputIsTerminal())
-	rtest.Equals(t, &output, term.OutputRaw())
+	rtest.Equals(t, io.Writer(&output), term.OutputRaw())
 	rtest.Equals(t, false, term.OutputIsTerminal())
 	rtest.Equals(t, false, term.CanUpdateStatus())
 }
diff --git a/internal/walker/rewriter.go b/internal/walker/rewriter.go
index bd05b90d7..f53577e5a 100644
--- a/internal/walker/rewriter.go
+++ b/internal/walker/rewriter.go
@@ -11,7 +11,7 @@ import (
 )
 
 type NodeRewriteFunc func(node *data.Node, path string) *data.Node
-type FailedTreeRewriteFunc func(nodeID restic.ID, path string, err error) (*data.Tree, error)
+type FailedTreeRewriteFunc func(nodeID restic.ID, path string, err error) (data.TreeNodeIterator, error)
 type QueryRewrittenSizeFunc func() SnapshotSize
 
 type SnapshotSize struct {
@@ -52,7 +52,7 @@ func NewTreeRewriter(opts RewriteOpts) *TreeRewriter {
 	}
 	if rw.opts.RewriteFailedTree == nil {
 		// fail with error by default
-		rw.opts.RewriteFailedTree = func(_ restic.ID, _ string, err error) (*data.Tree, error) {
+		rw.opts.RewriteFailedTree = func(_ restic.ID, _ string, err error) (data.TreeNodeIterator, error) {
 			return nil, err
 		}
 	}
@@ -117,15 +117,26 @@ func (t *TreeRewriter) RewriteTree(ctx context.Context, loader restic.BlobLoader
 		if nodeID != testID {
 			return restic.ID{}, fmt.Errorf("cannot encode tree at %q without losing information", nodepath)
 		}
+
+		// reload the tree to get a new iterator
+		curTree, err = data.LoadTree(ctx, loader, nodeID)
+		if err != nil {
+			// shouldn't fail as the first load was successful
+			return restic.ID{}, fmt.Errorf("failed to reload tree %v: %w", nodeID, err)
+		}
 	}
 	debug.Log("filterTree: %s, nodeId: %s\n", nodepath, nodeID.Str())
 
-	tb := data.NewTreeJSONBuilder()
-	for _, node := range curTree.Nodes {
+	tb := data.NewTreeWriter(saver)
+	for item := range curTree {
 		if ctx.Err() != nil {
 			return restic.ID{}, ctx.Err()
 		}
+		if item.Error != nil {
+			return restic.ID{}, item.Error
+		}
+		node := item.Node
 		path := path.Join(nodepath, node.Name)
 		node = t.opts.RewriteNode(node, path)
@@ -156,13 +167,11 @@ func (t *TreeRewriter) RewriteTree(ctx context.Context, loader restic.BlobLoader
 		}
 	}
 
-	tree, err := tb.Finalize()
+	newTreeID, err := tb.Finalize(ctx)
 	if err != nil {
 		return restic.ID{}, err
 	}
 
-	// Save new tree
-	newTreeID, _, _, err := saver.SaveBlob(ctx, restic.TreeBlob, tree, restic.ID{}, false)
 	if t.replaces != nil {
 		t.replaces[nodeID] = newTreeID
 	}
diff --git a/internal/walker/rewriter_test.go b/internal/walker/rewriter_test.go
index 9290a62d5..edc3685dc 100644
--- a/internal/walker/rewriter_test.go
+++ b/internal/walker/rewriter_test.go
@@ -2,42 +2,14 @@ package walker
 
 import (
 	"context"
+	"slices"
 	"testing"
 
-	"github.com/pkg/errors"
 	"github.com/restic/restic/internal/data"
 	"github.com/restic/restic/internal/restic"
 	"github.com/restic/restic/internal/test"
 )
 
-// WritableTreeMap also support saving
-type WritableTreeMap struct {
-	TreeMap
-}
-
-func (t WritableTreeMap) SaveBlob(_ context.Context, tpe restic.BlobType, buf []byte, id restic.ID, _ bool) 
(newID restic.ID, known bool, size int, err error) { - if tpe != restic.TreeBlob { - return restic.ID{}, false, 0, errors.New("can only save trees") - } - - if id.IsNull() { - id = restic.Hash(buf) - } - _, ok := t.TreeMap[id] - if ok { - return id, false, 0, nil - } - - t.TreeMap[id] = append([]byte{}, buf...) - return id, true, len(buf), nil -} - -func (t WritableTreeMap) Dump(test testing.TB) { - for k, v := range t.TreeMap { - test.Logf("%v: %v", k, string(v)) - } -} - type checkRewriteFunc func(t testing.TB) (rewriter *TreeRewriter, final func(testing.TB)) // checkRewriteItemOrder ensures that the order of the 'path' arguments is the one passed in as 'want'. @@ -279,7 +251,7 @@ func TestRewriter(t *testing.T) { test.newTree = test.tree } expRepo, expRoot := BuildTreeMap(test.newTree) - modrepo := WritableTreeMap{repo} + modrepo := data.TestWritableTreeMap{TestTreeMap: repo} ctx, cancel := context.WithCancel(context.TODO()) defer cancel() @@ -297,7 +269,7 @@ func TestRewriter(t *testing.T) { t.Log("Got") modrepo.Dump(t) t.Log("Expected") - WritableTreeMap{expRepo}.Dump(t) + data.TestWritableTreeMap{TestTreeMap: expRepo}.Dump(t) } }) } @@ -320,7 +292,7 @@ func TestSnapshotSizeQuery(t *testing.T) { t.Run("", func(t *testing.T) { repo, root := BuildTreeMap(tree) expRepo, expRoot := BuildTreeMap(newTree) - modrepo := WritableTreeMap{repo} + modrepo := data.TestWritableTreeMap{TestTreeMap: repo} ctx, cancel := context.WithCancel(context.TODO()) defer cancel() @@ -351,17 +323,17 @@ func TestSnapshotSizeQuery(t *testing.T) { t.Log("Got") modrepo.Dump(t) t.Log("Expected") - WritableTreeMap{expRepo}.Dump(t) + data.TestWritableTreeMap{TestTreeMap: expRepo}.Dump(t) } }) } func TestRewriterFailOnUnknownFields(t *testing.T) { - tm := WritableTreeMap{TreeMap{}} + tm := data.TestWritableTreeMap{TestTreeMap: data.TestTreeMap{}} node := []byte(`{"nodes":[{"name":"subfile","type":"file","mtime":"0001-01-01T00:00:00Z","atime":"0001-01-01T00:00:00Z","ctime":"0001-01-01T00:00:00Z","uid":0,"gid":0,"content":null,"unknown_field":42}]}`) id := restic.Hash(node) - tm.TreeMap[id] = node + tm.TestTreeMap[id] = node ctx, cancel := context.WithCancel(context.TODO()) defer cancel() @@ -392,7 +364,7 @@ func TestRewriterFailOnUnknownFields(t *testing.T) { } func TestRewriterTreeLoadError(t *testing.T) { - tm := WritableTreeMap{TreeMap{}} + tm := data.TestWritableTreeMap{TestTreeMap: data.TestTreeMap{}} id := restic.NewRandomID() ctx, cancel := context.WithCancel(context.TODO()) @@ -405,16 +377,15 @@ func TestRewriterTreeLoadError(t *testing.T) { t.Fatal("missing error on unloadable tree") } - replacementTree := &data.Tree{Nodes: []*data.Node{{Name: "replacement", Type: data.NodeTypeFile, Size: 42}}} - replacementID, err := data.SaveTree(ctx, tm, replacementTree) - test.OK(t, err) + replacementNode := &data.Node{Name: "replacement", Type: data.NodeTypeFile, Size: 42} + replacementID := data.TestSaveNodes(t, ctx, tm, []*data.Node{replacementNode}) rewriter = NewTreeRewriter(RewriteOpts{ - RewriteFailedTree: func(nodeID restic.ID, path string, err error) (*data.Tree, error) { + RewriteFailedTree: func(nodeID restic.ID, path string, err error) (data.TreeNodeIterator, error) { if nodeID != id || path != "/" { t.Fail() } - return replacementTree, nil + return slices.Values([]data.NodeOrError{{Node: replacementNode}}), nil }, }) newRoot, err := rewriter.RewriteTree(ctx, tm, tm, "/", id) diff --git a/internal/walker/walker.go b/internal/walker/walker.go index 67c4a9d03..8347c28c4 100644 --- a/internal/walker/walker.go +++ 
b/internal/walker/walker.go @@ -3,7 +3,6 @@ package walker import ( "context" "path" - "sort" "github.com/pkg/errors" @@ -52,15 +51,15 @@ func Walk(ctx context.Context, repo restic.BlobLoader, root restic.ID, visitor W // walk recursively traverses the tree, ignoring subtrees when the ID of the // subtree is in ignoreTrees. If err is nil and ignore is true, the subtree ID // will be added to ignoreTrees by walk. -func walk(ctx context.Context, repo restic.BlobLoader, prefix string, parentTreeID restic.ID, tree *data.Tree, visitor WalkVisitor) (err error) { - sort.Slice(tree.Nodes, func(i, j int) bool { - return tree.Nodes[i].Name < tree.Nodes[j].Name - }) - - for _, node := range tree.Nodes { +func walk(ctx context.Context, repo restic.BlobLoader, prefix string, parentTreeID restic.ID, tree data.TreeNodeIterator, visitor WalkVisitor) (err error) { + for item := range tree { + if item.Error != nil { + return item.Error + } if ctx.Err() != nil { return ctx.Err() } + node := item.Node p := path.Join(prefix, node.Name) diff --git a/internal/walker/walker_test.go b/internal/walker/walker_test.go index fa561bf19..fad95476d 100644 --- a/internal/walker/walker_test.go +++ b/internal/walker/walker_test.go @@ -6,7 +6,6 @@ import ( "sort" "testing" - "github.com/pkg/errors" "github.com/restic/restic/internal/data" "github.com/restic/restic/internal/restic" rtest "github.com/restic/restic/internal/test" @@ -20,13 +19,13 @@ type TestFile struct { Size uint64 } -func BuildTreeMap(tree TestTree) (m TreeMap, root restic.ID) { - m = TreeMap{} +func BuildTreeMap(tree TestTree) (m data.TestTreeMap, root restic.ID) { + m = data.TestTreeMap{} id := buildTreeMap(tree, m) return m, id } -func buildTreeMap(tree TestTree, m TreeMap) restic.ID { +func buildTreeMap(tree TestTree, m data.TestTreeMap) restic.ID { tb := data.NewTreeJSONBuilder() var names []string for name := range tree { @@ -75,24 +74,6 @@ func buildTreeMap(tree TestTree, m TreeMap) restic.ID { return id } -// TreeMap returns the trees from the map on LoadTree. -type TreeMap map[restic.ID][]byte - -func (t TreeMap) LoadBlob(_ context.Context, tpe restic.BlobType, id restic.ID, _ []byte) ([]byte, error) { - if tpe != restic.TreeBlob { - return nil, errors.New("can only load trees") - } - tree, ok := t[id] - if !ok { - return nil, errors.New("tree not found") - } - return tree, nil -} - -func (t TreeMap) Connections() uint { - return 2 -} - // checkFunc returns a function suitable for walking the tree to check // something, and a function which will check the final result. type checkFunc func(t testing.TB) (walker WalkFunc, leaveDir func(path string) error, final func(testing.TB, error))
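
For reference, a minimal sketch of how a caller might drive the new callback-based StreamTrees API from internal/data (compare the updated doc comment and the cmd_copy.go call site earlier in this diff). The helper countDataBlobs and its package are hypothetical and not part of this change; only the StreamTrees signature, the TreeNodeIterator semantics, and the fact that the process callback runs on worker goroutines are taken from the diff.

package example

import (
	"context"
	"fmt"
	"sync/atomic"

	"github.com/restic/restic/internal/data"
	"github.com/restic/restic/internal/restic"
)

// countDataBlobs walks all trees reachable from root and counts the data blob
// references of the file nodes it encounters.
func countDataBlobs(ctx context.Context, repo restic.Loader, root restic.ID) (uint64, error) {
	// skip is always called from the same goroutine, so the set needs no locking
	visited := restic.NewIDSet()
	// process may run concurrently on several worker goroutines, so the counter is atomic
	var count atomic.Uint64

	err := data.StreamTrees(ctx, repo, restic.IDs{root}, nil,
		func(treeID restic.ID) bool {
			// skip subtrees that were already traversed
			if visited.Has(treeID) {
				return true
			}
			visited.Insert(treeID)
			return false
		},
		func(treeID restic.ID, err error, nodes data.TreeNodeIterator) error {
			if err != nil {
				return fmt.Errorf("LoadTree(%v) returned error %v", treeID.Str(), err)
			}
			// the iterator must be drained unless an error is returned
			for item := range nodes {
				if item.Error != nil {
					return item.Error
				}
				count.Add(uint64(len(item.Node.Content)))
			}
			return nil
		})
	return count.Load(), err
}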