diff --git a/cmd/restic/cmd_copy.go b/cmd/restic/cmd_copy.go index fa81755d2..d6a5efe57 100644 --- a/cmd/restic/cmd_copy.go +++ b/cmd/restic/cmd_copy.go @@ -203,7 +203,7 @@ func copyTreeBatched(ctx context.Context, srcRepo restic.Repository, dstRepo res startTime := time.Now() // call WithBlobUploader() once and then loop over all selectedSnapshots - err := dstRepo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { + err := dstRepo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { for len(selectedSnapshots) > 0 && (batchSize < targetSize || time.Since(startTime) < minDuration) { sn := selectedSnapshots[0] selectedSnapshots = selectedSnapshots[1:] @@ -242,7 +242,7 @@ func copyTreeBatched(ctx context.Context, srcRepo restic.Repository, dstRepo res } func copyTree(ctx context.Context, srcRepo restic.Repository, dstRepo restic.Repository, - visitedTrees restic.AssociatedBlobSet, rootTreeID restic.ID, printer progress.Printer, uploader restic.BlobSaver) (uint64, error) { + visitedTrees restic.AssociatedBlobSet, rootTreeID restic.ID, printer progress.Printer, uploader restic.BlobSaverWithAsync) (uint64, error) { wg, wgCtx := errgroup.WithContext(ctx) diff --git a/cmd/restic/cmd_debug.go b/cmd/restic/cmd_debug.go index 48e6d58b7..2dd9d5cf0 100644 --- a/cmd/restic/cmd_debug.go +++ b/cmd/restic/cmd_debug.go @@ -353,7 +353,7 @@ func loadBlobs(ctx context.Context, opts DebugExamineOptions, repo restic.Reposi return err } - err = repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { + err = repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { for _, blob := range list { printer.S(" loading blob %v at %v (length %v)", blob.ID, blob.Offset, blob.Length) if int(blob.Offset+blob.Length) > len(pack) { diff --git a/cmd/restic/cmd_recover.go b/cmd/restic/cmd_recover.go index ca22ee2de..fec4c44b5 100644 --- a/cmd/restic/cmd_recover.go +++ b/cmd/restic/cmd_recover.go @@ -153,7 +153,7 @@ func runRecover(ctx context.Context, gopts global.Options, term ui.Terminal) err } var treeID restic.ID - err = repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { + err = repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { var err error treeID, err = data.SaveTree(ctx, uploader, tree) if err != nil { diff --git a/cmd/restic/cmd_rewrite.go b/cmd/restic/cmd_rewrite.go index 76a504652..9c53dcae6 100644 --- a/cmd/restic/cmd_rewrite.go +++ b/cmd/restic/cmd_rewrite.go @@ -194,7 +194,7 @@ func filterAndReplaceSnapshot(ctx context.Context, repo restic.Repository, sn *d var filteredTree restic.ID var summary *data.SnapshotSummary - err := repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { + err := repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { var err error filteredTree, summary, err = filter(ctx, sn, uploader) return err diff --git a/internal/archiver/archiver.go b/internal/archiver/archiver.go index d619ad9b4..0ed37eb5b 100644 --- a/internal/archiver/archiver.go +++ b/internal/archiver/archiver.go @@ -96,7 +96,6 @@ type Archiver struct { FS fs.FS Options Options - blobSaver *blobSaver fileSaver *fileSaver treeSaver *treeSaver mu sync.Mutex @@ -145,11 +144,6 @@ type Options struct { // turned out to be a good default for most situations). 
ReadConcurrency uint - // SaveBlobConcurrency sets how many blobs are hashed and saved - // concurrently. If it's set to zero, the default is the number of CPUs - // available in the system. - SaveBlobConcurrency uint - // SaveTreeConcurrency sets how many trees are marshalled and saved to the // repo concurrently. SaveTreeConcurrency uint @@ -165,12 +159,6 @@ func (o Options) ApplyDefaults() Options { o.ReadConcurrency = 2 } - if o.SaveBlobConcurrency == 0 { - // blob saving is CPU bound due to hash checking and encryption - // the actual upload is handled by the repository itself - o.SaveBlobConcurrency = uint(runtime.GOMAXPROCS(0)) - } - if o.SaveTreeConcurrency == 0 { // can either wait for a file, wait for a tree, serialize a tree or wait for saveblob // the last two are cpu-bound and thus mutually exclusive. @@ -834,24 +822,20 @@ func (arch *Archiver) loadParentTree(ctx context.Context, sn *data.Snapshot) *da } // runWorkers starts the worker pools, which are stopped when the context is cancelled. -func (arch *Archiver) runWorkers(ctx context.Context, wg *errgroup.Group, uploader restic.BlobSaver) { - arch.blobSaver = newBlobSaver(ctx, wg, uploader, arch.Options.SaveBlobConcurrency) - +func (arch *Archiver) runWorkers(ctx context.Context, wg *errgroup.Group, uploader restic.BlobSaverAsync) { arch.fileSaver = newFileSaver(ctx, wg, - arch.blobSaver.Save, + uploader, arch.Repo.Config().ChunkerPolynomial, - arch.Options.ReadConcurrency, arch.Options.SaveBlobConcurrency) + arch.Options.ReadConcurrency) arch.fileSaver.CompleteBlob = arch.CompleteBlob arch.fileSaver.NodeFromFileInfo = arch.nodeFromFileInfo - arch.treeSaver = newTreeSaver(ctx, wg, arch.Options.SaveTreeConcurrency, arch.blobSaver.Save, arch.Error) + arch.treeSaver = newTreeSaver(ctx, wg, arch.Options.SaveTreeConcurrency, uploader, arch.Error) } func (arch *Archiver) stopWorkers() { - arch.blobSaver.TriggerShutdown() arch.fileSaver.TriggerShutdown() arch.treeSaver.TriggerShutdown() - arch.blobSaver = nil arch.fileSaver = nil arch.treeSaver = nil } @@ -874,7 +858,7 @@ func (arch *Archiver) Snapshot(ctx context.Context, targets []string, opts Snaps var rootTreeID restic.ID - err = arch.Repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { + err = arch.Repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { wg, wgCtx := errgroup.WithContext(ctx) start := time.Now() diff --git a/internal/archiver/archiver_test.go b/internal/archiver/archiver_test.go index adc7695cb..a4012d585 100644 --- a/internal/archiver/archiver_test.go +++ b/internal/archiver/archiver_test.go @@ -56,7 +56,7 @@ func saveFile(t testing.TB, repo archiverRepo, filename string, filesystem fs.FS return err } - err := repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + err := repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { wg, ctx := errgroup.WithContext(ctx) arch.runWorkers(ctx, wg, uploader) @@ -219,7 +219,7 @@ func TestArchiverSave(t *testing.T) { arch.summary = &Summary{} var fnr futureNodeResult - err := repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { + err := repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { wg, ctx := errgroup.WithContext(ctx) arch.runWorkers(ctx, wg, uploader) @@ -296,7 +296,7 @@ func TestArchiverSaveReaderFS(t *testing.T) { arch.summary = &Summary{} var fnr futureNodeResult - 
err = repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { + err = repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { wg, ctx := errgroup.WithContext(ctx) arch.runWorkers(ctx, wg, uploader) @@ -415,29 +415,39 @@ type blobCountingRepo struct { saved map[restic.BlobHandle]uint } -func (repo *blobCountingRepo) WithBlobUploader(ctx context.Context, fn func(ctx context.Context, uploader restic.BlobSaver) error) error { - return repo.archiverRepo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { +func (repo *blobCountingRepo) WithBlobUploader(ctx context.Context, fn func(ctx context.Context, uploader restic.BlobSaverWithAsync) error) error { + return repo.archiverRepo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { return fn(ctx, &blobCountingSaver{saver: uploader, blobCountingRepo: repo}) }) } type blobCountingSaver struct { - saver restic.BlobSaver + saver restic.BlobSaverWithAsync blobCountingRepo *blobCountingRepo } +func (repo *blobCountingSaver) count(exists bool, h restic.BlobHandle) { + if exists { + return + } + repo.blobCountingRepo.m.Lock() + repo.blobCountingRepo.saved[h]++ + repo.blobCountingRepo.m.Unlock() +} + func (repo *blobCountingSaver) SaveBlob(ctx context.Context, t restic.BlobType, buf []byte, id restic.ID, storeDuplicate bool) (restic.ID, bool, int, error) { id, exists, size, err := repo.saver.SaveBlob(ctx, t, buf, id, storeDuplicate) - if exists { - return id, exists, size, err - } - h := restic.BlobHandle{ID: id, Type: t} - repo.blobCountingRepo.m.Lock() - repo.blobCountingRepo.saved[h]++ - repo.blobCountingRepo.m.Unlock() + repo.count(exists, restic.BlobHandle{ID: id, Type: t}) return id, exists, size, err } +func (repo *blobCountingSaver) SaveBlobAsync(ctx context.Context, t restic.BlobType, buf []byte, id restic.ID, storeDuplicate bool, cb func(newID restic.ID, known bool, size int, err error)) { + repo.saver.SaveBlobAsync(ctx, t, buf, id, storeDuplicate, func(newID restic.ID, known bool, size int, err error) { + repo.count(known, restic.BlobHandle{ID: newID, Type: t}) + cb(newID, known, size, err) + }) +} + func appendToFile(t testing.TB, filename string, data []byte) { f, err := os.OpenFile(filename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644) if err != nil { @@ -840,7 +850,7 @@ func TestArchiverSaveDir(t *testing.T) { defer back() var treeID restic.ID - err := repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + err := repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { wg, ctx := errgroup.WithContext(ctx) arch.runWorkers(ctx, wg, uploader) meta, err := testFS.OpenFile(test.target, fs.O_NOFOLLOW, true) @@ -906,7 +916,7 @@ func TestArchiverSaveDirIncremental(t *testing.T) { arch.summary = &Summary{} var fnr futureNodeResult - err := repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + err := repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { wg, ctx := errgroup.WithContext(ctx) arch.runWorkers(ctx, wg, uploader) meta, err := testFS.OpenFile(tempdir, fs.O_NOFOLLOW, true) @@ -1096,7 +1106,7 @@ func TestArchiverSaveTree(t *testing.T) { } var treeID restic.ID - err := repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + err := repo.WithBlobUploader(context.TODO(), 
func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { wg, ctx := errgroup.WithContext(ctx) arch.runWorkers(ctx, wg, uploader) @@ -2074,8 +2084,6 @@ func TestArchiverContextCanceled(t *testing.T) { type TrackFS struct { fs.FS - errorOn map[string]error - opened map[string]uint m sync.Mutex } @@ -2091,38 +2099,61 @@ func (m *TrackFS) OpenFile(name string, flag int, metadataOnly bool) (fs.File, e type failSaveRepo struct { archiverRepo failAfter int32 - cnt int32 + cnt atomic.Int32 err error } -func (f *failSaveRepo) WithBlobUploader(ctx context.Context, fn func(ctx context.Context, uploader restic.BlobSaver) error) error { - return f.archiverRepo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { - return fn(ctx, &failSaveSaver{saver: uploader, failSaveRepo: f}) +func (f *failSaveRepo) WithBlobUploader(ctx context.Context, fn func(ctx context.Context, uploader restic.BlobSaverWithAsync) error) error { + outerCtx, outerCancel := context.WithCancelCause(ctx) + defer outerCancel(f.err) + return f.archiverRepo.WithBlobUploader(outerCtx, func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { + return fn(ctx, &failSaveSaver{saver: uploader, failSaveRepo: f, semaphore: make(chan struct{}, 1), outerCancel: outerCancel}) }) } type failSaveSaver struct { - saver restic.BlobSaver + saver restic.BlobSaverWithAsync failSaveRepo *failSaveRepo + semaphore chan struct{} + outerCancel context.CancelCauseFunc } func (f *failSaveSaver) SaveBlob(ctx context.Context, t restic.BlobType, buf []byte, id restic.ID, storeDuplicate bool) (restic.ID, bool, int, error) { - val := atomic.AddInt32(&f.failSaveRepo.cnt, 1) + val := f.failSaveRepo.cnt.Add(1) if val >= f.failSaveRepo.failAfter { - return restic.Hash(buf), false, 0, f.failSaveRepo.err + return restic.ID{}, false, 0, f.failSaveRepo.err } return f.saver.SaveBlob(ctx, t, buf, id, storeDuplicate) } -func TestArchiverAbortEarlyOnError(t *testing.T) { - var testErr = errors.New("test error") +func (f *failSaveSaver) SaveBlobAsync(ctx context.Context, t restic.BlobType, buf []byte, id restic.ID, storeDuplicate bool, cb func(newID restic.ID, known bool, size int, err error)) { + // limit concurrency to make test reliable + f.semaphore <- struct{}{} + val := f.failSaveRepo.cnt.Add(1) + if val >= f.failSaveRepo.failAfter { + // kill the outer context to make SaveBlobAsync fail + // precisely injecting a specific error into the repository is not possible, so just cancel the context + f.outerCancel(f.failSaveRepo.err) + } + + f.saver.SaveBlobAsync(ctx, t, buf, id, storeDuplicate, func(newID restic.ID, known bool, size int, err error) { + if val >= f.failSaveRepo.failAfter { + if err == nil { + panic("expected error") + } + } + cb(newID, known, size, err) + <-f.semaphore + }) +} + +func TestArchiverAbortEarlyOnError(t *testing.T) { var tests = []struct { src TestDir wantOpen map[string]uint failAfter uint // error after so many blobs have been saved to the repo - err error }{ { src: TestDir{ @@ -2134,8 +2165,6 @@ func TestArchiverAbortEarlyOnError(t *testing.T) { }, wantOpen: map[string]uint{ filepath.FromSlash("dir/bar"): 1, - filepath.FromSlash("dir/baz"): 1, - filepath.FromSlash("dir/foo"): 1, }, }, { @@ -2162,9 +2191,8 @@ func TestArchiverAbortEarlyOnError(t *testing.T) { filepath.FromSlash("dir/file9"): 0, }, // fails after four to seven files were opened, as the ReadConcurrency allows for - // two queued files and SaveBlobConcurrency for one blob queued for saving. 
+ // two queued files and one blob queued for saving. failAfter: 4, - err: testErr, }, } @@ -2183,25 +2211,25 @@ func TestArchiverAbortEarlyOnError(t *testing.T) { opened: make(map[string]uint), } - if testFS.errorOn == nil { - testFS.errorOn = make(map[string]error) - } - + testErr := context.Canceled testRepo := &failSaveRepo{ archiverRepo: repo, failAfter: int32(test.failAfter), - err: test.err, + err: testErr, } // at most two files may be queued arch := New(testRepo, testFS, Options{ - ReadConcurrency: 2, - SaveBlobConcurrency: 1, + ReadConcurrency: 2, }) + arch.Error = func(item string, err error) error { + t.Logf("archiver error for %q: %v", item, err) + return err + } _, _, _, err := arch.Snapshot(ctx, []string{"."}, SnapshotOptions{Time: time.Now()}) - if !errors.Is(err, test.err) { - t.Errorf("expected error (%v) not found, got %v", test.err, err) + if !errors.Is(err, testErr) { + t.Errorf("expected error (%v) not found, got %v", testErr, err) } t.Logf("Snapshot return error: %v", err) @@ -2425,7 +2453,7 @@ func TestRacyFileTypeSwap(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - _ = repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { + _ = repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { wg, ctx := errgroup.WithContext(ctx) arch := New(repo, fs.Track{FS: statfs}, Options{}) diff --git a/internal/archiver/blob_saver.go b/internal/archiver/blob_saver.go deleted file mode 100644 index 356a32ce2..000000000 --- a/internal/archiver/blob_saver.go +++ /dev/null @@ -1,105 +0,0 @@ -package archiver - -import ( - "context" - "fmt" - - "github.com/restic/restic/internal/debug" - "github.com/restic/restic/internal/restic" - "golang.org/x/sync/errgroup" -) - -// saver allows saving a blob. -type saver interface { - SaveBlob(ctx context.Context, t restic.BlobType, data []byte, id restic.ID, storeDuplicate bool) (restic.ID, bool, int, error) -} - -// blobSaver concurrently saves incoming blobs to the repo. -type blobSaver struct { - repo saver - ch chan<- saveBlobJob -} - -// newBlobSaver returns a new blob. A worker pool is started, it is stopped -// when ctx is cancelled. -func newBlobSaver(ctx context.Context, wg *errgroup.Group, repo saver, workers uint) *blobSaver { - ch := make(chan saveBlobJob) - s := &blobSaver{ - repo: repo, - ch: ch, - } - - for i := uint(0); i < workers; i++ { - wg.Go(func() error { - return s.worker(ctx, ch) - }) - } - - return s -} - -func (s *blobSaver) TriggerShutdown() { - close(s.ch) -} - -// Save stores a blob in the repo. It checks the index and the known blobs -// before saving anything. It takes ownership of the buffer passed in. 
-func (s *blobSaver) Save(ctx context.Context, t restic.BlobType, buf *buffer, filename string, cb func(res saveBlobResponse)) { - select { - case s.ch <- saveBlobJob{BlobType: t, buf: buf, fn: filename, cb: cb}: - case <-ctx.Done(): - debug.Log("not sending job, context is cancelled") - } -} - -type saveBlobJob struct { - restic.BlobType - buf *buffer - fn string - cb func(res saveBlobResponse) -} - -type saveBlobResponse struct { - id restic.ID - length int - sizeInRepo int - known bool -} - -func (s *blobSaver) saveBlob(ctx context.Context, t restic.BlobType, buf []byte) (saveBlobResponse, error) { - id, known, sizeInRepo, err := s.repo.SaveBlob(ctx, t, buf, restic.ID{}, false) - - if err != nil { - return saveBlobResponse{}, err - } - - return saveBlobResponse{ - id: id, - length: len(buf), - sizeInRepo: sizeInRepo, - known: known, - }, nil -} - -func (s *blobSaver) worker(ctx context.Context, jobs <-chan saveBlobJob) error { - for { - var job saveBlobJob - var ok bool - select { - case <-ctx.Done(): - return nil - case job, ok = <-jobs: - if !ok { - return nil - } - } - - res, err := s.saveBlob(ctx, job.BlobType, job.buf.Data) - if err != nil { - debug.Log("saveBlob returned error, exiting: %v", err) - return fmt.Errorf("failed to save blob from file %q: %w", job.fn, err) - } - job.cb(res) - job.buf.Release() - } -} diff --git a/internal/archiver/blob_saver_test.go b/internal/archiver/blob_saver_test.go deleted file mode 100644 index e23ed12e5..000000000 --- a/internal/archiver/blob_saver_test.go +++ /dev/null @@ -1,116 +0,0 @@ -package archiver - -import ( - "context" - "fmt" - "runtime" - "strings" - "sync" - "sync/atomic" - "testing" - - "github.com/restic/restic/internal/errors" - "github.com/restic/restic/internal/restic" - rtest "github.com/restic/restic/internal/test" - "golang.org/x/sync/errgroup" -) - -var errTest = errors.New("test error") - -type saveFail struct { - cnt int32 - failAt int32 -} - -func (b *saveFail) SaveBlob(_ context.Context, _ restic.BlobType, _ []byte, id restic.ID, _ bool) (restic.ID, bool, int, error) { - val := atomic.AddInt32(&b.cnt, 1) - if val == b.failAt { - return restic.ID{}, false, 0, errTest - } - - return id, false, 0, nil -} - -func TestBlobSaver(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - wg, ctx := errgroup.WithContext(ctx) - saver := &saveFail{} - - b := newBlobSaver(ctx, wg, saver, uint(runtime.NumCPU())) - - var wait sync.WaitGroup - var results []saveBlobResponse - var lock sync.Mutex - - wait.Add(20) - for i := 0; i < 20; i++ { - buf := &buffer{Data: []byte(fmt.Sprintf("foo%d", i))} - idx := i - lock.Lock() - results = append(results, saveBlobResponse{}) - lock.Unlock() - b.Save(ctx, restic.DataBlob, buf, "file", func(res saveBlobResponse) { - lock.Lock() - results[idx] = res - lock.Unlock() - wait.Done() - }) - } - - wait.Wait() - for i, sbr := range results { - if sbr.known { - t.Errorf("blob %v is known, that should not be the case", i) - } - } - - b.TriggerShutdown() - - err := wg.Wait() - if err != nil { - t.Fatal(err) - } -} - -func TestBlobSaverError(t *testing.T) { - var tests = []struct { - blobs int - failAt int - }{ - {20, 2}, - {20, 5}, - {20, 15}, - {200, 150}, - } - - for _, test := range tests { - t.Run("", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - wg, ctx := errgroup.WithContext(ctx) - saver := &saveFail{ - failAt: int32(test.failAt), - } - - b := newBlobSaver(ctx, wg, saver, uint(runtime.NumCPU())) - - for i := 
0; i < test.blobs; i++ {
-				buf := &buffer{Data: []byte(fmt.Sprintf("foo%d", i))}
-				b.Save(ctx, restic.DataBlob, buf, "errfile", func(res saveBlobResponse) {})
-			}
-
-			b.TriggerShutdown()
-
-			err := wg.Wait()
-			if err == nil {
-				t.Errorf("expected error not found")
-			}
-
-			rtest.Assert(t, errors.Is(err, errTest), "unexpected error %v", err)
-			rtest.Assert(t, strings.Contains(err.Error(), "errfile"), "expected error to contain 'errfile' got: %v", err)
-		})
-	}
-}
diff --git a/internal/archiver/buffer.go b/internal/archiver/buffer.go
index d5bfb46b3..0a6ae6d8f 100644
--- a/internal/archiver/buffer.go
+++ b/internal/archiver/buffer.go
@@ -1,5 +1,7 @@
 package archiver
 
+import "sync"
+
 // buffer is a reusable buffer. After the buffer has been used, Release should
 // be called so the underlying slice is put back into the pool.
 type buffer struct {
@@ -14,41 +16,32 @@ func (b *buffer) Release() {
 		return
 	}
 
-	select {
-	case pool.ch <- b:
-	default:
-	}
+	pool.pool.Put(b)
 }
 
-// bufferPool implements a limited set of reusable buffers.
+// bufferPool implements a set of reusable buffers backed by a sync.Pool.
 type bufferPool struct {
-	ch          chan *buffer
+	pool        sync.Pool
 	defaultSize int
 }
 
-// newBufferPool initializes a new buffer pool. The pool stores at most max
-// items. New buffers are created with defaultSize. Buffers that have grown
-// larger are not put back.
-func newBufferPool(max int, defaultSize int) *bufferPool {
+// newBufferPool initializes a new buffer pool. New buffers are created with
+// defaultSize. Buffers that have grown larger are not put back.
+func newBufferPool(defaultSize int) *bufferPool {
 	b := &bufferPool{
-		ch:          make(chan *buffer, max),
 		defaultSize: defaultSize,
 	}
+	b.pool = sync.Pool{New: func() any {
+		return &buffer{
+			Data: make([]byte, defaultSize),
+			pool: b,
+		}
+	}}
 	return b
 }
 
 // Get returns a new buffer, either from the pool or newly allocated.
 func (pool *bufferPool) Get() *buffer {
-	select {
-	case buf := <-pool.ch:
-		return buf
-	default:
-	}
-
-	b := &buffer{
-		Data: make([]byte, pool.defaultSize),
-		pool: pool,
-	}
-
-	return b
+	return pool.pool.Get().(*buffer)
 }
diff --git a/internal/archiver/buffer_test.go b/internal/archiver/buffer_test.go
new file mode 100644
index 000000000..1b577fa2d
--- /dev/null
+++ b/internal/archiver/buffer_test.go
@@ -0,0 +1,58 @@
+package archiver
+
+import (
+	"testing"
+)
+
+func TestBufferPoolReuse(t *testing.T) {
+	success := false
+	// retries to avoid flakiness. The test can fail depending on the GC.
+	for i := 0; i < 100; i++ {
+		// Test that buffers are actually reused from the pool
+		pool := newBufferPool(1024)
+
+		// Get a buffer and modify it
+		buf1 := pool.Get()
+		buf1.Data[0] = 0xFF
+		originalAddr := &buf1.Data[0]
+		buf1.Release()
+
+		// Get another buffer and check if it's the same underlying slice
+		buf2 := pool.Get()
+		if &buf2.Data[0] == originalAddr {
+			success = true
+			break
+		}
+		buf2.Release()
+	}
+	if !success {
+		t.Error("buffer was not reused from pool")
+	}
+}
+
+func TestBufferPoolLargeBuffers(t *testing.T) {
+	success := false
+	// retries to avoid flakiness. The test can fail depending on the GC.
+	for i := 0; i < 100; i++ {
+		// Test that buffers larger than defaultSize are not returned to pool
+		pool := newBufferPool(1024)
+		buf := pool.Get()
+
+		// Grow the buffer beyond default size
+		buf.Data = append(buf.Data, make([]byte, 2048)...)
+ originalCap := cap(buf.Data) + + buf.Release() + + // Get a new buffer - should not be the same slice + newBuf := pool.Get() + if cap(newBuf.Data) != originalCap { + success = true + break + } + } + + if !success { + t.Error("large buffer was incorrectly returned to pool") + } +} diff --git a/internal/archiver/file_saver.go b/internal/archiver/file_saver.go index 8370bee4d..84e175d82 100644 --- a/internal/archiver/file_saver.go +++ b/internal/archiver/file_saver.go @@ -15,13 +15,10 @@ import ( "golang.org/x/sync/errgroup" ) -// saveBlobFn saves a blob to a repo. -type saveBlobFn func(context.Context, restic.BlobType, *buffer, string, func(res saveBlobResponse)) - // fileSaver concurrently saves incoming files to the repo. type fileSaver struct { saveFilePool *bufferPool - saveBlob saveBlobFn + uploader restic.BlobSaverAsync pol chunker.Pol @@ -34,16 +31,13 @@ type fileSaver struct { // newFileSaver returns a new file saver. A worker pool with fileWorkers is // started, it is stopped when ctx is cancelled. -func newFileSaver(ctx context.Context, wg *errgroup.Group, save saveBlobFn, pol chunker.Pol, fileWorkers, blobWorkers uint) *fileSaver { +func newFileSaver(ctx context.Context, wg *errgroup.Group, uploader restic.BlobSaverAsync, pol chunker.Pol, fileWorkers uint) *fileSaver { ch := make(chan saveFileJob) - - debug.Log("new file saver with %v file workers and %v blob workers", fileWorkers, blobWorkers) - - poolSize := fileWorkers + blobWorkers + debug.Log("new file saver with %v file workers", fileWorkers) s := &fileSaver{ - saveBlob: save, - saveFilePool: newBufferPool(int(poolSize), chunker.MaxSize), + uploader: uploader, + saveFilePool: newBufferPool(chunker.MaxSize), pol: pol, ch: ch, @@ -203,15 +197,20 @@ func (s *fileSaver) saveFile(ctx context.Context, chnker *chunker.Chunker, snPat node.Content = append(node.Content, restic.ID{}) lock.Unlock() - s.saveBlob(ctx, restic.DataBlob, buf, target, func(sbr saveBlobResponse) { - lock.Lock() - if !sbr.known { - fnr.stats.DataBlobs++ - fnr.stats.DataSize += uint64(sbr.length) - fnr.stats.DataSizeInRepo += uint64(sbr.sizeInRepo) + s.uploader.SaveBlobAsync(ctx, restic.DataBlob, buf.Data, restic.ID{}, false, func(newID restic.ID, known bool, sizeInRepo int, err error) { + defer buf.Release() + if err != nil { + completeError(err) + return } - node.Content[pos] = sbr.id + lock.Lock() + if !known { + fnr.stats.DataBlobs++ + fnr.stats.DataSize += uint64(len(buf.Data)) + fnr.stats.DataSizeInRepo += uint64(sizeInRepo) + } + node.Content[pos] = newID lock.Unlock() completeBlob() diff --git a/internal/archiver/file_saver_test.go b/internal/archiver/file_saver_test.go index 5aab78558..4dbf78548 100644 --- a/internal/archiver/file_saver_test.go +++ b/internal/archiver/file_saver_test.go @@ -11,7 +11,6 @@ import ( "github.com/restic/chunker" "github.com/restic/restic/internal/data" "github.com/restic/restic/internal/fs" - "github.com/restic/restic/internal/restic" "github.com/restic/restic/internal/test" "golang.org/x/sync/errgroup" ) @@ -31,44 +30,35 @@ func createTestFiles(t testing.TB, num int) (files []string) { return files } -func startFileSaver(ctx context.Context, t testing.TB, _ fs.FS) (*fileSaver, context.Context, *errgroup.Group) { +func startFileSaver(ctx context.Context, t testing.TB, _ fs.FS) (*fileSaver, *mockSaver, context.Context, *errgroup.Group) { wg, ctx := errgroup.WithContext(ctx) - saveBlob := func(ctx context.Context, tpe restic.BlobType, buf *buffer, _ string, cb func(saveBlobResponse)) { - cb(saveBlobResponse{ - id: 
restic.Hash(buf.Data), - length: len(buf.Data), - sizeInRepo: len(buf.Data), - known: false, - }) - } - workers := uint(runtime.NumCPU()) pol, err := chunker.RandomPolynomial() if err != nil { t.Fatal(err) } - s := newFileSaver(ctx, wg, saveBlob, pol, workers, workers) + saver := &mockSaver{saved: make(map[string]int)} + s := newFileSaver(ctx, wg, saver, pol, workers) s.NodeFromFileInfo = func(snPath, filename string, meta ToNoder, ignoreXattrListError bool) (*data.Node, error) { return meta.ToNode(ignoreXattrListError, t.Logf) } - return s, ctx, wg + return s, saver, ctx, wg } func TestFileSaver(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - files := createTestFiles(t, 15) - startFn := func() {} completeReadingFn := func() {} completeFn := func(*data.Node, ItemStats) {} + files := createTestFiles(t, 15) testFs := fs.Local{} - s, ctx, wg := startFileSaver(ctx, t, testFs) + s, saver, ctx, wg := startFileSaver(ctx, t, testFs) var results []futureNode @@ -89,6 +79,8 @@ func TestFileSaver(t *testing.T) { } } + test.Assert(t, len(saver.saved) == len(files), "expected %d saved files, got %d", len(files), len(saver.saved)) + s.TriggerShutdown() err := wg.Wait() diff --git a/internal/archiver/tree_saver.go b/internal/archiver/tree_saver.go index d0e802765..8b38b5eb2 100644 --- a/internal/archiver/tree_saver.go +++ b/internal/archiver/tree_saver.go @@ -12,7 +12,7 @@ import ( // treeSaver concurrently saves incoming trees to the repo. type treeSaver struct { - saveBlob saveBlobFn + uploader restic.BlobSaverAsync errFn ErrorFunc ch chan<- saveTreeJob @@ -20,12 +20,12 @@ type treeSaver struct { // newTreeSaver returns a new tree saver. A worker pool with treeWorkers is // started, it is stopped when ctx is cancelled. -func newTreeSaver(ctx context.Context, wg *errgroup.Group, treeWorkers uint, saveBlob saveBlobFn, errFn ErrorFunc) *treeSaver { +func newTreeSaver(ctx context.Context, wg *errgroup.Group, treeWorkers uint, uploader restic.BlobSaverAsync, errFn ErrorFunc) *treeSaver { ch := make(chan saveTreeJob) s := &treeSaver{ ch: ch, - saveBlob: saveBlob, + uploader: uploader, errFn: errFn, } @@ -129,21 +129,35 @@ func (s *treeSaver) save(ctx context.Context, job *saveTreeJob) (*data.Node, Ite return nil, stats, err } - b := &buffer{Data: buf} - ch := make(chan saveBlobResponse, 1) - s.saveBlob(ctx, restic.TreeBlob, b, job.target, func(res saveBlobResponse) { - ch <- res + var ( + known bool + length int + sizeInRepo int + id restic.ID + ) + + ch := make(chan struct{}, 1) + s.uploader.SaveBlobAsync(ctx, restic.TreeBlob, buf, restic.ID{}, false, func(newID restic.ID, cbKnown bool, cbSizeInRepo int, cbErr error) { + known = cbKnown + length = len(buf) + sizeInRepo = cbSizeInRepo + id = newID + err = cbErr + ch <- struct{}{} }) select { - case sbr := <-ch: - if !sbr.known { + case <-ch: + if err != nil { + return nil, stats, err + } + if !known { stats.TreeBlobs++ - stats.TreeSize += uint64(sbr.length) - stats.TreeSizeInRepo += uint64(sbr.sizeInRepo) + stats.TreeSize += uint64(length) + stats.TreeSizeInRepo += uint64(sizeInRepo) } - node.Subtree = &sbr.id + node.Subtree = &id return node, stats, nil case <-ctx.Done(): return nil, stats, ctx.Err() diff --git a/internal/archiver/tree_saver_test.go b/internal/archiver/tree_saver_test.go index 2a4826444..ed3a148af 100644 --- a/internal/archiver/tree_saver_test.go +++ b/internal/archiver/tree_saver_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "runtime" + "sync" "testing" "github.com/restic/restic/internal/data" @@ 
-13,13 +14,20 @@ import ( "golang.org/x/sync/errgroup" ) -func treeSaveHelper(_ context.Context, _ restic.BlobType, buf *buffer, _ string, cb func(res saveBlobResponse)) { - cb(saveBlobResponse{ - id: restic.NewRandomID(), - known: false, - length: len(buf.Data), - sizeInRepo: len(buf.Data), - }) +type mockSaver struct { + saved map[string]int + mutex sync.Mutex +} + +func (m *mockSaver) SaveBlobAsync(_ context.Context, _ restic.BlobType, buf []byte, id restic.ID, storeDuplicate bool, cb func(newID restic.ID, known bool, sizeInRepo int, err error)) { + // Fake async operation + go func() { + m.mutex.Lock() + m.saved[string(buf)]++ + m.mutex.Unlock() + + cb(restic.Hash(buf), false, len(buf), nil) + }() } func setupTreeSaver() (context.Context, context.CancelFunc, *treeSaver, func() error) { @@ -30,7 +38,7 @@ func setupTreeSaver() (context.Context, context.CancelFunc, *treeSaver, func() e return err } - b := newTreeSaver(ctx, wg, uint(runtime.NumCPU()), treeSaveHelper, errFn) + b := newTreeSaver(ctx, wg, uint(runtime.NumCPU()), &mockSaver{saved: make(map[string]int)}, errFn) shutdown := func() error { b.TriggerShutdown() diff --git a/internal/checker/checker_test.go b/internal/checker/checker_test.go index 960942d80..8c78f4395 100644 --- a/internal/checker/checker_test.go +++ b/internal/checker/checker_test.go @@ -527,7 +527,7 @@ func TestCheckerBlobTypeConfusion(t *testing.T) { } var id restic.ID - test.OK(t, repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { + test.OK(t, repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { var err error id, err = data.SaveTree(ctx, uploader, damagedTree) return err @@ -536,7 +536,7 @@ func TestCheckerBlobTypeConfusion(t *testing.T) { buf, err := repo.LoadBlob(ctx, restic.TreeBlob, id, nil) test.OK(t, err) - test.OK(t, repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { + test.OK(t, repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { var err error _, _, _, err = uploader.SaveBlob(ctx, restic.DataBlob, buf, id, false) return err @@ -561,7 +561,7 @@ func TestCheckerBlobTypeConfusion(t *testing.T) { } var rootID restic.ID - test.OK(t, repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { + test.OK(t, repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { var err error rootID, err = data.SaveTree(ctx, uploader, rootTree) return err diff --git a/internal/data/testing.go b/internal/data/testing.go index be4ab4edb..8187833a6 100644 --- a/internal/data/testing.go +++ b/internal/data/testing.go @@ -136,7 +136,7 @@ func TestCreateSnapshot(t testing.TB, repo restic.Repository, at time.Time, dept } var treeID restic.ID - test.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + test.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { treeID = fs.saveTree(ctx, uploader, seed, depth) return nil })) diff --git a/internal/data/tree_test.go b/internal/data/tree_test.go index 9164f4da1..054cf7c0a 100644 --- a/internal/data/tree_test.go +++ b/internal/data/tree_test.go @@ -107,7 +107,7 @@ func TestEmptyLoadTree(t *testing.T) { tree := data.NewTree(0) var id restic.ID - rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + rtest.OK(t, repo.WithBlobUploader(context.TODO(), 
func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { var err error // save tree id, err = data.SaveTree(ctx, uploader, tree) diff --git a/internal/repository/fuzz_test.go b/internal/repository/fuzz_test.go index 16155f3a4..62dbd167e 100644 --- a/internal/repository/fuzz_test.go +++ b/internal/repository/fuzz_test.go @@ -20,7 +20,7 @@ func FuzzSaveLoadBlob(f *testing.F) { id := restic.Hash(blob) repo, _, _ := TestRepositoryWithVersion(t, 2) - rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { _, _, _, err := uploader.SaveBlob(ctx, restic.DataBlob, blob, id, false) return err })) diff --git a/internal/repository/prune.go b/internal/repository/prune.go index 772765129..843837617 100644 --- a/internal/repository/prune.go +++ b/internal/repository/prune.go @@ -563,7 +563,7 @@ func (plan *PrunePlan) Execute(ctx context.Context, printer progress.Printer) er if len(plan.repackPacks) != 0 { printer.P("repacking packs\n") bar := printer.NewCounter("packs repacked") - err := repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { + err := repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { return CopyBlobs(ctx, repo, repo, uploader, plan.repackPacks, plan.keepBlobs, bar, printer.P) }) if err != nil { diff --git a/internal/repository/prune_internal_test.go b/internal/repository/prune_internal_test.go index 49a876884..640ab061b 100644 --- a/internal/repository/prune_internal_test.go +++ b/internal/repository/prune_internal_test.go @@ -47,7 +47,7 @@ func TestPruneMaxUnusedDuplicate(t *testing.T) { {bufs[1], bufs[3]}, {bufs[2], bufs[3]}, } { - rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { for _, blob := range blobs { id, _, _, err := uploader.SaveBlob(ctx, restic.DataBlob, blob, restic.ID{}, true) keep.Insert(restic.BlobHandle{Type: restic.DataBlob, ID: id}) diff --git a/internal/repository/prune_test.go b/internal/repository/prune_test.go index 744de0b14..a363acd41 100644 --- a/internal/repository/prune_test.go +++ b/internal/repository/prune_test.go @@ -25,7 +25,7 @@ func testPrune(t *testing.T, opts repository.PruneOptions, errOnUnused bool) { createRandomBlobs(t, random, repo, 5, 0.5, true) keep, _ := selectBlobs(t, random, repo, 0.5) - rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { // duplicate a few blobs to exercise those code paths for blob := range keep { buf, err := repo.LoadBlob(ctx, blob.Type, blob.ID, nil) @@ -133,7 +133,7 @@ func TestPruneSmall(t *testing.T) { const numBlobsCreated = 55 keep := restic.NewBlobSet() - rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { // we need a minum of 11 packfiles, each packfile will be about 5 Mb long for i := 0; i < numBlobsCreated; i++ { buf := make([]byte, blobSize) diff --git a/internal/repository/repack.go 
b/internal/repository/repack.go index ca0a8a48b..c2eaa8f41 100644 --- a/internal/repository/repack.go +++ b/internal/repository/repack.go @@ -32,7 +32,7 @@ func CopyBlobs( ctx context.Context, repo restic.Repository, dstRepo restic.Repository, - dstUploader restic.BlobSaver, + dstUploader restic.BlobSaverWithAsync, packs restic.IDSet, keepBlobs repackBlobSet, p *progress.Counter, @@ -57,7 +57,7 @@ func repack( ctx context.Context, repo restic.Repository, dstRepo restic.Repository, - uploader restic.BlobSaver, + uploader restic.BlobSaverWithAsync, packs restic.IDSet, keepBlobs repackBlobSet, p *progress.Counter, diff --git a/internal/repository/repack_test.go b/internal/repository/repack_test.go index bedacaa7e..0c1095301 100644 --- a/internal/repository/repack_test.go +++ b/internal/repository/repack_test.go @@ -20,7 +20,7 @@ func randomSize(random *rand.Rand, min, max int) int { func createRandomBlobs(t testing.TB, random *rand.Rand, repo restic.Repository, blobs int, pData float32, smallBlobs bool) { // two loops to allow creating multiple pack files for blobs > 0 { - rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { for blobs > 0 { blobs-- var ( @@ -70,7 +70,7 @@ func createRandomWrongBlob(t testing.TB, random *rand.Rand, repo restic.Reposito // invert first data byte buf[0] ^= 0xff - rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { _, _, _, err := uploader.SaveBlob(ctx, restic.DataBlob, buf, id, false) return err })) @@ -150,7 +150,7 @@ func findPacksForBlobs(t *testing.T, repo restic.Repository, blobs restic.BlobSe } func repack(t *testing.T, repo restic.Repository, be backend.Backend, packs restic.IDSet, blobs restic.BlobSet) { - rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { return repository.CopyBlobs(ctx, repo, repo, uploader, packs, blobs, nil, nil) })) @@ -265,7 +265,7 @@ func testRepackCopy(t *testing.T, version uint) { _, keepBlobs := selectBlobs(t, random, repo, 0.2) copyPacks := findPacksForBlobs(t, repo, keepBlobs) - rtest.OK(t, repoWrapped.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + rtest.OK(t, repoWrapped.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { return repository.CopyBlobs(ctx, repoWrapped, dstRepoWrapped, uploader, copyPacks, keepBlobs, nil, nil) })) rebuildAndReloadIndex(t, dstRepo) @@ -303,7 +303,7 @@ func testRepackWrongBlob(t *testing.T, version uint) { _, keepBlobs := selectBlobs(t, random, repo, 0) rewritePacks := findPacksForBlobs(t, repo, keepBlobs) - err := repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + err := repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { return repository.CopyBlobs(ctx, repo, repo, uploader, rewritePacks, keepBlobs, nil, nil) }) if err == nil { @@ -336,7 +336,7 @@ func testRepackBlobFallback(t *testing.T, version uint) { modbuf[0] ^= 0xff // create pack with broken copy - 
rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error {
+	rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaverWithAsync) error {
 		_, _, _, err := uploader.SaveBlob(ctx, restic.DataBlob, modbuf, id, false)
 		return err
 	}))
@@ -346,13 +346,13 @@ func testRepackBlobFallback(t *testing.T, version uint) {
 	rewritePacks := findPacksForBlobs(t, repo, keepBlobs)
 
 	// create pack with valid copy
-	rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error {
+	rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaverWithAsync) error {
 		_, _, _, err := uploader.SaveBlob(ctx, restic.DataBlob, buf, id, true)
 		return err
 	}))
 
 	// repack must fallback to valid copy
-	rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error {
+	rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaverWithAsync) error {
 		return repository.CopyBlobs(ctx, repo, repo, uploader, rewritePacks, keepBlobs, nil, nil)
 	}))
 
diff --git a/internal/repository/repair_pack.go b/internal/repository/repair_pack.go
index a6f4a52b8..0c9d3a43f 100644
--- a/internal/repository/repair_pack.go
+++ b/internal/repository/repair_pack.go
@@ -15,7 +15,7 @@ func RepairPacks(ctx context.Context, repo *Repository, ids restic.IDSet, printe
 	bar.SetMax(uint64(len(ids)))
 	defer bar.Done()
 
-	err := repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error {
+	err := repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaverWithAsync) error {
 		// examine all data the indexes have for the pack file
 		for b := range repo.ListPacksFromIndex(ctx, ids) {
 			blobs := b.Blobs
diff --git a/internal/repository/repository.go b/internal/repository/repository.go
index 9abdaeec2..32d2e1aac 100644
--- a/internal/repository/repository.go
+++ b/internal/repository/repository.go
@@ -42,6 +42,8 @@ type Repository struct {
 	opts Options
 
 	packerWg *errgroup.Group
+	mainWg *errgroup.Group
+	blobSaver *sync.WaitGroup
 	uploader *packerUploader
 	treePM *packerManager
 	dataPM *packerManager
@@ -559,9 +561,14 @@ func (r *Repository) removeUnpacked(ctx context.Context, t restic.FileType, id r
 	return r.be.Remove(ctx, backend.Handle{Type: t, Name: id.String()})
 }
 
-func (r *Repository) WithBlobUploader(ctx context.Context, fn func(ctx context.Context, uploader restic.BlobSaver) error) error {
+func (r *Repository) WithBlobUploader(ctx context.Context, fn func(ctx context.Context, uploader restic.BlobSaverWithAsync) error) error {
 	wg, ctx := errgroup.WithContext(ctx)
+	// pack uploader + the wg.Go call below + blob savers (CPU bound)
+	wg.SetLimit(2 + runtime.GOMAXPROCS(0))
+	r.mainWg = wg
 	r.startPackUploader(ctx, wg)
+	// blob savers are spawned on demand; use a wait group to keep track of them
+	r.blobSaver = &sync.WaitGroup{}
 	wg.Go(func() error {
 		if err := fn(ctx, &blobSaverRepo{repo: r}); err != nil {
 			return err
@@ -574,14 +581,6 @@ func (r *Repository) WithBlobUploader(ctx context.Context, fn func(ctx context.C
 	return wg.Wait()
 }
 
-type blobSaverRepo struct {
-	repo *Repository
-}
-
-func (r *blobSaverRepo) SaveBlob(ctx context.Context, t restic.BlobType, buf []byte, id restic.ID, storeDuplicate bool) (newID restic.ID, known bool, size int, err error) {
-	return r.repo.saveBlob(ctx, t, buf, id, storeDuplicate)
-}
-
 func (r *Repository) startPackUploader(ctx context.Context, wg
*errgroup.Group) {
 	if r.packerWg != nil {
 		panic("uploader already started")
@@ -598,17 +597,40 @@ func (r *Repository) startPackUploader(ctx context.Context, wg *errgroup.Group)
 	})
 }
 
+type blobSaverRepo struct {
+	repo *Repository
+}
+
+func (r *blobSaverRepo) SaveBlob(ctx context.Context, t restic.BlobType, buf []byte, id restic.ID, storeDuplicate bool) (newID restic.ID, known bool, size int, err error) {
+	return r.repo.saveBlob(ctx, t, buf, id, storeDuplicate)
+}
+
+func (r *blobSaverRepo) SaveBlobAsync(ctx context.Context, t restic.BlobType, buf []byte, id restic.ID, storeDuplicate bool, cb func(newID restic.ID, known bool, size int, err error)) {
+	r.repo.saveBlobAsync(ctx, t, buf, id, storeDuplicate, cb)
+}
+
 // Flush saves all remaining packs and the index
 func (r *Repository) flush(ctx context.Context) error {
-	if err := r.flushPacks(ctx); err != nil {
+	r.flushBlobSaver()
+	r.mainWg = nil
+
+	if err := r.flushPackUploader(ctx); err != nil {
 		return err
 	}
 
 	return r.idx.Flush(ctx, &internalRepository{r})
 }
 
+func (r *Repository) flushBlobSaver() {
+	if r.blobSaver == nil {
+		return
+	}
+	r.blobSaver.Wait()
+	r.blobSaver = nil
+}
+
-// FlushPacks saves all remaining packs.
-func (r *Repository) flushPacks(ctx context.Context) error {
+// flushPackUploader saves all remaining packs.
+func (r *Repository) flushPackUploader(ctx context.Context) error {
 	if r.packerWg == nil {
 		return nil
 	}
@@ -994,6 +1016,19 @@ func (r *Repository) saveBlob(ctx context.Context, t restic.BlobType, buf []byte
 	return newID, known, size, err
 }
 
+func (r *Repository) saveBlobAsync(ctx context.Context, t restic.BlobType, buf []byte, id restic.ID, storeDuplicate bool, cb func(newID restic.ID, known bool, size int, err error)) {
+	// track the in-flight saver so that flushBlobSaver can wait for it
+	r.blobSaver.Add(1)
+	r.mainWg.Go(func() error {
+		defer r.blobSaver.Done()
+		if ctx.Err() != nil {
+			// fail fast if the context is cancelled
+			cb(restic.ID{}, false, 0, ctx.Err())
+			return ctx.Err()
+		}
+		newID, known, size, err := r.saveBlob(ctx, t, buf, id, storeDuplicate)
+		cb(newID, known, size, err)
+		return err
+	})
+}
+
 type backendLoadFn func(ctx context.Context, h backend.Handle, length int, offset int64, fn func(rd io.Reader) error) error
 type loadBlobFn func(ctx context.Context, t restic.BlobType, id restic.ID, buf []byte) ([]byte, error)
diff --git a/internal/repository/repository_test.go b/internal/repository/repository_test.go
index 2a181312c..f2ef1d082 100644
--- a/internal/repository/repository_test.go
+++ b/internal/repository/repository_test.go
@@ -4,11 +4,13 @@ import (
 	"bytes"
 	"context"
 	"crypto/sha256"
+	"fmt"
 	"io"
 	"math/rand"
 	"path/filepath"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
@@ -51,7 +53,7 @@ func testSave(t *testing.T, version uint, calculateID bool) {
 
 	id := restic.Hash(data)
 
-	rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error {
+	rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaverWithAsync) error {
 		// save
 		inputID := restic.ID{}
 		if !calculateID {
@@ -97,7 +99,7 @@ func testSavePackMerging(t *testing.T, targetPercentage int, expectedPacks int)
 	})
 
 	var ids restic.IDs
-	rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error {
+	rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaverWithAsync) error {
 		// add blobs with size targetPercentage / 100 * repo.PackSize to the repository
 		blobSize := repository.MinPackSize / 100
 		for range targetPercentage {
@@ -147,7 +149,7 @@ func benchmarkSaveAndEncrypt(t *testing.B, version uint) {
 	t.ResetTimer()
t.SetBytes(int64(size)) - _ = repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + _ = repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { for i := 0; i < t.N; i++ { _, _, _, err = uploader.SaveBlob(ctx, restic.DataBlob, data, id, true) rtest.OK(t, err) @@ -168,7 +170,7 @@ func testLoadBlob(t *testing.T, version uint) { rtest.OK(t, err) var id restic.ID - rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { var err error id, _, _, err = uploader.SaveBlob(ctx, restic.DataBlob, buf, restic.ID{}, false) return err @@ -196,7 +198,7 @@ func TestLoadBlobBroken(t *testing.T) { buf := rtest.Random(42, 1000) var id restic.ID - rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { var err error id, _, _, err = uploader.SaveBlob(ctx, restic.TreeBlob, buf, restic.ID{}, false) return err @@ -225,7 +227,7 @@ func benchmarkLoadBlob(b *testing.B, version uint) { rtest.OK(b, err) var id restic.ID - rtest.OK(b, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error { + rtest.OK(b, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { var err error id, _, _, err = uploader.SaveBlob(ctx, restic.DataBlob, buf, restic.ID{}, false) return err @@ -361,7 +363,7 @@ func TestRepositoryLoadUnpackedRetryBroken(t *testing.T) { // saveRandomDataBlobs generates random data blobs and saves them to the repository. 
func saveRandomDataBlobs(t testing.TB, repo restic.Repository, num int, sizeMax int) {
-	rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error {
+	rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaverWithAsync) error {
 		for i := 0; i < num; i++ {
 			size := rand.Int() % sizeMax
@@ -432,7 +434,7 @@ func TestListPack(t *testing.T) {
 	buf := rtest.Random(42, 1000)
 
 	var id restic.ID
-	rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaver) error {
+	rtest.OK(t, repo.WithBlobUploader(context.TODO(), func(ctx context.Context, uploader restic.BlobSaverWithAsync) error {
 		var err error
 		id, _, _, err = uploader.SaveBlob(ctx, restic.TreeBlob, buf, restic.ID{}, false)
 		return err
@@ -487,3 +489,65 @@ func TestNoDoubleInit(t *testing.T) {
 	err = repo.Init(context.TODO(), r.Config().Version, rtest.TestPassword, &pol)
 	rtest.Assert(t, strings.Contains(err.Error(), "repository already contains snapshots"), "expected already contains snapshots error, got %q", err)
 }
+
+func TestSaveBlobAsync(t *testing.T) {
+	repo, _, _ := repository.TestRepositoryWithVersion(t, 2)
+	ctx := context.Background()
+
+	type result struct {
+		id    restic.ID
+		known bool
+		size  int
+		err   error
+	}
+	numCalls := 10
+	results := make([]result, numCalls)
+	var resultsMutex sync.Mutex
+
+	err := repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaverWithAsync) error {
+		var wg sync.WaitGroup
+		wg.Add(numCalls)
+		for i := 0; i < numCalls; i++ {
+			// Use unique data for each call
+			testData := []byte(fmt.Sprintf("test blob data %d", i))
+			uploader.SaveBlobAsync(ctx, restic.DataBlob, testData, restic.ID{}, false,
+				func(newID restic.ID, known bool, size int, err error) {
+					defer wg.Done()
+					resultsMutex.Lock()
+					results[i] = result{newID, known, size, err}
+					resultsMutex.Unlock()
+				})
+		}
+		wg.Wait()
+		return nil
+	})
+	rtest.OK(t, err)
+
+	for i, result := range results {
+		testData := []byte(fmt.Sprintf("test blob data %d", i))
+		expectedID := restic.Hash(testData)
+		rtest.Assert(t, result.err == nil, "result %d: unexpected error %v", i, result.err)
+		rtest.Assert(t, result.id.Equal(expectedID), "result %d: expected ID %v, got %v", i, expectedID, result.id)
+		rtest.Assert(t, !result.known, "result %d: expected unknown blob", i)
+	}
+}
+
+func TestSaveBlobAsyncErrorHandling(t *testing.T) {
+	repo, _, _ := repository.TestRepositoryWithVersion(t, 2)
+	ctx, cancel := context.WithCancel(context.Background())
+
+	var callbackCalled atomic.Bool
+
+	err := repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaverWithAsync) error {
+		cancel()
+		// Callback must be called even if the context is canceled
+		uploader.SaveBlobAsync(ctx, restic.DataBlob, []byte("test blob data"), restic.ID{}, false,
+			func(newID restic.ID, known bool, size int, err error) {
+				callbackCalled.Store(true)
+			})
+		return nil
+	})
+
+	rtest.Assert(t, errors.Is(err, context.Canceled), "expected context canceled error, got %v", err)
+	rtest.Assert(t, callbackCalled.Load(), "callback was not called")
+}
diff --git a/internal/restic/repository.go b/internal/restic/repository.go
index ed0c64cf0..c7f326823 100644
--- a/internal/restic/repository.go
+++ b/internal/restic/repository.go
@@ -42,7 +42,7 @@ type Repository interface {
-	// WithUploader starts the necessary workers to upload new blobs. Once the callback returns,
+	// WithBlobUploader starts the necessary workers to upload new blobs. Once the callback returns,
 	// the workers are stopped and the index is written to the repository.
The callback must use // the passed context and must not keep references to any of its parameters after returning. - WithBlobUploader(ctx context.Context, fn func(ctx context.Context, uploader BlobSaver) error) error + WithBlobUploader(ctx context.Context, fn func(ctx context.Context, uploader BlobSaverWithAsync) error) error // List calls the function fn for each file of type t in the repository. // When an error is returned by fn, processing stops and List() returns the @@ -162,11 +162,23 @@ type BlobLoader interface { } type WithBlobUploader interface { - WithBlobUploader(ctx context.Context, fn func(ctx context.Context, uploader BlobSaver) error) error + WithBlobUploader(ctx context.Context, fn func(ctx context.Context, uploader BlobSaverWithAsync) error) error +} + +type BlobSaverWithAsync interface { + BlobSaver + BlobSaverAsync } type BlobSaver interface { - SaveBlob(context.Context, BlobType, []byte, ID, bool) (ID, bool, int, error) + // SaveBlob saves a blob to the repository. ctx must be derived from the context created by WithBlobUploader. + SaveBlob(ctx context.Context, tpe BlobType, buf []byte, id ID, storeDuplicate bool) (newID ID, known bool, sizeInRepo int, err error) +} + +type BlobSaverAsync interface { + // SaveBlobAsync saves a blob to the repository. ctx must be derived from the context created by WithBlobUploader. + // The callback is called asynchronously from a different goroutine. + SaveBlobAsync(ctx context.Context, tpe BlobType, buf []byte, id ID, storeDuplicate bool, cb func(newID ID, known bool, sizeInRepo int, err error)) } // Loader loads a blob from a repository. diff --git a/internal/restorer/restorer_test.go b/internal/restorer/restorer_test.go index fe9db6b33..2e419f55c 100644 --- a/internal/restorer/restorer_test.go +++ b/internal/restorer/restorer_test.go @@ -171,7 +171,7 @@ func saveSnapshot(t testing.TB, repo restic.Repository, snapshot Snapshot, getGe defer cancel() var treeID restic.ID - rtest.OK(t, repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaver) error { + rtest.OK(t, repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { treeID = saveDir(t, uploader, snapshot.Nodes, 1000, getGenericAttributes) return nil }))
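Usage sketch (illustrative only, not part of the patch): the helper below shows how a caller is expected to drive the reworked uploader API. Only WithBlobUploader, BlobSaverWithAsync, and SaveBlobAsync come from this change; saveOneBlob and its error handling are hypothetical. Because the callback fires on a repository goroutine, the caller synchronizes with a sync.WaitGroup, mirroring the new TestSaveBlobAsync above.

package example

import (
	"context"
	"sync"

	"github.com/restic/restic/internal/restic"
)

// saveOneBlob saves a single data blob and blocks until its callback has run.
// Hypothetical helper for illustration; not part of this change.
func saveOneBlob(ctx context.Context, repo restic.Repository, buf []byte) (restic.ID, error) {
	var (
		id    restic.ID
		cbErr error
		wg    sync.WaitGroup
	)
	err := repo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaverWithAsync) error {
		wg.Add(1)
		// The callback is invoked from a repository goroutine, so wait for it
		// before returning; once this function returns, the workers stop and
		// the index is written.
		uploader.SaveBlobAsync(ctx, restic.DataBlob, buf, restic.ID{}, false,
			func(newID restic.ID, known bool, sizeInRepo int, err error) {
				defer wg.Done()
				id, cbErr = newID, err
			})
		wg.Wait()
		return cbErr
	})
	return id, err
}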