diff --git a/postings.go b/postings.go index 2c6f890223..fcbbd4235c 100644 --- a/postings.go +++ b/postings.go @@ -139,30 +139,30 @@ func Merge(its ...Postings) Postings { a := its[0] for _, b := range its[1:] { - a = newMergePostings(a, b) + a = newMergedPostings(a, b) } return a } -type mergePostings struct { +type mergedPostings struct { a, b Postings aok, bok bool cur uint32 } -func newMergePostings(a, b Postings) *mergePostings { - it := &mergePostings{a: a, b: b} +func newMergedPostings(a, b Postings) *mergedPostings { + it := &mergedPostings{a: a, b: b} it.aok = it.a.Next() it.bok = it.b.Next() return it } -func (it *mergePostings) At() uint32 { +func (it *mergedPostings) At() uint32 { return it.cur } -func (it *mergePostings) Next() bool { +func (it *mergedPostings) Next() bool { if !it.aok && !it.bok { return false } @@ -197,13 +197,37 @@ func (it *mergePostings) Next() bool { return true } -func (it *mergePostings) Seek(id uint32) bool { +func (it *mergedPostings) Seek(id uint32) bool { it.aok = it.a.Seek(id) it.bok = it.b.Seek(id) - return it.Next() + + if !it.aok && !it.bok { + return false + } + + if !it.aok { + it.cur = it.b.At() + + return true + } + if !it.bok { + it.cur = it.a.At() + + return true + } + + acur, bcur := it.a.At(), it.b.At() + + if acur < bcur { + it.cur = acur + } else { + it.cur = bcur + } + + return true } -func (it *mergePostings) Err() error { +func (it *mergedPostings) Err() error { if it.a.Err() != nil { return it.a.Err() } diff --git a/postings_test.go b/postings_test.go index 14dd689509..62b0dd726b 100644 --- a/postings_test.go +++ b/postings_test.go @@ -3,7 +3,6 @@ package tsdb import ( "encoding/binary" "math/rand" - "reflect" "testing" "github.com/stretchr/testify/require" @@ -47,17 +46,13 @@ func TestIntersect(t *testing.T) { }, } - for i, c := range cases { + for _, c := range cases { a := newListPostings(c.a) b := newListPostings(c.b) res, err := expandPostings(Intersect(a, b)) - if err != nil { - t.Fatalf("%d: Unexpected error: %s", i, err) - } - if !reflect.DeepEqual(res, c.res) { - t.Fatalf("%d: Expected %v but got %v", i, c.res, res) - } + require.NoError(t, err) + require.Equal(t, c.res, res) } } @@ -80,12 +75,9 @@ func TestMultiIntersect(t *testing.T) { pc := newListPostings(c.c) res, err := expandPostings(Intersect(pa, pb, pc)) - if err != nil { - t.Fatalf("Unexpected error: %s", err) - } - if !reflect.DeepEqual(res, c.res) { - t.Fatalf("Expected %v but got %v", c.res, res) - } + + require.NoError(t, err) + require.Equal(t, c.res, res) } } @@ -141,12 +133,8 @@ func TestMultiMerge(t *testing.T) { i3 := newListPostings(c.c) res, err := expandPostings(Merge(i1, i2, i3)) - if err != nil { - t.Fatalf("Unexpected error: %s", err) - } - if !reflect.DeepEqual(res, c.res) { - t.Fatalf("Expected %v but got %v", c.res, res) - } + require.NoError(t, err) + require.Equal(t, c.res, res) } } @@ -176,14 +164,68 @@ func TestMerge(t *testing.T) { a := newListPostings(c.a) b := newListPostings(c.b) - res, err := expandPostings(newMergePostings(a, b)) - if err != nil { - t.Fatalf("Unexpected error: %s", err) - } - if !reflect.DeepEqual(res, c.res) { - t.Fatalf("Expected %v but got %v", c.res, res) - } + res, err := expandPostings(newMergedPostings(a, b)) + require.NoError(t, err) + require.Equal(t, c.res, res) } + + t.Run("Seek", func(t *testing.T) { + var cases = []struct { + a, b []uint32 + + seek uint32 + success bool + res []uint32 + }{ + { + a: []uint32{1, 2, 3, 4, 5}, + b: []uint32{6, 7, 8, 9, 10}, + + seek: 0, + success: true, + res: []uint32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + }, + { + a: []uint32{1, 2, 3, 4, 5}, + b: []uint32{6, 7, 8, 9, 10}, + + seek: 2, + success: true, + res: []uint32{2, 3, 4, 5, 6, 7, 8, 9, 10}, + }, + { + a: []uint32{1, 2, 3, 4, 5}, + b: []uint32{4, 5, 6, 7, 8}, + + seek: 9, + success: false, + res: nil, + }, + { + a: []uint32{1, 2, 3, 4, 9, 10}, + b: []uint32{1, 4, 5, 6, 7, 8, 10, 11}, + + seek: 10, + success: true, + res: []uint32{10, 11}, + }, + } + + for _, c := range cases { + a := newListPostings(c.a) + b := newListPostings(c.b) + + p := newMergedPostings(a, b) + + require.Equal(t, c.success, p.Seek(c.seek)) + + res, err := expandPostings(p) + require.NoError(t, err) + require.Equal(t, c.res, res) + } + + return + }) } func TestBigEndian(t *testing.T) {