diff --git a/internal/restorer/filerestorer.go b/internal/restorer/filerestorer.go index 166bf1ff8..d197bcf5b 100644 --- a/internal/restorer/filerestorer.go +++ b/internal/restorer/filerestorer.go @@ -217,6 +217,8 @@ func (r *fileRestorer) restoreFiles(ctx context.Context) error { wg, ctx := errgroup.WithContext(ctx) downloadCh := make(chan *packInfo) + // purge the handle cache on return so files still held open by it get closed + defer r.filesWriter.flush() worker := func() error { for pack := range downloadCh { if err := r.downloadPack(ctx, pack); err != nil { diff --git a/internal/restorer/fileswriter.go b/internal/restorer/fileswriter.go index d6f78f2d7..2b60f7185 100644 --- a/internal/restorer/fileswriter.go +++ b/internal/restorer/fileswriter.go @@ -7,6 +7,7 @@ import ( "syscall" "github.com/cespare/xxhash/v2" + "github.com/hashicorp/golang-lru/v2/simplelru" "github.com/restic/restic/internal/debug" "github.com/restic/restic/internal/errors" "github.com/restic/restic/internal/fs" @@ -20,6 +21,8 @@ import ( type filesWriter struct { buckets []filesWriterBucket allowRecursiveDelete bool + cacheMu sync.Mutex + cache *simplelru.LRU[string, *partialFile] } type filesWriterBucket struct { @@ -34,13 +37,27 @@ type partialFile struct { } func newFilesWriter(count int, allowRecursiveDelete bool) *filesWriter { - buckets := make([]filesWriterBucket, count) - for b := 0; b < count; b++ { + // use a large, fixed number of buckets to minimize bucket contention; count only sizes the handle cache below + // creating a new file can be slow, so make sure that files typically end up in different buckets. 
+ buckets := make([]filesWriterBucket, 1024) + for b := 0; b < len(buckets); b++ { buckets[b].files = make(map[string]*partialFile) } + + cache, err := simplelru.NewLRU[string, *partialFile](count+50, func(_ string, wr *partialFile) { + // eviction callback: close the file only when it is not in use (users > 0 means a writer is taking it over); the Close error is discarded + if wr.users == 0 { + _ = wr.Close() + } + }) + if err != nil { + panic(err) // can't happen + } + return &filesWriter{ buckets: buckets, allowRecursiveDelete: allowRecursiveDelete, + cache: cache, } } @@ -173,6 +190,24 @@ func (w *filesWriter) writeToFile(path string, blob []byte, offset int64, create bucket.files[path].users++ return wr, nil } + + // Check the global LRU cache for a cached file handle + w.cacheMu.Lock() + cached, ok := w.cache.Get(path) + if ok { + // mark as in use first: the Remove call below runs the eviction callback, which skips Close while users > 0 + cached.users++ + + w.cache.Remove(path) + w.cacheMu.Unlock() + + // Use the cached file handle + bucket.files[path] = cached + return cached, nil + } + w.cacheMu.Unlock() + + // Not in cache, open/create the file var f *os.File var err error if createSize >= 0 { @@ -194,11 +229,14 @@ func (w *filesWriter) writeToFile(path string, blob []byte, offset int64, create bucket.lock.Lock() defer bucket.lock.Unlock() - if bucket.files[path].users == 1 { - delete(bucket.files, path) - return wr.Close() - } bucket.files[path].users-- + if bucket.files[path].users == 0 { + delete(bucket.files, path) + // Add to cache to allow re-use. Cache will close files on overflow or flush. NOTE(review): a Close error is no longer reported to the caller. 
+ w.cacheMu.Lock() + w.cache.Add(path, wr) + w.cacheMu.Unlock() + } return nil } @@ -217,3 +255,10 @@ func (w *filesWriter) writeToFile(path string, blob []byte, offset int64, create return releaseWriter(wr) } + +func (w *filesWriter) flush() { + w.cacheMu.Lock() + defer w.cacheMu.Unlock() + + w.cache.Purge() +} diff --git a/internal/restorer/fileswriter_test.go b/internal/restorer/fileswriter_test.go index 9ea8767b8..dfdccb647 100644 --- a/internal/restorer/fileswriter_test.go +++ b/internal/restorer/fileswriter_test.go @@ -30,6 +30,8 @@ func TestFilesWriterBasic(t *testing.T) { rtest.OK(t, w.writeToFile(f2, []byte{2}, 1, -1, false)) rtest.Equals(t, 0, len(w.buckets[0].files)) + w.flush() + buf, err := os.ReadFile(f1) rtest.OK(t, err) rtest.Equals(t, []byte{1, 1}, buf) @@ -51,11 +53,13 @@ func TestFilesWriterRecursiveOverwrite(t *testing.T) { err := w.writeToFile(path, []byte{1}, 0, 2, false) rtest.Assert(t, errors.Is(err, notEmptyDirError()), "unexpected error got %v", err) rtest.Equals(t, 0, len(w.buckets[0].files)) + w.flush() // must replace directory w = newFilesWriter(1, true) rtest.OK(t, w.writeToFile(path, []byte{1, 1}, 0, 2, false)) rtest.Equals(t, 0, len(w.buckets[0].files)) + w.flush() buf, err := os.ReadFile(path) rtest.OK(t, err)