mirror of
https://github.com/helm/helm.git
synced 2026-05-28 04:35:48 -04:00
fix: use local WaitGroup in Push to fix concurrent call safety
Address Copilot review feedback: - Move WaitGroup from Client struct field to local variable in Push() - Pass WaitGroup to runWorker as parameter instead of accessing via receiver - Add documentation comments to runWorker, blob type, and its methods - Add TestPushConcurrent to verify concurrent Push operations are safe This fixes a potential race condition if Push() is called concurrently on the same Client instance, where the shared WaitGroup could cause incorrect synchronization between unrelated Push operations. Signed-off-by: Terry Howe <terrylhowe@gmail.com>
This commit is contained in:
parent
52f64d280a
commit
a8466f849b
2 changed files with 122 additions and 8 deletions
|
|
@ -79,7 +79,6 @@ type (
|
|||
credentialsStore credentials.Store
|
||||
httpClient *http.Client
|
||||
plainHTTP bool
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// ClientOption allows specifying various settings configurable by the user for overriding the defaults
|
||||
|
|
@ -685,8 +684,9 @@ func (c *Client) Push(data []byte, ref string, options ...PushOption) (*PushResu
|
|||
}
|
||||
|
||||
layers := []ocispec.Descriptor{chartBlob.descriptor}
|
||||
var wg sync.WaitGroup
|
||||
if !exists {
|
||||
c.runWorker(ctx, chartBlob.push)
|
||||
runWorker(ctx, &wg, chartBlob.push)
|
||||
}
|
||||
|
||||
configData, err := json.Marshal(meta)
|
||||
|
|
@ -694,14 +694,14 @@ func (c *Client) Push(data []byte, ref string, options ...PushOption) (*PushResu
|
|||
return nil, err
|
||||
}
|
||||
configBlob := newBlob(repository, ConfigMediaType, configData)
|
||||
c.runWorker(ctx, configBlob.pushNew)
|
||||
runWorker(ctx, &wg, configBlob.pushNew)
|
||||
|
||||
var provBlob blob
|
||||
if operation.provData != nil {
|
||||
provBlob = newBlob(repository, ProvLayerMediaType, operation.provData)
|
||||
c.runWorker(ctx, provBlob.pushNew)
|
||||
runWorker(ctx, &wg, provBlob.pushNew)
|
||||
}
|
||||
c.wg.Wait()
|
||||
wg.Wait()
|
||||
|
||||
if chartBlob.err != nil {
|
||||
return nil, chartBlob.err
|
||||
|
|
@ -949,14 +949,20 @@ func (c *Client) tagManifest(ctx context.Context, memoryStore *memory.Store,
|
|||
manifestData, parsedRef.String())
|
||||
}
|
||||
|
||||
func (c *Client) runWorker(ctx context.Context, worker func(context.Context)) {
|
||||
c.wg.Add(1)
|
||||
// runWorker spawns a goroutine to execute the worker function and tracks it
|
||||
// with the provided WaitGroup. The WaitGroup counter is incremented before
|
||||
// spawning and decremented when the worker completes.
|
||||
func runWorker(ctx context.Context, wg *sync.WaitGroup, worker func(context.Context)) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer c.wg.Done()
|
||||
defer wg.Done()
|
||||
worker(ctx)
|
||||
}()
|
||||
}
|
||||
|
||||
// blob represents a content-addressable blob to be pushed to an OCI registry.
|
||||
// It encapsulates the data, media type, and destination repository, and tracks
|
||||
// the resulting descriptor and any error from push operations.
|
||||
type blob struct {
|
||||
mediaType string
|
||||
dst *remote.Repository
|
||||
|
|
@ -965,6 +971,7 @@ type blob struct {
|
|||
err error
|
||||
}
|
||||
|
||||
// newBlob creates a new blob with the given repository, media type, and data.
|
||||
func newBlob(dst *remote.Repository, mediaType string, data []byte) blob {
|
||||
return blob{
|
||||
mediaType: mediaType,
|
||||
|
|
@ -973,6 +980,9 @@ func newBlob(dst *remote.Repository, mediaType string, data []byte) blob {
|
|||
}
|
||||
}
|
||||
|
||||
// exists checks if the blob already exists in the registry by computing its
|
||||
// digest and querying the repository. It also populates the blob's descriptor
|
||||
// with size, media type, and digest information.
|
||||
func (b *blob) exists(ctx context.Context) (bool, error) {
|
||||
hash := sha256.Sum256(b.data)
|
||||
b.descriptor.Size = int64(len(b.data))
|
||||
|
|
@ -981,6 +991,9 @@ func (b *blob) exists(ctx context.Context) (bool, error) {
|
|||
return b.dst.Exists(ctx, b.descriptor)
|
||||
}
|
||||
|
||||
// pushNew checks if the blob exists in the registry first, and only pushes
|
||||
// if it doesn't exist. This avoids redundant uploads for blobs that are
|
||||
// already present. Any error is stored in b.err.
|
||||
func (b *blob) pushNew(ctx context.Context) {
|
||||
var exists bool
|
||||
exists, b.err = b.exists(ctx)
|
||||
|
|
@ -993,6 +1006,8 @@ func (b *blob) pushNew(ctx context.Context) {
|
|||
b.descriptor, b.err = oras.PushBytes(ctx, b.dst, b.mediaType, b.data)
|
||||
}
|
||||
|
||||
// push unconditionally pushes the blob to the registry without checking
|
||||
// for existence first. Any error is stored in b.err.
|
||||
func (b *blob) push(ctx context.Context) {
|
||||
b.descriptor, b.err = oras.PushBytes(ctx, b.dst, b.mediaType, b.data)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,11 +17,15 @@ limitations under the License.
|
|||
package registry
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
|
||||
|
|
@ -166,3 +170,98 @@ func TestWarnIfHostHasPath(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPushConcurrent verifies that concurrent Push operations on the same Client
|
||||
// do not interfere with each other. This test is designed to catch race conditions
|
||||
// when run with -race flag.
|
||||
func TestPushConcurrent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a mock registry server that accepts pushes
|
||||
var mu sync.Mutex
|
||||
uploads := make(map[string][]byte)
|
||||
var uploadCounter int
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == http.MethodHead && strings.Contains(r.URL.Path, "/blobs/"):
|
||||
// Blob existence check - return 404 to force upload
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
|
||||
case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/blobs/uploads/"):
|
||||
// Start upload - return upload URL with unique ID
|
||||
mu.Lock()
|
||||
uploadCounter++
|
||||
uploadID := fmt.Sprintf("upload-%d", uploadCounter)
|
||||
mu.Unlock()
|
||||
w.Header().Set("Location", fmt.Sprintf("%s%s", r.URL.Path, uploadID))
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
|
||||
case r.Method == http.MethodPut && strings.Contains(r.URL.Path, "/blobs/uploads/"):
|
||||
// Complete upload - extract digest from query param
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
digest := r.URL.Query().Get("digest")
|
||||
mu.Lock()
|
||||
uploads[r.URL.Path] = body
|
||||
mu.Unlock()
|
||||
w.Header().Set("Docker-Content-Digest", digest)
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
|
||||
case r.Method == http.MethodPut && strings.Contains(r.URL.Path, "/manifests/"):
|
||||
// Manifest push - compute actual sha256 digest of the body
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
hash := sha256.Sum256(body)
|
||||
digest := fmt.Sprintf("sha256:%x", hash)
|
||||
w.Header().Set("Docker-Content-Digest", digest)
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
host := strings.TrimPrefix(srv.URL, "http://")
|
||||
|
||||
// Create client
|
||||
credFile := filepath.Join(t.TempDir(), "config.json")
|
||||
client, err := NewClient(
|
||||
ClientOptWriter(io.Discard),
|
||||
ClientOptCredentialsFile(credFile),
|
||||
ClientOptPlainHTTP(),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Load test chart
|
||||
chartData, err := os.ReadFile("../downloader/testdata/local-subchart-0.1.0.tgz")
|
||||
require.NoError(t, err, "no error loading test chart")
|
||||
|
||||
meta, err := extractChartMeta(chartData)
|
||||
require.NoError(t, err, "no error extracting chart meta")
|
||||
|
||||
// Run concurrent pushes
|
||||
const numGoroutines = 10
|
||||
var wg sync.WaitGroup
|
||||
errs := make(chan error, numGoroutines)
|
||||
|
||||
for i := range numGoroutines {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
// Each goroutine pushes to a different tag to avoid conflicts
|
||||
ref := fmt.Sprintf("%s/testrepo/%s:%s-%d", host, meta.Name, meta.Version, idx)
|
||||
_, err := client.Push(chartData, ref, PushOptStrictMode(false))
|
||||
if err != nil {
|
||||
errs <- fmt.Errorf("goroutine %d: %w", idx, err)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errs)
|
||||
|
||||
// Check for errors
|
||||
for err := range errs {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue