ktesting: more flexible WithContext

As a special case, WithContext preserved the logger in the parent context. But
for the upcoming usage of WithValue to store a Kubernetes client it is
important to also preserve access to other values.
This commit is contained in:
Patrick Ohly 2026-04-08 11:04:47 +02:00
parent 8629b87211
commit bc2a34caae
2 changed files with 31 additions and 9 deletions

View file

@ -26,7 +26,6 @@ import (
"testing/synctest"
"time"
"github.com/go-logr/logr"
"github.com/onsi/gomega"
apiextensions "k8s.io/apiextensions-apiserver/pkg/client/clientset/clientset"
@ -333,15 +332,12 @@ func run(tCtx TContext, name string, syncTest bool, cb func(tCtx TContext)) bool
// tCtx := ktesting.WithContext(tCtx, ctx)
// ...
//
// This is important because the Context in the callback could have
// a different deadline than in the parent TContext.
// Cancellation and deadline are determined by the new context.
// Values are looked up first in the new context, then the old one.
// In other words, values set previous via WithValue are still
// available.
func (tCtx TContext) WithContext(ctx context.Context) TContext {
logger := tCtx.Logger()
tCtx.Context = ctx
if _, err := logr.FromContext(ctx); err != nil {
// Keep using the logger from the parent context.
tCtx = tCtx.WithLogger(logger)
}
tCtx.Context = &chainContext{Context: ctx, previousCtx: tCtx.Context}
return tCtx
}
@ -351,6 +347,18 @@ func (tCtx TContext) WithValue(key, val any) TContext {
return tCtx.WithContext(ctx)
}
type chainContext struct {
context.Context
previousCtx context.Context
}
func (ctx *chainContext) Value(key any) any {
if val := ctx.Context.Value(key); val != nil {
return val
}
return ctx.previousCtx.Value(key)
}
// TContext implements [context.Context], [testing.TB] and some additional
// methods. [TContext] is the public pointer type for referencing a TC.
// Variables are usually called tCtx. To ensure that test code does not

View file

@ -17,6 +17,7 @@ limitations under the License.
package ktesting_test
import (
"context"
"sync"
"testing"
"testing/synctest"
@ -208,3 +209,16 @@ func TestWithNamespace(t *testing.T) {
tCtxWithNamespace := tCtx.WithNamespace(namespace)
tCtx.Expect(tCtxWithNamespace.Namespace()).To(gomega.Equal(namespace))
}
func TestWithContext(t *testing.T) {
tCtx := ktesting.Init(t)
tCtx.Cancel("done")
tCtx = tCtx.WithValue("foo", "bar")
deadline := time.Now().Add(-time.Minute)
ctx, cancel := context.WithDeadline(context.Background(), deadline)
defer cancel()
newCtx := tCtx.WithContext(ctx)
tCtx.Expect(context.Cause(tCtx)).To(gomega.MatchError(gomega.ContainSubstring("done")))
tCtx.Expect(newCtx.Err()).To(gomega.MatchError(context.DeadlineExceeded))
tCtx.Expect(newCtx.Value("foo")).To(gomega.Equal("bar"))
}