diff --git a/internal/archiver/archiver_test.go b/internal/archiver/archiver_test.go index eb60e1174..a4012d585 100644 --- a/internal/archiver/archiver_test.go +++ b/internal/archiver/archiver_test.go @@ -2104,8 +2104,10 @@ type failSaveRepo struct { } func (f *failSaveRepo) WithBlobUploader(ctx context.Context, fn func(ctx context.Context, uploader restic.BlobSaverWithAsync) error) error { - return f.archiverRepo.WithBlobUploader(ctx, func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { - return fn(ctx, &failSaveSaver{saver: uploader, failSaveRepo: f, semaphore: make(chan struct{}, 1)}) + outerCtx, outerCancel := context.WithCancelCause(ctx) + defer outerCancel(f.err) + return f.archiverRepo.WithBlobUploader(outerCtx, func(ctx context.Context, uploader restic.BlobSaverWithAsync) error { + return fn(ctx, &failSaveSaver{saver: uploader, failSaveRepo: f, semaphore: make(chan struct{}, 1), outerCancel: outerCancel}) }) } @@ -2113,6 +2115,7 @@ type failSaveSaver struct { saver restic.BlobSaverWithAsync failSaveRepo *failSaveRepo semaphore chan struct{} + outerCancel context.CancelCauseFunc } func (f *failSaveSaver) SaveBlob(ctx context.Context, t restic.BlobType, buf []byte, id restic.ID, storeDuplicate bool) (restic.ID, bool, int, error) { @@ -2130,10 +2133,9 @@ func (f *failSaveSaver) SaveBlobAsync(ctx context.Context, t restic.BlobType, bu val := f.failSaveRepo.cnt.Add(1) if val >= f.failSaveRepo.failAfter { - // use a canceled context to make SaveBlobAsync fail - var cancel context.CancelCauseFunc - ctx, cancel = context.WithCancelCause(ctx) - cancel(f.failSaveRepo.err) + // kill the outer context to make SaveBlobAsync fail + // precisely injecting a specific error into the repository is not possible, so just cancel the context + f.outerCancel(f.failSaveRepo.err) } f.saver.SaveBlobAsync(ctx, t, buf, id, storeDuplicate, func(newID restic.ID, known bool, size int, err error) { @@ -2141,7 +2143,6 @@ func (f *failSaveSaver) SaveBlobAsync(ctx 
context.Context, t restic.BlobType, bu if err == nil { panic("expected error") } - err = f.failSaveRepo.err } cb(newID, known, size, err) <-f.semaphore @@ -2149,13 +2150,10 @@ func (f *failSaveSaver) SaveBlobAsync(ctx context.Context, t restic.BlobType, bu } func TestArchiverAbortEarlyOnError(t *testing.T) { - var testErr = errors.New("test error") - var tests = []struct { src TestDir wantOpen map[string]uint failAfter uint // error after so many blobs have been saved to the repo - err error }{ { src: TestDir{ @@ -2167,10 +2165,7 @@ func TestArchiverAbortEarlyOnError(t *testing.T) { }, wantOpen: map[string]uint{ filepath.FromSlash("dir/bar"): 1, - filepath.FromSlash("dir/baz"): 1, - filepath.FromSlash("dir/foo"): 1, }, - err: testErr, }, { src: TestDir{ @@ -2198,7 +2193,6 @@ func TestArchiverAbortEarlyOnError(t *testing.T) { // fails after four to seven files were opened, as the ReadConcurrency allows for // two queued files and one blob queued for saving. failAfter: 4, - err: testErr, }, } @@ -2217,10 +2211,11 @@ func TestArchiverAbortEarlyOnError(t *testing.T) { opened: make(map[string]uint), } + testErr := context.Canceled testRepo := &failSaveRepo{ archiverRepo: repo, failAfter: int32(test.failAfter), - err: test.err, + err: testErr, } // at most two files may be queued @@ -2233,8 +2228,8 @@ func TestArchiverAbortEarlyOnError(t *testing.T) { } _, _, _, err := arch.Snapshot(ctx, []string{"."}, SnapshotOptions{Time: time.Now()}) - if !errors.Is(err, test.err) { - t.Errorf("expected error (%v) not found, got %v", test.err, err) + if !errors.Is(err, testErr) { + t.Errorf("expected error (%v) not found, got %v", testErr, err) } t.Logf("Snapshot return error: %v", err) diff --git a/internal/repository/repository.go b/internal/repository/repository.go index 1d5f5a505..32d2e1aac 100644 --- a/internal/repository/repository.go +++ b/internal/repository/repository.go @@ -42,7 +42,8 @@ type Repository struct { opts Options packerWg *errgroup.Group - blobWg *errgroup.Group 
+ mainWg *errgroup.Group + blobSaver *sync.WaitGroup uploader *packerUploader treePM *packerManager dataPM *packerManager @@ -562,12 +563,14 @@ func (r *Repository) removeUnpacked(ctx context.Context, t restic.FileType, id r func (r *Repository) WithBlobUploader(ctx context.Context, fn func(ctx context.Context, uploader restic.BlobSaverWithAsync) error) error { wg, ctx := errgroup.WithContext(ctx) + // pack uploader + wg.Go below + blob saver (CPU bound) + wg.SetLimit(2 + runtime.GOMAXPROCS(0)) + r.mainWg = wg r.startPackUploader(ctx, wg) - saverCtx := r.startBlobSaver(ctx, wg) + // blob savers are spawned on demand, use wait group to keep track of them + r.blobSaver = &sync.WaitGroup{} wg.Go(func() error { - // must use saverCtx to ensure that the ctx used for saveBlob calls is bound to it - // otherwise the blob saver could deadlock in case of an error. - if err := fn(saverCtx, &blobSaverRepo{repo: r}); err != nil { + if err := fn(ctx, &blobSaverRepo{repo: r}); err != nil { return err } if err := r.flush(ctx); err != nil { @@ -594,22 +597,6 @@ func (r *Repository) startPackUploader(ctx context.Context, wg *errgroup.Group) }) } -func (r *Repository) startBlobSaver(ctx context.Context, wg *errgroup.Group) context.Context { - // blob upload computations are CPU bound - blobWg, blobCtx := errgroup.WithContext(ctx) - blobWg.SetLimit(runtime.GOMAXPROCS(0)) - r.blobWg = blobWg - - wg.Go(func() error { - // As the goroutines are only spawned on demand, wait until the context is canceled. - // This will either happen on an error while saving a blob or when blobWg.Wait() is called - // by flushBlobUploader().
- <-blobCtx.Done() - return blobWg.Wait() - }) - return blobCtx -} - type blobSaverRepo struct { repo *Repository } @@ -624,28 +611,26 @@ func (r *blobSaverRepo) SaveBlobAsync(ctx context.Context, t restic.BlobType, bu // Flush saves all remaining packs and the index func (r *Repository) flush(ctx context.Context) error { - if err := r.flushBlobUploader(); err != nil { - return err - } + r.flushBlobSaver() + r.mainWg = nil - if err := r.flushPacks(ctx); err != nil { + if err := r.flushPackUploader(ctx); err != nil { return err } return r.idx.Flush(ctx, &internalRepository{r}) } -func (r *Repository) flushBlobUploader() error { - if r.blobWg == nil { - return nil +func (r *Repository) flushBlobSaver() { + if r.blobSaver == nil { + return } - err := r.blobWg.Wait() - r.blobWg = nil - return err + r.blobSaver.Wait() + r.blobSaver = nil } // FlushPacks saves all remaining packs. -func (r *Repository) flushPacks(ctx context.Context) error { +func (r *Repository) flushPackUploader(ctx context.Context) error { if r.packerWg == nil { return nil } @@ -1032,11 +1017,11 @@ func (r *Repository) saveBlob(ctx context.Context, t restic.BlobType, buf []byte } func (r *Repository) saveBlobAsync(ctx context.Context, t restic.BlobType, buf []byte, id restic.ID, storeDuplicate bool, cb func(newID restic.ID, known bool, size int, err error)) { - r.blobWg.Go(func() error { + r.mainWg.Go(func() error { if ctx.Err() != nil { // fail fast if the context is cancelled - cb(restic.ID{}, false, 0, context.Cause(ctx)) - return context.Cause(ctx) + cb(restic.ID{}, false, 0, ctx.Err()) + return ctx.Err() } newID, known, size, err := r.saveBlob(ctx, t, buf, id, storeDuplicate) cb(newID, known, size, err)