From c1e722f5fa5d271daa3a16a3c48aefce3f18a403 Mon Sep 17 00:00:00 2001 From: Eric Lippmann Date: Mon, 20 Sep 2021 09:12:58 +0200 Subject: [PATCH] Do not close channel too early This fixes a data race where the pairs channel was closed too early when the context is canceled and therefore the outer errgroup returns from Redis operations before Wait() is called on the inner errgroup. Unfinished Go methods in the inner errgroup would then try to work on a closed channel. --- pkg/icingaredis/client.go | 119 +++++++++++++++++--------------------- 1 file changed, 54 insertions(+), 65 deletions(-) diff --git a/pkg/icingaredis/client.go b/pkg/icingaredis/client.go index 308dffec..f05b456f 100644 --- a/pkg/icingaredis/client.go +++ b/pkg/icingaredis/client.go @@ -46,6 +46,7 @@ func (o *Options) Validate() error { if o.HScanCount < 1 { return errors.New("hscan_count must be at least 1") } + return nil } @@ -63,14 +64,13 @@ type HPair struct { // HYield yields HPair field-value pairs for all fields in the hash stored at key. func (c *Client) HYield(ctx context.Context, key string) (<-chan HPair, <-chan error) { pairs := make(chan HPair) - g, ctx := errgroup.WithContext(ctx) c.logger.Infof("Syncing %s", key) - g.Go(func() error { - var cnt com.Counter - + return pairs, com.WaitAsync(contracts.WaiterFunc(func() error { defer close(pairs) + + var cnt com.Counter defer utils.Timed(time.Now(), func(elapsed time.Duration) { c.logger.Infof("Fetched %d elements of %s in %s", cnt.Val(), key, elapsed) }) @@ -79,8 +79,6 @@ func (c *Client) HYield(ctx context.Context, key string) (<-chan HPair, <-chan e var err error var page []string - g, ctx := errgroup.WithContext(ctx) - for { cmd := c.HScan(ctx, key, cursor, "", int64(c.options.HScanCount)) page, cursor, err = cmd.Result() @@ -89,94 +87,85 @@ func (c *Client) HYield(ctx context.Context, key string) (<-chan HPair, <-chan e return WrapCmdErr(cmd) } - g.Go(func(page []string) func() error { - return func() error { - for i := 0; i < len(page); i += 2 { - select { - case pairs <- HPair{ - Field: page[i], - Value: page[i+1], - }: - cnt.Inc() - case <-ctx.Done(): - return ctx.Err() - } - } - - return nil + for i := 0; i < len(page); i += 2 { + select { + case pairs <- HPair{ + Field: page[i], + Value: page[i+1], + }: + cnt.Inc() + case <-ctx.Done(): + return ctx.Err() } - }(page)) + } if cursor == 0 { break } } - return g.Wait() - }) - - return pairs, com.WaitAsync(g) + return nil + })) } // HMYield yields HPair field-value pairs for the specified fields in the hash stored at key. func (c *Client) HMYield(ctx context.Context, key string, fields ...string) (<-chan HPair, <-chan error) { pairs := make(chan HPair) - g, ctx := errgroup.WithContext(ctx) - // Use context from group. - batches := utils.BatchSliceOfStrings(ctx, fields, c.options.HMGetCount) - g.Go(func() error { - defer close(pairs) + return pairs, com.WaitAsync(contracts.WaiterFunc(func() error { + g, ctx := errgroup.WithContext(ctx) + + defer func() { + // Wait until the group is done so that we can safely close the pairs channel, + // because on error, sem.Acquire will return before calling g.Wait(), + // which can result in goroutines working on a closed channel. + _ = g.Wait() + close(pairs) + }() + + // Use context from group. + batches := utils.BatchSliceOfStrings(ctx, fields, c.options.HMGetCount) sem := semaphore.NewWeighted(int64(c.options.MaxHMGetConnections)) - g, ctx := errgroup.WithContext(ctx) - for batch := range batches { if err := sem.Acquire(ctx, 1); err != nil { return errors.Wrap(err, "can't acquire semaphore") } - g.Go(func(batch []string) func() error { - return func() error { - defer sem.Release(1) + batch := batch + g.Go(func() error { + defer sem.Release(1) - cmd := c.HMGet(ctx, key, batch...) - vals, err := cmd.Result() + cmd := c.HMGet(ctx, key, batch...) + vals, err := cmd.Result() - if err != nil { - return WrapCmdErr(cmd) + if err != nil { + return WrapCmdErr(cmd) + } + + for i, v := range vals { + if v == nil { + c.logger.Warnf("HMGET %s: field %#v missing", key, batch[i]) + continue } - g.Go(func() error { - for i, v := range vals { - if v == nil { - c.logger.Warnf("HMGET %s: field %#v missing", key, batch[i]) - continue - } - - select { - case pairs <- HPair{ - Field: batch[i], - Value: v.(string), - }: - case <-ctx.Done(): - return ctx.Err() - } - } - - return nil - }) - - return nil + select { + case pairs <- HPair{ + Field: batch[i], + Value: v.(string), + }: + case <-ctx.Done(): + return ctx.Err() + } } - }(batch)) + + return nil + }) } return g.Wait() - }) - - return pairs, com.WaitAsync(g) + })) } // StreamLastId fetches the last message of a stream and returns its ID.