diff --git a/pkg/com/counter.go b/pkg/com/counter.go index 9c6e5264..3692c63d 100644 --- a/pkg/com/counter.go +++ b/pkg/com/counter.go @@ -16,7 +16,7 @@ func (c *Counter) Inc() { } // Val returns the counter value. -func (c Counter) Val() uint64 { +func (c *Counter) Val() uint64 { return atomic.LoadUint64(c.ptr()) } diff --git a/pkg/contracts/contracts.go b/pkg/contracts/contracts.go index 25513aa7..51594144 100644 --- a/pkg/contracts/contracts.go +++ b/pkg/contracts/contracts.go @@ -66,6 +66,15 @@ type Waiter interface { Wait() error // Wait waits for execution to complete. } +// The WaiterFunc type is an adapter to allow the use of ordinary functions as Waiter. +// If f is a function with the appropriate signature, WaiterFunc(f) is a Waiter that calls f. +type WaiterFunc func() error + +// Wait implements the Waiter interface. +func (f WaiterFunc) Wait() error { + return f() +} + // Initer implements the Init method, // which initializes the object in addition to zeroing. type Initer interface { diff --git a/pkg/icingaredis/client.go b/pkg/icingaredis/client.go index 308dffec..aeea1c02 100644 --- a/pkg/icingaredis/client.go +++ b/pkg/icingaredis/client.go @@ -46,6 +46,7 @@ func (o *Options) Validate() error { if o.HScanCount < 1 { return errors.New("hscan_count must be at least 1") } + return nil } @@ -62,25 +63,22 @@ type HPair struct { // HYield yields HPair field-value pairs for all fields in the hash stored at key. func (c *Client) HYield(ctx context.Context, key string) (<-chan HPair, <-chan error) { - pairs := make(chan HPair) - g, ctx := errgroup.WithContext(ctx) + pairs := make(chan HPair, c.options.HScanCount) c.logger.Infof("Syncing %s", key) - g.Go(func() error { - var cnt com.Counter - + return pairs, com.WaitAsync(contracts.WaiterFunc(func() error { defer close(pairs) + + var cnt uint64 defer utils.Timed(time.Now(), func(elapsed time.Duration) { - c.logger.Infof("Fetched %d elements of %s in %s", cnt.Val(), key, elapsed) + c.logger.Infof("Fetched %d elements of %s in %s", cnt, key, elapsed) }) var cursor uint64 var err error var page []string - g, ctx := errgroup.WithContext(ctx) - for { cmd := c.HScan(ctx, key, cursor, "", int64(c.options.HScanCount)) page, cursor, err = cmd.Result() @@ -89,94 +87,85 @@ func (c *Client) HYield(ctx context.Context, key string) (<-chan HPair, <-chan e return WrapCmdErr(cmd) } - g.Go(func(page []string) func() error { - return func() error { - for i := 0; i < len(page); i += 2 { - select { - case pairs <- HPair{ - Field: page[i], - Value: page[i+1], - }: - cnt.Inc() - case <-ctx.Done(): - return ctx.Err() - } - } - - return nil + for i := 0; i < len(page); i += 2 { + select { + case pairs <- HPair{ + Field: page[i], + Value: page[i+1], + }: + cnt++ + case <-ctx.Done(): + return ctx.Err() } - }(page)) + } if cursor == 0 { break } } - return g.Wait() - }) - - return pairs, com.WaitAsync(g) + return nil + })) } // HMYield yields HPair field-value pairs for the specified fields in the hash stored at key. func (c *Client) HMYield(ctx context.Context, key string, fields ...string) (<-chan HPair, <-chan error) { pairs := make(chan HPair) - g, ctx := errgroup.WithContext(ctx) - // Use context from group. - batches := utils.BatchSliceOfStrings(ctx, fields, c.options.HMGetCount) - g.Go(func() error { - defer close(pairs) + return pairs, com.WaitAsync(contracts.WaiterFunc(func() error { + g, ctx := errgroup.WithContext(ctx) + + defer func() { + // Wait until the group is done so that we can safely close the pairs channel, + // because on error, sem.Acquire will return before calling g.Wait(), + // which can result in goroutines working on a closed channel. + _ = g.Wait() + close(pairs) + }() + + // Use context from group. + batches := utils.BatchSliceOfStrings(ctx, fields, c.options.HMGetCount) sem := semaphore.NewWeighted(int64(c.options.MaxHMGetConnections)) - g, ctx := errgroup.WithContext(ctx) - for batch := range batches { if err := sem.Acquire(ctx, 1); err != nil { return errors.Wrap(err, "can't acquire semaphore") } - g.Go(func(batch []string) func() error { - return func() error { - defer sem.Release(1) + batch := batch + g.Go(func() error { + defer sem.Release(1) - cmd := c.HMGet(ctx, key, batch...) - vals, err := cmd.Result() + cmd := c.HMGet(ctx, key, batch...) + vals, err := cmd.Result() - if err != nil { - return WrapCmdErr(cmd) + if err != nil { + return WrapCmdErr(cmd) + } + + for i, v := range vals { + if v == nil { + c.logger.Warnf("HMGET %s: field %#v missing", key, batch[i]) + continue } - g.Go(func() error { - for i, v := range vals { - if v == nil { - c.logger.Warnf("HMGET %s: field %#v missing", key, batch[i]) - continue - } - - select { - case pairs <- HPair{ - Field: batch[i], - Value: v.(string), - }: - case <-ctx.Done(): - return ctx.Err() - } - } - - return nil - }) - - return nil + select { + case pairs <- HPair{ + Field: batch[i], + Value: v.(string), + }: + case <-ctx.Done(): + return ctx.Err() + } } - }(batch)) + + return nil + }) } return g.Wait() - }) - - return pairs, com.WaitAsync(g) + })) } // StreamLastId fetches the last message of a stream and returns its ID.