Merge pull request #365 from Icinga/data-races

Fix data races
This commit is contained in:
Julian Brost 2021-09-23 12:32:19 +02:00 committed by GitHub
commit 4457f9f440
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 66 additions and 68 deletions

View file

@ -16,7 +16,7 @@ func (c *Counter) Inc() {
}
// Val returns the counter value.
func (c Counter) Val() uint64 {
func (c *Counter) Val() uint64 {
return atomic.LoadUint64(c.ptr())
}

View file

@ -66,6 +66,15 @@ type Waiter interface {
Wait() error // Wait waits for execution to complete.
}
// The WaiterFunc type is an adapter to allow the use of ordinary functions as Waiter.
// If f is a function with the appropriate signature, WaiterFunc(f) is a Waiter that calls f.
type WaiterFunc func() error
// Wait implements the Waiter interface.
func (f WaiterFunc) Wait() error {
return f()
}
// Initer implements the Init method,
// which initializes the object in addition to zeroing.
type Initer interface {

View file

@ -46,6 +46,7 @@ func (o *Options) Validate() error {
if o.HScanCount < 1 {
return errors.New("hscan_count must be at least 1")
}
return nil
}
@ -62,25 +63,22 @@ type HPair struct {
// HYield yields HPair field-value pairs for all fields in the hash stored at key.
func (c *Client) HYield(ctx context.Context, key string) (<-chan HPair, <-chan error) {
pairs := make(chan HPair)
g, ctx := errgroup.WithContext(ctx)
pairs := make(chan HPair, c.options.HScanCount)
c.logger.Infof("Syncing %s", key)
g.Go(func() error {
var cnt com.Counter
return pairs, com.WaitAsync(contracts.WaiterFunc(func() error {
defer close(pairs)
var cnt uint64
defer utils.Timed(time.Now(), func(elapsed time.Duration) {
c.logger.Infof("Fetched %d elements of %s in %s", cnt.Val(), key, elapsed)
c.logger.Infof("Fetched %d elements of %s in %s", cnt, key, elapsed)
})
var cursor uint64
var err error
var page []string
g, ctx := errgroup.WithContext(ctx)
for {
cmd := c.HScan(ctx, key, cursor, "", int64(c.options.HScanCount))
page, cursor, err = cmd.Result()
@ -89,94 +87,85 @@ func (c *Client) HYield(ctx context.Context, key string) (<-chan HPair, <-chan e
return WrapCmdErr(cmd)
}
g.Go(func(page []string) func() error {
return func() error {
for i := 0; i < len(page); i += 2 {
select {
case pairs <- HPair{
Field: page[i],
Value: page[i+1],
}:
cnt.Inc()
case <-ctx.Done():
return ctx.Err()
}
}
return nil
for i := 0; i < len(page); i += 2 {
select {
case pairs <- HPair{
Field: page[i],
Value: page[i+1],
}:
cnt++
case <-ctx.Done():
return ctx.Err()
}
}(page))
}
if cursor == 0 {
break
}
}
return g.Wait()
})
return pairs, com.WaitAsync(g)
return nil
}))
}
// HMYield yields HPair field-value pairs for the specified fields in the hash stored at key.
func (c *Client) HMYield(ctx context.Context, key string, fields ...string) (<-chan HPair, <-chan error) {
pairs := make(chan HPair)
g, ctx := errgroup.WithContext(ctx)
// Use context from group.
batches := utils.BatchSliceOfStrings(ctx, fields, c.options.HMGetCount)
g.Go(func() error {
defer close(pairs)
return pairs, com.WaitAsync(contracts.WaiterFunc(func() error {
g, ctx := errgroup.WithContext(ctx)
defer func() {
// Wait until the group is done so that we can safely close the pairs channel,
// because on error, sem.Acquire will return before calling g.Wait(),
// which can result in goroutines working on a closed channel.
_ = g.Wait()
close(pairs)
}()
// Use context from group.
batches := utils.BatchSliceOfStrings(ctx, fields, c.options.HMGetCount)
sem := semaphore.NewWeighted(int64(c.options.MaxHMGetConnections))
g, ctx := errgroup.WithContext(ctx)
for batch := range batches {
if err := sem.Acquire(ctx, 1); err != nil {
return errors.Wrap(err, "can't acquire semaphore")
}
g.Go(func(batch []string) func() error {
return func() error {
defer sem.Release(1)
batch := batch
g.Go(func() error {
defer sem.Release(1)
cmd := c.HMGet(ctx, key, batch...)
vals, err := cmd.Result()
cmd := c.HMGet(ctx, key, batch...)
vals, err := cmd.Result()
if err != nil {
return WrapCmdErr(cmd)
if err != nil {
return WrapCmdErr(cmd)
}
for i, v := range vals {
if v == nil {
c.logger.Warnf("HMGET %s: field %#v missing", key, batch[i])
continue
}
g.Go(func() error {
for i, v := range vals {
if v == nil {
c.logger.Warnf("HMGET %s: field %#v missing", key, batch[i])
continue
}
select {
case pairs <- HPair{
Field: batch[i],
Value: v.(string),
}:
case <-ctx.Done():
return ctx.Err()
}
}
return nil
})
return nil
select {
case pairs <- HPair{
Field: batch[i],
Value: v.(string),
}:
case <-ctx.Done():
return ctx.Err()
}
}
}(batch))
return nil
})
}
return g.Wait()
})
return pairs, com.WaitAsync(g)
}))
}
// StreamLastId fetches the last message of a stream and returns its ID.