diff --git a/test/utils/ktesting/contexthelper_test.go b/test/utils/ktesting/contexthelper_test.go index 1bb79ca34c0..721d2eada55 100644 --- a/test/utils/ktesting/contexthelper_test.go +++ b/test/utils/ktesting/contexthelper_test.go @@ -19,9 +19,8 @@ package ktesting import ( "context" "errors" - "os" - "strings" "testing" + "testing/synctest" "time" "github.com/stretchr/testify/assert" @@ -38,9 +37,12 @@ func TestCause(t *testing.T) { timeoutCause := canceledError("I timed out") parentCause := errors.New("parent canceled") - t.Parallel() + contextBackground := func(t *testing.T) context.Context { + return context.Background() + } + for name, tt := range map[string]struct { - parentCtx context.Context + parentCtx func(t *testing.T) context.Context timeout time.Duration sleep time.Duration cancelCause string @@ -48,95 +50,81 @@ func TestCause(t *testing.T) { expectDeadline time.Duration }{ "nothing": { - parentCtx: context.Background(), + parentCtx: contextBackground, timeout: 5 * time.Millisecond, sleep: time.Millisecond, }, "timeout": { - parentCtx: context.Background(), + parentCtx: contextBackground, timeout: time.Millisecond, sleep: 5 * time.Millisecond, expectErr: context.Canceled, expectCause: canceledError(timeoutCause), }, "parent-canceled": { - parentCtx: func() context.Context { + parentCtx: func(t *testing.T) context.Context { ctx, cancel := context.WithCancel(context.Background()) cancel() return ctx - }(), + }, timeout: time.Millisecond, sleep: 5 * time.Millisecond, expectErr: context.Canceled, expectCause: context.Canceled, }, "parent-cause": { - parentCtx: func() context.Context { + parentCtx: func(t *testing.T) context.Context { ctx, cancel := context.WithCancelCause(context.Background()) cancel(parentCause) return ctx - }(), + }, timeout: time.Millisecond, sleep: 5 * time.Millisecond, expectErr: context.Canceled, expectCause: parentCause, }, "deadline-no-parent": { - parentCtx: context.Background(), + parentCtx: contextBackground, timeout: time.Minute, expectDeadline: time.Minute, }, "deadline-parent": { - parentCtx: func() context.Context { + parentCtx: func(t *testing.T) context.Context { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) t.Cleanup(cancel) return ctx - }(), + }, timeout: 2 * time.Minute, expectDeadline: time.Minute, }, "deadline-child": { - parentCtx: func() context.Context { + parentCtx: func(t *testing.T) context.Context { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) t.Cleanup(cancel) return ctx - }(), + }, timeout: time.Minute, expectDeadline: time.Minute, }, } { t.Run(name, func(t *testing.T) { - ctx, cancel := withTimeout(tt.parentCtx, t, tt.timeout, timeoutCause.Error()) - if tt.cancelCause != "" { - cancel(tt.cancelCause) - } - if tt.expectDeadline != 0 { - actualDeadline, ok := ctx.Deadline() - if assert.True(t, ok, "should have had a deadline") { - // Testing timing behavior is unreliable in Prow because - // the test runs in parallel with several others. - // Therefore this check is skipped if a CI environment is - // detected. - ci, _ := os.LookupEnv("CI") - switch strings.ToLower(ci) { - case "yes", "true", "1": - // Skip. - default: - assert.InDelta(t, time.Until(actualDeadline), tt.expectDeadline, float64(time.Second), "remaining time till Deadline()") + synctest.Test(t, func(t *testing.T) { + ctx, cancel := withTimeout(tt.parentCtx(t), t, tt.timeout, timeoutCause.Error()) + if tt.cancelCause != "" { + cancel(tt.cancelCause) + } + if tt.expectDeadline != 0 { + actualDeadline, ok := ctx.Deadline() + if assert.True(t, ok, "should have had a deadline") { + assert.Equal(t, tt.expectDeadline, time.Until(actualDeadline), "remaining time till Deadline()") } } - } - time.Sleep(tt.sleep) - actualErr := ctx.Err() - actualCause := context.Cause(ctx) - ci, _ := os.LookupEnv("CI") - switch strings.ToLower(ci) { - case "yes", "true", "1": - // Skip. - default: + time.Sleep(tt.sleep) + actualErr := ctx.Err() + actualCause := context.Cause(ctx) assert.Equal(t, tt.expectErr, actualErr, "ctx.Err()") assert.Equal(t, tt.expectCause, actualCause, "context.Cause()") - } + }) }) } } diff --git a/test/utils/ktesting/examples/with_ktesting/example_test.go b/test/utils/ktesting/examples/with_ktesting/example_test.go index 99739960be6..eabfee59f6b 100644 --- a/test/utils/ktesting/examples/with_ktesting/example_test.go +++ b/test/utils/ktesting/examples/with_ktesting/example_test.go @@ -40,6 +40,16 @@ func TestTimeout(t *testing.T) { if deadline, ok := t.Deadline(); ok { t.Logf("Will fail shortly before the test suite deadline at %s.", deadline) } + + // This is how Ginkgo and ktesting communicate to Gomega how to + // provide a progress report when stuck in e.g. gomega.Eventually. + // Here we use this to provide some additional output when + // this example is sent a SIGUSR1. + remove := tCtx.Value("GINKGO_SPEC_CONTEXT").(interface { + AttachProgressReporter(func() string) func() + }).AttachProgressReporter(func() string { return "waiting for timeout or interrupt" }) + defer remove() + select { case <-time.After(1000 * time.Hour): // This should not be reached. diff --git a/test/utils/ktesting/helper_test.go b/test/utils/ktesting/helper_test.go index 423708f5396..85d509ded73 100644 --- a/test/utils/ktesting/helper_test.go +++ b/test/utils/ktesting/helper_test.go @@ -21,6 +21,7 @@ import ( "regexp" "strings" "testing" + "testing/synctest" "time" "github.com/stretchr/testify/assert" @@ -35,23 +36,25 @@ type testcase struct { } func (tc testcase) run(t *testing.T) { - buffer := &mockTB{} - tCtx := Init(buffer) - start := time.Now() - func() { - defer func() { - if r := recover(); r != nil && r != logBufferStop { - panic(r) - } + synctest.Test(t, func(t *testing.T) { + buffer := &mockTB{} + tCtx := Init(buffer) + start := time.Now() + func() { + defer func() { + if r := recover(); r != nil && r != logBufferStop { + panic(r) + } + }() + tc.cb(tCtx) }() - tc.cb(tCtx) - }() - duration := time.Since(start) + duration := time.Since(start) - trace := buffer.log.String() - t.Logf("Trace:\n%s\n", trace) - assert.InDelta(t, tc.expectDuration.Seconds(), duration.Seconds(), 0.1, "callback invocation duration %s", duration) - assert.Equal(t, tc.expectTrace, normalize(trace)) + trace := buffer.log.String() + t.Logf("Trace:\n%s\n", trace) + assert.Equal(t, tc.expectDuration, duration, "callback invocation duration %s") + assert.Equal(t, tc.expectTrace, normalize(trace)) + }) } // normalize replaces parts of message texts which may vary with constant strings. diff --git a/test/utils/ktesting/main_test.go b/test/utils/ktesting/main_test.go index 17c461f6610..dcab727fc43 100644 --- a/test/utils/ktesting/main_test.go +++ b/test/utils/ktesting/main_test.go @@ -29,10 +29,6 @@ func TestMain(m *testing.M) { // Bail out early when -help was given as parameter. flag.Parse() - // The unit tests assume that they run as a unit test, with progress reporting enabled. - // This leaks a goroutine, so we have to do it before IgnoreCurrent. - initSignals() - // Must be called *before* creating new goroutines. goleakOpts := []goleak.Option{ goleak.IgnoreCurrent(), diff --git a/test/utils/ktesting/signals.go b/test/utils/ktesting/signals.go index b6ac765f259..f498c46259c 100644 --- a/test/utils/ktesting/signals.go +++ b/test/utils/ktesting/signals.go @@ -24,13 +24,16 @@ import ( "os/signal" "strings" "sync" + "testing" ) var ( - interruptCtx = context.Background() - - defaultProgressReporter = new(progressReporter) - defaultSignalChannel chan os.Signal + // defaultProgressReporter is inactive until init is called. + defaultProgressReporter = &progressReporter{ + // os.Stderr gets redirected by "go test". "go test -v" has to be + // used to see the output while a test runs. + out: os.Stderr, + } ) const ginkgoSpecContextKey = "GINKGO_SPEC_CONTEXT" @@ -39,42 +42,18 @@ type ginkgoReporter interface { AttachProgressReporter(reporter func() string) func() } -// initSignals is invoked once when ktesting is used for a `go test` unit test. -// It implements support for triggering a progress report in -// a running test when sending it a USR1 signal, similar to the corresponding -// Ginkgo feature. -func initSignals() { - signalCtx, _ := signal.NotifyContext(context.Background(), os.Interrupt) - cancelCtx, cancel := context.WithCancelCause(context.Background()) - go func() { - <-signalCtx.Done() - cancel(errors.New("received interrupt signal")) - }() - - // This reimplements the contract between Ginkgo and Gomega for progress reporting. - // When using Ginkgo contexts, Ginkgo will implement it. This here is for "go test". - // - // nolint:staticcheck // It complains about using a plain string. This can only be fixed - // by Ginkgo and Gomega formalizing this interface and define a type (somewhere... - // probably cannot be in either Ginkgo or Gomega). - interruptCtx = context.WithValue(cancelCtx, ginkgoSpecContextKey, defaultProgressReporter) - - defaultSignalChannel = make(chan os.Signal, 1) - // progressSignals will be empty on Windows. - if len(progressSignals) > 0 { - signal.Notify(defaultSignalChannel, progressSignals...) - } - - // os.Stderr gets redirected by "go test". "go test -v" has to be - // used to see the output while a test runs. - defaultProgressReporter.setOutput(os.Stderr) - go defaultProgressReporter.run(interruptCtx, defaultSignalChannel) -} - -var initSignalsOnce sync.Once - type progressReporter struct { - mutex sync.Mutex + // initMutex protects initialization and finalization of the reporter. + initMutex sync.Mutex + + usageCount int64 + wg sync.WaitGroup + signalCtx, interruptCtx context.Context + signalCancel func() + progressChannel chan os.Signal + + // reportMutex protects report creation and settings. + reportMutex sync.Mutex reporterCounter int64 reporters map[int64]func() string out io.Writer @@ -82,18 +61,77 @@ type progressReporter struct { var _ ginkgoReporter = &progressReporter{} -func (p *progressReporter) setOutput(out io.Writer) io.Writer { - p.mutex.Lock() - defer p.mutex.Unlock() - oldOut := p.out - p.out = out - return oldOut +// init is invoked by Init. It returns the context to be used for the +// new TContext. +// +// By default, that is just context.Background. In a Go unit test, it +// is a context connected to os.Interrupt. +// +// Once activated like that in a Go unit test, the progressReporter implements +// support for triggering a progress report in a running test when sending it a +// USR1 signal, similar to the corresponding Ginkgo feature. +// +// This support is active until the last test terminates. +func (p *progressReporter) init(tb TB) context.Context { + if _, ok := tb.(testing.TB); !ok { + // Not in a Go unit test. + return context.Background() + } + + p.initMutex.Lock() + defer p.initMutex.Unlock() + + p.usageCount++ + tb.Cleanup(p.finalize) + if p.usageCount > 1 { + // Was already initialized. + return p.interruptCtx + } + + p.signalCtx, p.signalCancel = signal.NotifyContext(context.Background(), os.Interrupt) + cancelCtx, cancel := context.WithCancelCause(context.Background()) + p.wg.Go(func() { + <-p.signalCtx.Done() + cancel(errors.New("received interrupt signal")) + }) + + // This reimplements the contract between Ginkgo and Gomega for progress reporting. + // When using Ginkgo contexts, Ginkgo will implement it. This here is for "go test". + // + // nolint:staticcheck // It complains about using a plain string. This can only be fixed + // by Ginkgo and Gomega formalizing this interface and define a type (somewhere... + // probably cannot be in either Ginkgo or Gomega). + p.interruptCtx = context.WithValue(cancelCtx, ginkgoSpecContextKey, defaultProgressReporter) + + p.progressChannel = make(chan os.Signal, 1) + // progressSignals will be empty on Windows. + if len(progressSignals) > 0 { + signal.Notify(p.progressChannel, progressSignals...) + } + + p.wg.Go(p.run) + + return p.interruptCtx +} + +func (p *progressReporter) finalize() { + p.initMutex.Lock() + defer p.initMutex.Unlock() + + p.usageCount-- + if p.usageCount > 0 { + // Still in use. + return + } + + p.signalCancel() + p.wg.Wait() } // AttachProgressReporter implements Gomega's contextWithAttachProgressReporter. func (p *progressReporter) AttachProgressReporter(reporter func() string) func() { - p.mutex.Lock() - defer p.mutex.Unlock() + p.reportMutex.Lock() + defer p.reportMutex.Unlock() // TODO (?): identify the caller and record that for dumpProgress. p.reporterCounter++ @@ -108,18 +146,27 @@ func (p *progressReporter) AttachProgressReporter(reporter func() string) func() } func (p *progressReporter) detachProgressReporter(id int64) { - p.mutex.Lock() - defer p.mutex.Unlock() + p.reportMutex.Lock() + defer p.reportMutex.Unlock() delete(p.reporters, id) } -func (p *progressReporter) run(ctx context.Context, progressSignalChannel chan os.Signal) { +func (p *progressReporter) run() { for { select { - case <-ctx.Done(): + case <-p.interruptCtx.Done(): + // Maybe do one last progress report? + // + // This is primarily for unit testing of ktesting itself, + // in a normal test we don't care anymore. + select { + case <-p.progressChannel: + p.dumpProgress() + default: + } return - case <-progressSignalChannel: + case <-p.progressChannel: p.dumpProgress() } } @@ -132,8 +179,8 @@ func (p *progressReporter) run(ctx context.Context, progressSignalChannel chan o // But perhaps dumping goroutines and their callstacks is useful anyway? TODO: // look at how Ginkgo does it and replicate some of it. func (p *progressReporter) dumpProgress() { - p.mutex.Lock() - defer p.mutex.Unlock() + p.reportMutex.Lock() + defer p.reportMutex.Unlock() var buffer strings.Builder buffer.WriteString("You requested a progress report.\n") diff --git a/test/utils/ktesting/stepcontext_test.go b/test/utils/ktesting/stepcontext_test.go index 72035bde5a7..c8f3c143011 100644 --- a/test/utils/ktesting/stepcontext_test.go +++ b/test/utils/ktesting/stepcontext_test.go @@ -17,12 +17,12 @@ limitations under the License. package ktesting import ( - "bytes" + "io" "os" "testing" - "time" - "github.com/stretchr/testify/assert" + "github.com/onsi/gomega" + "go.uber.org/goleak" ) func TestStepContext(t *testing.T) { @@ -65,32 +65,57 @@ func TestStepContext(t *testing.T) { step: Error a b 42 `, }, - "progress": { - cb: func(tCtx TContext) { - tCtx = WithStep(tCtx, "step") - var buffer bytes.Buffer - oldOut := defaultProgressReporter.setOutput(&buffer) - defer defaultProgressReporter.setOutput(oldOut) - remove := tCtx.Value("GINKGO_SPEC_CONTEXT").(ginkgoReporter).AttachProgressReporter(func() string { return "hello world" }) - defer remove() - defaultSignalChannel <- os.Interrupt - // No good way to sync here, so let's just wait. - time.Sleep(5 * time.Second) - defaultProgressReporter.setOutput(oldOut) - tCtx.Log(buffer.String()) - - noSuchValue := tCtx.Value("some other key") - assert.Nil(tCtx, noSuchValue, "value for unknown context value key") - }, - expectTrace: `(LOG) : step: You requested a progress report. - - step: hello world -`, - expectDuration: 5 * time.Second, - }, } { t.Run(name, func(t *testing.T) { tc.run(t) }) } } + +func TestProgressReport(t *testing.T) { + t.Cleanup(func() { + goleak.VerifyNone(t) + }) + + oldOut := defaultProgressReporter.out + reportStream := newOutputStream() + defaultProgressReporter.out = reportStream + t.Cleanup(func() { + defaultProgressReporter.out = oldOut + }) + + // This must use a real testing.T, otherwise Init doesn't initialize signal handling. + tCtx := Init(t) + tCtx = WithStep(tCtx, "step") + removeReporter := tCtx.Value("GINKGO_SPEC_CONTEXT").(ginkgoReporter).AttachProgressReporter(func() string { return "hello world" }) + defer removeReporter() + tCtx.Expect(tCtx.Value("some other key")).To(gomega.BeNil(), "value for unknown context value key") + + // Trigger report and wait for it. + defaultProgressReporter.progressChannel <- os.Interrupt + report := <-reportStream.stream + tCtx.Expect(report).To(gomega.Equal(`You requested a progress report. + +step: hello world +`), "report") +} + +// outputStream forwards exactly one Write call to a stream. +// A second Write call is an error and will panic. +type outputStream struct { + stream chan string +} + +var _ io.Writer = &outputStream{} + +func newOutputStream() *outputStream { + return &outputStream{ + stream: make(chan string), + } +} + +func (s *outputStream) Write(buf []byte) (int, error) { + s.stream <- string(buf) + close(s.stream) + return len(buf), nil +} diff --git a/test/utils/ktesting/tcontext.go b/test/utils/ktesting/tcontext.go index 97960b67c0c..688d5cca5c8 100644 --- a/test/utils/ktesting/tcontext.go +++ b/test/utils/ktesting/tcontext.go @@ -136,40 +136,34 @@ func Init(tb TB, opts ...InitOption) TContext { Deadline() (time.Time, bool) }) - ctx := interruptCtx + ctx := defaultProgressReporter.init(tb) var header func() string if c.PerTestOutput { logger := newLogger(tb, c.BufferLogs) - ctx = klog.NewContext(interruptCtx, logger) + ctx = klog.NewContext(ctx, logger) header = klogHeader } + var cancelTimeout func(cause string) if deadlineOK { if deadline, ok := deadlineTB.Deadline(); ok { timeLeft := time.Until(deadline) timeLeft -= CleanupGracePeriod - ctx, cancel := withTimeout(ctx, tb, timeLeft, fmt.Sprintf("test suite deadline (%s) is close, need to clean up before the %s cleanup grace period", deadline.Truncate(time.Second), CleanupGracePeriod)) - tc := TC{ - Context: ctx, - testingTB: testingTB{TB: tb}, - cancel: cancel, - } - return &tc + ctx, cancelTimeout = withTimeout(ctx, tb, timeLeft, fmt.Sprintf("test suite deadline (%s) is close, need to clean up before the %s cleanup grace period", deadline.Truncate(time.Second), CleanupGracePeriod)) } } - tCtx := WithCancel(InitCtx(ctx, tb)) - tCtx.perTestHeader = header - tCtx.Cleanup(func() { - tCtx.Cancel(cleanupErr(tCtx.Name()).Error()) - }) - // Only enable signal handling if we are sure that we are not - // in a Ginkgo suite. Only structs from the testing package - // can implement this interface because it contains an "internal" - // method, so this has to run under `go test`. - if _, ok := tb.(testing.TB); ok { - initSignalsOnce.Do(initSignals) + // Construct new TContext with context and settings as determined above. + tCtx := InitCtx(ctx, tb) + if cancelTimeout != nil { + tCtx.cancel = cancelTimeout + } else { + tCtx = WithCancel(tCtx) + tCtx.Cleanup(func() { + tCtx.Cancel(cleanupErr(tCtx.Name()).Error()) + }) } + tCtx.perTestHeader = header return tCtx }