diff --git a/cmd/restic/cmd_debug.go b/cmd/restic/cmd_debug.go index f9ede9df1..48e6d58b7 100644 --- a/cmd/restic/cmd_debug.go +++ b/cmd/restic/cmd_debug.go @@ -353,13 +353,7 @@ func loadBlobs(ctx context.Context, opts DebugExamineOptions, repo restic.Reposi return err } - wg, ctx := errgroup.WithContext(ctx) - - if opts.ReuploadBlobs { - repo.StartPackUploader(ctx, wg) - } - - wg.Go(func() error { + err = repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { for _, blob := range list { printer.S(" loading blob %v at %v (length %v)", blob.ID, blob.Offset, blob.Length) if int(blob.Offset+blob.Length) > len(pack) { @@ -417,21 +411,16 @@ func loadBlobs(ctx context.Context, opts DebugExamineOptions, repo restic.Reposi } } if opts.ReuploadBlobs { - _, _, _, err := repo.SaveBlob(ctx, blob.Type, plaintext, id, true) + _, _, _, err := uploader.SaveBlob(ctx, blob.Type, plaintext, id, true) if err != nil { return err } printer.S(" uploaded %v %v", blob.Type, id) } } - - if opts.ReuploadBlobs { - return repo.Flush(ctx) - } return nil }) - - return wg.Wait() + return err } func storePlainBlob(id restic.ID, prefix string, plain []byte, printer progress.Printer) error { diff --git a/cmd/restic/cmd_recover.go b/cmd/restic/cmd_recover.go index 3dda6214f..ca22ee2de 100644 --- a/cmd/restic/cmd_recover.go +++ b/cmd/restic/cmd_recover.go @@ -13,7 +13,6 @@ import ( "github.com/restic/restic/internal/ui" "github.com/restic/restic/internal/ui/progress" "github.com/spf13/cobra" - "golang.org/x/sync/errgroup" ) func newRecoverCommand(globalOptions *global.Options) *cobra.Command { @@ -153,24 +152,15 @@ func runRecover(ctx context.Context, gopts global.Options, term ui.Terminal) err } } - wg, wgCtx := errgroup.WithContext(ctx) - repo.StartPackUploader(wgCtx, wg) - var treeID restic.ID - wg.Go(func() error { + err = repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { var err error - treeID, err = data.SaveTree(wgCtx, repo, tree) + treeID, err = data.SaveTree(ctx, uploader, tree) if err != nil { return errors.Fatalf("unable to save new tree to the repository: %v", err) } - - err = repo.Flush(wgCtx) - if err != nil { - return errors.Fatalf("unable to save blobs to the repository: %v", err) - } return nil }) - err = wg.Wait() if err != nil { return err } diff --git a/cmd/restic/cmd_repair_snapshots.go b/cmd/restic/cmd_repair_snapshots.go index e7547bdba..226da3d44 100644 --- a/cmd/restic/cmd_repair_snapshots.go +++ b/cmd/restic/cmd_repair_snapshots.go @@ -130,19 +130,15 @@ func runRepairSnapshots(ctx context.Context, gopts global.Options, opts RepairOp node.Size = newSize return node }, - RewriteFailedTree: func(_ restic.ID, path string, _ error) (restic.ID, error) { + RewriteFailedTree: func(_ restic.ID, path string, _ error) (*data.Tree, error) { if path == "/" { printer.P(" dir %q: not readable", path) // remove snapshots with invalid root node - return restic.ID{}, nil + return nil, nil } // If a subtree fails to load, remove it printer.P(" dir %q: replaced with empty directory", path) - emptyID, err := data.SaveTree(ctx, repo, &data.Tree{}) - if err != nil { - return restic.ID{}, err - } - return emptyID, nil + return &data.Tree{}, nil }, AllowUnstableSerialization: true, }) @@ -151,8 +147,8 @@ func runRepairSnapshots(ctx context.Context, gopts global.Options, opts RepairOp for sn := range FindFilteredSnapshots(ctx, snapshotLister, repo, &opts.SnapshotFilter, args, printer) { printer.P("\n%v", sn) changed, err := filterAndReplaceSnapshot(ctx, repo, sn, - func(ctx context.Context, sn *data.Snapshot) (restic.ID, *data.SnapshotSummary, error) { - id, err := rewriter.RewriteTree(ctx, repo, "/", *sn.Tree) + func(ctx context.Context, sn *data.Snapshot, uploader restic.BlobSaver) (restic.ID, *data.SnapshotSummary, error) { + id, err := rewriter.RewriteTree(ctx, repo, uploader, "/", *sn.Tree) return id, nil, err }, opts.DryRun, opts.Forget, nil, "repaired", printer) if err != nil { diff --git a/cmd/restic/cmd_rewrite.go b/cmd/restic/cmd_rewrite.go index 470b89024..76a504652 100644 --- a/cmd/restic/cmd_rewrite.go +++ b/cmd/restic/cmd_rewrite.go @@ -6,7 +6,6 @@ import ( "github.com/spf13/cobra" "github.com/spf13/pflag" - "golang.org/x/sync/errgroup" "github.com/restic/restic/internal/data" "github.com/restic/restic/internal/debug" @@ -125,7 +124,7 @@ func (opts *RewriteOptions) AddFlags(f *pflag.FlagSet) { // rewriteFilterFunc returns the filtered tree ID or an error. If a snapshot summary is returned, the snapshot will // be updated accordingly. -type rewriteFilterFunc func(ctx context.Context, sn *data.Snapshot) (restic.ID, *data.SnapshotSummary, error) +type rewriteFilterFunc func(ctx context.Context, sn *data.Snapshot, uploader restic.BlobSaver) (restic.ID, *data.SnapshotSummary, error) func rewriteSnapshot(ctx context.Context, repo *repository.Repository, sn *data.Snapshot, opts RewriteOptions, printer progress.Printer) (bool, error) { if sn.Tree == nil { @@ -165,8 +164,8 @@ func rewriteSnapshot(ctx context.Context, repo *repository.Repository, sn *data. rewriter, querySize := walker.NewSnapshotSizeRewriter(rewriteNode) - filter = func(ctx context.Context, sn *data.Snapshot) (restic.ID, *data.SnapshotSummary, error) { - id, err := rewriter.RewriteTree(ctx, repo, "/", *sn.Tree) + filter = func(ctx context.Context, sn *data.Snapshot, uploader restic.BlobSaver) (restic.ID, *data.SnapshotSummary, error) { + id, err := rewriter.RewriteTree(ctx, repo, uploader, "/", *sn.Tree) if err != nil { return restic.ID{}, nil, err } @@ -181,7 +180,7 @@ func rewriteSnapshot(ctx context.Context, repo *repository.Repository, sn *data. } } else { - filter = func(_ context.Context, sn *data.Snapshot) (restic.ID, *data.SnapshotSummary, error) { + filter = func(_ context.Context, sn *data.Snapshot, _ restic.BlobSaver) (restic.ID, *data.SnapshotSummary, error) { return *sn.Tree, nil, nil } } @@ -193,21 +192,13 @@ func rewriteSnapshot(ctx context.Context, repo *repository.Repository, sn *data. func filterAndReplaceSnapshot(ctx context.Context, repo restic.Repository, sn *data.Snapshot, filter rewriteFilterFunc, dryRun bool, forget bool, newMetadata *snapshotMetadata, addTag string, printer progress.Printer) (bool, error) { - wg, wgCtx := errgroup.WithContext(ctx) - repo.StartPackUploader(wgCtx, wg) - var filteredTree restic.ID var summary *data.SnapshotSummary - wg.Go(func() error { + err := repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { var err error - filteredTree, summary, err = filter(ctx, sn) - if err != nil { - return err - } - - return repo.Flush(wgCtx) + filteredTree, summary, err = filter(ctx, sn, uploader) + return err }) - err := wg.Wait() if err != nil { return false, err } diff --git a/internal/archiver/archiver.go b/internal/archiver/archiver.go index d14cf74e3..d619ad9b4 100644 --- a/internal/archiver/archiver.go +++ b/internal/archiver/archiver.go @@ -74,12 +74,10 @@ type ToNoder interface { type archiverRepo interface { restic.Loader - restic.BlobSaver + restic.WithBlobUploader restic.SaverUnpacked[restic.WriteableFileType] Config() restic.Config - StartPackUploader(ctx context.Context, wg *errgroup.Group) - Flush(ctx context.Context) error } // Archiver saves a directory structure to the repo. @@ -836,8 +834,8 @@ func (arch *Archiver) loadParentTree(ctx context.Context, sn *data.Snapshot) *da } // runWorkers starts the worker pools, which are stopped when the context is cancelled. -func (arch *Archiver) runWorkers(ctx context.Context, wg *errgroup.Group) { - arch.blobSaver = newBlobSaver(ctx, wg, arch.Repo, arch.Options.SaveBlobConcurrency) +func (arch *Archiver) runWorkers(ctx context.Context, wg *errgroup.Group, uploader restic.BlobSaver) { + arch.blobSaver = newBlobSaver(ctx, wg, uploader, arch.Options.SaveBlobConcurrency) arch.fileSaver = newFileSaver(ctx, wg, arch.blobSaver.Save, @@ -876,15 +874,12 @@ func (arch *Archiver) Snapshot(ctx context.Context, targets []string, opts Snaps var rootTreeID restic.ID - wgUp, wgUpCtx := errgroup.WithContext(ctx) - arch.Repo.StartPackUploader(wgUpCtx, wgUp) - - wgUp.Go(func() error { - wg, wgCtx := errgroup.WithContext(wgUpCtx) + err = arch.Repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { + wg, wgCtx := errgroup.WithContext(ctx) start := time.Now() wg.Go(func() error { - arch.runWorkers(wgCtx, wg) + arch.runWorkers(wgCtx, wg, uploader) debug.Log("starting snapshot") fn, nodeCount, err := arch.saveTree(wgCtx, "/", atree, arch.loadParentTree(wgCtx, opts.ParentSnapshot), func(_ *data.Node, is ItemStats) { @@ -919,10 +914,8 @@ func (arch *Archiver) Snapshot(ctx context.Context, targets []string, opts Snaps debug.Log("error while saving tree: %v", err) return err } - - return arch.Repo.Flush(ctx) + return nil }) - err = wgUp.Wait() if err != nil { return nil, restic.ID{}, nil, err } diff --git a/internal/archiver/archiver_test.go b/internal/archiver/archiver_test.go index 061306879..adc7695cb 100644 --- a/internal/archiver/archiver_test.go +++ b/internal/archiver/archiver_test.go @@ -39,17 +39,6 @@ func prepareTempdirRepoSrc(t testing.TB, src TestDir) (string, *repository.Repos } func saveFile(t testing.TB, repo archiverRepo, filename string, filesystem fs.FS) (*data.Node, ItemStats) { - wg, ctx := errgroup.WithContext(context.TODO()) - repo.StartPackUploader(ctx, wg) - - arch := New(repo, filesystem, Options{}) - arch.runWorkers(ctx, wg) - - arch.Error = func(item string, err error) error { - t.Errorf("archiver error for %v: %v", item, err) - return err - } - var ( completeReadingCallback bool @@ -58,47 +47,55 @@ func saveFile(t testing.TB, repo archiverRepo, filename string, filesystem fs.FS completeCallback bool startCallback bool + fnr futureNodeResult ) - completeReading := func() { - completeReadingCallback = true - if completeCallback { - t.Error("callbacks called in wrong order") + arch := New(repo, filesystem, Options{}) + arch.Error = func(item string, err error) error { + t.Errorf("archiver error for %v: %v", item, err) + return err + } + + err := repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + wg, ctx := errgroup.WithContext(ctx) + arch.runWorkers(ctx, wg, uploader) + + completeReading := func() { + completeReadingCallback = true + if completeCallback { + t.Error("callbacks called in wrong order") + } } - } - complete := func(node *data.Node, stats ItemStats) { - completeCallback = true - completeCallbackNode = node - completeCallbackStats = stats - } + complete := func(node *data.Node, stats ItemStats) { + completeCallback = true + completeCallbackNode = node + completeCallbackStats = stats + } - start := func() { - startCallback = true - } + start := func() { + startCallback = true + } - file, err := arch.FS.OpenFile(filename, fs.O_NOFOLLOW, false) + file, err := arch.FS.OpenFile(filename, fs.O_NOFOLLOW, false) + if err != nil { + t.Fatal(err) + } + + res := arch.fileSaver.Save(ctx, "/", filename, file, start, completeReading, complete) + + fnr = res.take(ctx) + if fnr.err != nil { + t.Fatal(fnr.err) + } + + arch.stopWorkers() + return wg.Wait() + }) if err != nil { t.Fatal(err) } - res := arch.fileSaver.Save(ctx, "/", filename, file, start, completeReading, complete) - - fnr := res.take(ctx) - if fnr.err != nil { - t.Fatal(fnr.err) - } - - arch.stopWorkers() - err = repo.Flush(context.Background()) - if err != nil { - t.Fatal(err) - } - - if err := wg.Wait(); err != nil { - t.Fatal(err) - } - if !startCallback { t.Errorf("start callback did not happen") } @@ -214,44 +211,45 @@ func TestArchiverSave(t *testing.T) { tempdir, repo := prepareTempdirRepoSrc(t, TestDir{"file": testfile}) - wg, ctx := errgroup.WithContext(ctx) - repo.StartPackUploader(ctx, wg) - arch := New(repo, fs.Track{FS: fs.Local{}}, Options{}) arch.Error = func(item string, err error) error { t.Errorf("archiver error for %v: %v", item, err) return err } - arch.runWorkers(ctx, wg) arch.summary = &Summary{} - node, excluded, err := arch.save(ctx, "/", filepath.Join(tempdir, "file"), nil) - if err != nil { - t.Fatal(err) - } + var fnr futureNodeResult + err := repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { + wg, ctx := errgroup.WithContext(ctx) + arch.runWorkers(ctx, wg, uploader) - if excluded { - t.Errorf("Save() excluded the node, that's unexpected") - } + node, excluded, err := arch.save(ctx, "/", filepath.Join(tempdir, "file"), nil) + if err != nil { + t.Fatal(err) + } - fnr := node.take(ctx) - if fnr.err != nil { - t.Fatal(fnr.err) - } + if excluded { + t.Errorf("Save() excluded the node, that's unexpected") + } - if fnr.node == nil { - t.Fatalf("returned node is nil") - } + fnr = node.take(ctx) + if fnr.err != nil { + t.Fatal(fnr.err) + } - stats := fnr.stats + if fnr.node == nil { + t.Fatalf("returned node is nil") + } - arch.stopWorkers() - err = repo.Flush(ctx) + arch.stopWorkers() + return wg.Wait() + }) if err != nil { t.Fatal(err) } TestEnsureFileContent(ctx, t, repo, "file", fnr.node, testfile) + stats := fnr.stats if stats.DataSize != uint64(len(testfile.Content)) { t.Errorf("wrong stats returned in DataSize, want %d, got %d", len(testfile.Content), stats.DataSize) } @@ -283,9 +281,6 @@ func TestArchiverSaveReaderFS(t *testing.T) { repo := repository.TestRepository(t) - wg, ctx := errgroup.WithContext(ctx) - repo.StartPackUploader(ctx, wg) - ts := time.Now() filename := "xx" readerFs, err := fs.NewReader(filename, io.NopCloser(strings.NewReader(test.Data)), fs.ReaderOptions{ @@ -298,37 +293,41 @@ func TestArchiverSaveReaderFS(t *testing.T) { t.Errorf("archiver error for %v: %v", item, err) return err } - arch.runWorkers(ctx, wg) arch.summary = &Summary{} - node, excluded, err := arch.save(ctx, "/", filename, nil) - t.Logf("Save returned %v %v", node, err) - if err != nil { - t.Fatal(err) - } + var fnr futureNodeResult + err = repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { + wg, ctx := errgroup.WithContext(ctx) + arch.runWorkers(ctx, wg, uploader) - if excluded { - t.Errorf("Save() excluded the node, that's unexpected") - } + node, excluded, err := arch.save(ctx, "/", filename, nil) + t.Logf("Save returned %v %v", node, err) + if err != nil { + t.Fatal(err) + } - fnr := node.take(ctx) - if fnr.err != nil { - t.Fatal(fnr.err) - } + if excluded { + t.Errorf("Save() excluded the node, that's unexpected") + } - if fnr.node == nil { - t.Fatalf("returned node is nil") - } + fnr = node.take(ctx) + if fnr.err != nil { + t.Fatal(fnr.err) + } - stats := fnr.stats + if fnr.node == nil { + t.Fatalf("returned node is nil") + } - arch.stopWorkers() - err = repo.Flush(ctx) + arch.stopWorkers() + return wg.Wait() + }) if err != nil { t.Fatal(err) } TestEnsureFileContent(ctx, t, repo, "file", fnr.node, TestFile{Content: test.Data}) + stats := fnr.stats if stats.DataSize != uint64(len(test.Data)) { t.Errorf("wrong stats returned in DataSize, want %d, got %d", len(test.Data), stats.DataSize) } @@ -416,27 +415,29 @@ type blobCountingRepo struct { saved map[restic.BlobHandle]uint } -func (repo *blobCountingRepo) SaveBlob(ctx context.Context, t restic.BlobType, buf []byte, id restic.ID, storeDuplicate bool) (restic.ID, bool, int, error) { - id, exists, size, err := repo.archiverRepo.SaveBlob(ctx, t, buf, id, storeDuplicate) +func (repo *blobCountingRepo) WithBlobUploader(ctx context.Context, fn func(ctx context.Context, uploader restic.BlobSaver) error) error { + return repo.archiverRepo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { + return fn(ctx, &blobCountingSaver{saver: uploader, blobCountingRepo: repo}) + }) +} + +type blobCountingSaver struct { + saver restic.BlobSaver + blobCountingRepo *blobCountingRepo +} + +func (repo *blobCountingSaver) SaveBlob(ctx context.Context, t restic.BlobType, buf []byte, id restic.ID, storeDuplicate bool) (restic.ID, bool, int, error) { + id, exists, size, err := repo.saver.SaveBlob(ctx, t, buf, id, storeDuplicate) if exists { return id, exists, size, err } h := restic.BlobHandle{ID: id, Type: t} - repo.m.Lock() - repo.saved[h]++ - repo.m.Unlock() + repo.blobCountingRepo.m.Lock() + repo.blobCountingRepo.saved[h]++ + repo.blobCountingRepo.m.Unlock() return id, exists, size, err } -func (repo *blobCountingRepo) SaveTree(ctx context.Context, t *data.Tree) (restic.ID, error) { - id, err := data.SaveTree(ctx, repo.archiverRepo, t) - h := restic.BlobHandle{ID: id, Type: restic.TreeBlob} - repo.m.Lock() - repo.saved[h]++ - repo.m.Unlock() - return id, err -} - func appendToFile(t testing.TB, filename string, data []byte) { f, err := os.OpenFile(filename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644) if err != nil { @@ -826,12 +827,8 @@ func TestArchiverSaveDir(t *testing.T) { t.Run("", func(t *testing.T) { tempdir, repo := prepareTempdirRepoSrc(t, test.src) - wg, ctx := errgroup.WithContext(context.Background()) - repo.StartPackUploader(ctx, wg) - testFS := fs.Track{FS: fs.Local{}} arch := New(repo, testFS, Options{}) - arch.runWorkers(ctx, wg) arch.summary = &Summary{} chdir := tempdir @@ -842,43 +839,42 @@ func TestArchiverSaveDir(t *testing.T) { back := rtest.Chdir(t, chdir) defer back() - meta, err := testFS.OpenFile(test.target, fs.O_NOFOLLOW, true) - rtest.OK(t, err) - ft, err := arch.saveDir(ctx, "/", test.target, meta, nil, nil) - rtest.OK(t, err) - rtest.OK(t, meta.Close()) + var treeID restic.ID + err := repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + wg, ctx := errgroup.WithContext(ctx) + arch.runWorkers(ctx, wg, uploader) + meta, err := testFS.OpenFile(test.target, fs.O_NOFOLLOW, true) + rtest.OK(t, err) + ft, err := arch.saveDir(ctx, "/", test.target, meta, nil, nil) + rtest.OK(t, err) + rtest.OK(t, meta.Close()) - fnr := ft.take(ctx) - node, stats := fnr.node, fnr.stats + fnr := ft.take(ctx) + node, stats := fnr.node, fnr.stats - t.Logf("stats: %v", stats) - if stats.DataSize != 0 { - t.Errorf("wrong stats returned in DataSize, want 0, got %d", stats.DataSize) - } - if stats.DataBlobs != 0 { - t.Errorf("wrong stats returned in DataBlobs, want 0, got %d", stats.DataBlobs) - } - if stats.TreeSize == 0 { - t.Errorf("wrong stats returned in TreeSize, want > 0, got %d", stats.TreeSize) - } - if stats.TreeBlobs <= 0 { - t.Errorf("wrong stats returned in TreeBlobs, want > 0, got %d", stats.TreeBlobs) - } + t.Logf("stats: %v", stats) + if stats.DataSize != 0 { + t.Errorf("wrong stats returned in DataSize, want 0, got %d", stats.DataSize) + } + if stats.DataBlobs != 0 { + t.Errorf("wrong stats returned in DataBlobs, want 0, got %d", stats.DataBlobs) + } + if stats.TreeSize == 0 { + t.Errorf("wrong stats returned in TreeSize, want > 0, got %d", stats.TreeSize) + } + if stats.TreeBlobs <= 0 { + t.Errorf("wrong stats returned in TreeBlobs, want > 0, got %d", stats.TreeBlobs) + } - node.Name = targetNodeName - tree := &data.Tree{Nodes: []*data.Node{node}} - treeID, err := data.SaveTree(ctx, repo, tree) - if err != nil { - t.Fatal(err) - } - arch.stopWorkers() - - err = repo.Flush(ctx) - if err != nil { - t.Fatal(err) - } - - err = wg.Wait() + node.Name = targetNodeName + tree := &data.Tree{Nodes: []*data.Node{node}} + treeID, err = data.SaveTree(ctx, uploader, tree) + if err != nil { + t.Fatal(err) + } + arch.stopWorkers() + return wg.Wait() + }) if err != nil { t.Fatal(err) } @@ -905,27 +901,30 @@ func TestArchiverSaveDirIncremental(t *testing.T) { // save the empty directory several times in a row, then have a look if the // archiver did save the same tree several times for i := 0; i < 5; i++ { - wg, ctx := errgroup.WithContext(context.TODO()) - repo.StartPackUploader(ctx, wg) - testFS := fs.Track{FS: fs.Local{}} arch := New(repo, testFS, Options{}) - arch.runWorkers(ctx, wg) arch.summary = &Summary{} - meta, err := testFS.OpenFile(tempdir, fs.O_NOFOLLOW, true) - rtest.OK(t, err) - ft, err := arch.saveDir(ctx, "/", tempdir, meta, nil, nil) - rtest.OK(t, err) - rtest.OK(t, meta.Close()) + var fnr futureNodeResult + err := repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + wg, ctx := errgroup.WithContext(ctx) + arch.runWorkers(ctx, wg, uploader) + meta, err := testFS.OpenFile(tempdir, fs.O_NOFOLLOW, true) + rtest.OK(t, err) + ft, err := arch.saveDir(ctx, "/", tempdir, meta, nil, nil) + rtest.OK(t, err) + rtest.OK(t, meta.Close()) - fnr := ft.take(ctx) - node, stats := fnr.node, fnr.stats + fnr = ft.take(ctx) + arch.stopWorkers() + return wg.Wait() + }) if err != nil { t.Fatal(err) } + node, stats := fnr.node, fnr.stats if i == 0 { // operation must have added new tree data if stats.DataSize != 0 { @@ -958,16 +957,6 @@ func TestArchiverSaveDirIncremental(t *testing.T) { t.Logf("node subtree %v", node.Subtree) - arch.stopWorkers() - err = repo.Flush(ctx) - if err != nil { - t.Fatal(err) - } - err = wg.Wait() - if err != nil { - t.Fatal(err) - } - for h, n := range repo.saved { if n > 1 { t.Errorf("iteration %v: blob %v saved more than once (%d times)", i, h, n) @@ -1097,11 +1086,6 @@ func TestArchiverSaveTree(t *testing.T) { testFS := fs.Track{FS: fs.Local{}} arch := New(repo, testFS, Options{}) - - wg, ctx := errgroup.WithContext(context.TODO()) - repo.StartPackUploader(ctx, wg) - - arch.runWorkers(ctx, wg) arch.summary = &Summary{} back := rtest.Chdir(t, tempdir) @@ -1111,29 +1095,31 @@ func TestArchiverSaveTree(t *testing.T) { test.prepare(t) } - atree, err := newTree(testFS, test.targets) - if err != nil { - t.Fatal(err) - } + var treeID restic.ID + err := repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + wg, ctx := errgroup.WithContext(ctx) + arch.runWorkers(ctx, wg, uploader) - fn, _, err := arch.saveTree(ctx, "/", atree, nil, nil) - if err != nil { - t.Fatal(err) - } + atree, err := newTree(testFS, test.targets) + if err != nil { + t.Fatal(err) + } - fnr := fn.take(context.TODO()) - if fnr.err != nil { - t.Fatal(fnr.err) - } + fn, _, err := arch.saveTree(ctx, "/", atree, nil, nil) + if err != nil { + t.Fatal(err) + } - treeID := *fnr.node.Subtree + fnr := fn.take(ctx) + if fnr.err != nil { + t.Fatal(fnr.err) + } - arch.stopWorkers() - err = repo.Flush(ctx) - if err != nil { - t.Fatal(err) - } - err = wg.Wait() + treeID = *fnr.node.Subtree + + arch.stopWorkers() + return wg.Wait() + }) if err != nil { t.Fatal(err) } @@ -2109,13 +2095,24 @@ type failSaveRepo struct { err error } -func (f *failSaveRepo) SaveBlob(ctx context.Context, t restic.BlobType, buf []byte, id restic.ID, storeDuplicate bool) (restic.ID, bool, int, error) { - val := atomic.AddInt32(&f.cnt, 1) - if val >= f.failAfter { - return restic.Hash(buf), false, 0, f.err +func (f *failSaveRepo) WithBlobUploader(ctx context.Context, fn func(ctx context.Context, uploader restic.BlobSaver) error) error { + return f.archiverRepo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { + return fn(ctx, &failSaveSaver{saver: uploader, failSaveRepo: f}) + }) +} + +type failSaveSaver struct { + saver restic.BlobSaver + failSaveRepo *failSaveRepo +} + +func (f *failSaveSaver) SaveBlob(ctx context.Context, t restic.BlobType, buf []byte, id restic.ID, storeDuplicate bool) (restic.ID, bool, int, error) { + val := atomic.AddInt32(&f.failSaveRepo.cnt, 1) + if val >= f.failSaveRepo.failAfter { + return restic.Hash(buf), false, 0, f.failSaveRepo.err } - return f.archiverRepo.SaveBlob(ctx, t, buf, id, storeDuplicate) + return f.saver.SaveBlob(ctx, t, buf, id, storeDuplicate) } func TestArchiverAbortEarlyOnError(t *testing.T) { @@ -2428,25 +2425,27 @@ func TestRacyFileTypeSwap(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - wg, ctx := errgroup.WithContext(ctx) - repo.StartPackUploader(ctx, wg) + _ = repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { + wg, ctx := errgroup.WithContext(ctx) - arch := New(repo, fs.Track{FS: statfs}, Options{}) - arch.Error = func(item string, err error) error { - t.Logf("archiver error as expected for %v: %v", item, err) - return err - } - arch.runWorkers(ctx, wg) + arch := New(repo, fs.Track{FS: statfs}, Options{}) + arch.Error = func(item string, err error) error { + t.Logf("archiver error as expected for %v: %v", item, err) + return err + } + arch.runWorkers(ctx, wg, uploader) - // fs.Track will panic if the file was not closed - _, excluded, err := arch.save(ctx, "/", tempfile, nil) - rtest.Assert(t, err != nil && strings.Contains(err.Error(), "changed type, refusing to archive"), "save() returned wrong error: %v", err) - tpe := "file" - if dirError { - tpe = "directory" - } - rtest.Assert(t, strings.Contains(err.Error(), tpe+" "), "unexpected item type in error: %v", err) - rtest.Assert(t, !excluded, "Save() excluded the node, that's unexpected") + // fs.Track will panic if the file was not closed + _, excluded, err := arch.save(ctx, "/", tempfile, nil) + rtest.Assert(t, err != nil && strings.Contains(err.Error(), "changed type, refusing to archive"), "save() returned wrong error: %v", err) + tpe := "file" + if dirError { + tpe = "directory" + } + rtest.Assert(t, strings.Contains(err.Error(), tpe+" "), "unexpected item type in error: %v", err) + rtest.Assert(t, !excluded, "Save() excluded the node, that's unexpected") + return nil + }) }) } } diff --git a/internal/checker/checker_test.go b/internal/checker/checker_test.go index c836c6412..960942d80 100644 --- a/internal/checker/checker_test.go +++ b/internal/checker/checker_test.go @@ -22,7 +22,6 @@ import ( "github.com/restic/restic/internal/repository/hashing" "github.com/restic/restic/internal/restic" "github.com/restic/restic/internal/test" - "golang.org/x/sync/errgroup" ) var checkerTestData = filepath.Join("testdata", "checker-test-repo.tar.gz") @@ -527,19 +526,21 @@ func TestCheckerBlobTypeConfusion(t *testing.T) { Nodes: []*data.Node{damagedNode}, } - wg, wgCtx := errgroup.WithContext(ctx) - repo.StartPackUploader(wgCtx, wg) - id, err := data.SaveTree(ctx, repo, damagedTree) - test.OK(t, repo.Flush(ctx)) - test.OK(t, err) + var id restic.ID + test.OK(t, repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { + var err error + id, err = data.SaveTree(ctx, uploader, damagedTree) + return err + })) buf, err := repo.LoadBlob(ctx, restic.TreeBlob, id, nil) test.OK(t, err) - wg, wgCtx = errgroup.WithContext(ctx) - repo.StartPackUploader(wgCtx, wg) - _, _, _, err = repo.SaveBlob(ctx, restic.DataBlob, buf, id, false) - test.OK(t, err) + test.OK(t, repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { + var err error + _, _, _, err = uploader.SaveBlob(ctx, restic.DataBlob, buf, id, false) + return err + })) malNode := &data.Node{ Name: "aaaaa", @@ -559,10 +560,12 @@ func TestCheckerBlobTypeConfusion(t *testing.T) { Nodes: []*data.Node{malNode, dirNode}, } - rootID, err := data.SaveTree(ctx, repo, rootTree) - test.OK(t, err) - - test.OK(t, repo.Flush(ctx)) + var rootID restic.ID + test.OK(t, repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { + var err error + rootID, err = data.SaveTree(ctx, uploader, rootTree) + return err + })) snapshot, err := data.NewSnapshot([]string{"/damaged"}, []string{"test"}, "foo", time.Now()) test.OK(t, err) diff --git a/internal/data/testing.go b/internal/data/testing.go index 0c5fdb6d7..be4ab4edb 100644 --- a/internal/data/testing.go +++ b/internal/data/testing.go @@ -10,7 +10,7 @@ import ( "github.com/restic/chunker" "github.com/restic/restic/internal/restic" - "golang.org/x/sync/errgroup" + "github.com/restic/restic/internal/test" ) // fakeFile returns a reader which yields deterministic pseudo-random data. @@ -28,7 +28,7 @@ type fakeFileSystem struct { // saveFile reads from rd and saves the blobs in the repository. The list of // IDs is returned. -func (fs *fakeFileSystem) saveFile(ctx context.Context, rd io.Reader) (blobs restic.IDs) { +func (fs *fakeFileSystem) saveFile(ctx context.Context, uploader restic.BlobSaver, rd io.Reader) (blobs restic.IDs) { if fs.buf == nil { fs.buf = make([]byte, chunker.MaxSize) } @@ -50,7 +50,7 @@ func (fs *fakeFileSystem) saveFile(ctx context.Context, rd io.Reader) (blobs res fs.t.Fatalf("unable to save chunk in repo: %v", err) } - id, _, _, err := fs.repo.SaveBlob(ctx, restic.DataBlob, chunk.Data, restic.ID{}, false) + id, _, _, err := uploader.SaveBlob(ctx, restic.DataBlob, chunk.Data, restic.ID{}, false) if err != nil { fs.t.Fatalf("error saving chunk: %v", err) } @@ -68,7 +68,7 @@ const ( ) // saveTree saves a tree of fake files in the repo and returns the ID. -func (fs *fakeFileSystem) saveTree(ctx context.Context, seed int64, depth int) restic.ID { +func (fs *fakeFileSystem) saveTree(ctx context.Context, uploader restic.BlobSaver, seed int64, depth int) restic.ID { rnd := rand.NewSource(seed) numNodes := int(rnd.Int63() % maxNodes) @@ -78,7 +78,7 @@ func (fs *fakeFileSystem) saveTree(ctx context.Context, seed int64, depth int) r // randomly select the type of the node, either tree (p = 1/4) or file (p = 3/4). if depth > 1 && rnd.Int63()%4 == 0 { treeSeed := rnd.Int63() % maxSeed - id := fs.saveTree(ctx, treeSeed, depth-1) + id := fs.saveTree(ctx, uploader, treeSeed, depth-1) node := &Node{ Name: fmt.Sprintf("dir-%v", treeSeed), @@ -101,13 +101,13 @@ func (fs *fakeFileSystem) saveTree(ctx context.Context, seed int64, depth int) r Size: uint64(fileSize), } - node.Content = fs.saveFile(ctx, fakeFile(fileSeed, fileSize)) + node.Content = fs.saveFile(ctx, uploader, fakeFile(fileSeed, fileSize)) tree.Nodes = append(tree.Nodes, node) } tree.Sort() - id, err := SaveTree(ctx, fs.repo, &tree) + id, err := SaveTree(ctx, uploader, &tree) if err != nil { fs.t.Fatalf("SaveTree returned error: %v", err) } @@ -135,17 +135,13 @@ func TestCreateSnapshot(t testing.TB, repo restic.Repository, at time.Time, dept rand: rand.New(rand.NewSource(seed)), } - var wg errgroup.Group - repo.StartPackUploader(context.TODO(), &wg) - - treeID := fs.saveTree(context.TODO(), seed, depth) + var treeID restic.ID + test.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + treeID = fs.saveTree(ctx, uploader, seed, depth) + return nil + })) snapshot.Tree = &treeID - err = repo.Flush(context.Background()) - if err != nil { - t.Fatal(err) - } - id, err := SaveSnapshot(context.TODO(), repo, snapshot) if err != nil { t.Fatal(err) diff --git a/internal/data/tree_test.go b/internal/data/tree_test.go index 92949628b..9164f4da1 100644 --- a/internal/data/tree_test.go +++ b/internal/data/tree_test.go @@ -15,7 +15,6 @@ import ( "github.com/restic/restic/internal/repository" "github.com/restic/restic/internal/restic" rtest "github.com/restic/restic/internal/test" - "golang.org/x/sync/errgroup" ) var testFiles = []struct { @@ -106,15 +105,14 @@ func TestNodeComparison(t *testing.T) { func TestEmptyLoadTree(t *testing.T) { repo := repository.TestRepository(t) - var wg errgroup.Group - repo.StartPackUploader(context.TODO(), &wg) - // save tree tree := data.NewTree(0) - id, err := data.SaveTree(context.TODO(), repo, tree) - rtest.OK(t, err) - - // save packs - rtest.OK(t, repo.Flush(context.Background())) + var id restic.ID + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + var err error + // save tree + id, err = data.SaveTree(ctx, uploader, tree) + return err + })) // load tree again tree2, err := data.LoadTree(context.TODO(), repo, id) @@ -187,7 +185,6 @@ func testLoadTree(t *testing.T, version uint) { // archive a few files repo, _, _ := repository.TestRepositoryWithVersion(t, version) sn := archiver.TestSnapshot(t, repo, rtest.BenchArchiveDirectory, nil) - rtest.OK(t, repo.Flush(context.Background())) _, err := data.LoadTree(context.TODO(), repo, *sn.Tree) rtest.OK(t, err) @@ -205,7 +202,6 @@ func benchmarkLoadTree(t *testing.B, version uint) { // archive a few files repo, _, _ := repository.TestRepositoryWithVersion(t, version) sn := archiver.TestSnapshot(t, repo, rtest.BenchArchiveDirectory, nil) - rtest.OK(t, repo.Flush(context.Background())) t.ResetTimer() diff --git a/internal/repository/fuzz_test.go b/internal/repository/fuzz_test.go index c20f9a710..16155f3a4 100644 --- a/internal/repository/fuzz_test.go +++ b/internal/repository/fuzz_test.go @@ -5,7 +5,7 @@ import ( "testing" "github.com/restic/restic/internal/restic" - "golang.org/x/sync/errgroup" + rtest "github.com/restic/restic/internal/test" ) // Test saving a blob and loading it again, with varying buffer sizes. @@ -20,17 +20,10 @@ func FuzzSaveLoadBlob(f *testing.F) { id := restic.Hash(blob) repo, _, _ := TestRepositoryWithVersion(t, 2) - var wg errgroup.Group - repo.StartPackUploader(context.TODO(), &wg) - - _, _, _, err := repo.SaveBlob(context.TODO(), restic.DataBlob, blob, id, false) - if err != nil { - t.Fatal(err) - } - err = repo.Flush(context.TODO()) - if err != nil { - t.Fatal(err) - } + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + _, _, _, err := uploader.SaveBlob(ctx, restic.DataBlob, blob, id, false) + return err + })) buf, err := repo.LoadBlob(context.TODO(), restic.DataBlob, id, make([]byte, buflen)) if err != nil { diff --git a/internal/repository/prune_internal_test.go b/internal/repository/prune_internal_test.go index 5c9d0572e..49a876884 100644 --- a/internal/repository/prune_internal_test.go +++ b/internal/repository/prune_internal_test.go @@ -10,7 +10,6 @@ import ( "github.com/restic/restic/internal/restic" rtest "github.com/restic/restic/internal/test" "github.com/restic/restic/internal/ui/progress" - "golang.org/x/sync/errgroup" ) // TestPruneMaxUnusedDuplicate checks that MaxUnused correctly accounts for duplicates. @@ -48,16 +47,14 @@ func TestPruneMaxUnusedDuplicate(t *testing.T) { {bufs[1], bufs[3]}, {bufs[2], bufs[3]}, } { - var wg errgroup.Group - repo.StartPackUploader(context.TODO(), &wg) - - for _, blob := range blobs { - id, _, _, err := repo.SaveBlob(context.TODO(), restic.DataBlob, blob, restic.ID{}, true) - keep.Insert(restic.BlobHandle{Type: restic.DataBlob, ID: id}) - rtest.OK(t, err) - } - - rtest.OK(t, repo.Flush(context.Background())) + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + for _, blob := range blobs { + id, _, _, err := uploader.SaveBlob(ctx, restic.DataBlob, blob, restic.ID{}, true) + keep.Insert(restic.BlobHandle{Type: restic.DataBlob, ID: id}) + rtest.OK(t, err) + } + return nil + })) } opts := PruneOptions{ diff --git a/internal/repository/prune_test.go b/internal/repository/prune_test.go index 6e5e05abf..744de0b14 100644 --- a/internal/repository/prune_test.go +++ b/internal/repository/prune_test.go @@ -13,7 +13,6 @@ import ( "github.com/restic/restic/internal/restic" rtest "github.com/restic/restic/internal/test" "github.com/restic/restic/internal/ui/progress" - "golang.org/x/sync/errgroup" ) func testPrune(t *testing.T, opts repository.PruneOptions, errOnUnused bool) { @@ -26,16 +25,16 @@ func testPrune(t *testing.T, opts repository.PruneOptions, errOnUnused bool) { createRandomBlobs(t, random, repo, 5, 0.5, true) keep, _ := selectBlobs(t, random, repo, 0.5) - var wg errgroup.Group - repo.StartPackUploader(context.TODO(), &wg) - // duplicate a few blobs to exercise those code paths - for blob := range keep { - buf, err := repo.LoadBlob(context.TODO(), blob.Type, blob.ID, nil) - rtest.OK(t, err) - _, _, _, err = repo.SaveBlob(context.TODO(), blob.Type, buf, blob.ID, true) - rtest.OK(t, err) - } - rtest.OK(t, repo.Flush(context.TODO())) + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + // duplicate a few blobs to exercise those code paths + for blob := range keep { + buf, err := repo.LoadBlob(ctx, blob.Type, blob.ID, nil) + rtest.OK(t, err) + _, _, _, err = uploader.SaveBlob(ctx, blob.Type, buf, blob.ID, true) + rtest.OK(t, err) + } + return nil + })) plan, err := repository.PlanPrune(context.TODO(), opts, repo, func(ctx context.Context, repo restic.Repository, usedBlobs restic.FindBlobSet) error { for blob := range keep { @@ -133,20 +132,19 @@ func TestPruneSmall(t *testing.T) { const blobSize = 1000 * 1000 const numBlobsCreated = 55 - var wg errgroup.Group - repo.StartPackUploader(context.TODO(), &wg) keep := restic.NewBlobSet() - // we need a minum of 11 packfiles, each packfile will be about 5 Mb long - for i := 0; i < numBlobsCreated; i++ { - buf := make([]byte, blobSize) - random.Read(buf) + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + // we need a minum of 11 packfiles, each packfile will be about 5 Mb long + for i := 0; i < numBlobsCreated; i++ { + buf := make([]byte, blobSize) + random.Read(buf) - id, _, _, err := repo.SaveBlob(context.TODO(), restic.DataBlob, buf, restic.ID{}, false) - rtest.OK(t, err) - keep.Insert(restic.BlobHandle{Type: restic.DataBlob, ID: id}) - } - - rtest.OK(t, repo.Flush(context.Background())) + id, _, _, err := uploader.SaveBlob(ctx, restic.DataBlob, buf, restic.ID{}, false) + rtest.OK(t, err) + keep.Insert(restic.BlobHandle{Type: restic.DataBlob, ID: id}) + } + return nil + })) // gather number of packfiles repoPacks, err := pack.Size(context.TODO(), repo, false) diff --git a/internal/repository/repack.go b/internal/repository/repack.go index 929191478..6ee86eb22 100644 --- a/internal/repository/repack.go +++ b/internal/repository/repack.go @@ -47,16 +47,12 @@ func Repack( return nil, errors.New("repack step requires a backend connection limit of at least two") } - wg, wgCtx := errgroup.WithContext(ctx) - - dstRepo.StartPackUploader(wgCtx, wg) - wg.Go(func() error { + err = dstRepo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { var err error - obsoletePacks, err = repack(wgCtx, repo, dstRepo, packs, keepBlobs, p, logf) + obsoletePacks, err = repack(ctx, repo, dstRepo, uploader, packs, keepBlobs, p, logf) return err }) - - if err := wg.Wait(); err != nil { + if err != nil { return nil, err } return obsoletePacks, nil @@ -66,6 +62,7 @@ func repack( ctx context.Context, repo restic.Repository, dstRepo restic.Repository, + uploader restic.BlobSaver, packs restic.IDSet, keepBlobs repackBlobSet, p *progress.Counter, @@ -132,7 +129,7 @@ func repack( } // We do want to save already saved blobs! - _, _, _, err = dstRepo.SaveBlob(wgCtx, blob.Type, buf, blob.ID, true) + _, _, _, err = uploader.SaveBlob(wgCtx, blob.Type, buf, blob.ID, true) if err != nil { return err } @@ -163,9 +160,5 @@ func repack( return nil, err } - if err := dstRepo.Flush(ctx); err != nil { - return nil, err - } - return packs, nil } diff --git a/internal/repository/repack_test.go b/internal/repository/repack_test.go index 9248e42c2..599178371 100644 --- a/internal/repository/repack_test.go +++ b/internal/repository/repack_test.go @@ -11,7 +11,6 @@ import ( "github.com/restic/restic/internal/restic" rtest "github.com/restic/restic/internal/test" "github.com/restic/restic/internal/ui/progress" - "golang.org/x/sync/errgroup" ) func randomSize(random *rand.Rand, min, max int) int { @@ -19,50 +18,47 @@ func randomSize(random *rand.Rand, min, max int) int { } func createRandomBlobs(t testing.TB, random *rand.Rand, repo restic.Repository, blobs int, pData float32, smallBlobs bool) { - var wg errgroup.Group - repo.StartPackUploader(context.TODO(), &wg) + // two loops to allow creating multiple pack files + for blobs > 0 { + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + for blobs > 0 { + blobs-- + var ( + tpe restic.BlobType + length int + ) - for i := 0; i < blobs; i++ { - var ( - tpe restic.BlobType - length int - ) + if random.Float32() < pData { + tpe = restic.DataBlob + if smallBlobs { + length = randomSize(random, 1*1024, 20*1024) // 1KiB to 20KiB of data + } else { + length = randomSize(random, 10*1024, 1024*1024) // 10KiB to 1MiB of data + } + } else { + tpe = restic.TreeBlob + length = randomSize(random, 1*1024, 20*1024) // 1KiB to 20KiB + } - if random.Float32() < pData { - tpe = restic.DataBlob - if smallBlobs { - length = randomSize(random, 1*1024, 20*1024) // 1KiB to 20KiB of data - } else { - length = randomSize(random, 10*1024, 1024*1024) // 10KiB to 1MiB of data + buf := make([]byte, length) + random.Read(buf) + + id, exists, _, err := uploader.SaveBlob(ctx, tpe, buf, restic.ID{}, false) + if err != nil { + t.Fatalf("SaveFrom() error %v", err) + } + + if exists { + t.Errorf("duplicate blob %v/%v ignored", id, restic.DataBlob) + continue + } + + if rand.Float32() < 0.2 { + break + } } - } else { - tpe = restic.TreeBlob - length = randomSize(random, 1*1024, 20*1024) // 1KiB to 20KiB - } - - buf := make([]byte, length) - random.Read(buf) - - id, exists, _, err := repo.SaveBlob(context.TODO(), tpe, buf, restic.ID{}, false) - if err != nil { - t.Fatalf("SaveFrom() error %v", err) - } - - if exists { - t.Errorf("duplicate blob %v/%v ignored", id, restic.DataBlob) - continue - } - - if rand.Float32() < 0.2 { - if err = repo.Flush(context.Background()); err != nil { - t.Fatalf("repo.Flush() returned error %v", err) - } - repo.StartPackUploader(context.TODO(), &wg) - } - } - - if err := repo.Flush(context.Background()); err != nil { - t.Fatalf("repo.Flush() returned error %v", err) + return nil + })) } } @@ -74,16 +70,10 @@ func createRandomWrongBlob(t testing.TB, random *rand.Rand, repo restic.Reposito // invert first data byte buf[0] ^= 0xff - var wg errgroup.Group - repo.StartPackUploader(context.TODO(), &wg) - _, _, _, err := repo.SaveBlob(context.TODO(), restic.DataBlob, buf, id, false) - if err != nil { - t.Fatalf("SaveFrom() error %v", err) - } - - if err := repo.Flush(context.Background()); err != nil { - t.Fatalf("repo.Flush() returned error %v", err) - } + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + _, _, _, err := uploader.SaveBlob(ctx, restic.DataBlob, buf, id, false) + return err + })) return restic.BlobHandle{ID: id, Type: restic.DataBlob} } @@ -349,24 +339,23 @@ func testRepackBlobFallback(t *testing.T, version uint) { modbuf[0] ^= 0xff // create pack with broken copy - var wg errgroup.Group - repo.StartPackUploader(context.TODO(), &wg) - _, _, _, err := repo.SaveBlob(context.TODO(), restic.DataBlob, modbuf, id, false) - rtest.OK(t, err) - rtest.OK(t, repo.Flush(context.Background())) + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + _, _, _, err := uploader.SaveBlob(ctx, restic.DataBlob, modbuf, id, false) + return err + })) // find pack with damaged blob keepBlobs := restic.NewBlobSet(restic.BlobHandle{Type: restic.DataBlob, ID: id}) rewritePacks := findPacksForBlobs(t, repo, keepBlobs) // create pack with valid copy - repo.StartPackUploader(context.TODO(), &wg) - _, _, _, err = repo.SaveBlob(context.TODO(), restic.DataBlob, buf, id, true) - rtest.OK(t, err) - rtest.OK(t, repo.Flush(context.Background())) + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + _, _, _, err := uploader.SaveBlob(ctx, restic.DataBlob, buf, id, true) + return err + })) // repack must fallback to valid copy - _, err = repository.Repack(context.TODO(), repo, repo, rewritePacks, keepBlobs, nil, nil) + _, err := repository.Repack(context.TODO(), repo, repo, rewritePacks, keepBlobs, nil, nil) rtest.OK(t, err) keepBlobs = restic.NewBlobSet(restic.BlobHandle{Type: restic.DataBlob, ID: id}) diff --git a/internal/repository/repair_index.go b/internal/repository/repair_index.go index 929de3db2..d6734428f 100644 --- a/internal/repository/repair_index.go +++ b/internal/repository/repair_index.go @@ -93,10 +93,6 @@ func RepairIndex(ctx context.Context, repo *Repository, opts RepairIndexOptions, } } - if err := repo.Flush(ctx); err != nil { - return err - } - err = rewriteIndexFiles(ctx, repo, removePacks, oldIndexes, obsoleteIndexes, printer) if err != nil { return err diff --git a/internal/repository/repair_pack.go b/internal/repository/repair_pack.go index a9f8413e4..a6f4a52b8 100644 --- a/internal/repository/repair_pack.go +++ b/internal/repository/repair_pack.go @@ -7,21 +7,17 @@ import ( "github.com/restic/restic/internal/restic" "github.com/restic/restic/internal/ui/progress" - "golang.org/x/sync/errgroup" ) func RepairPacks(ctx context.Context, repo *Repository, ids restic.IDSet, printer progress.Printer) error { - wg, wgCtx := errgroup.WithContext(ctx) - repo.StartPackUploader(wgCtx, wg) - printer.P("salvaging intact data from specified pack files") bar := printer.NewCounter("pack files") bar.SetMax(uint64(len(ids))) defer bar.Done() - wg.Go(func() error { + err := repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { // examine all data the indexes have for the pack file - for b := range repo.ListPacksFromIndex(wgCtx, ids) { + for b := range repo.ListPacksFromIndex(ctx, ids) { blobs := b.Blobs if len(blobs) == 0 { printer.E("no blobs found for pack %v", b.PackID) @@ -29,12 +25,12 @@ func RepairPacks(ctx context.Context, repo *Repository, ids restic.IDSet, printe continue } - err := repo.LoadBlobsFromPack(wgCtx, b.PackID, blobs, func(blob restic.BlobHandle, buf []byte, err error) error { + err := repo.LoadBlobsFromPack(ctx, b.PackID, blobs, func(blob restic.BlobHandle, buf []byte, err error) error { if err != nil { printer.E("failed to load blob %v: %v", blob.ID, err) return nil } - id, _, _, err := repo.SaveBlob(wgCtx, blob.Type, buf, restic.ID{}, true) + id, _, _, err := uploader.SaveBlob(ctx, blob.Type, buf, restic.ID{}, true) if !id.Equal(blob.ID) { panic("pack id mismatch during upload") } @@ -46,14 +42,12 @@ func RepairPacks(ctx context.Context, repo *Repository, ids restic.IDSet, printe } bar.Add(1) } - return repo.Flush(wgCtx) + return nil }) - - err := wg.Wait() - bar.Done() if err != nil { return err } + bar.Done() // remove salvaged packs from index err = rewriteIndexFiles(ctx, repo, ids, nil, nil, printer) diff --git a/internal/repository/repository.go b/internal/repository/repository.go index 897724414..3aa87faff 100644 --- a/internal/repository/repository.go +++ b/internal/repository/repository.go @@ -559,16 +559,30 @@ func (r *Repository) removeUnpacked(ctx context.Context, t restic.FileType, id r return r.be.Remove(ctx, backend.Handle{Type: t, Name: id.String()}) } -// Flush saves all remaining packs and the index -func (r *Repository) Flush(ctx context.Context) error { - if err := r.flushPacks(ctx); err != nil { - return err - } - - return r.idx.Flush(ctx, &internalRepository{r}) +func (r *Repository) WithBlobUploader(ctx context.Context, fn func(ctx context.Context, uploader restic.BlobSaver) error) error { + wg, ctx := errgroup.WithContext(ctx) + r.startPackUploader(ctx, wg) + wg.Go(func() error { + if err := fn(ctx, &blobSaverRepo{repo: r}); err != nil { + return err + } + if err := r.flush(ctx); err != nil { + return fmt.Errorf("error flushing repository: %w", err) + } + return nil + }) + return wg.Wait() } -func (r *Repository) StartPackUploader(ctx context.Context, wg *errgroup.Group) { +type blobSaverRepo struct { + repo *Repository +} + +func (r *blobSaverRepo) SaveBlob(ctx context.Context, t restic.BlobType, buf []byte, id restic.ID, storeDuplicate bool) (newID restic.ID, known bool, size int, err error) { + return r.repo.saveBlob(ctx, t, buf, id, storeDuplicate) +} + +func (r *Repository) startPackUploader(ctx context.Context, wg *errgroup.Group) { if r.packerWg != nil { panic("uploader already started") } @@ -584,6 +598,15 @@ func (r *Repository) StartPackUploader(ctx context.Context, wg *errgroup.Group) }) } +// Flush saves all remaining packs and the index +func (r *Repository) flush(ctx context.Context) error { + if err := r.flushPacks(ctx); err != nil { + return err + } + + return r.idx.Flush(ctx, &internalRepository{r}) +} + // FlushPacks saves all remaining packs. func (r *Repository) flushPacks(ctx context.Context) error { if r.packerWg == nil { @@ -697,7 +720,7 @@ func (r *Repository) createIndexFromPacks(ctx context.Context, packsize map[rest // track spawned goroutines using wg, create a new context which is // cancelled as soon as an error occurs. - wg, ctx := errgroup.WithContext(ctx) + wg, wgCtx := errgroup.WithContext(ctx) type FileInfo struct { restic.ID @@ -710,8 +733,8 @@ func (r *Repository) createIndexFromPacks(ctx context.Context, packsize map[rest defer close(ch) for id, size := range packsize { select { - case <-ctx.Done(): - return ctx.Err() + case <-wgCtx.Done(): + return wgCtx.Err() case ch <- FileInfo{id, size}: } } @@ -721,14 +744,14 @@ func (r *Repository) createIndexFromPacks(ctx context.Context, packsize map[rest // a worker receives an pack ID from ch, reads the pack contents, and adds them to idx worker := func() error { for fi := range ch { - entries, _, err := r.ListPack(ctx, fi.ID, fi.Size) + entries, _, err := r.ListPack(wgCtx, fi.ID, fi.Size) if err != nil { debug.Log("unable to list pack file %v", fi.ID.Str()) m.Lock() invalid = append(invalid, fi.ID) m.Unlock() } - if err := r.idx.StorePack(ctx, fi.ID, entries, &internalRepository{r}); err != nil { + if err := r.idx.StorePack(wgCtx, fi.ID, entries, &internalRepository{r}); err != nil { return err } p.Add(1) @@ -749,6 +772,12 @@ func (r *Repository) createIndexFromPacks(ctx context.Context, packsize map[rest return invalid, err } + // flush the index to the repository + err = r.flush(ctx) + if err != nil { + return invalid, err + } + return invalid, nil } @@ -905,14 +934,14 @@ func (r *Repository) Close() error { return r.be.Close() } -// SaveBlob saves a blob of type t into the repository. +// saveBlob saves a blob of type t into the repository. // It takes care that no duplicates are saved; this can be overwritten // by setting storeDuplicate to true. // If id is the null id, it will be computed and returned. // Also returns if the blob was already known before. // If the blob was not known before, it returns the number of bytes the blob // occupies in the repo (compressed or not, including encryption overhead). -func (r *Repository) SaveBlob(ctx context.Context, t restic.BlobType, buf []byte, id restic.ID, storeDuplicate bool) (newID restic.ID, known bool, size int, err error) { +func (r *Repository) saveBlob(ctx context.Context, t restic.BlobType, buf []byte, id restic.ID, storeDuplicate bool) (newID restic.ID, known bool, size int, err error) { if int64(len(buf)) > math.MaxUint32 { return restic.ID{}, false, 0, fmt.Errorf("blob is larger than 4GB") diff --git a/internal/repository/repository_test.go b/internal/repository/repository_test.go index fbce3b046..60bfd2b5c 100644 --- a/internal/repository/repository_test.go +++ b/internal/repository/repository_test.go @@ -22,7 +22,6 @@ import ( "github.com/restic/restic/internal/repository/index" "github.com/restic/restic/internal/restic" rtest "github.com/restic/restic/internal/test" - "golang.org/x/sync/errgroup" ) var testSizes = []int{5, 23, 2<<18 + 23, 1 << 20} @@ -52,19 +51,17 @@ func testSave(t *testing.T, version uint, calculateID bool) { id := restic.Hash(data) - var wg errgroup.Group - repo.StartPackUploader(context.TODO(), &wg) - - // save - inputID := restic.ID{} - if !calculateID { - inputID = id - } - sid, _, _, err := repo.SaveBlob(context.TODO(), restic.DataBlob, data, inputID, false) - rtest.OK(t, err) - rtest.Equals(t, id, sid) - - rtest.OK(t, repo.Flush(context.Background())) + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + // save + inputID := restic.ID{} + if !calculateID { + inputID = id + } + sid, _, _, err := uploader.SaveBlob(ctx, restic.DataBlob, data, inputID, false) + rtest.OK(t, err) + rtest.Equals(t, id, sid) + return nil + })) // read back buf, err := repo.LoadBlob(context.TODO(), restic.DataBlob, id, nil) @@ -98,23 +95,22 @@ func testSavePackMerging(t *testing.T, targetPercentage int, expectedPacks int) // minimum pack size to speed up test PackSize: repository.MinPackSize, }) - var wg errgroup.Group - repo.StartPackUploader(context.TODO(), &wg) var ids restic.IDs - // add blobs with size targetPercentage / 100 * repo.PackSize to the repository - blobSize := repository.MinPackSize / 100 - for range targetPercentage { - data := make([]byte, blobSize) - _, err := io.ReadFull(rnd, data) - rtest.OK(t, err) + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + // add blobs with size targetPercentage / 100 * repo.PackSize to the repository + blobSize := repository.MinPackSize / 100 + for range targetPercentage { + data := make([]byte, blobSize) + _, err := io.ReadFull(rnd, data) + rtest.OK(t, err) - sid, _, _, err := repo.SaveBlob(context.TODO(), restic.DataBlob, data, restic.ID{}, false) - rtest.OK(t, err) - ids = append(ids, sid) - } - - rtest.OK(t, repo.Flush(context.Background())) + sid, _, _, err := uploader.SaveBlob(ctx, restic.DataBlob, data, restic.ID{}, false) + rtest.OK(t, err) + ids = append(ids, sid) + } + return nil + })) // check that all blobs are readable for _, id := range ids { @@ -146,17 +142,18 @@ func benchmarkSaveAndEncrypt(t *testing.B, version uint) { rtest.OK(t, err) id := restic.ID(sha256.Sum256(data)) - var wg errgroup.Group - repo.StartPackUploader(context.Background(), &wg) t.ReportAllocs() t.ResetTimer() t.SetBytes(int64(size)) - for i := 0; i < t.N; i++ { - _, _, _, err = repo.SaveBlob(context.TODO(), restic.DataBlob, data, id, true) - rtest.OK(t, err) - } + _ = repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + for i := 0; i < t.N; i++ { + _, _, _, err = uploader.SaveBlob(ctx, restic.DataBlob, data, id, true) + rtest.OK(t, err) + } + return nil + }) } func TestLoadBlob(t *testing.T) { @@ -170,12 +167,12 @@ func testLoadBlob(t *testing.T, version uint) { _, err := io.ReadFull(rnd, buf) rtest.OK(t, err) - var wg errgroup.Group - repo.StartPackUploader(context.TODO(), &wg) - - id, _, _, err := repo.SaveBlob(context.TODO(), restic.DataBlob, buf, restic.ID{}, false) - rtest.OK(t, err) - rtest.OK(t, repo.Flush(context.Background())) + var id restic.ID + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + var err error + id, _, _, err = uploader.SaveBlob(ctx, restic.DataBlob, buf, restic.ID{}, false) + return err + })) base := crypto.CiphertextLength(length) for _, testlength := range []int{0, base - 20, base - 1, base, base + 7, base + 15, base + 1000} { @@ -198,11 +195,12 @@ func TestLoadBlobBroken(t *testing.T) { repo, _ := repository.TestRepositoryWithBackend(t, &damageOnceBackend{Backend: be}, restic.StableRepoVersion, repository.Options{}) buf := rtest.Random(42, 1000) - var wg errgroup.Group - repo.StartPackUploader(context.TODO(), &wg) - id, _, _, err := repo.SaveBlob(context.TODO(), restic.TreeBlob, buf, restic.ID{}, false) - rtest.OK(t, err) - rtest.OK(t, repo.Flush(context.Background())) + var id restic.ID + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + var err error + id, _, _, err = uploader.SaveBlob(ctx, restic.TreeBlob, buf, restic.ID{}, false) + return err + })) // setup cache after saving the blob to make sure that the damageOnceBackend damages the cached data c := cache.TestNewCache(t) @@ -226,12 +224,12 @@ func benchmarkLoadBlob(b *testing.B, version uint) { _, err := io.ReadFull(rnd, buf) rtest.OK(b, err) - var wg errgroup.Group - repo.StartPackUploader(context.TODO(), &wg) - - id, _, _, err := repo.SaveBlob(context.TODO(), restic.DataBlob, buf, restic.ID{}, false) - rtest.OK(b, err) - rtest.OK(b, repo.Flush(context.Background())) + var id restic.ID + rtest.OK(b, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + var err error + id, _, _, err = uploader.SaveBlob(ctx, restic.DataBlob, buf, restic.ID{}, false) + return err + })) b.ResetTimer() b.SetBytes(int64(length)) @@ -363,19 +361,19 @@ func TestRepositoryLoadUnpackedRetryBroken(t *testing.T) { // saveRandomDataBlobs generates random data blobs and saves them to the repository. func saveRandomDataBlobs(t testing.TB, repo restic.Repository, num int, sizeMax int) { - var wg errgroup.Group - repo.StartPackUploader(context.TODO(), &wg) + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + for i := 0; i < num; i++ { + size := rand.Int() % sizeMax - for i := 0; i < num; i++ { - size := rand.Int() % sizeMax + buf := make([]byte, size) + _, err := io.ReadFull(rnd, buf) + rtest.OK(t, err) - buf := make([]byte, size) - _, err := io.ReadFull(rnd, buf) - rtest.OK(t, err) - - _, _, _, err = repo.SaveBlob(context.TODO(), restic.DataBlob, buf, restic.ID{}, false) - rtest.OK(t, err) - } + _, _, _, err = uploader.SaveBlob(ctx, restic.DataBlob, buf, restic.ID{}, false) + rtest.OK(t, err) + } + return nil + })) } func TestRepositoryIncrementalIndex(t *testing.T) { @@ -389,14 +387,10 @@ func testRepositoryIncrementalIndex(t *testing.T, version uint) { // add a few rounds of packs for j := 0; j < 5; j++ { - // add some packs, write intermediate index + // add some packs and write index saveRandomDataBlobs(t, repo, 20, 1<<15) - rtest.OK(t, repo.Flush(context.TODO())) } - // save final index - rtest.OK(t, repo.Flush(context.TODO())) - packEntries := make(map[restic.ID]map[restic.ID]struct{}) err := repo.List(context.TODO(), restic.IndexFile, func(id restic.ID, size int64) error { @@ -437,11 +431,12 @@ func TestListPack(t *testing.T) { repo, _ := repository.TestRepositoryWithBackend(t, &damageOnceBackend{Backend: be}, restic.StableRepoVersion, repository.Options{}) buf := rtest.Random(42, 1000) - var wg errgroup.Group - repo.StartPackUploader(context.TODO(), &wg) - id, _, _, err := repo.SaveBlob(context.TODO(), restic.TreeBlob, buf, restic.ID{}, false) - rtest.OK(t, err) - rtest.OK(t, repo.Flush(context.Background())) + var id restic.ID + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + var err error + id, _, _, err = uploader.SaveBlob(ctx, restic.TreeBlob, buf, restic.ID{}, false) + return err + })) // setup cache after saving the blob to make sure that the damageOnceBackend damages the cached data c := cache.TestNewCache(t) diff --git a/internal/restic/repository.go b/internal/restic/repository.go index 509a0db8a..236171cd2 100644 --- a/internal/restic/repository.go +++ b/internal/restic/repository.go @@ -7,7 +7,6 @@ import ( "github.com/restic/restic/internal/crypto" "github.com/restic/restic/internal/errors" "github.com/restic/restic/internal/ui/progress" - "golang.org/x/sync/errgroup" ) // ErrInvalidData is used to report that a file is corrupted @@ -37,12 +36,10 @@ type Repository interface { LoadBlob(ctx context.Context, t BlobType, id ID, buf []byte) ([]byte, error) LoadBlobsFromPack(ctx context.Context, packID ID, blobs []Blob, handleBlobFn func(blob BlobHandle, buf []byte, err error) error) error - // StartPackUploader start goroutines to upload new pack files. The errgroup - // is used to immediately notify about an upload error. Flush() will also return - // that error. - StartPackUploader(ctx context.Context, wg *errgroup.Group) - SaveBlob(ctx context.Context, t BlobType, buf []byte, id ID, storeDuplicate bool) (newID ID, known bool, size int, err error) - Flush(ctx context.Context) error + // WithUploader starts the necessary workers to upload new blobs. Once the callback returns, + // the workers are stopped and the index is written to the repository. The callback must use + // the passed context and must not keep references to any of its parameters after returning. + WithBlobUploader(ctx context.Context, fn func(ctx context.Context, uploader BlobSaver) error) error // List calls the function fn for each file of type t in the repository. // When an error is returned by fn, processing stops and List() returns the @@ -160,6 +157,10 @@ type BlobLoader interface { LoadBlob(context.Context, BlobType, ID, []byte) ([]byte, error) } +type WithBlobUploader interface { + WithBlobUploader(ctx context.Context, fn func(ctx context.Context, uploader BlobSaver) error) error +} + type BlobSaver interface { SaveBlob(context.Context, BlobType, []byte, ID, bool) (ID, bool, int, error) } diff --git a/internal/restorer/restorer_test.go b/internal/restorer/restorer_test.go index ee9f7ef58..fe9db6b33 100644 --- a/internal/restorer/restorer_test.go +++ b/internal/restorer/restorer_test.go @@ -25,7 +25,6 @@ import ( rtest "github.com/restic/restic/internal/test" "github.com/restic/restic/internal/ui/progress" restoreui "github.com/restic/restic/internal/ui/restore" - "golang.org/x/sync/errgroup" ) type Node interface{} @@ -171,13 +170,11 @@ func saveSnapshot(t testing.TB, repo restic.Repository, snapshot Snapshot, getGe ctx, cancel := context.WithCancel(context.Background()) defer cancel() - wg, wgCtx := errgroup.WithContext(ctx) - repo.StartPackUploader(wgCtx, wg) - treeID := saveDir(t, repo, snapshot.Nodes, 1000, getGenericAttributes) - err := repo.Flush(ctx) - if err != nil { - t.Fatal(err) - } + var treeID restic.ID + rtest.OK(t, repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { + treeID = saveDir(t, uploader, snapshot.Nodes, 1000, getGenericAttributes) + return nil + })) sn, err := data.NewSnapshot([]string{"test"}, nil, "", time.Now()) if err != nil { diff --git a/internal/walker/rewriter.go b/internal/walker/rewriter.go index 4d70e4d09..bd05b90d7 100644 --- a/internal/walker/rewriter.go +++ b/internal/walker/rewriter.go @@ -11,7 +11,7 @@ import ( ) type NodeRewriteFunc func(node *data.Node, path string) *data.Node -type FailedTreeRewriteFunc func(nodeID restic.ID, path string, err error) (restic.ID, error) +type FailedTreeRewriteFunc func(nodeID restic.ID, path string, err error) (*data.Tree, error) type QueryRewrittenSizeFunc func() SnapshotSize type SnapshotSize struct { @@ -52,8 +52,8 @@ func NewTreeRewriter(opts RewriteOpts) *TreeRewriter { } if rw.opts.RewriteFailedTree == nil { // fail with error by default - rw.opts.RewriteFailedTree = func(_ restic.ID, _ string, err error) (restic.ID, error) { - return restic.ID{}, err + rw.opts.RewriteFailedTree = func(_ restic.ID, _ string, err error) (*data.Tree, error) { + return nil, err } } return rw @@ -82,12 +82,7 @@ func NewSnapshotSizeRewriter(rewriteNode NodeRewriteFunc) (*TreeRewriter, QueryR return t, ss } -type BlobLoadSaver interface { - restic.BlobSaver - restic.BlobLoader -} - -func (t *TreeRewriter) RewriteTree(ctx context.Context, repo BlobLoadSaver, nodepath string, nodeID restic.ID) (newNodeID restic.ID, err error) { +func (t *TreeRewriter) RewriteTree(ctx context.Context, loader restic.BlobLoader, saver restic.BlobSaver, nodepath string, nodeID restic.ID) (newNodeID restic.ID, err error) { // check if tree was already changed newID, ok := t.replaces[nodeID] if ok { @@ -95,16 +90,27 @@ func (t *TreeRewriter) RewriteTree(ctx context.Context, repo BlobLoadSaver, node } // a nil nodeID will lead to a load error - curTree, err := data.LoadTree(ctx, repo, nodeID) + curTree, err := data.LoadTree(ctx, loader, nodeID) if err != nil { - return t.opts.RewriteFailedTree(nodeID, nodepath, err) + replacement, err := t.opts.RewriteFailedTree(nodeID, nodepath, err) + if err != nil { + return restic.ID{}, err + } + if replacement != nil { + replacementID, err := data.SaveTree(ctx, saver, replacement) + if err != nil { + return restic.ID{}, err + } + return replacementID, nil + } + return restic.ID{}, nil } if !t.opts.AllowUnstableSerialization { // check that we can properly encode this tree without losing information // The alternative of using json/Decoder.DisallowUnknownFields() doesn't work as we use // a custom UnmarshalJSON to decode trees, see also https://github.com/golang/go/issues/41144 - testID, err := data.SaveTree(ctx, repo, curTree) + testID, err := data.SaveTree(ctx, saver, curTree) if err != nil { return restic.ID{}, err } @@ -139,7 +145,7 @@ func (t *TreeRewriter) RewriteTree(ctx context.Context, repo BlobLoadSaver, node if node.Subtree != nil { subtree = *node.Subtree } - newID, err := t.RewriteTree(ctx, repo, path, subtree) + newID, err := t.RewriteTree(ctx, loader, saver, path, subtree) if err != nil { return restic.ID{}, err } @@ -156,7 +162,7 @@ func (t *TreeRewriter) RewriteTree(ctx context.Context, repo BlobLoadSaver, node } // Save new tree - newTreeID, _, _, err := repo.SaveBlob(ctx, restic.TreeBlob, tree, restic.ID{}, false) + newTreeID, _, _, err := saver.SaveBlob(ctx, restic.TreeBlob, tree, restic.ID{}, false) if t.replaces != nil { t.replaces[nodeID] = newTreeID } diff --git a/internal/walker/rewriter_test.go b/internal/walker/rewriter_test.go index a21af36c6..9290a62d5 100644 --- a/internal/walker/rewriter_test.go +++ b/internal/walker/rewriter_test.go @@ -285,7 +285,7 @@ func TestRewriter(t *testing.T) { defer cancel() rewriter, last := test.check(t) - newRoot, err := rewriter.RewriteTree(ctx, modrepo, "/", root) + newRoot, err := rewriter.RewriteTree(ctx, modrepo, modrepo, "/", root) if err != nil { t.Error(err) } @@ -335,7 +335,7 @@ func TestSnapshotSizeQuery(t *testing.T) { return node } rewriter, querySize := NewSnapshotSizeRewriter(rewriteNode) - newRoot, err := rewriter.RewriteTree(ctx, modrepo, "/", root) + newRoot, err := rewriter.RewriteTree(ctx, modrepo, modrepo, "/", root) if err != nil { t.Error(err) } @@ -373,7 +373,7 @@ func TestRewriterFailOnUnknownFields(t *testing.T) { return node }, }) - _, err := rewriter.RewriteTree(ctx, tm, "/", id) + _, err := rewriter.RewriteTree(ctx, tm, tm, "/", id) if err == nil { t.Error("missing error on unknown field") @@ -383,7 +383,7 @@ func TestRewriterFailOnUnknownFields(t *testing.T) { rewriter = NewTreeRewriter(RewriteOpts{ AllowUnstableSerialization: true, }) - root, err := rewriter.RewriteTree(ctx, tm, "/", id) + root, err := rewriter.RewriteTree(ctx, tm, tm, "/", id) test.OK(t, err) _, expRoot := BuildTreeMap(TestTree{ "subfile": TestFile{}, @@ -400,21 +400,24 @@ func TestRewriterTreeLoadError(t *testing.T) { // also check that load error by default cause the operation to fail rewriter := NewTreeRewriter(RewriteOpts{}) - _, err := rewriter.RewriteTree(ctx, tm, "/", id) + _, err := rewriter.RewriteTree(ctx, tm, tm, "/", id) if err == nil { t.Fatal("missing error on unloadable tree") } - replacementID := restic.NewRandomID() + replacementTree := &data.Tree{Nodes: []*data.Node{{Name: "replacement", Type: data.NodeTypeFile, Size: 42}}} + replacementID, err := data.SaveTree(ctx, tm, replacementTree) + test.OK(t, err) + rewriter = NewTreeRewriter(RewriteOpts{ - RewriteFailedTree: func(nodeID restic.ID, path string, err error) (restic.ID, error) { + RewriteFailedTree: func(nodeID restic.ID, path string, err error) (*data.Tree, error) { if nodeID != id || path != "/" { t.Fail() } - return replacementID, nil + return replacementTree, nil }, }) - newRoot, err := rewriter.RewriteTree(ctx, tm, "/", id) + newRoot, err := rewriter.RewriteTree(ctx, tm, tm, "/", id) test.OK(t, err) test.Equals(t, replacementID, newRoot) }