Do not close channel too early

This fixes a data race where the pairs channel was closed too early
when the context is canceled and therefore the outer errgroup
returns from Redis operations before Wait() is called on the inner
errgroup. Unfinished Go methods in the inner errgroup would then
try to work on a closed channel.
This commit is contained in:
Eric Lippmann 2021-09-20 09:12:58 +02:00
parent 7351559793
commit c1e722f5fa

View file

@ -46,6 +46,7 @@ func (o *Options) Validate() error {
if o.HScanCount < 1 {
return errors.New("hscan_count must be at least 1")
}
return nil
}
@ -63,14 +64,13 @@ type HPair struct {
// HYield yields HPair field-value pairs for all fields in the hash stored at key.
func (c *Client) HYield(ctx context.Context, key string) (<-chan HPair, <-chan error) {
pairs := make(chan HPair)
g, ctx := errgroup.WithContext(ctx)
c.logger.Infof("Syncing %s", key)
g.Go(func() error {
var cnt com.Counter
return pairs, com.WaitAsync(contracts.WaiterFunc(func() error {
defer close(pairs)
var cnt com.Counter
defer utils.Timed(time.Now(), func(elapsed time.Duration) {
c.logger.Infof("Fetched %d elements of %s in %s", cnt.Val(), key, elapsed)
})
@ -79,8 +79,6 @@ func (c *Client) HYield(ctx context.Context, key string) (<-chan HPair, <-chan e
var err error
var page []string
g, ctx := errgroup.WithContext(ctx)
for {
cmd := c.HScan(ctx, key, cursor, "", int64(c.options.HScanCount))
page, cursor, err = cmd.Result()
@ -89,94 +87,85 @@ func (c *Client) HYield(ctx context.Context, key string) (<-chan HPair, <-chan e
return WrapCmdErr(cmd)
}
g.Go(func(page []string) func() error {
return func() error {
for i := 0; i < len(page); i += 2 {
select {
case pairs <- HPair{
Field: page[i],
Value: page[i+1],
}:
cnt.Inc()
case <-ctx.Done():
return ctx.Err()
}
}
return nil
for i := 0; i < len(page); i += 2 {
select {
case pairs <- HPair{
Field: page[i],
Value: page[i+1],
}:
cnt.Inc()
case <-ctx.Done():
return ctx.Err()
}
}(page))
}
if cursor == 0 {
break
}
}
return g.Wait()
})
return pairs, com.WaitAsync(g)
return nil
}))
}
// HMYield yields HPair field-value pairs for the specified fields in the hash stored at key.
func (c *Client) HMYield(ctx context.Context, key string, fields ...string) (<-chan HPair, <-chan error) {
pairs := make(chan HPair)
g, ctx := errgroup.WithContext(ctx)
// Use context from group.
batches := utils.BatchSliceOfStrings(ctx, fields, c.options.HMGetCount)
g.Go(func() error {
defer close(pairs)
return pairs, com.WaitAsync(contracts.WaiterFunc(func() error {
g, ctx := errgroup.WithContext(ctx)
defer func() {
// Wait until the group is done so that we can safely close the pairs channel,
// because on error, sem.Acquire will return before calling g.Wait(),
// which can result in goroutines working on a closed channel.
_ = g.Wait()
close(pairs)
}()
// Use context from group.
batches := utils.BatchSliceOfStrings(ctx, fields, c.options.HMGetCount)
sem := semaphore.NewWeighted(int64(c.options.MaxHMGetConnections))
g, ctx := errgroup.WithContext(ctx)
for batch := range batches {
if err := sem.Acquire(ctx, 1); err != nil {
return errors.Wrap(err, "can't acquire semaphore")
}
g.Go(func(batch []string) func() error {
return func() error {
defer sem.Release(1)
batch := batch
g.Go(func() error {
defer sem.Release(1)
cmd := c.HMGet(ctx, key, batch...)
vals, err := cmd.Result()
cmd := c.HMGet(ctx, key, batch...)
vals, err := cmd.Result()
if err != nil {
return WrapCmdErr(cmd)
if err != nil {
return WrapCmdErr(cmd)
}
for i, v := range vals {
if v == nil {
c.logger.Warnf("HMGET %s: field %#v missing", key, batch[i])
continue
}
g.Go(func() error {
for i, v := range vals {
if v == nil {
c.logger.Warnf("HMGET %s: field %#v missing", key, batch[i])
continue
}
select {
case pairs <- HPair{
Field: batch[i],
Value: v.(string),
}:
case <-ctx.Done():
return ctx.Err()
}
}
return nil
})
return nil
select {
case pairs <- HPair{
Field: batch[i],
Value: v.(string),
}:
case <-ctx.Done():
return ctx.Err()
}
}
}(batch))
return nil
})
}
return g.Wait()
})
return pairs, com.WaitAsync(g)
}))
}
// StreamLastId fetches the last message of a stream and returns its ID.