diff --git a/cmd/restic/cmd_diff.go b/cmd/restic/cmd_diff.go index 3f8e16c3a..30b0878c0 100644 --- a/cmd/restic/cmd_diff.go +++ b/cmd/restic/cmd_diff.go @@ -5,7 +5,6 @@ import ( "encoding/json" "path" "reflect" - "sort" "github.com/restic/restic/internal/data" "github.com/restic/restic/internal/debug" @@ -240,37 +239,6 @@ func (c *Comparer) collectDir(ctx context.Context, blobs restic.AssociatedBlobSe return ctx.Err() } -func uniqueNodeNames(tree1, tree2 data.TreeNodeIterator) (tree1Nodes, tree2Nodes map[string]*data.Node, uniqueNames []string, err error) { - names := make(map[string]struct{}) - tree1Nodes = make(map[string]*data.Node) - for item := range tree1 { - if item.Error != nil { - return nil, nil, nil, item.Error - } - node := item.Node - tree1Nodes[node.Name] = node - names[node.Name] = struct{}{} - } - - tree2Nodes = make(map[string]*data.Node) - for item := range tree2 { - if item.Error != nil { - return nil, nil, nil, item.Error - } - node := item.Node - tree2Nodes[node.Name] = node - names[node.Name] = struct{}{} - } - - uniqueNames = make([]string, 0, len(names)) - for name := range names { - uniqueNames = append(uniqueNames, name) - } - - sort.Strings(uniqueNames) - return tree1Nodes, tree2Nodes, uniqueNames, nil -} - func (c *Comparer) diffTree(ctx context.Context, stats *DiffStatsContainer, prefix string, id1, id2 restic.ID) error { debug.Log("diffing %v to %v", id1, id2) tree1, err := data.LoadTree(ctx, c.repo, id1) @@ -283,24 +251,29 @@ func (c *Comparer) diffTree(ctx context.Context, stats *DiffStatsContainer, pref return err } - tree1Nodes, tree2Nodes, names, err := uniqueNodeNames(tree1, tree2) - if err != nil { - return err - } - - for _, name := range names { + for dt := range data.DualTreeIterator(tree1, tree2) { + if dt.Error != nil { + return dt.Error + } if ctx.Err() != nil { return ctx.Err() } - node1, t1 := tree1Nodes[name] - node2, t2 := tree2Nodes[name] + node1 := dt.Tree1 + node2 := dt.Tree2 + + var name string + if node1 != nil { 
+ name = node1.Name + } else { + name = node2.Name + } addBlobs(stats.BlobsBefore, node1) addBlobs(stats.BlobsAfter, node2) switch { - case t1 && t2: + case node1 != nil && node2 != nil: name := path.Join(prefix, name) mod := "" @@ -346,7 +319,7 @@ func (c *Comparer) diffTree(ctx context.Context, stats *DiffStatsContainer, pref c.printError("error: %v", err) } } - case t1 && !t2: + case node1 != nil && node2 == nil: prefix := path.Join(prefix, name) if node1.Type == data.NodeTypeDir { prefix += "/" @@ -360,7 +333,7 @@ func (c *Comparer) diffTree(ctx context.Context, stats *DiffStatsContainer, pref c.printError("error: %v", err) } } - case !t1 && t2: + case node1 == nil && node2 != nil: prefix := path.Join(prefix, name) if node2.Type == data.NodeTypeDir { prefix += "/" diff --git a/internal/data/tree.go b/internal/data/tree.go index 026ad9cdb..e5777673f 100644 --- a/internal/data/tree.go +++ b/internal/data/tree.go @@ -273,3 +273,105 @@ func FindTreeDirectory(ctx context.Context, repo restic.BlobLoader, id *restic.I } return id, nil } + +type peekableNodeIterator struct { + iter func() (NodeOrError, bool) + stop func() + value *Node +} + +func newPeekableNodeIterator(tree TreeNodeIterator) (*peekableNodeIterator, error) { + iter, stop := iter.Pull(tree) + it := &peekableNodeIterator{iter: iter, stop: stop} + err := it.Next() + if err != nil { + it.Close() + return nil, err + } + return it, nil +} + +func (i *peekableNodeIterator) Next() error { + item, ok := i.iter() + if item.Error != nil || !ok { + i.value = nil + return item.Error + } + i.value = item.Node + return nil +} + +func (i *peekableNodeIterator) Peek() *Node { + return i.value +} + +func (i *peekableNodeIterator) Close() { + i.stop() +} + +type DualTree struct { + Tree1 *Node + Tree2 *Node + Error error +} + +// DualTreeIterator iterates over two trees in parallel. It returns a sequence of DualTree structs. +// The sequence is terminated when both trees are exhausted. 
The error field must be checked before +// accessing any of the nodes. Each element carries the pending node with the lower name from tree1 or tree2, or both nodes when their names are equal (this presumably relies on both inputs yielding nodes in ascending name order, TODO confirm); an element with Error set carries no nodes and is the last one yielded. The returned sequence panics if it is iterated a second time. +func DualTreeIterator(tree1, tree2 TreeNodeIterator) iter.Seq[DualTree] { + started := false + return func(yield func(DualTree) bool) { + if started { + panic("tree iterator is single use only") + } + started = true + iter1, err := newPeekableNodeIterator(tree1) + if err != nil { + yield(DualTree{Tree1: nil, Tree2: nil, Error: err}) + return + } + defer iter1.Close() + iter2, err := newPeekableNodeIterator(tree2) + if err != nil { + yield(DualTree{Tree1: nil, Tree2: nil, Error: err}) + return + } + defer iter2.Close() + + for { + node1 := iter1.Peek() + node2 := iter2.Peek() + if node1 == nil && node2 == nil { + // both iterators are exhausted + break + } else if node1 != nil && node2 != nil { + // names differ: keep only the node whose name compares lower and leave the other one pending for a later round; equal names are kept as a pair + if node1.Name < node2.Name { + node2 = nil + } else if node1.Name > node2.Name { + node1 = nil + } + } + + // the nodes kept above are about to be yielded, so advance their iterators first; if Next fails the pending pair is dropped and only the error is yielded after the loop + if node1 != nil { + if err = iter1.Next(); err != nil { + break + } + } + if node2 != nil { + if err = iter2.Next(); err != nil { + break + } + } + + if !yield(DualTree{Tree1: node1, Tree2: node2, Error: err}) { + return + } + } + if err != nil { + yield(DualTree{Tree1: nil, Tree2: nil, Error: err}) + return + } + } +}