diff --git a/cmd/restic/cmd_copy.go b/cmd/restic/cmd_copy.go
index f209015f0..46e70f120 100644
--- a/cmd/restic/cmd_copy.go
+++ b/cmd/restic/cmd_copy.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"fmt"
 	"iter"
+	"sync"
 	"time"

 	"github.com/restic/restic/internal/data"
@@ -14,7 +15,6 @@ import (
 	"github.com/restic/restic/internal/restic"
 	"github.com/restic/restic/internal/ui"
 	"github.com/restic/restic/internal/ui/progress"
-	"golang.org/x/sync/errgroup"

 	"github.com/spf13/cobra"
 	"github.com/spf13/pflag"
@@ -257,19 +257,13 @@ func copyTreeBatched(ctx context.Context, srcRepo restic.Repository, dstRepo res
 func copyTree(ctx context.Context, srcRepo restic.Repository, dstRepo restic.Repository,
 	visitedTrees restic.AssociatedBlobSet, rootTreeID restic.ID, printer progress.Printer, uploader restic.BlobSaverWithAsync) (uint64, error) {

-	wg, wgCtx := errgroup.WithContext(ctx)
-
-	treeStream := data.StreamTrees(wgCtx, wg, srcRepo, restic.IDs{rootTreeID}, func(treeID restic.ID) bool {
-		handle := restic.BlobHandle{ID: treeID, Type: restic.TreeBlob}
-		visited := visitedTrees.Has(handle)
-		visitedTrees.Insert(handle)
-		return visited
-	}, nil)
-
 	copyBlobs := srcRepo.NewAssociatedBlobSet()
 	packList := restic.NewIDSet()

+	var lock sync.Mutex
 	enqueue := func(h restic.BlobHandle) {
+		lock.Lock()
+		defer lock.Unlock()
 		if _, ok := dstRepo.LookupBlobSize(h.Type, h.ID); !ok {
 			pb := srcRepo.LookupBlob(h.Type, h.ID)
 			copyBlobs.Insert(h)
@@ -279,26 +273,28 @@ func copyTree(ctx context.Context, srcRepo restic.Repository, dstRepo restic.Rep
 		}
 	}

-	wg.Go(func() error {
-		for tree := range treeStream {
-			if tree.Error != nil {
-				return fmt.Errorf("LoadTree(%v) returned error %v", tree.ID.Str(), tree.Error)
-			}
+	err := data.StreamTrees(ctx, srcRepo, restic.IDs{rootTreeID}, nil, func(treeID restic.ID) bool {
+		handle := restic.BlobHandle{ID: treeID, Type: restic.TreeBlob}
+		visited := visitedTrees.Has(handle)
+		visitedTrees.Insert(handle)
+		return visited
+	}, func(treeID restic.ID, err error, tree *data.Tree) error {
+		if err != nil {
+			return fmt.Errorf("LoadTree(%v) returned error %v", treeID.Str(), err)
+		}

-			// copy raw tree bytes to avoid problems if the serialization changes
-			enqueue(restic.BlobHandle{ID: tree.ID, Type: restic.TreeBlob})
+		// copy raw tree bytes to avoid problems if the serialization changes
+		enqueue(restic.BlobHandle{ID: treeID, Type: restic.TreeBlob})

-			for _, entry := range tree.Nodes {
-				// Recursion into directories is handled by StreamTrees
-				// Copy the blobs for this file.
-				for _, blobID := range entry.Content {
-					enqueue(restic.BlobHandle{Type: restic.DataBlob, ID: blobID})
-				}
+		for _, entry := range tree.Nodes {
+			// Recursion into directories is handled by StreamTrees
+			// Copy the blobs for this file.
+			for _, blobID := range entry.Content {
+				enqueue(restic.BlobHandle{Type: restic.DataBlob, ID: blobID})
 			}
 		}
 		return nil
 	})
-	err := wg.Wait()
 	if err != nil {
 		return 0, err
 	}
diff --git a/internal/checker/checker.go b/internal/checker/checker.go
index c985951fd..164e6b053 100644
--- a/internal/checker/checker.go
+++ b/internal/checker/checker.go
@@ -3,7 +3,6 @@ package checker
 import (
 	"context"
 	"fmt"
-	"runtime"
 	"sync"

 	"github.com/restic/restic/internal/data"
@@ -12,7 +11,6 @@ import (
 	"github.com/restic/restic/internal/repository"
 	"github.com/restic/restic/internal/restic"
 	"github.com/restic/restic/internal/ui/progress"
-	"golang.org/x/sync/errgroup"
 )

 // Checker runs various checks on a repository. It is advisable to create an
@@ -92,31 +90,6 @@ func (e *TreeError) Error() string {
 	return fmt.Sprintf("tree %v: %v", e.ID, e.Errors)
 }

-// checkTreeWorker checks the trees received and sends out errors to errChan.
-func (c *Checker) checkTreeWorker(ctx context.Context, trees <-chan data.TreeItem, out chan<- error) {
-	for job := range trees {
-		debug.Log("check tree %v (tree %v, err %v)", job.ID, job.Tree, job.Error)
-
-		var errs []error
-		if job.Error != nil {
-			errs = append(errs, job.Error)
-		} else {
-			errs = c.checkTree(job.ID, job.Tree)
-		}
-
-		if len(errs) == 0 {
-			continue
-		}
-		treeError := &TreeError{ID: job.ID, Errors: errs}
-		select {
-		case <-ctx.Done():
-			return
-		case out <- treeError:
-			debug.Log("tree %v: sent %d errors", treeError.ID, len(treeError.Errors))
-		}
-	}
-}
-
 func loadSnapshotTreeIDs(ctx context.Context, lister restic.Lister, repo restic.LoaderUnpacked) (ids restic.IDs, errs []error) {
 	err := data.ForAllSnapshots(ctx, lister, repo, nil, func(id restic.ID, sn *data.Snapshot, err error) error {
 		if err != nil {
@@ -171,6 +144,7 @@ func (c *Checker) Structure(ctx context.Context, p *progress.Counter, errChan ch
 	p.SetMax(uint64(len(trees)))
 	debug.Log("need to check %d trees from snapshots, %d errs returned", len(trees), len(errs))

+	defer close(errChan)
 	for _, err := range errs {
 		select {
 		case <-ctx.Done():
@@ -179,8 +153,7 @@ func (c *Checker) Structure(ctx context.Context, p *progress.Counter, errChan ch
 		}
 	}

-	wg, ctx := errgroup.WithContext(ctx)
-	treeStream := data.StreamTrees(ctx, wg, c.repo, trees, func(treeID restic.ID) bool {
+	err := data.StreamTrees(ctx, c.repo, trees, p, func(treeID restic.ID) bool {
 		// blobRefs may be accessed in parallel by checkTree
 		c.blobRefs.Lock()
 		h := restic.BlobHandle{ID: treeID, Type: restic.TreeBlob}
@@ -189,21 +162,32 @@ func (c *Checker) Structure(ctx context.Context, p *progress.Counter, errChan ch
 		c.blobRefs.M.Insert(h)
 		c.blobRefs.Unlock()
 		return blobReferenced
-	}, p)
+	}, func(treeID restic.ID, err error, tree *data.Tree) error {
+		debug.Log("check tree %v (err %v)", treeID, err)

-	defer close(errChan)
-	// The checkTree worker only processes already decoded trees and is thus CPU-bound
-	workerCount := runtime.GOMAXPROCS(0)
-	for i := 0; i < workerCount; i++ {
-		wg.Go(func() error {
-			c.checkTreeWorker(ctx, treeStream, errChan)
+		var errs []error
+		if err != nil {
+			errs = append(errs, err)
+		} else {
+			errs = c.checkTree(treeID, tree)
+		}
+		if len(errs) == 0 {
 			return nil
-		})
-	}
+		}

-	// the wait group should not return an error because no worker returns an
+		treeError := &TreeError{ID: treeID, Errors: errs}
+		select {
+		case <-ctx.Done():
+			return nil
+		case errChan <- treeError:
+			debug.Log("tree %v: sent %d errors", treeError.ID, len(treeError.Errors))
+		}
+
+		return nil
+	})
+
+	// StreamTrees should not return an error because no worker returns an
 	// error, so panic if that has changed somehow.
-	err := wg.Wait()
 	if err != nil {
 		panic(err)
 	}
diff --git a/internal/data/find.go b/internal/data/find.go
index 14d64670e..a009c6496 100644
--- a/internal/data/find.go
+++ b/internal/data/find.go
@@ -6,7 +6,6 @@ import (

 	"github.com/restic/restic/internal/restic"
 	"github.com/restic/restic/internal/ui/progress"
-	"golang.org/x/sync/errgroup"
 )

 // FindUsedBlobs traverses the tree ID and adds all seen blobs (trees and data
@@ -14,8 +13,7 @@ import (
 func FindUsedBlobs(ctx context.Context, repo restic.Loader, treeIDs restic.IDs, blobs restic.FindBlobSet, p *progress.Counter) error {
 	var lock sync.Mutex

-	wg, ctx := errgroup.WithContext(ctx)
-	treeStream := StreamTrees(ctx, wg, repo, treeIDs, func(treeID restic.ID) bool {
+	return StreamTrees(ctx, repo, treeIDs, p, func(treeID restic.ID) bool {
 		// locking is necessary the goroutine below concurrently adds data blobs
 		lock.Lock()
 		h := restic.BlobHandle{ID: treeID, Type: restic.TreeBlob}
@@ -24,26 +22,21 @@ func FindUsedBlobs(ctx context.Context, repo restic.Loader, treeIDs restic.IDs,
 		blobs.Insert(h)
 		lock.Unlock()
 		return blobReferenced
-	}, p)
-
-	wg.Go(func() error {
-		for tree := range treeStream {
-			if tree.Error != nil {
-				return tree.Error
-			}
+	}, func(_ restic.ID, err error, tree *Tree) error {
+		if err != nil {
+			return err
+		}
+		for _, node := range tree.Nodes {
 			lock.Lock()
-			for _, node := range tree.Nodes {
-				switch node.Type {
-				case NodeTypeFile:
-					for _, blob := range node.Content {
-						blobs.Insert(restic.BlobHandle{ID: blob, Type: restic.DataBlob})
-					}
+			switch node.Type {
+			case NodeTypeFile:
+				for _, blob := range node.Content {
+					blobs.Insert(restic.BlobHandle{ID: blob, Type: restic.DataBlob})
 				}
 			}
 			lock.Unlock()
 		}
 		return nil
 	})

-	return wg.Wait()
 }
diff --git a/internal/data/tree_stream.go b/internal/data/tree_stream.go
index c7d3588b5..042a55f7e 100644
--- a/internal/data/tree_stream.go
+++ b/internal/data/tree_stream.go
@@ -2,26 +2,20 @@ package data

 import (
 	"context"
-	"errors"
 	"runtime"
 	"sync"

 	"github.com/restic/restic/internal/debug"
+	"github.com/restic/restic/internal/errors"
 	"github.com/restic/restic/internal/restic"
 	"github.com/restic/restic/internal/ui/progress"
 	"golang.org/x/sync/errgroup"
 )

-// TreeItem is used to return either an error or the tree for a tree id
-type TreeItem struct {
-	restic.ID
-	Error error
-	*Tree
-}
-
 type trackedTreeItem struct {
-	TreeItem
-	rootIdx int
+	restic.ID
+	Subtrees restic.IDs
+	rootIdx  int
 }

 type trackedID struct {
@@ -30,34 +24,55 @@ type trackedID struct {
 }

 // loadTreeWorker loads trees from repo and sends them to out.
-func loadTreeWorker(ctx context.Context, repo restic.Loader,
-	in <-chan trackedID, out chan<- trackedTreeItem) {
+func loadTreeWorker(
+	ctx context.Context,
+	repo restic.Loader,
+	in <-chan trackedID,
+	process func(id restic.ID, error error, tree *Tree) error,
+	out chan<- trackedTreeItem,
+) error {

 	for treeID := range in {
 		tree, err := LoadTree(ctx, repo, treeID.ID)
+		if tree == nil && err == nil {
+			err = errors.New("tree is nil and error is nil")
+		}
 		debug.Log("load tree %v (%v) returned err: %v", tree, treeID, err)
-		job := trackedTreeItem{TreeItem: TreeItem{ID: treeID.ID, Error: err, Tree: tree}, rootIdx: treeID.rootIdx}
+
+		err = process(treeID.ID, err, tree)
+		if err != nil {
+			return err
+		}
+
+		var subtrees restic.IDs
+		if tree != nil {
+			subtrees = tree.Subtrees()
+		}
+
+		job := trackedTreeItem{ID: treeID.ID, Subtrees: subtrees, rootIdx: treeID.rootIdx}

 		select {
 		case <-ctx.Done():
-			return
+			return nil
 		case out <- job:
 		}
 	}
+	return nil
 }

+// filterTree receives the result of a tree load and queues new trees for loading and processing.
 func filterTrees(ctx context.Context, repo restic.Loader, trees restic.IDs, loaderChan chan<- trackedID, hugeTreeLoaderChan chan<- trackedID,
-	in <-chan trackedTreeItem, out chan<- TreeItem, skip func(tree restic.ID) bool, p *progress.Counter) {
+	in <-chan trackedTreeItem, skip func(tree restic.ID) bool, p *progress.Counter) {
 	var (
 		inCh = in
-		outCh chan<- TreeItem
 		loadCh chan<- trackedID
-		job TreeItem
 		nextTreeID trackedID
 		outstandingLoadTreeJobs = 0
 	)

+	// tracks how many trees are currently waiting to be processed for a given root tree
 	rootCounter := make([]int, len(trees))
+	// build initial backlog
 	backlog := make([]trackedID, 0, len(trees))
 	for idx, id := range trees {
 		backlog = append(backlog, trackedID{ID: id, rootIdx: idx})
@@ -65,6 +80,7 @@ func filterTrees(ctx context.Context, repo restic.Loader, trees restic.IDs, load
 	}

 	for {
+		// if no tree is waiting to be sent, pick the next one
 		if loadCh == nil && len(backlog) > 0 {
 			// process last added ids first, that is traverse the tree in depth-first order
 			ln := len(backlog) - 1
@@ -86,7 +102,8 @@ func filterTrees(ctx context.Context, repo restic.Loader, trees restic.IDs, load
 			}
 		}

-		if loadCh == nil && outCh == nil && outstandingLoadTreeJobs == 0 {
+		// loadCh is only nil at this point if the backlog is empty
+		if loadCh == nil && outstandingLoadTreeJobs == 0 {
 			debug.Log("backlog is empty, all channels nil, exiting")
 			return
 		}
@@ -103,7 +120,6 @@ func filterTrees(ctx context.Context, repo restic.Loader, trees restic.IDs, load
 			if !ok {
 				debug.Log("input channel closed")
 				inCh = nil
-				in = nil
 				continue
 			}

@@ -111,58 +127,47 @@ func filterTrees(ctx context.Context, repo restic.Loader, trees restic.IDs, load
 			rootCounter[j.rootIdx]--

 			debug.Log("input job tree %v", j.ID)
-
-			if j.Error != nil {
-				debug.Log("received job with error: %v (tree %v, ID %v)", j.Error, j.Tree, j.ID)
-			} else if j.Tree == nil {
-				debug.Log("received job with nil tree pointer: %v (ID %v)", j.Error, j.ID)
-				// send a new job with the new error instead of the old one
-				j = trackedTreeItem{TreeItem: TreeItem{ID: j.ID, Error: errors.New("tree is nil and error is nil")}, rootIdx: j.rootIdx}
-			} else {
-				subtrees := j.Tree.Subtrees()
-				debug.Log("subtrees for tree %v: %v", j.ID, subtrees)
-				// iterate backwards over subtree to compensate backwards traversal order of nextTreeID selection
-				for i := len(subtrees) - 1; i >= 0; i-- {
-					id := subtrees[i]
-					if id.IsNull() {
-						// We do not need to raise this error here, it is
-						// checked when the tree is checked. Just make sure
-						// that we do not add any null IDs to the backlog.
-						debug.Log("tree %v has nil subtree", j.ID)
-						continue
-					}
-					backlog = append(backlog, trackedID{ID: id, rootIdx: j.rootIdx})
-					rootCounter[j.rootIdx]++
+			// iterate backwards over subtree to compensate backwards traversal order of nextTreeID selection
+			for i := len(j.Subtrees) - 1; i >= 0; i-- {
+				id := j.Subtrees[i]
+				if id.IsNull() {
+					// We do not need to raise this error here, it is
+					// checked when the tree is checked. Just make sure
+					// that we do not add any null IDs to the backlog.
+					debug.Log("tree %v has nil subtree", j.ID)
+					continue
 				}
+				backlog = append(backlog, trackedID{ID: id, rootIdx: j.rootIdx})
+				rootCounter[j.rootIdx]++
 			}
+			// the progress check must happen after j.Subtrees was added to the backlog
 			if p != nil && rootCounter[j.rootIdx] == 0 {
 				p.Add(1)
 			}
-
-			job = j.TreeItem
-			outCh = out
-			inCh = nil
-
-		case outCh <- job:
-			debug.Log("tree sent to process: %v", job.ID)
-			outCh = nil
-			inCh = in
 		}
 	}
 }

 // StreamTrees iteratively loads the given trees and their subtrees. The skip method
-// is guaranteed to always be called from the same goroutine. To shutdown the started
-// goroutines, either read all items from the channel or cancel the context. Then `Wait()`
-// on the errgroup until all goroutines were stopped.
-func StreamTrees(ctx context.Context, wg *errgroup.Group, repo restic.Loader, trees restic.IDs, skip func(tree restic.ID) bool, p *progress.Counter) <-chan TreeItem {
+// is guaranteed to always be called from the same goroutine. The process function is
+// directly called from the worker goroutines. It MUST read `nodes` until it returns an
+// error or completes. If the process function returns an error, then StreamTrees will
+// abort and return the error.
+func StreamTrees(
+	ctx context.Context,
+	repo restic.Loader,
+	trees restic.IDs,
+	p *progress.Counter,
+	skip func(tree restic.ID) bool,
+	process func(id restic.ID, error error, tree *Tree) error,
+) error {
 	loaderChan := make(chan trackedID)
 	hugeTreeChan := make(chan trackedID, 10)
 	loadedTreeChan := make(chan trackedTreeItem)
-	treeStream := make(chan TreeItem)

 	var loadTreeWg sync.WaitGroup

+	wg, ctx := errgroup.WithContext(ctx)
 	// decoding a tree can take quite some time such that this can be both CPU- or IO-bound
 	// one extra worker to handle huge tree blobs
 	workerCount := int(repo.Connections()) + runtime.GOMAXPROCS(0) + 1
@@ -174,8 +179,7 @@ func StreamTrees(ctx context.Context, wg *errgroup.Group, repo restic.Loader, tr
 		loadTreeWg.Add(1)
 		wg.Go(func() error {
 			defer loadTreeWg.Done()
-			loadTreeWorker(ctx, repo, workerLoaderChan, loadedTreeChan)
-			return nil
+			return loadTreeWorker(ctx, repo, workerLoaderChan, process, loadedTreeChan)
 		})
 	}

@@ -189,9 +193,8 @@ func StreamTrees(ctx context.Context, wg *errgroup.Group, repo restic.Loader, tr
 	wg.Go(func() error {
 		defer close(loaderChan)
 		defer close(hugeTreeChan)
-		defer close(treeStream)
-		filterTrees(ctx, repo, trees, loaderChan, hugeTreeChan, loadedTreeChan, treeStream, skip, p)
+		filterTrees(ctx, repo, trees, loaderChan, hugeTreeChan, loadedTreeChan, skip, p)
 		return nil
 	})
-	return treeStream
+	return wg.Wait()
 }
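Usage note (not part of the diff): with this change StreamTrees no longer returns a channel of TreeItem; callers pass a skip callback and a process callback and receive the first error as the return value. Below is a minimal sketch of the new call shape, assuming a restic.Loader and a root tree ID are already available; the package name and the walkTrees helper are made up for illustration, and the data/restic imports are internal restic packages, so this only compiles inside the restic module.

package example // hypothetical package, for illustration only

import (
	"context"
	"fmt"

	"github.com/restic/restic/internal/data"
	"github.com/restic/restic/internal/restic"
)

// walkTrees visits every tree reachable from rootTreeID using the callback-based API.
func walkTrees(ctx context.Context, repo restic.Loader, rootTreeID restic.ID) error {
	// skip is guaranteed to run on a single goroutine, so an unlocked IDSet is enough here
	seen := restic.NewIDSet()

	return data.StreamTrees(ctx, repo, restic.IDs{rootTreeID}, nil, func(treeID restic.ID) bool {
		// skip callback: report whether this tree was already queued
		if seen.Has(treeID) {
			return true
		}
		seen.Insert(treeID)
		return false
	}, func(treeID restic.ID, err error, tree *data.Tree) error {
		// process callback: runs on the worker goroutines; returning an error aborts the stream
		if err != nil {
			return fmt.Errorf("LoadTree(%v): %w", treeID.Str(), err)
		}
		for _, node := range tree.Nodes {
			// subtree recursion is handled by StreamTrees itself; inspect file nodes here
			_ = node
		}
		return nil
	})
}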