diff --git a/pkg/icingaredis/client.go b/pkg/icingaredis/client.go index 308dffec..f05b456f 100644 --- a/pkg/icingaredis/client.go +++ b/pkg/icingaredis/client.go @@ -46,6 +46,7 @@ func (o *Options) Validate() error { if o.HScanCount < 1 { return errors.New("hscan_count must be at least 1") } + return nil } @@ -63,14 +64,13 @@ type HPair struct { // HYield yields HPair field-value pairs for all fields in the hash stored at key. func (c *Client) HYield(ctx context.Context, key string) (<-chan HPair, <-chan error) { pairs := make(chan HPair) - g, ctx := errgroup.WithContext(ctx) c.logger.Infof("Syncing %s", key) - g.Go(func() error { - var cnt com.Counter - + return pairs, com.WaitAsync(contracts.WaiterFunc(func() error { defer close(pairs) + + var cnt com.Counter defer utils.Timed(time.Now(), func(elapsed time.Duration) { c.logger.Infof("Fetched %d elements of %s in %s", cnt.Val(), key, elapsed) }) @@ -79,8 +79,6 @@ func (c *Client) HYield(ctx context.Context, key string) (<-chan HPair, <-chan e var err error var page []string - g, ctx := errgroup.WithContext(ctx) - for { cmd := c.HScan(ctx, key, cursor, "", int64(c.options.HScanCount)) page, cursor, err = cmd.Result() @@ -89,94 +87,85 @@ func (c *Client) HYield(ctx context.Context, key string) (<-chan HPair, <-chan e return WrapCmdErr(cmd) } - g.Go(func(page []string) func() error { - return func() error { - for i := 0; i < len(page); i += 2 { - select { - case pairs <- HPair{ - Field: page[i], - Value: page[i+1], - }: - cnt.Inc() - case <-ctx.Done(): - return ctx.Err() - } - } - - return nil + for i := 0; i < len(page); i += 2 { + select { + case pairs <- HPair{ + Field: page[i], + Value: page[i+1], + }: + cnt.Inc() + case <-ctx.Done(): + return ctx.Err() } - }(page)) + } if cursor == 0 { break } } - return g.Wait() - }) - - return pairs, com.WaitAsync(g) + return nil + })) } // HMYield yields HPair field-value pairs for the specified fields in the hash stored at key. func (c *Client) HMYield(ctx context.Context, key string, fields ...string) (<-chan HPair, <-chan error) { pairs := make(chan HPair) - g, ctx := errgroup.WithContext(ctx) - // Use context from group. - batches := utils.BatchSliceOfStrings(ctx, fields, c.options.HMGetCount) - g.Go(func() error { - defer close(pairs) + return pairs, com.WaitAsync(contracts.WaiterFunc(func() error { + g, ctx := errgroup.WithContext(ctx) + + defer func() { + // Wait until the group is done so that we can safely close the pairs channel, + // because on error, sem.Acquire will return before calling g.Wait(), + // which can result in goroutines working on a closed channel. + _ = g.Wait() + close(pairs) + }() + + // Use context from group. + batches := utils.BatchSliceOfStrings(ctx, fields, c.options.HMGetCount) sem := semaphore.NewWeighted(int64(c.options.MaxHMGetConnections)) - g, ctx := errgroup.WithContext(ctx) - for batch := range batches { if err := sem.Acquire(ctx, 1); err != nil { return errors.Wrap(err, "can't acquire semaphore") } - g.Go(func(batch []string) func() error { - return func() error { - defer sem.Release(1) + batch := batch + g.Go(func() error { + defer sem.Release(1) - cmd := c.HMGet(ctx, key, batch...) - vals, err := cmd.Result() + cmd := c.HMGet(ctx, key, batch...) + vals, err := cmd.Result() - if err != nil { - return WrapCmdErr(cmd) + if err != nil { + return WrapCmdErr(cmd) + } + + for i, v := range vals { + if v == nil { + c.logger.Warnf("HMGET %s: field %#v missing", key, batch[i]) + continue } - g.Go(func() error { - for i, v := range vals { - if v == nil { - c.logger.Warnf("HMGET %s: field %#v missing", key, batch[i]) - continue - } - - select { - case pairs <- HPair{ - Field: batch[i], - Value: v.(string), - }: - case <-ctx.Done(): - return ctx.Err() - } - } - - return nil - }) - - return nil + select { + case pairs <- HPair{ + Field: batch[i], + Value: v.(string), + }: + case <-ctx.Done(): + return ctx.Err() + } } - }(batch)) + + return nil + }) } return g.Wait() - }) - - return pairs, com.WaitAsync(g) + })) } // StreamLastId fetches the last message of a stream and returns its ID.