diff --git a/cmd/restic/cmd_debug.go b/cmd/restic/cmd_debug.go index 013b0c3ff..fe269fc3e 100644 --- a/cmd/restic/cmd_debug.go +++ b/cmd/restic/cmd_debug.go @@ -352,7 +352,7 @@ func loadBlobs(ctx context.Context, opts DebugExamineOptions, repo restic.Reposi return err } - err = repo.WithBlobUploader(ctx, func(ctx context.Context) error { + err = repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { for _, blob := range list { printer.S(" loading blob %v at %v (length %v)", blob.ID, blob.Offset, blob.Length) if int(blob.Offset+blob.Length) > len(pack) { @@ -410,7 +410,7 @@ func loadBlobs(ctx context.Context, opts DebugExamineOptions, repo restic.Reposi } } if opts.ReuploadBlobs { - _, _, _, err := repo.SaveBlob(ctx, blob.Type, plaintext, id, true) + _, _, _, err := uploader.SaveBlob(ctx, blob.Type, plaintext, id, true) if err != nil { return err } diff --git a/cmd/restic/cmd_recover.go b/cmd/restic/cmd_recover.go index 0f58759d0..afd710a54 100644 --- a/cmd/restic/cmd_recover.go +++ b/cmd/restic/cmd_recover.go @@ -152,9 +152,9 @@ func runRecover(ctx context.Context, gopts GlobalOptions, term ui.Terminal) erro } var treeID restic.ID - err = repo.WithBlobUploader(ctx, func(ctx context.Context) error { + err = repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { var err error - treeID, err = data.SaveTree(ctx, repo, tree) + treeID, err = data.SaveTree(ctx, uploader, tree) if err != nil { return errors.Fatalf("unable to save new tree to the repository: %v", err) } diff --git a/cmd/restic/cmd_repair_snapshots.go b/cmd/restic/cmd_repair_snapshots.go index 0a702e596..2a09dc72e 100644 --- a/cmd/restic/cmd_repair_snapshots.go +++ b/cmd/restic/cmd_repair_snapshots.go @@ -129,19 +129,15 @@ func runRepairSnapshots(ctx context.Context, gopts GlobalOptions, opts RepairOpt node.Size = newSize return node }, - RewriteFailedTree: func(_ restic.ID, path string, _ error) (restic.ID, error) { + RewriteFailedTree: func(_ restic.ID, path string, _ error) (*data.Tree, error) { if path == "/" { printer.P(" dir %q: not readable", path) // remove snapshots with invalid root node - return restic.ID{}, nil + return nil, nil } // If a subtree fails to load, remove it printer.P(" dir %q: replaced with empty directory", path) - emptyID, err := data.SaveTree(ctx, repo, &data.Tree{}) - if err != nil { - return restic.ID{}, err - } - return emptyID, nil + return &data.Tree{}, nil }, AllowUnstableSerialization: true, }) @@ -150,8 +146,8 @@ func runRepairSnapshots(ctx context.Context, gopts GlobalOptions, opts RepairOpt for sn := range FindFilteredSnapshots(ctx, snapshotLister, repo, &opts.SnapshotFilter, args, printer) { printer.P("\n%v", sn) changed, err := filterAndReplaceSnapshot(ctx, repo, sn, - func(ctx context.Context, sn *data.Snapshot) (restic.ID, *data.SnapshotSummary, error) { - id, err := rewriter.RewriteTree(ctx, repo, "/", *sn.Tree) + func(ctx context.Context, sn *data.Snapshot, uploader restic.BlobSaver) (restic.ID, *data.SnapshotSummary, error) { + id, err := rewriter.RewriteTree(ctx, repo, uploader, "/", *sn.Tree) return id, nil, err }, opts.DryRun, opts.Forget, nil, "repaired", printer) if err != nil { diff --git a/cmd/restic/cmd_rewrite.go b/cmd/restic/cmd_rewrite.go index e70a457d8..c40b1f9ba 100644 --- a/cmd/restic/cmd_rewrite.go +++ b/cmd/restic/cmd_rewrite.go @@ -123,7 +123,7 @@ func (opts *RewriteOptions) AddFlags(f *pflag.FlagSet) { // rewriteFilterFunc returns the filtered tree ID or an error. If a snapshot summary is returned, the snapshot will // be updated accordingly. -type rewriteFilterFunc func(ctx context.Context, sn *data.Snapshot) (restic.ID, *data.SnapshotSummary, error) +type rewriteFilterFunc func(ctx context.Context, sn *data.Snapshot, uploader restic.BlobSaver) (restic.ID, *data.SnapshotSummary, error) func rewriteSnapshot(ctx context.Context, repo *repository.Repository, sn *data.Snapshot, opts RewriteOptions, printer progress.Printer) (bool, error) { if sn.Tree == nil { @@ -163,8 +163,8 @@ func rewriteSnapshot(ctx context.Context, repo *repository.Repository, sn *data. rewriter, querySize := walker.NewSnapshotSizeRewriter(rewriteNode) - filter = func(ctx context.Context, sn *data.Snapshot) (restic.ID, *data.SnapshotSummary, error) { - id, err := rewriter.RewriteTree(ctx, repo, "/", *sn.Tree) + filter = func(ctx context.Context, sn *data.Snapshot, uploader restic.BlobSaver) (restic.ID, *data.SnapshotSummary, error) { + id, err := rewriter.RewriteTree(ctx, repo, uploader, "/", *sn.Tree) if err != nil { return restic.ID{}, nil, err } @@ -179,7 +179,7 @@ func rewriteSnapshot(ctx context.Context, repo *repository.Repository, sn *data. } } else { - filter = func(_ context.Context, sn *data.Snapshot) (restic.ID, *data.SnapshotSummary, error) { + filter = func(_ context.Context, sn *data.Snapshot, _ restic.BlobSaver) (restic.ID, *data.SnapshotSummary, error) { return *sn.Tree, nil, nil } } @@ -193,9 +193,9 @@ func filterAndReplaceSnapshot(ctx context.Context, repo restic.Repository, sn *d var filteredTree restic.ID var summary *data.SnapshotSummary - err := repo.WithBlobUploader(ctx, func(ctx context.Context) error { + err := repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { var err error - filteredTree, summary, err = filter(ctx, sn) + filteredTree, summary, err = filter(ctx, sn, uploader) return err }) if err != nil { diff --git a/internal/archiver/archiver.go b/internal/archiver/archiver.go index 7dc51ecfb..d619ad9b4 100644 --- a/internal/archiver/archiver.go +++ b/internal/archiver/archiver.go @@ -74,11 +74,10 @@ type ToNoder interface { type archiverRepo interface { restic.Loader - restic.BlobSaver + restic.WithBlobUploader restic.SaverUnpacked[restic.WriteableFileType] Config() restic.Config - WithBlobUploader(ctx context.Context, fn func(ctx context.Context) error) error } // Archiver saves a directory structure to the repo. @@ -835,8 +834,8 @@ func (arch *Archiver) loadParentTree(ctx context.Context, sn *data.Snapshot) *da } // runWorkers starts the worker pools, which are stopped when the context is cancelled. -func (arch *Archiver) runWorkers(ctx context.Context, wg *errgroup.Group) { - arch.blobSaver = newBlobSaver(ctx, wg, arch.Repo, arch.Options.SaveBlobConcurrency) +func (arch *Archiver) runWorkers(ctx context.Context, wg *errgroup.Group, uploader restic.BlobSaver) { + arch.blobSaver = newBlobSaver(ctx, wg, uploader, arch.Options.SaveBlobConcurrency) arch.fileSaver = newFileSaver(ctx, wg, arch.blobSaver.Save, @@ -875,12 +874,12 @@ func (arch *Archiver) Snapshot(ctx context.Context, targets []string, opts Snaps var rootTreeID restic.ID - err = arch.Repo.WithBlobUploader(ctx, func(ctx context.Context) error { + err = arch.Repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { wg, wgCtx := errgroup.WithContext(ctx) start := time.Now() wg.Go(func() error { - arch.runWorkers(wgCtx, wg) + arch.runWorkers(wgCtx, wg, uploader) debug.Log("starting snapshot") fn, nodeCount, err := arch.saveTree(wgCtx, "/", atree, arch.loadParentTree(wgCtx, opts.ParentSnapshot), func(_ *data.Node, is ItemStats) { diff --git a/internal/archiver/archiver_test.go b/internal/archiver/archiver_test.go index b80942c68..adc7695cb 100644 --- a/internal/archiver/archiver_test.go +++ b/internal/archiver/archiver_test.go @@ -56,9 +56,9 @@ func saveFile(t testing.TB, repo archiverRepo, filename string, filesystem fs.FS return err } - err := repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error { + err := repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { wg, ctx := errgroup.WithContext(ctx) - arch.runWorkers(ctx, wg) + arch.runWorkers(ctx, wg, uploader) completeReading := func() { completeReadingCallback = true @@ -219,9 +219,9 @@ func TestArchiverSave(t *testing.T) { arch.summary = &Summary{} var fnr futureNodeResult - err := repo.WithBlobUploader(ctx, func(ctx context.Context) error { + err := repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { wg, ctx := errgroup.WithContext(ctx) - arch.runWorkers(ctx, wg) + arch.runWorkers(ctx, wg, uploader) node, excluded, err := arch.save(ctx, "/", filepath.Join(tempdir, "file"), nil) if err != nil { @@ -296,9 +296,9 @@ func TestArchiverSaveReaderFS(t *testing.T) { arch.summary = &Summary{} var fnr futureNodeResult - err = repo.WithBlobUploader(ctx, func(ctx context.Context) error { + err = repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { wg, ctx := errgroup.WithContext(ctx) - arch.runWorkers(ctx, wg) + arch.runWorkers(ctx, wg, uploader) node, excluded, err := arch.save(ctx, "/", filename, nil) t.Logf("Save returned %v %v", node, err) @@ -415,27 +415,29 @@ type blobCountingRepo struct { saved map[restic.BlobHandle]uint } -func (repo *blobCountingRepo) SaveBlob(ctx context.Context, t restic.BlobType, buf []byte, id restic.ID, storeDuplicate bool) (restic.ID, bool, int, error) { - id, exists, size, err := repo.archiverRepo.SaveBlob(ctx, t, buf, id, storeDuplicate) +func (repo *blobCountingRepo) WithBlobUploader(ctx context.Context, fn func(ctx context.Context, uploader restic.BlobSaver) error) error { + return repo.archiverRepo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { + return fn(ctx, &blobCountingSaver{saver: uploader, blobCountingRepo: repo}) + }) +} + +type blobCountingSaver struct { + saver restic.BlobSaver + blobCountingRepo *blobCountingRepo +} + +func (repo *blobCountingSaver) SaveBlob(ctx context.Context, t restic.BlobType, buf []byte, id restic.ID, storeDuplicate bool) (restic.ID, bool, int, error) { + id, exists, size, err := repo.saver.SaveBlob(ctx, t, buf, id, storeDuplicate) if exists { return id, exists, size, err } h := restic.BlobHandle{ID: id, Type: t} - repo.m.Lock() - repo.saved[h]++ - repo.m.Unlock() + repo.blobCountingRepo.m.Lock() + repo.blobCountingRepo.saved[h]++ + repo.blobCountingRepo.m.Unlock() return id, exists, size, err } -func (repo *blobCountingRepo) SaveTree(ctx context.Context, t *data.Tree) (restic.ID, error) { - id, err := data.SaveTree(ctx, repo.archiverRepo, t) - h := restic.BlobHandle{ID: id, Type: restic.TreeBlob} - repo.m.Lock() - repo.saved[h]++ - repo.m.Unlock() - return id, err -} - func appendToFile(t testing.TB, filename string, data []byte) { f, err := os.OpenFile(filename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644) if err != nil { @@ -838,9 +840,9 @@ func TestArchiverSaveDir(t *testing.T) { defer back() var treeID restic.ID - err := repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error { + err := repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { wg, ctx := errgroup.WithContext(ctx) - arch.runWorkers(ctx, wg) + arch.runWorkers(ctx, wg, uploader) meta, err := testFS.OpenFile(test.target, fs.O_NOFOLLOW, true) rtest.OK(t, err) ft, err := arch.saveDir(ctx, "/", test.target, meta, nil, nil) @@ -866,7 +868,7 @@ func TestArchiverSaveDir(t *testing.T) { node.Name = targetNodeName tree := &data.Tree{Nodes: []*data.Node{node}} - treeID, err = data.SaveTree(ctx, repo, tree) + treeID, err = data.SaveTree(ctx, uploader, tree) if err != nil { t.Fatal(err) } @@ -904,9 +906,9 @@ func TestArchiverSaveDirIncremental(t *testing.T) { arch.summary = &Summary{} var fnr futureNodeResult - err := repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error { + err := repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { wg, ctx := errgroup.WithContext(ctx) - arch.runWorkers(ctx, wg) + arch.runWorkers(ctx, wg, uploader) meta, err := testFS.OpenFile(tempdir, fs.O_NOFOLLOW, true) rtest.OK(t, err) ft, err := arch.saveDir(ctx, "/", tempdir, meta, nil, nil) @@ -1094,9 +1096,9 @@ func TestArchiverSaveTree(t *testing.T) { } var treeID restic.ID - err := repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error { + err := repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { wg, ctx := errgroup.WithContext(ctx) - arch.runWorkers(ctx, wg) + arch.runWorkers(ctx, wg, uploader) atree, err := newTree(testFS, test.targets) if err != nil { @@ -2093,13 +2095,24 @@ type failSaveRepo struct { err error } -func (f *failSaveRepo) SaveBlob(ctx context.Context, t restic.BlobType, buf []byte, id restic.ID, storeDuplicate bool) (restic.ID, bool, int, error) { - val := atomic.AddInt32(&f.cnt, 1) - if val >= f.failAfter { - return restic.Hash(buf), false, 0, f.err +func (f *failSaveRepo) WithBlobUploader(ctx context.Context, fn func(ctx context.Context, uploader restic.BlobSaver) error) error { + return f.archiverRepo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { + return fn(ctx, &failSaveSaver{saver: uploader, failSaveRepo: f}) + }) +} + +type failSaveSaver struct { + saver restic.BlobSaver + failSaveRepo *failSaveRepo +} + +func (f *failSaveSaver) SaveBlob(ctx context.Context, t restic.BlobType, buf []byte, id restic.ID, storeDuplicate bool) (restic.ID, bool, int, error) { + val := atomic.AddInt32(&f.failSaveRepo.cnt, 1) + if val >= f.failSaveRepo.failAfter { + return restic.Hash(buf), false, 0, f.failSaveRepo.err } - return f.archiverRepo.SaveBlob(ctx, t, buf, id, storeDuplicate) + return f.saver.SaveBlob(ctx, t, buf, id, storeDuplicate) } func TestArchiverAbortEarlyOnError(t *testing.T) { @@ -2412,7 +2425,7 @@ func TestRacyFileTypeSwap(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - _ = repo.WithBlobUploader(ctx, func(ctx context.Context) error { + _ = repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { wg, ctx := errgroup.WithContext(ctx) arch := New(repo, fs.Track{FS: statfs}, Options{}) @@ -2420,7 +2433,7 @@ func TestRacyFileTypeSwap(t *testing.T) { t.Logf("archiver error as expected for %v: %v", item, err) return err } - arch.runWorkers(ctx, wg) + arch.runWorkers(ctx, wg, uploader) // fs.Track will panic if the file was not closed _, excluded, err := arch.save(ctx, "/", tempfile, nil) diff --git a/internal/checker/checker_test.go b/internal/checker/checker_test.go index 445d7a55f..f461ce5f8 100644 --- a/internal/checker/checker_test.go +++ b/internal/checker/checker_test.go @@ -525,18 +525,18 @@ func TestCheckerBlobTypeConfusion(t *testing.T) { } var id restic.ID - test.OK(t, repo.WithBlobUploader(ctx, func(ctx context.Context) error { + test.OK(t, repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { var err error - id, err = data.SaveTree(ctx, repo, damagedTree) + id, err = data.SaveTree(ctx, uploader, damagedTree) return err })) buf, err := repo.LoadBlob(ctx, restic.TreeBlob, id, nil) test.OK(t, err) - test.OK(t, repo.WithBlobUploader(ctx, func(ctx context.Context) error { + test.OK(t, repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { var err error - _, _, _, err = repo.SaveBlob(ctx, restic.DataBlob, buf, id, false) + _, _, _, err = uploader.SaveBlob(ctx, restic.DataBlob, buf, id, false) return err })) @@ -559,9 +559,9 @@ func TestCheckerBlobTypeConfusion(t *testing.T) { } var rootID restic.ID - test.OK(t, repo.WithBlobUploader(ctx, func(ctx context.Context) error { + test.OK(t, repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { var err error - rootID, err = data.SaveTree(ctx, repo, rootTree) + rootID, err = data.SaveTree(ctx, uploader, rootTree) return err })) diff --git a/internal/data/testing.go b/internal/data/testing.go index 8e8d9fbde..be4ab4edb 100644 --- a/internal/data/testing.go +++ b/internal/data/testing.go @@ -28,7 +28,7 @@ type fakeFileSystem struct { // saveFile reads from rd and saves the blobs in the repository. The list of // IDs is returned. -func (fs *fakeFileSystem) saveFile(ctx context.Context, rd io.Reader) (blobs restic.IDs) { +func (fs *fakeFileSystem) saveFile(ctx context.Context, uploader restic.BlobSaver, rd io.Reader) (blobs restic.IDs) { if fs.buf == nil { fs.buf = make([]byte, chunker.MaxSize) } @@ -50,7 +50,7 @@ func (fs *fakeFileSystem) saveFile(ctx context.Context, rd io.Reader) (blobs res fs.t.Fatalf("unable to save chunk in repo: %v", err) } - id, _, _, err := fs.repo.SaveBlob(ctx, restic.DataBlob, chunk.Data, restic.ID{}, false) + id, _, _, err := uploader.SaveBlob(ctx, restic.DataBlob, chunk.Data, restic.ID{}, false) if err != nil { fs.t.Fatalf("error saving chunk: %v", err) } @@ -68,7 +68,7 @@ const ( ) // saveTree saves a tree of fake files in the repo and returns the ID. -func (fs *fakeFileSystem) saveTree(ctx context.Context, seed int64, depth int) restic.ID { +func (fs *fakeFileSystem) saveTree(ctx context.Context, uploader restic.BlobSaver, seed int64, depth int) restic.ID { rnd := rand.NewSource(seed) numNodes := int(rnd.Int63() % maxNodes) @@ -78,7 +78,7 @@ func (fs *fakeFileSystem) saveTree(ctx context.Context, seed int64, depth int) r // randomly select the type of the node, either tree (p = 1/4) or file (p = 3/4). if depth > 1 && rnd.Int63()%4 == 0 { treeSeed := rnd.Int63() % maxSeed - id := fs.saveTree(ctx, treeSeed, depth-1) + id := fs.saveTree(ctx, uploader, treeSeed, depth-1) node := &Node{ Name: fmt.Sprintf("dir-%v", treeSeed), @@ -101,13 +101,13 @@ func (fs *fakeFileSystem) saveTree(ctx context.Context, seed int64, depth int) r Size: uint64(fileSize), } - node.Content = fs.saveFile(ctx, fakeFile(fileSeed, fileSize)) + node.Content = fs.saveFile(ctx, uploader, fakeFile(fileSeed, fileSize)) tree.Nodes = append(tree.Nodes, node) } tree.Sort() - id, err := SaveTree(ctx, fs.repo, &tree) + id, err := SaveTree(ctx, uploader, &tree) if err != nil { fs.t.Fatalf("SaveTree returned error: %v", err) } @@ -136,8 +136,8 @@ func TestCreateSnapshot(t testing.TB, repo restic.Repository, at time.Time, dept } var treeID restic.ID - test.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error { - treeID = fs.saveTree(ctx, seed, depth) + test.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + treeID = fs.saveTree(ctx, uploader, seed, depth) return nil })) snapshot.Tree = &treeID diff --git a/internal/data/tree_test.go b/internal/data/tree_test.go index 51cf16f2c..9164f4da1 100644 --- a/internal/data/tree_test.go +++ b/internal/data/tree_test.go @@ -107,10 +107,10 @@ func TestEmptyLoadTree(t *testing.T) { tree := data.NewTree(0) var id restic.ID - rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error { + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { var err error // save tree - id, err = data.SaveTree(ctx, repo, tree) + id, err = data.SaveTree(ctx, uploader, tree) return err })) diff --git a/internal/repository/fuzz_test.go b/internal/repository/fuzz_test.go index ea4375289..16155f3a4 100644 --- a/internal/repository/fuzz_test.go +++ b/internal/repository/fuzz_test.go @@ -20,8 +20,8 @@ func FuzzSaveLoadBlob(f *testing.F) { id := restic.Hash(blob) repo, _, _ := TestRepositoryWithVersion(t, 2) - rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error { - _, _, _, err := repo.SaveBlob(ctx, restic.DataBlob, blob, id, false) + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + _, _, _, err := uploader.SaveBlob(ctx, restic.DataBlob, blob, id, false) return err })) diff --git a/internal/repository/prune_internal_test.go b/internal/repository/prune_internal_test.go index 089567afe..49a876884 100644 --- a/internal/repository/prune_internal_test.go +++ b/internal/repository/prune_internal_test.go @@ -47,9 +47,9 @@ func TestPruneMaxUnusedDuplicate(t *testing.T) { {bufs[1], bufs[3]}, {bufs[2], bufs[3]}, } { - rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error { + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { for _, blob := range blobs { - id, _, _, err := repo.SaveBlob(ctx, restic.DataBlob, blob, restic.ID{}, true) + id, _, _, err := uploader.SaveBlob(ctx, restic.DataBlob, blob, restic.ID{}, true) keep.Insert(restic.BlobHandle{Type: restic.DataBlob, ID: id}) rtest.OK(t, err) } diff --git a/internal/repository/prune_test.go b/internal/repository/prune_test.go index c66998193..744de0b14 100644 --- a/internal/repository/prune_test.go +++ b/internal/repository/prune_test.go @@ -25,12 +25,12 @@ func testPrune(t *testing.T, opts repository.PruneOptions, errOnUnused bool) { createRandomBlobs(t, random, repo, 5, 0.5, true) keep, _ := selectBlobs(t, random, repo, 0.5) - rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error { + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { // duplicate a few blobs to exercise those code paths for blob := range keep { buf, err := repo.LoadBlob(ctx, blob.Type, blob.ID, nil) rtest.OK(t, err) - _, _, _, err = repo.SaveBlob(ctx, blob.Type, buf, blob.ID, true) + _, _, _, err = uploader.SaveBlob(ctx, blob.Type, buf, blob.ID, true) rtest.OK(t, err) } return nil @@ -133,13 +133,13 @@ func TestPruneSmall(t *testing.T) { const numBlobsCreated = 55 keep := restic.NewBlobSet() - rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error { + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { // we need a minum of 11 packfiles, each packfile will be about 5 Mb long for i := 0; i < numBlobsCreated; i++ { buf := make([]byte, blobSize) random.Read(buf) - id, _, _, err := repo.SaveBlob(ctx, restic.DataBlob, buf, restic.ID{}, false) + id, _, _, err := uploader.SaveBlob(ctx, restic.DataBlob, buf, restic.ID{}, false) rtest.OK(t, err) keep.Insert(restic.BlobHandle{Type: restic.DataBlob, ID: id}) } diff --git a/internal/repository/repack.go b/internal/repository/repack.go index 65f2c0bc4..6ee86eb22 100644 --- a/internal/repository/repack.go +++ b/internal/repository/repack.go @@ -47,9 +47,9 @@ func Repack( return nil, errors.New("repack step requires a backend connection limit of at least two") } - err = dstRepo.WithBlobUploader(ctx, func(ctx context.Context) error { + err = dstRepo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { var err error - obsoletePacks, err = repack(ctx, repo, dstRepo, packs, keepBlobs, p, logf) + obsoletePacks, err = repack(ctx, repo, dstRepo, uploader, packs, keepBlobs, p, logf) return err }) if err != nil { @@ -62,6 +62,7 @@ func repack( ctx context.Context, repo restic.Repository, dstRepo restic.Repository, + uploader restic.BlobSaver, packs restic.IDSet, keepBlobs repackBlobSet, p *progress.Counter, @@ -128,7 +129,7 @@ func repack( } // We do want to save already saved blobs! - _, _, _, err = dstRepo.SaveBlob(wgCtx, blob.Type, buf, blob.ID, true) + _, _, _, err = uploader.SaveBlob(wgCtx, blob.Type, buf, blob.ID, true) if err != nil { return err } diff --git a/internal/repository/repack_test.go b/internal/repository/repack_test.go index 1aec4e069..599178371 100644 --- a/internal/repository/repack_test.go +++ b/internal/repository/repack_test.go @@ -20,7 +20,7 @@ func randomSize(random *rand.Rand, min, max int) int { func createRandomBlobs(t testing.TB, random *rand.Rand, repo restic.Repository, blobs int, pData float32, smallBlobs bool) { // two loops to allow creating multiple pack files for blobs > 0 { - rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error { + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { for blobs > 0 { blobs-- var ( @@ -43,7 +43,7 @@ func createRandomBlobs(t testing.TB, random *rand.Rand, repo restic.Repository, buf := make([]byte, length) random.Read(buf) - id, exists, _, err := repo.SaveBlob(ctx, tpe, buf, restic.ID{}, false) + id, exists, _, err := uploader.SaveBlob(ctx, tpe, buf, restic.ID{}, false) if err != nil { t.Fatalf("SaveFrom() error %v", err) } @@ -70,8 +70,8 @@ func createRandomWrongBlob(t testing.TB, random *rand.Rand, repo restic.Reposito // invert first data byte buf[0] ^= 0xff - rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error { - _, _, _, err := repo.SaveBlob(ctx, restic.DataBlob, buf, id, false) + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + _, _, _, err := uploader.SaveBlob(ctx, restic.DataBlob, buf, id, false) return err })) return restic.BlobHandle{ID: id, Type: restic.DataBlob} @@ -339,8 +339,8 @@ func testRepackBlobFallback(t *testing.T, version uint) { modbuf[0] ^= 0xff // create pack with broken copy - rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error { - _, _, _, err := repo.SaveBlob(ctx, restic.DataBlob, modbuf, id, false) + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + _, _, _, err := uploader.SaveBlob(ctx, restic.DataBlob, modbuf, id, false) return err })) @@ -349,8 +349,8 @@ func testRepackBlobFallback(t *testing.T, version uint) { rewritePacks := findPacksForBlobs(t, repo, keepBlobs) // create pack with valid copy - rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error { - _, _, _, err := repo.SaveBlob(ctx, restic.DataBlob, buf, id, true) + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + _, _, _, err := uploader.SaveBlob(ctx, restic.DataBlob, buf, id, true) return err })) diff --git a/internal/repository/repair_pack.go b/internal/repository/repair_pack.go index 966c81b53..a6f4a52b8 100644 --- a/internal/repository/repair_pack.go +++ b/internal/repository/repair_pack.go @@ -15,7 +15,7 @@ func RepairPacks(ctx context.Context, repo *Repository, ids restic.IDSet, printe bar.SetMax(uint64(len(ids))) defer bar.Done() - err := repo.WithBlobUploader(ctx, func(ctx context.Context) error { + err := repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { // examine all data the indexes have for the pack file for b := range repo.ListPacksFromIndex(ctx, ids) { blobs := b.Blobs @@ -30,7 +30,7 @@ func RepairPacks(ctx context.Context, repo *Repository, ids restic.IDSet, printe printer.E("failed to load blob %v: %v", blob.ID, err) return nil } - id, _, _, err := repo.SaveBlob(ctx, blob.Type, buf, restic.ID{}, true) + id, _, _, err := uploader.SaveBlob(ctx, blob.Type, buf, restic.ID{}, true) if !id.Equal(blob.ID) { panic("pack id mismatch during upload") } diff --git a/internal/repository/repository.go b/internal/repository/repository.go index f19a6fe44..3aa87faff 100644 --- a/internal/repository/repository.go +++ b/internal/repository/repository.go @@ -559,11 +559,11 @@ func (r *Repository) removeUnpacked(ctx context.Context, t restic.FileType, id r return r.be.Remove(ctx, backend.Handle{Type: t, Name: id.String()}) } -func (r *Repository) WithBlobUploader(ctx context.Context, fn func(ctx context.Context) error) error { +func (r *Repository) WithBlobUploader(ctx context.Context, fn func(ctx context.Context, uploader restic.BlobSaver) error) error { wg, ctx := errgroup.WithContext(ctx) r.startPackUploader(ctx, wg) wg.Go(func() error { - if err := fn(ctx); err != nil { + if err := fn(ctx, &blobSaverRepo{repo: r}); err != nil { return err } if err := r.flush(ctx); err != nil { @@ -574,6 +574,14 @@ func (r *Repository) WithBlobUploader(ctx context.Context, fn func(ctx context.C return wg.Wait() } +type blobSaverRepo struct { + repo *Repository +} + +func (r *blobSaverRepo) SaveBlob(ctx context.Context, t restic.BlobType, buf []byte, id restic.ID, storeDuplicate bool) (newID restic.ID, known bool, size int, err error) { + return r.repo.saveBlob(ctx, t, buf, id, storeDuplicate) +} + func (r *Repository) startPackUploader(ctx context.Context, wg *errgroup.Group) { if r.packerWg != nil { panic("uploader already started") @@ -926,14 +934,14 @@ func (r *Repository) Close() error { return r.be.Close() } -// SaveBlob saves a blob of type t into the repository. +// saveBlob saves a blob of type t into the repository. // It takes care that no duplicates are saved; this can be overwritten // by setting storeDuplicate to true. // If id is the null id, it will be computed and returned. // Also returns if the blob was already known before. // If the blob was not known before, it returns the number of bytes the blob // occupies in the repo (compressed or not, including encryption overhead). -func (r *Repository) SaveBlob(ctx context.Context, t restic.BlobType, buf []byte, id restic.ID, storeDuplicate bool) (newID restic.ID, known bool, size int, err error) { +func (r *Repository) saveBlob(ctx context.Context, t restic.BlobType, buf []byte, id restic.ID, storeDuplicate bool) (newID restic.ID, known bool, size int, err error) { if int64(len(buf)) > math.MaxUint32 { return restic.ID{}, false, 0, fmt.Errorf("blob is larger than 4GB") diff --git a/internal/repository/repository_test.go b/internal/repository/repository_test.go index 9f364bfc3..60bfd2b5c 100644 --- a/internal/repository/repository_test.go +++ b/internal/repository/repository_test.go @@ -51,13 +51,13 @@ func testSave(t *testing.T, version uint, calculateID bool) { id := restic.Hash(data) - rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error { + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { // save inputID := restic.ID{} if !calculateID { inputID = id } - sid, _, _, err := repo.SaveBlob(ctx, restic.DataBlob, data, inputID, false) + sid, _, _, err := uploader.SaveBlob(ctx, restic.DataBlob, data, inputID, false) rtest.OK(t, err) rtest.Equals(t, id, sid) return nil @@ -97,7 +97,7 @@ func testSavePackMerging(t *testing.T, targetPercentage int, expectedPacks int) }) var ids restic.IDs - rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error { + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { // add blobs with size targetPercentage / 100 * repo.PackSize to the repository blobSize := repository.MinPackSize / 100 for range targetPercentage { @@ -105,7 +105,7 @@ func testSavePackMerging(t *testing.T, targetPercentage int, expectedPacks int) _, err := io.ReadFull(rnd, data) rtest.OK(t, err) - sid, _, _, err := repo.SaveBlob(ctx, restic.DataBlob, data, restic.ID{}, false) + sid, _, _, err := uploader.SaveBlob(ctx, restic.DataBlob, data, restic.ID{}, false) rtest.OK(t, err) ids = append(ids, sid) } @@ -147,9 +147,9 @@ func benchmarkSaveAndEncrypt(t *testing.B, version uint) { t.ResetTimer() t.SetBytes(int64(size)) - _ = repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error { + _ = repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { for i := 0; i < t.N; i++ { - _, _, _, err = repo.SaveBlob(ctx, restic.DataBlob, data, id, true) + _, _, _, err = uploader.SaveBlob(ctx, restic.DataBlob, data, id, true) rtest.OK(t, err) } return nil @@ -168,9 +168,9 @@ func testLoadBlob(t *testing.T, version uint) { rtest.OK(t, err) var id restic.ID - rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error { + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { var err error - id, _, _, err = repo.SaveBlob(ctx, restic.DataBlob, buf, restic.ID{}, false) + id, _, _, err = uploader.SaveBlob(ctx, restic.DataBlob, buf, restic.ID{}, false) return err })) @@ -196,9 +196,9 @@ func TestLoadBlobBroken(t *testing.T) { buf := rtest.Random(42, 1000) var id restic.ID - rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error { + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { var err error - id, _, _, err = repo.SaveBlob(ctx, restic.TreeBlob, buf, restic.ID{}, false) + id, _, _, err = uploader.SaveBlob(ctx, restic.TreeBlob, buf, restic.ID{}, false) return err })) @@ -225,9 +225,9 @@ func benchmarkLoadBlob(b *testing.B, version uint) { rtest.OK(b, err) var id restic.ID - rtest.OK(b, repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error { + rtest.OK(b, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { var err error - id, _, _, err = repo.SaveBlob(ctx, restic.DataBlob, buf, restic.ID{}, false) + id, _, _, err = uploader.SaveBlob(ctx, restic.DataBlob, buf, restic.ID{}, false) return err })) @@ -361,7 +361,7 @@ func TestRepositoryLoadUnpackedRetryBroken(t *testing.T) { // saveRandomDataBlobs generates random data blobs and saves them to the repository. func saveRandomDataBlobs(t testing.TB, repo restic.Repository, num int, sizeMax int) { - rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error { + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { for i := 0; i < num; i++ { size := rand.Int() % sizeMax @@ -369,7 +369,7 @@ func saveRandomDataBlobs(t testing.TB, repo restic.Repository, num int, sizeMax _, err := io.ReadFull(rnd, buf) rtest.OK(t, err) - _, _, _, err = repo.SaveBlob(ctx, restic.DataBlob, buf, restic.ID{}, false) + _, _, _, err = uploader.SaveBlob(ctx, restic.DataBlob, buf, restic.ID{}, false) rtest.OK(t, err) } return nil @@ -432,9 +432,9 @@ func TestListPack(t *testing.T) { buf := rtest.Random(42, 1000) var id restic.ID - rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context) error { + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { var err error - id, _, _, err = repo.SaveBlob(ctx, restic.TreeBlob, buf, restic.ID{}, false) + id, _, _, err = uploader.SaveBlob(ctx, restic.TreeBlob, buf, restic.ID{}, false) return err })) diff --git a/internal/restic/repository.go b/internal/restic/repository.go index 1087f3757..236171cd2 100644 --- a/internal/restic/repository.go +++ b/internal/restic/repository.go @@ -39,8 +39,7 @@ type Repository interface { // WithUploader starts the necessary workers to upload new blobs. Once the callback returns, // the workers are stopped and the index is written to the repository. The callback must use // the passed context and must not keep references to any of its parameters after returning. - WithBlobUploader(ctx context.Context, fn func(ctx context.Context) error) error - SaveBlob(ctx context.Context, t BlobType, buf []byte, id ID, storeDuplicate bool) (newID ID, known bool, size int, err error) + WithBlobUploader(ctx context.Context, fn func(ctx context.Context, uploader BlobSaver) error) error // List calls the function fn for each file of type t in the repository. // When an error is returned by fn, processing stops and List() returns the @@ -158,6 +157,10 @@ type BlobLoader interface { LoadBlob(context.Context, BlobType, ID, []byte) ([]byte, error) } +type WithBlobUploader interface { + WithBlobUploader(ctx context.Context, fn func(ctx context.Context, uploader BlobSaver) error) error +} + type BlobSaver interface { SaveBlob(context.Context, BlobType, []byte, ID, bool) (ID, bool, int, error) } diff --git a/internal/restorer/restorer_test.go b/internal/restorer/restorer_test.go index 28fae47da..fe9db6b33 100644 --- a/internal/restorer/restorer_test.go +++ b/internal/restorer/restorer_test.go @@ -171,8 +171,8 @@ func saveSnapshot(t testing.TB, repo restic.Repository, snapshot Snapshot, getGe defer cancel() var treeID restic.ID - rtest.OK(t, repo.WithBlobUploader(ctx, func(ctx context.Context) error { - treeID = saveDir(t, repo, snapshot.Nodes, 1000, getGenericAttributes) + rtest.OK(t, repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { + treeID = saveDir(t, uploader, snapshot.Nodes, 1000, getGenericAttributes) return nil })) diff --git a/internal/walker/rewriter.go b/internal/walker/rewriter.go index 4d70e4d09..bd05b90d7 100644 --- a/internal/walker/rewriter.go +++ b/internal/walker/rewriter.go @@ -11,7 +11,7 @@ import ( ) type NodeRewriteFunc func(node *data.Node, path string) *data.Node -type FailedTreeRewriteFunc func(nodeID restic.ID, path string, err error) (restic.ID, error) +type FailedTreeRewriteFunc func(nodeID restic.ID, path string, err error) (*data.Tree, error) type QueryRewrittenSizeFunc func() SnapshotSize type SnapshotSize struct { @@ -52,8 +52,8 @@ func NewTreeRewriter(opts RewriteOpts) *TreeRewriter { } if rw.opts.RewriteFailedTree == nil { // fail with error by default - rw.opts.RewriteFailedTree = func(_ restic.ID, _ string, err error) (restic.ID, error) { - return restic.ID{}, err + rw.opts.RewriteFailedTree = func(_ restic.ID, _ string, err error) (*data.Tree, error) { + return nil, err } } return rw @@ -82,12 +82,7 @@ func NewSnapshotSizeRewriter(rewriteNode NodeRewriteFunc) (*TreeRewriter, QueryR return t, ss } -type BlobLoadSaver interface { - restic.BlobSaver - restic.BlobLoader -} - -func (t *TreeRewriter) RewriteTree(ctx context.Context, repo BlobLoadSaver, nodepath string, nodeID restic.ID) (newNodeID restic.ID, err error) { +func (t *TreeRewriter) RewriteTree(ctx context.Context, loader restic.BlobLoader, saver restic.BlobSaver, nodepath string, nodeID restic.ID) (newNodeID restic.ID, err error) { // check if tree was already changed newID, ok := t.replaces[nodeID] if ok { @@ -95,16 +90,27 @@ func (t *TreeRewriter) RewriteTree(ctx context.Context, repo BlobLoadSaver, node } // a nil nodeID will lead to a load error - curTree, err := data.LoadTree(ctx, repo, nodeID) + curTree, err := data.LoadTree(ctx, loader, nodeID) if err != nil { - return t.opts.RewriteFailedTree(nodeID, nodepath, err) + replacement, err := t.opts.RewriteFailedTree(nodeID, nodepath, err) + if err != nil { + return restic.ID{}, err + } + if replacement != nil { + replacementID, err := data.SaveTree(ctx, saver, replacement) + if err != nil { + return restic.ID{}, err + } + return replacementID, nil + } + return restic.ID{}, nil } if !t.opts.AllowUnstableSerialization { // check that we can properly encode this tree without losing information // The alternative of using json/Decoder.DisallowUnknownFields() doesn't work as we use // a custom UnmarshalJSON to decode trees, see also https://github.com/golang/go/issues/41144 - testID, err := data.SaveTree(ctx, repo, curTree) + testID, err := data.SaveTree(ctx, saver, curTree) if err != nil { return restic.ID{}, err } @@ -139,7 +145,7 @@ func (t *TreeRewriter) RewriteTree(ctx context.Context, repo BlobLoadSaver, node if node.Subtree != nil { subtree = *node.Subtree } - newID, err := t.RewriteTree(ctx, repo, path, subtree) + newID, err := t.RewriteTree(ctx, loader, saver, path, subtree) if err != nil { return restic.ID{}, err } @@ -156,7 +162,7 @@ func (t *TreeRewriter) RewriteTree(ctx context.Context, repo BlobLoadSaver, node } // Save new tree - newTreeID, _, _, err := repo.SaveBlob(ctx, restic.TreeBlob, tree, restic.ID{}, false) + newTreeID, _, _, err := saver.SaveBlob(ctx, restic.TreeBlob, tree, restic.ID{}, false) if t.replaces != nil { t.replaces[nodeID] = newTreeID } diff --git a/internal/walker/rewriter_test.go b/internal/walker/rewriter_test.go index a21af36c6..9290a62d5 100644 --- a/internal/walker/rewriter_test.go +++ b/internal/walker/rewriter_test.go @@ -285,7 +285,7 @@ func TestRewriter(t *testing.T) { defer cancel() rewriter, last := test.check(t) - newRoot, err := rewriter.RewriteTree(ctx, modrepo, "/", root) + newRoot, err := rewriter.RewriteTree(ctx, modrepo, modrepo, "/", root) if err != nil { t.Error(err) } @@ -335,7 +335,7 @@ func TestSnapshotSizeQuery(t *testing.T) { return node } rewriter, querySize := NewSnapshotSizeRewriter(rewriteNode) - newRoot, err := rewriter.RewriteTree(ctx, modrepo, "/", root) + newRoot, err := rewriter.RewriteTree(ctx, modrepo, modrepo, "/", root) if err != nil { t.Error(err) } @@ -373,7 +373,7 @@ func TestRewriterFailOnUnknownFields(t *testing.T) { return node }, }) - _, err := rewriter.RewriteTree(ctx, tm, "/", id) + _, err := rewriter.RewriteTree(ctx, tm, tm, "/", id) if err == nil { t.Error("missing error on unknown field") @@ -383,7 +383,7 @@ func TestRewriterFailOnUnknownFields(t *testing.T) { rewriter = NewTreeRewriter(RewriteOpts{ AllowUnstableSerialization: true, }) - root, err := rewriter.RewriteTree(ctx, tm, "/", id) + root, err := rewriter.RewriteTree(ctx, tm, tm, "/", id) test.OK(t, err) _, expRoot := BuildTreeMap(TestTree{ "subfile": TestFile{}, @@ -400,21 +400,24 @@ func TestRewriterTreeLoadError(t *testing.T) { // also check that load error by default cause the operation to fail rewriter := NewTreeRewriter(RewriteOpts{}) - _, err := rewriter.RewriteTree(ctx, tm, "/", id) + _, err := rewriter.RewriteTree(ctx, tm, tm, "/", id) if err == nil { t.Fatal("missing error on unloadable tree") } - replacementID := restic.NewRandomID() + replacementTree := &data.Tree{Nodes: []*data.Node{{Name: "replacement", Type: data.NodeTypeFile, Size: 42}}} + replacementID, err := data.SaveTree(ctx, tm, replacementTree) + test.OK(t, err) + rewriter = NewTreeRewriter(RewriteOpts{ - RewriteFailedTree: func(nodeID restic.ID, path string, err error) (restic.ID, error) { + RewriteFailedTree: func(nodeID restic.ID, path string, err error) (*data.Tree, error) { if nodeID != id || path != "/" { t.Fail() } - return replacementID, nil + return replacementTree, nil }, }) - newRoot, err := rewriter.RewriteTree(ctx, tm, "/", id) + newRoot, err := rewriter.RewriteTree(ctx, tm, tm, "/", id) test.OK(t, err) test.Equals(t, replacementID, newRoot) }