Merge pull request #136896 from aaron-prindle/modal-validation-upstream

Implement declarative modal validation (+k8s:discriminator and +k8s:member)
This commit is contained in:
Kubernetes Prow Robot 2026-02-14 03:08:00 +05:30 committed by GitHub
commit 18c8b8c4d3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 1595 additions and 0 deletions

View file

@ -0,0 +1,71 @@
/*
Copyright 2025 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package validate
import (
"context"
"k8s.io/apimachinery/pkg/api/operation"
"k8s.io/apimachinery/pkg/util/validation/field"
)
// DiscriminatedRule defines a validation to apply for a specific discriminator value.
type DiscriminatedRule[Tfield any, Tdisc comparable] struct {
Value Tdisc
Validation ValidateFunc[Tfield]
}
// Discriminated validates a member field based on a discriminator value.
// It iterates through the rules and applies the first one that matches the discriminator.
// If no rule matches, it applies the defaultValidation if provided.
//
// It performs ratcheting: if the operation is an Update, and neither the discriminator
// nor the value (checked via equiv) have changed, validation is skipped.
func Discriminated[Tfield any, Tdisc comparable, Tstruct any](ctx context.Context, op operation.Operation, structPath *field.Path,
obj, oldObj *Tstruct, fieldName string, getMemberValue func(*Tstruct) Tfield, getDiscriminator func(*Tstruct) Tdisc,
equiv MatchFunc[Tfield], defaultValidation ValidateFunc[Tfield], rules []DiscriminatedRule[Tfield, Tdisc],
) field.ErrorList {
value := getMemberValue(obj)
discriminator := getDiscriminator(obj)
var oldValue Tfield
var oldDiscriminator Tdisc
if oldObj != nil {
oldValue = getMemberValue(oldObj)
oldDiscriminator = getDiscriminator(oldObj)
}
if op.Type == operation.Update && oldObj != nil && discriminator == oldDiscriminator && equiv(value, oldValue) {
return nil
}
fldPath := structPath.Child(fieldName)
for _, rule := range rules {
if rule.Value == discriminator {
if rule.Validation == nil {
return nil
}
return rule.Validation(ctx, op, fldPath, value, oldValue)
}
}
if defaultValidation != nil {
return defaultValidation(ctx, op, fldPath, value, oldValue)
}
return nil
}

View file

@ -0,0 +1,265 @@
/*
Copyright 2025 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package validate
import (
"context"
"reflect"
"testing"
"k8s.io/apimachinery/pkg/api/operation"
"k8s.io/apimachinery/pkg/util/validation/field"
)
func TestDiscriminated(t *testing.T) {
errMatch := field.ErrorList{field.Invalid(field.NewPath("foo"), "bar", "match error")}
errDefault := field.ErrorList{field.Invalid(field.NewPath("foo"), "bar", "default error")}
mockValid := func(_ context.Context, _ operation.Operation, _ *field.Path, _, _ *string) field.ErrorList {
return nil
}
mockErrorMatch := func(_ context.Context, _ operation.Operation, _ *field.Path, _, _ *string) field.ErrorList {
return errMatch
}
mockErrorDefault := func(_ context.Context, _ operation.Operation, _ *field.Path, _, _ *string) field.ErrorList {
return errDefault
}
// mockEqual compares pointer values by dereferencing, not by pointer identity.
mockEqual := func(a, b *string) bool {
if a == nil && b == nil {
return true
}
if a == nil || b == nil {
return false
}
return *a == *b
}
testCases := []struct {
name string
opType operation.Type
discriminator string
oldDiscriminator string
value *string
oldValue *string
rules []DiscriminatedRule[*string, string]
defaultValidation ValidateFunc[*string]
expected field.ErrorList
}{
{
name: "matches rule, returns valid",
opType: operation.Create,
discriminator: "A",
oldDiscriminator: "A",
rules: []DiscriminatedRule[*string, string]{
{Value: "A", Validation: mockValid},
{Value: "B", Validation: mockErrorMatch},
},
defaultValidation: mockErrorDefault,
expected: nil,
},
{
name: "matches rule, returns error",
opType: operation.Create,
discriminator: "B",
oldDiscriminator: "B",
rules: []DiscriminatedRule[*string, string]{
{Value: "A", Validation: mockValid},
{Value: "B", Validation: mockErrorMatch},
},
defaultValidation: mockErrorDefault,
expected: errMatch,
},
{
name: "ratcheting: update, unchanged, skips validation",
opType: operation.Update,
discriminator: "B",
oldDiscriminator: "B", // unchanged
value: nil,
oldValue: nil, // unchanged
rules: []DiscriminatedRule[*string, string]{
{Value: "B", Validation: mockErrorMatch}, // would fail if run
},
defaultValidation: mockErrorDefault,
expected: nil,
},
{
name: "ratcheting: update, same value different pointers, skips validation",
opType: operation.Update,
discriminator: "B",
oldDiscriminator: "B",
value: strPtr("same"),
oldValue: strPtr("same"), // different pointer, same value
rules: []DiscriminatedRule[*string, string]{
{Value: "B", Validation: mockErrorMatch}, // would fail if run
},
defaultValidation: mockErrorDefault,
expected: nil,
},
{
name: "ratcheting: update, discriminator changed, runs validation",
opType: operation.Update,
discriminator: "B",
oldDiscriminator: "A", // changed
value: nil,
oldValue: nil,
rules: []DiscriminatedRule[*string, string]{
{Value: "B", Validation: mockErrorMatch},
},
defaultValidation: mockErrorDefault,
expected: errMatch,
},
{
name: "ratcheting: update, value changed, discriminator unchanged, runs validation",
opType: operation.Update,
discriminator: "B",
oldDiscriminator: "B", // unchanged
value: strPtr("new"),
oldValue: strPtr("old"), // changed
rules: []DiscriminatedRule[*string, string]{
{Value: "B", Validation: mockErrorMatch},
},
defaultValidation: mockErrorDefault,
expected: errMatch,
},
{
name: "matches rule with nil validation, returns valid",
opType: operation.Create,
discriminator: "A",
rules: []DiscriminatedRule[*string, string]{
{Value: "A", Validation: nil},
},
defaultValidation: mockErrorDefault,
expected: nil,
},
{
name: "no match, runs default",
opType: operation.Create,
discriminator: "C",
rules: []DiscriminatedRule[*string, string]{
{Value: "A", Validation: mockValid},
{Value: "B", Validation: mockErrorMatch},
},
defaultValidation: mockErrorDefault,
expected: errDefault,
},
{
name: "no match, nil default, returns valid",
opType: operation.Create,
discriminator: "C",
rules: []DiscriminatedRule[*string, string]{
{Value: "A", Validation: mockValid},
},
defaultValidation: nil,
expected: nil,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
oldDisc := tc.oldDiscriminator
if oldDisc == "" {
oldDisc = tc.discriminator
}
type StringP struct {
Val *string
Disc string
}
newObj := &StringP{Val: tc.value, Disc: tc.discriminator}
var oldObj *StringP
if tc.opType == operation.Update {
oldObj = &StringP{Val: tc.oldValue, Disc: oldDisc}
}
getVal := func(p *StringP) *string { return p.Val }
getDisc := func(p *StringP) string { return p.Disc }
got := Discriminated[*string, string, StringP](context.Background(), operation.Operation{Type: tc.opType}, field.NewPath("root"), newObj, oldObj, "field", getVal, getDisc, mockEqual, tc.defaultValidation, tc.rules)
if !reflect.DeepEqual(got, tc.expected) {
t.Errorf("got %v want %v", got, tc.expected)
}
})
}
}
func TestDiscriminatedIntDiscriminator(t *testing.T) {
errMatch := field.ErrorList{field.Invalid(field.NewPath("foo"), "bar", "match error")}
errDefault := field.ErrorList{field.Invalid(field.NewPath("foo"), "bar", "default error")}
mockErrorMatch := func(_ context.Context, _ operation.Operation, _ *field.Path, _, _ *string) field.ErrorList {
return errMatch
}
mockErrorDefault := func(_ context.Context, _ operation.Operation, _ *field.Path, _, _ *string) field.ErrorList {
return errDefault
}
mockEqual := func(a, b *string) bool {
return a == b
}
type IntP struct {
Val *string
Disc int
}
newObj := &IntP{Val: nil, Disc: 1}
getVal := func(p *IntP) *string { return p.Val }
getDisc := func(p *IntP) int { return p.Disc }
got := Discriminated[*string, int, IntP](context.Background(), operation.Operation{Type: operation.Create}, field.NewPath("root"), newObj, nil, "field", getVal, getDisc, mockEqual, mockErrorDefault, []DiscriminatedRule[*string, int]{
{Value: 1, Validation: mockErrorMatch},
})
if !reflect.DeepEqual(got, errMatch) {
t.Errorf("int discriminator: got %v want %v", got, errMatch)
}
}
func TestDiscriminatedBoolDiscriminator(t *testing.T) {
errMatch := field.ErrorList{field.Invalid(field.NewPath("foo"), "bar", "match error")}
errDefault := field.ErrorList{field.Invalid(field.NewPath("foo"), "bar", "default error")}
mockErrorMatch := func(_ context.Context, _ operation.Operation, _ *field.Path, _, _ *string) field.ErrorList {
return errMatch
}
mockErrorDefault := func(_ context.Context, _ operation.Operation, _ *field.Path, _, _ *string) field.ErrorList {
return errDefault
}
mockEqual := func(a, b *string) bool {
return a == b
}
type BoolP struct {
Val *string
Disc bool
}
newObj := &BoolP{Val: nil, Disc: true}
getVal := func(p *BoolP) *string { return p.Val }
getDisc := func(p *BoolP) bool { return p.Disc }
got := Discriminated[*string, bool, BoolP](context.Background(), operation.Operation{Type: operation.Create}, field.NewPath("root"), newObj, nil, "field", getVal, getDisc, mockEqual, mockErrorDefault, []DiscriminatedRule[*string, bool]{
{Value: true, Validation: mockErrorMatch},
})
if !reflect.DeepEqual(got, errMatch) {
t.Errorf("bool discriminator: got %v want %v", got, errMatch)
}
}
// strPtr returns a new pointer to a copy of s, guaranteeing a distinct allocation.
func strPtr(s string) *string {
return &s
}

View file

@ -0,0 +1,120 @@
/*
Copyright 2025 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// +k8s:validation-gen=TypeMeta
// +k8s:validation-gen-scheme-registry=k8s.io/code-generator/cmd/validation-gen/testscheme.Scheme
// This is a test package.
package discriminator
import "k8s.io/code-generator/cmd/validation-gen/testscheme"
var localSchemeBuilder = testscheme.New()
type StrictUnion struct {
TypeMeta int
// +k8s:discriminator
D1 string `json:"d1"`
// +k8s:member("A")=+k8s:required
FieldA *string `json:"fieldA,omitempty"`
// +k8s:member("B")=+k8s:required
FieldB *string `json:"fieldB,omitempty"`
}
type SharedField struct {
TypeMeta int
// +k8s:discriminator
D1 string `json:"d1"`
// Valid in A and B, implicitly forbidden in C.
// +k8s:member("A")=+k8s:optional
// +k8s:member("B")=+k8s:optional
FieldA *string `json:"fieldA,omitempty"`
}
type ChainedValidation struct {
TypeMeta int
// +k8s:discriminator
D1 string `json:"d1"`
// In mode A, it is required AND must have maxLength 5.
// +k8s:member("A")=+k8s:required
// +k8s:member("A")=+k8s:maxLength=5
FieldA *string `json:"fieldA,omitempty"`
}
type ImplicitForbidden struct {
TypeMeta int
// +k8s:discriminator
D1 string `json:"d1"`
// Field is only mentioned for mode A. Mode B should implicitly forbid it.
// +k8s:member("A")=+k8s:optional
FieldA *string `json:"fieldA,omitempty"`
}
type NonStringDiscriminator struct {
TypeMeta int
// +k8s:discriminator(name:"Bool")
D1 bool `json:"d1"`
// +k8s:member(discriminator:"Bool", value:"true")=+k8s:required
FieldA *string `json:"fieldA,omitempty"`
// +k8s:discriminator(name:"Int")
D2 int `json:"d2"`
// +k8s:member(discriminator:"Int", value:"1")=+k8s:required
FieldB *string `json:"fieldB,omitempty"`
}
type MultipleDiscriminators struct {
TypeMeta int
// +k8s:discriminator(name:"D1")
D1 string `json:"d1"`
// +k8s:discriminator(name:"D2")
D2 string `json:"d2"`
// +k8s:member(discriminator:"D1", value:"A")=+k8s:required
FieldA *string `json:"fieldA,omitempty"`
// +k8s:member(discriminator:"D2", value:"B")=+k8s:required
FieldB *string `json:"fieldB,omitempty"`
}
type Collections struct {
TypeMeta int
// +k8s:discriminator
D1 string `json:"d1"`
// +k8s:member("A")=+k8s:optional
ListField []string `json:"listField,omitempty"`
// +k8s:member("A")=+k8s:optional
MapField map[string]string `json:"mapField,omitempty"`
}
type TypeMeta int

View file

@ -0,0 +1,204 @@
/*
Copyright 2025 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package discriminator
import (
"testing"
"k8s.io/apimachinery/pkg/util/validation/field"
"k8s.io/utils/ptr"
)
func TestStrictUnion(t *testing.T) {
st := localSchemeBuilder.Test(t)
// Mode A: FieldA required, FieldB implicitly forbidden
st.Value(&StrictUnion{D1: "A", FieldA: ptr.To("val")}).ExpectValid()
st.Value(&StrictUnion{D1: "A"}).ExpectMatches(field.ErrorMatcher{}.ByType().ByField(), field.ErrorList{
field.Required(field.NewPath("fieldA"), ""),
})
st.Value(&StrictUnion{D1: "A", FieldA: ptr.To("val"), FieldB: ptr.To("val")}).ExpectMatches(field.ErrorMatcher{}.ByType().ByField(), field.ErrorList{
field.Forbidden(field.NewPath("fieldB"), ""),
})
// Mode B: FieldA implicitly forbidden, FieldB required
st.Value(&StrictUnion{D1: "B", FieldB: ptr.To("val")}).ExpectValid()
st.Value(&StrictUnion{D1: "B"}).ExpectMatches(field.ErrorMatcher{}.ByType().ByField(), field.ErrorList{
field.Required(field.NewPath("fieldB"), ""),
})
st.Value(&StrictUnion{D1: "B", FieldA: ptr.To("val"), FieldB: ptr.To("val")}).ExpectMatches(field.ErrorMatcher{}.ByType().ByField(), field.ErrorList{
field.Forbidden(field.NewPath("fieldA"), ""),
})
}
func TestSharedField(t *testing.T) {
st := localSchemeBuilder.Test(t)
// Valid (optional) in A and B
st.Value(&SharedField{D1: "A"}).ExpectValid()
st.Value(&SharedField{D1: "A", FieldA: ptr.To("val")}).ExpectValid()
st.Value(&SharedField{D1: "B"}).ExpectValid()
st.Value(&SharedField{D1: "B", FieldA: ptr.To("val")}).ExpectValid()
// Forbidden in C
st.Value(&SharedField{D1: "C", FieldA: ptr.To("val")}).ExpectMatches(field.ErrorMatcher{}.ByType().ByField(), field.ErrorList{
field.Forbidden(field.NewPath("fieldA"), ""),
})
}
func TestChainedValidation(t *testing.T) {
st := localSchemeBuilder.Test(t)
// Mode A: Required AND maxLength 5
st.Value(&ChainedValidation{D1: "A", FieldA: ptr.To("abc")}).ExpectValid()
st.Value(&ChainedValidation{D1: "A"}).ExpectMatches(field.ErrorMatcher{}.ByType().ByField(), field.ErrorList{
field.Required(field.NewPath("fieldA"), ""),
})
st.Value(&ChainedValidation{D1: "A", FieldA: ptr.To("too-long")}).ExpectMatches(field.ErrorMatcher{}.ByType().ByField(), field.ErrorList{
field.TooLong(field.NewPath("fieldA"), "too-long", 5),
})
// Mode B: Unlisted, so implicitly forbidden
st.Value(&ChainedValidation{D1: "B", FieldA: ptr.To("val")}).ExpectMatches(field.ErrorMatcher{}.ByType().ByField(), field.ErrorList{
field.Forbidden(field.NewPath("fieldA"), ""),
})
}
func TestImplicitForbidden(t *testing.T) {
st := localSchemeBuilder.Test(t)
// Mode A: Optional
st.Value(&ImplicitForbidden{D1: "A"}).ExpectValid()
st.Value(&ImplicitForbidden{D1: "A", FieldA: ptr.To("val")}).ExpectValid()
// Mode B: Not listed, so implicitly Forbidden
st.Value(&ImplicitForbidden{D1: "B", FieldA: ptr.To("val")}).ExpectMatches(field.ErrorMatcher{}.ByType().ByField(), field.ErrorList{
field.Forbidden(field.NewPath("fieldA"), ""),
})
}
func TestNonStringDiscriminator(t *testing.T) {
st := localSchemeBuilder.Test(t)
// Bool mode
st.Value(&NonStringDiscriminator{D1: true, FieldA: ptr.To("val")}).ExpectValid()
st.Value(&NonStringDiscriminator{D1: true}).ExpectMatches(field.ErrorMatcher{}.ByType().ByField(), field.ErrorList{
field.Required(field.NewPath("fieldA"), ""),
})
st.Value(&NonStringDiscriminator{D1: false, FieldA: ptr.To("val")}).ExpectMatches(field.ErrorMatcher{}.ByType().ByField(), field.ErrorList{
field.Forbidden(field.NewPath("fieldA"), ""),
})
// Int mode
st.Value(&NonStringDiscriminator{D2: 1, FieldB: ptr.To("val")}).ExpectValid()
st.Value(&NonStringDiscriminator{D2: 1}).ExpectMatches(field.ErrorMatcher{}.ByType().ByField(), field.ErrorList{
field.Required(field.NewPath("fieldB"), ""),
})
st.Value(&NonStringDiscriminator{D2: 2, FieldB: ptr.To("val")}).ExpectMatches(field.ErrorMatcher{}.ByType().ByField(), field.ErrorList{
field.Forbidden(field.NewPath("fieldB"), ""),
})
}
func TestMultipleDiscriminators(t *testing.T) {
st := localSchemeBuilder.Test(t)
st.Value(&MultipleDiscriminators{
D1: "A",
D2: "B",
FieldA: ptr.To("valA"),
FieldB: ptr.To("valB"),
}).ExpectValid()
st.Value(&MultipleDiscriminators{
D1: "A",
D2: "B",
}).ExpectMatches(field.ErrorMatcher{}.ByType().ByField(), field.ErrorList{
field.Required(field.NewPath("fieldA"), ""),
field.Required(field.NewPath("fieldB"), ""),
})
}
func TestCollections(t *testing.T) {
st := localSchemeBuilder.Test(t)
// Mode A: Collections are valid (optional)
st.Value(&Collections{
D1: "A",
}).ExpectValid()
st.Value(&Collections{
D1: "A",
ListField: []string{"item"},
MapField: map[string]string{"key": "val"},
}).ExpectValid()
// Mode B: Unlisted, so implicitly forbidden
st.Value(&Collections{
D1: "B",
ListField: []string{"item"},
}).ExpectMatches(field.ErrorMatcher{}.ByType().ByField(), field.ErrorList{
field.Forbidden(field.NewPath("listField"), ""),
})
st.Value(&Collections{
D1: "B",
MapField: map[string]string{"key": "val"},
}).ExpectMatches(field.ErrorMatcher{}.ByType().ByField(), field.ErrorList{
field.Forbidden(field.NewPath("mapField"), ""),
})
}
func TestRatcheting(t *testing.T) {
mkTest := func() *ChainedValidation {
return &ChainedValidation{
D1: "A",
FieldA: ptr.To("too-long-string"),
}
}
st := localSchemeBuilder.Test(t)
// 1. New object is invalid
st.Value(mkTest()).ExpectMatches(field.ErrorMatcher{}.ByType().ByField(), field.ErrorList{
field.TooLong(field.NewPath("fieldA"), "too-long-string", 5),
})
// 2. Unchanged update is valid (ratcheting)
st.Value(mkTest()).OldValue(mkTest()).ExpectValid()
// 3. Changed value re-validates (and fails)
mkDifferent := func() *ChainedValidation {
return &ChainedValidation{
D1: "A",
FieldA: ptr.To("also-too-long"),
}
}
st.Value(mkTest()).OldValue(mkDifferent()).ExpectMatches(field.ErrorMatcher{}.ByType().ByField(), field.ErrorList{
field.TooLong(field.NewPath("fieldA"), "too-long-string", 5),
})
// 4. Changed discriminator re-validates (and fails)
mkDifferentDisc := func() *ChainedValidation {
return &ChainedValidation{
D1: "B", // Discriminator changed from B -> A
FieldA: ptr.To("too-long-string"),
}
}
st.Value(mkTest()).OldValue(mkDifferentDisc()).ExpectMatches(field.ErrorMatcher{}.ByType().ByField(), field.ErrorList{
field.TooLong(field.NewPath("fieldA"), "too-long-string", 5),
})
}

View file

@ -0,0 +1,391 @@
//go:build !ignore_autogenerated
// +build !ignore_autogenerated
/*
Copyright The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Code generated by validation-gen. DO NOT EDIT.
package discriminator
import (
context "context"
fmt "fmt"
operation "k8s.io/apimachinery/pkg/api/operation"
safe "k8s.io/apimachinery/pkg/api/safe"
validate "k8s.io/apimachinery/pkg/api/validate"
field "k8s.io/apimachinery/pkg/util/validation/field"
testscheme "k8s.io/code-generator/cmd/validation-gen/testscheme"
)
func init() { localSchemeBuilder.Register(RegisterValidations) }
// RegisterValidations adds validation functions to the given scheme.
// Public to allow building arbitrary schemes.
func RegisterValidations(scheme *testscheme.Scheme) error {
// type ChainedValidation
scheme.AddValidationFunc((*ChainedValidation)(nil), func(ctx context.Context, op operation.Operation, obj, oldObj interface{}) field.ErrorList {
switch op.Request.SubresourcePath() {
case "/":
return Validate_ChainedValidation(ctx, op, nil /* fldPath */, obj.(*ChainedValidation), safe.Cast[*ChainedValidation](oldObj))
}
return field.ErrorList{field.InternalError(nil, fmt.Errorf("no validation found for %T, subresource: %v", obj, op.Request.SubresourcePath()))}
})
// type Collections
scheme.AddValidationFunc((*Collections)(nil), func(ctx context.Context, op operation.Operation, obj, oldObj interface{}) field.ErrorList {
switch op.Request.SubresourcePath() {
case "/":
return Validate_Collections(ctx, op, nil /* fldPath */, obj.(*Collections), safe.Cast[*Collections](oldObj))
}
return field.ErrorList{field.InternalError(nil, fmt.Errorf("no validation found for %T, subresource: %v", obj, op.Request.SubresourcePath()))}
})
// type ImplicitForbidden
scheme.AddValidationFunc((*ImplicitForbidden)(nil), func(ctx context.Context, op operation.Operation, obj, oldObj interface{}) field.ErrorList {
switch op.Request.SubresourcePath() {
case "/":
return Validate_ImplicitForbidden(ctx, op, nil /* fldPath */, obj.(*ImplicitForbidden), safe.Cast[*ImplicitForbidden](oldObj))
}
return field.ErrorList{field.InternalError(nil, fmt.Errorf("no validation found for %T, subresource: %v", obj, op.Request.SubresourcePath()))}
})
// type MultipleDiscriminators
scheme.AddValidationFunc((*MultipleDiscriminators)(nil), func(ctx context.Context, op operation.Operation, obj, oldObj interface{}) field.ErrorList {
switch op.Request.SubresourcePath() {
case "/":
return Validate_MultipleDiscriminators(ctx, op, nil /* fldPath */, obj.(*MultipleDiscriminators), safe.Cast[*MultipleDiscriminators](oldObj))
}
return field.ErrorList{field.InternalError(nil, fmt.Errorf("no validation found for %T, subresource: %v", obj, op.Request.SubresourcePath()))}
})
// type NonStringDiscriminator
scheme.AddValidationFunc((*NonStringDiscriminator)(nil), func(ctx context.Context, op operation.Operation, obj, oldObj interface{}) field.ErrorList {
switch op.Request.SubresourcePath() {
case "/":
return Validate_NonStringDiscriminator(ctx, op, nil /* fldPath */, obj.(*NonStringDiscriminator), safe.Cast[*NonStringDiscriminator](oldObj))
}
return field.ErrorList{field.InternalError(nil, fmt.Errorf("no validation found for %T, subresource: %v", obj, op.Request.SubresourcePath()))}
})
// type SharedField
scheme.AddValidationFunc((*SharedField)(nil), func(ctx context.Context, op operation.Operation, obj, oldObj interface{}) field.ErrorList {
switch op.Request.SubresourcePath() {
case "/":
return Validate_SharedField(ctx, op, nil /* fldPath */, obj.(*SharedField), safe.Cast[*SharedField](oldObj))
}
return field.ErrorList{field.InternalError(nil, fmt.Errorf("no validation found for %T, subresource: %v", obj, op.Request.SubresourcePath()))}
})
// type StrictUnion
scheme.AddValidationFunc((*StrictUnion)(nil), func(ctx context.Context, op operation.Operation, obj, oldObj interface{}) field.ErrorList {
switch op.Request.SubresourcePath() {
case "/":
return Validate_StrictUnion(ctx, op, nil /* fldPath */, obj.(*StrictUnion), safe.Cast[*StrictUnion](oldObj))
}
return field.ErrorList{field.InternalError(nil, fmt.Errorf("no validation found for %T, subresource: %v", obj, op.Request.SubresourcePath()))}
})
return nil
}
// Validate_ChainedValidation validates an instance of ChainedValidation according
// to declarative validation rules in the API schema.
func Validate_ChainedValidation(ctx context.Context, op operation.Operation, fldPath *field.Path, obj, oldObj *ChainedValidation) (errs field.ErrorList) {
errs = append(errs, validate.Discriminated(ctx, op, fldPath, obj, oldObj, "fieldA", func(obj *ChainedValidation) *string { return obj.FieldA }, func(obj *ChainedValidation) string { return obj.D1 }, validate.DirectEqualPtr, func(ctx context.Context, op operation.Operation, fldPath *field.Path, obj, oldObj *string) field.ErrorList {
errs := field.ErrorList{}
errs = append(errs, validate.ForbiddenPointer(ctx, op, fldPath, obj, oldObj)...)
return errs
}, []validate.DiscriminatedRule[*string, string]{
{
Value: "A", Validation: func(ctx context.Context, op operation.Operation, fldPath *field.Path, obj, oldObj *string) field.ErrorList {
errs := field.ErrorList{}
earlyReturn := false
if e := validate.RequiredPointer(ctx, op, fldPath, obj, oldObj); len(e) != 0 {
errs = append(errs, e...)
earlyReturn = true
}
if earlyReturn {
return errs
}
errs = append(errs, validate.MaxLength(ctx, op, fldPath, obj, oldObj, 5)...)
return errs
}},
})...)
// field ChainedValidation.TypeMeta has no validation
// field ChainedValidation.D1 has no validation
// field ChainedValidation.FieldA has no validation
return errs
}
// Validate_Collections validates an instance of Collections according
// to declarative validation rules in the API schema.
func Validate_Collections(ctx context.Context, op operation.Operation, fldPath *field.Path, obj, oldObj *Collections) (errs field.ErrorList) {
errs = append(errs, validate.Discriminated(ctx, op, fldPath, obj, oldObj, "listField", func(obj *Collections) []string { return obj.ListField }, func(obj *Collections) string { return obj.D1 }, validate.SemanticDeepEqual, func(ctx context.Context, op operation.Operation, fldPath *field.Path, obj, oldObj []string) field.ErrorList {
errs := field.ErrorList{}
errs = append(errs, validate.ForbiddenSlice(ctx, op, fldPath, obj, oldObj)...)
return errs
}, []validate.DiscriminatedRule[[]string, string]{
{
Value: "A", Validation: func(ctx context.Context, op operation.Operation, fldPath *field.Path, obj, oldObj []string) field.ErrorList {
errs := field.ErrorList{}
earlyReturn := false
if e := validate.OptionalSlice(ctx, op, fldPath, obj, oldObj); len(e) != 0 {
earlyReturn = true
}
if earlyReturn {
return errs
}
return errs
}},
})...)
errs = append(errs, validate.Discriminated(ctx, op, fldPath, obj, oldObj, "mapField", func(obj *Collections) map[string]string { return obj.MapField }, func(obj *Collections) string { return obj.D1 }, validate.SemanticDeepEqual, func(ctx context.Context, op operation.Operation, fldPath *field.Path, obj, oldObj map[string]string) field.ErrorList {
errs := field.ErrorList{}
errs = append(errs, validate.ForbiddenMap(ctx, op, fldPath, obj, oldObj)...)
return errs
}, []validate.DiscriminatedRule[map[string]string, string]{
{
Value: "A", Validation: func(ctx context.Context, op operation.Operation, fldPath *field.Path, obj, oldObj map[string]string) field.ErrorList {
errs := field.ErrorList{}
earlyReturn := false
if e := validate.OptionalMap(ctx, op, fldPath, obj, oldObj); len(e) != 0 {
earlyReturn = true
}
if earlyReturn {
return errs
}
return errs
}},
})...)
// field Collections.TypeMeta has no validation
// field Collections.D1 has no validation
// field Collections.ListField has no validation
// field Collections.MapField has no validation
return errs
}
// Validate_ImplicitForbidden validates an instance of ImplicitForbidden according
// to declarative validation rules in the API schema.
func Validate_ImplicitForbidden(ctx context.Context, op operation.Operation, fldPath *field.Path, obj, oldObj *ImplicitForbidden) (errs field.ErrorList) {
errs = append(errs, validate.Discriminated(ctx, op, fldPath, obj, oldObj, "fieldA", func(obj *ImplicitForbidden) *string { return obj.FieldA }, func(obj *ImplicitForbidden) string { return obj.D1 }, validate.DirectEqualPtr, func(ctx context.Context, op operation.Operation, fldPath *field.Path, obj, oldObj *string) field.ErrorList {
errs := field.ErrorList{}
errs = append(errs, validate.ForbiddenPointer(ctx, op, fldPath, obj, oldObj)...)
return errs
}, []validate.DiscriminatedRule[*string, string]{
{
Value: "A", Validation: func(ctx context.Context, op operation.Operation, fldPath *field.Path, obj, oldObj *string) field.ErrorList {
errs := field.ErrorList{}
earlyReturn := false
if e := validate.OptionalPointer(ctx, op, fldPath, obj, oldObj); len(e) != 0 {
earlyReturn = true
}
if earlyReturn {
return errs
}
return errs
}},
})...)
// field ImplicitForbidden.TypeMeta has no validation
// field ImplicitForbidden.D1 has no validation
// field ImplicitForbidden.FieldA has no validation
return errs
}
// Validate_MultipleDiscriminators validates an instance of MultipleDiscriminators according
// to declarative validation rules in the API schema.
func Validate_MultipleDiscriminators(ctx context.Context, op operation.Operation, fldPath *field.Path, obj, oldObj *MultipleDiscriminators) (errs field.ErrorList) {
errs = append(errs, validate.Discriminated(ctx, op, fldPath, obj, oldObj, "fieldA", func(obj *MultipleDiscriminators) *string { return obj.FieldA }, func(obj *MultipleDiscriminators) string { return obj.D1 }, validate.DirectEqualPtr, func(ctx context.Context, op operation.Operation, fldPath *field.Path, obj, oldObj *string) field.ErrorList {
errs := field.ErrorList{}
errs = append(errs, validate.ForbiddenPointer(ctx, op, fldPath, obj, oldObj)...)
return errs
}, []validate.DiscriminatedRule[*string, string]{
{
Value: "A", Validation: func(ctx context.Context, op operation.Operation, fldPath *field.Path, obj, oldObj *string) field.ErrorList {
errs := field.ErrorList{}
earlyReturn := false
if e := validate.RequiredPointer(ctx, op, fldPath, obj, oldObj); len(e) != 0 {
errs = append(errs, e...)
earlyReturn = true
}
if earlyReturn {
return errs
}
return errs
}},
})...)
errs = append(errs, validate.Discriminated(ctx, op, fldPath, obj, oldObj, "fieldB", func(obj *MultipleDiscriminators) *string { return obj.FieldB }, func(obj *MultipleDiscriminators) string { return obj.D2 }, validate.DirectEqualPtr, func(ctx context.Context, op operation.Operation, fldPath *field.Path, obj, oldObj *string) field.ErrorList {
errs := field.ErrorList{}
errs = append(errs, validate.ForbiddenPointer(ctx, op, fldPath, obj, oldObj)...)
return errs
}, []validate.DiscriminatedRule[*string, string]{
{
Value: "B", Validation: func(ctx context.Context, op operation.Operation, fldPath *field.Path, obj, oldObj *string) field.ErrorList {
errs := field.ErrorList{}
earlyReturn := false
if e := validate.RequiredPointer(ctx, op, fldPath, obj, oldObj); len(e) != 0 {
errs = append(errs, e...)
earlyReturn = true
}
if earlyReturn {
return errs
}
return errs
}},
})...)
// field MultipleDiscriminators.TypeMeta has no validation
// field MultipleDiscriminators.D1 has no validation
// field MultipleDiscriminators.D2 has no validation
// field MultipleDiscriminators.FieldA has no validation
// field MultipleDiscriminators.FieldB has no validation
return errs
}
// Validate_NonStringDiscriminator validates an instance of NonStringDiscriminator according
// to declarative validation rules in the API schema.
func Validate_NonStringDiscriminator(ctx context.Context, op operation.Operation, fldPath *field.Path, obj, oldObj *NonStringDiscriminator) (errs field.ErrorList) {
errs = append(errs, validate.Discriminated(ctx, op, fldPath, obj, oldObj, "fieldA", func(obj *NonStringDiscriminator) *string { return obj.FieldA }, func(obj *NonStringDiscriminator) bool { return obj.D1 }, validate.DirectEqualPtr, func(ctx context.Context, op operation.Operation, fldPath *field.Path, obj, oldObj *string) field.ErrorList {
errs := field.ErrorList{}
errs = append(errs, validate.ForbiddenPointer(ctx, op, fldPath, obj, oldObj)...)
return errs
}, []validate.DiscriminatedRule[*string, bool]{
{
Value: true, Validation: func(ctx context.Context, op operation.Operation, fldPath *field.Path, obj, oldObj *string) field.ErrorList {
errs := field.ErrorList{}
earlyReturn := false
if e := validate.RequiredPointer(ctx, op, fldPath, obj, oldObj); len(e) != 0 {
errs = append(errs, e...)
earlyReturn = true
}
if earlyReturn {
return errs
}
return errs
}},
})...)
errs = append(errs, validate.Discriminated(ctx, op, fldPath, obj, oldObj, "fieldB", func(obj *NonStringDiscriminator) *string { return obj.FieldB }, func(obj *NonStringDiscriminator) int { return obj.D2 }, validate.DirectEqualPtr, func(ctx context.Context, op operation.Operation, fldPath *field.Path, obj, oldObj *string) field.ErrorList {
errs := field.ErrorList{}
errs = append(errs, validate.ForbiddenPointer(ctx, op, fldPath, obj, oldObj)...)
return errs
}, []validate.DiscriminatedRule[*string, int]{
{
Value: 1, Validation: func(ctx context.Context, op operation.Operation, fldPath *field.Path, obj, oldObj *string) field.ErrorList {
errs := field.ErrorList{}
earlyReturn := false
if e := validate.RequiredPointer(ctx, op, fldPath, obj, oldObj); len(e) != 0 {
errs = append(errs, e...)
earlyReturn = true
}
if earlyReturn {
return errs
}
return errs
}},
})...)
// field NonStringDiscriminator.TypeMeta has no validation
// field NonStringDiscriminator.D1 has no validation
// field NonStringDiscriminator.FieldA has no validation
// field NonStringDiscriminator.D2 has no validation
// field NonStringDiscriminator.FieldB has no validation
return errs
}
// Validate_SharedField validates an instance of SharedField according
// to declarative validation rules in the API schema.
func Validate_SharedField(ctx context.Context, op operation.Operation, fldPath *field.Path, obj, oldObj *SharedField) (errs field.ErrorList) {
errs = append(errs, validate.Discriminated(ctx, op, fldPath, obj, oldObj, "fieldA", func(obj *SharedField) *string { return obj.FieldA }, func(obj *SharedField) string { return obj.D1 }, validate.DirectEqualPtr, func(ctx context.Context, op operation.Operation, fldPath *field.Path, obj, oldObj *string) field.ErrorList {
errs := field.ErrorList{}
errs = append(errs, validate.ForbiddenPointer(ctx, op, fldPath, obj, oldObj)...)
return errs
}, []validate.DiscriminatedRule[*string, string]{
{
Value: "A", Validation: func(ctx context.Context, op operation.Operation, fldPath *field.Path, obj, oldObj *string) field.ErrorList {
errs := field.ErrorList{}
earlyReturn := false
if e := validate.OptionalPointer(ctx, op, fldPath, obj, oldObj); len(e) != 0 {
earlyReturn = true
}
if earlyReturn {
return errs
}
return errs
}},
{
Value: "B", Validation: func(ctx context.Context, op operation.Operation, fldPath *field.Path, obj, oldObj *string) field.ErrorList {
errs := field.ErrorList{}
earlyReturn := false
if e := validate.OptionalPointer(ctx, op, fldPath, obj, oldObj); len(e) != 0 {
earlyReturn = true
}
if earlyReturn {
return errs
}
return errs
}},
})...)
// field SharedField.TypeMeta has no validation
// field SharedField.D1 has no validation
// field SharedField.FieldA has no validation
return errs
}
// Validate_StrictUnion validates an instance of StrictUnion according
// to declarative validation rules in the API schema.
func Validate_StrictUnion(ctx context.Context, op operation.Operation, fldPath *field.Path, obj, oldObj *StrictUnion) (errs field.ErrorList) {
errs = append(errs, validate.Discriminated(ctx, op, fldPath, obj, oldObj, "fieldA", func(obj *StrictUnion) *string { return obj.FieldA }, func(obj *StrictUnion) string { return obj.D1 }, validate.DirectEqualPtr, func(ctx context.Context, op operation.Operation, fldPath *field.Path, obj, oldObj *string) field.ErrorList {
errs := field.ErrorList{}
errs = append(errs, validate.ForbiddenPointer(ctx, op, fldPath, obj, oldObj)...)
return errs
}, []validate.DiscriminatedRule[*string, string]{
{
Value: "A", Validation: func(ctx context.Context, op operation.Operation, fldPath *field.Path, obj, oldObj *string) field.ErrorList {
errs := field.ErrorList{}
earlyReturn := false
if e := validate.RequiredPointer(ctx, op, fldPath, obj, oldObj); len(e) != 0 {
errs = append(errs, e...)
earlyReturn = true
}
if earlyReturn {
return errs
}
return errs
}},
})...)
errs = append(errs, validate.Discriminated(ctx, op, fldPath, obj, oldObj, "fieldB", func(obj *StrictUnion) *string { return obj.FieldB }, func(obj *StrictUnion) string { return obj.D1 }, validate.DirectEqualPtr, func(ctx context.Context, op operation.Operation, fldPath *field.Path, obj, oldObj *string) field.ErrorList {
errs := field.ErrorList{}
errs = append(errs, validate.ForbiddenPointer(ctx, op, fldPath, obj, oldObj)...)
return errs
}, []validate.DiscriminatedRule[*string, string]{
{
Value: "B", Validation: func(ctx context.Context, op operation.Operation, fldPath *field.Path, obj, oldObj *string) field.ErrorList {
errs := field.ErrorList{}
earlyReturn := false
if e := validate.RequiredPointer(ctx, op, fldPath, obj, oldObj); len(e) != 0 {
errs = append(errs, e...)
earlyReturn = true
}
if earlyReturn {
return errs
}
return errs
}},
})...)
// field StrictUnion.TypeMeta has no validation
// field StrictUnion.D1 has no validation
// field StrictUnion.FieldA has no validation
// field StrictUnion.FieldB has no validation
return errs
}

View file

@ -1574,6 +1574,73 @@ func toGolangSourceDataLiteral(sw *generator.SnippetWriter, c *generator.Context
emitFunctionCall(sw, c, v.Function, "ctx", "op", "fldPath", "obj", "oldObj")
sw.Do("\n}", targs)
}
case validators.MultiWrapperFunction:
// MultiWrapperFunction generates a closure to execute multiple validation functions.
targs := generator.Args{
"field": mkSymbolArgs(c, fieldPkgSymbols),
"operation": mkSymbolArgs(c, operationPkgSymbols),
"context": mkSymbolArgs(c, contextPkgSymbols),
"objType": v.ObjType,
"objTypePfx": "*",
}
// Use the nilable form to pass pointers to standard validation functions.
if util.IsNilableType(v.ObjType) {
targs["objTypePfx"] = ""
}
sw.Do("func(", targs)
sw.Do(" ctx $.context.Context|raw$, ", targs)
sw.Do(" op $.operation.Operation|raw$, ", targs)
sw.Do(" fldPath *$.field.Path|raw$, ", targs)
sw.Do(" obj, oldObj $.objTypePfx$$.objType|raw$ ", targs)
sw.Do(") $.field.ErrorList|raw$ {\n", targs)
sw.Do("errs := $.field.ErrorList|raw${}\n", targs)
// Determine if any wrapped functions short-circuit.
hasShortCircuits := false
lastShortCircuitIdx := -1
for i, fg := range v.Functions {
if fg.Flags.IsSet(validators.ShortCircuit) {
hasShortCircuits = true
lastShortCircuitIdx = i
}
}
if hasShortCircuits {
sw.Do("earlyReturn := false\n", nil)
}
for i, fg := range v.Functions {
isNonError := fg.Flags.IsSet(validators.NonError)
if fg.Flags.IsSet(validators.ShortCircuit) {
// Short-circuiting functions stop execution if they return an error.
sw.Do("if e := ", nil)
emitFunctionCall(sw, c, fg, "ctx", "op", "fldPath", "obj", "oldObj")
sw.Do("; len(e) != 0 {\n", nil)
if !isNonError {
sw.Do(" errs = append(errs, e...)\n", nil)
}
sw.Do(" earlyReturn = true\n", nil)
sw.Do("}\n", nil)
// If a failure occurred during short-circuiting, return early.
if i == lastShortCircuitIdx {
sw.Do("if earlyReturn {\n", nil)
sw.Do(" return errs\n", nil)
sw.Do("}\n", nil)
}
} else {
// Standard functions append errors to the list.
if isNonError {
emitFunctionCall(sw, c, fg, "ctx", "op", "fldPath", "obj", "oldObj")
} else {
sw.Do("errs = append(errs, ", nil)
emitFunctionCall(sw, c, fg, "ctx", "op", "fldPath", "obj", "oldObj")
sw.Do("...)\n", nil)
}
}
}
sw.Do("return errs\n}", nil)
case validators.Literal:
sw.Do("$.$", v)
case validators.FunctionGen:

View file

@ -0,0 +1,469 @@
/*
Copyright 2025 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package validators
import (
"fmt"
"regexp"
"slices"
"strconv"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/code-generator/cmd/validation-gen/util"
"k8s.io/gengo/v2/codetags"
"k8s.io/gengo/v2/parser/tags"
"k8s.io/gengo/v2/types"
)
const (
discriminatorTagName = "k8s:discriminator"
memberTagName = "k8s:member"
)
// validGroupNameRegex restricts discriminator group names to identifiers that
// start with a letter and contain only alphanumeric characters and underscores.
var validGroupNameRegex = regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9_]*$`)
func init() {
RegisterTagValidator(&discriminatorTagValidator{discriminatorDefinitions})
RegisterTagValidator(&memberTagValidator{discriminatorDefinitions, nil})
RegisterTypeValidator(&discriminatorFieldValidator{discriminatorDefinitions})
RegisterFieldValidator(&discriminatorFieldValidator{discriminatorDefinitions})
}
// discriminatorDefinitions stores all discriminator definitions found by tag validators.
// Key is the struct path.
var discriminatorDefinitions = map[string]discriminatorGroups{}
type discriminatorGroups map[string]*discriminatorGroup
type discriminatorGroup struct {
name string
discriminatorMember *types.Member
// members maps field names to their rules in this discriminator group.
members map[string]*fieldMemberRules
}
type fieldMemberRules struct {
member *types.Member
rules []memberRule
}
type memberRule struct {
value string
validations Validations
}
func (mg discriminatorGroups) getOrCreate(name string) *discriminatorGroup {
if name == "" {
name = "default"
}
g, ok := mg[name]
if !ok {
g = &discriminatorGroup{
name: name,
members: make(map[string]*fieldMemberRules),
}
mg[name] = g
}
return g
}
type discriminatorTagValidator struct {
shared map[string]discriminatorGroups
}
func (mtv *discriminatorTagValidator) Init(_ Config) {}
func (mtv *discriminatorTagValidator) TagName() string {
return discriminatorTagName
}
func (mtv *discriminatorTagValidator) ValidScopes() sets.Set[Scope] {
return sets.New(ScopeField)
}
func (mtv *discriminatorTagValidator) GetValidations(context Context, tag codetags.Tag) (Validations, error) {
if util.NativeType(context.Type).Kind == types.Pointer {
return Validations{}, fmt.Errorf("can only be used on non-pointer types")
}
if t := util.NonPointer(util.NativeType(context.Type)); t.Kind != types.Builtin || (t.Name.Name != "string" && t.Name.Name != "bool" && !types.IsInteger(t)) {
return Validations{}, fmt.Errorf("can only be used on string, bool or integer types (%s)", rootTypeString(context.Type, t))
}
if mtv.shared[context.ParentPath.String()] == nil {
mtv.shared[context.ParentPath.String()] = make(discriminatorGroups)
}
groupName := ""
if nameArg, ok := tag.NamedArg("name"); ok {
groupName = nameArg.Value
}
if groupName != "" && !validGroupNameRegex.MatchString(groupName) {
return Validations{}, fmt.Errorf("discriminator group name must match %s, got %q", validGroupNameRegex.String(), groupName)
}
if groupName == "default" {
return Validations{}, fmt.Errorf("discriminator group name %q is reserved", groupName)
}
group := mtv.shared[context.ParentPath.String()].getOrCreate(groupName)
if group.discriminatorMember != nil && group.discriminatorMember != context.Member {
return Validations{}, fmt.Errorf("duplicate discriminator: %q", groupName)
}
group.discriminatorMember = context.Member
return Validations{}, nil
}
func (mtv *discriminatorTagValidator) Docs() TagDoc {
return TagDoc{
Tag: mtv.TagName(),
StabilityLevel: TagStabilityLevelAlpha,
Scopes: mtv.ValidScopes().UnsortedList(),
Description: "Indicates that this field is a discriminator for state-based validation.",
Args: []TagArgDoc{{
Name: "name",
Description: "<string>",
Docs: "the name of the discriminator group, if more than one exists",
Type: codetags.ArgTypeString,
}},
}
}
type memberTagValidator struct {
shared map[string]discriminatorGroups
validator TagValidationExtractor
}
func (mtv *memberTagValidator) Init(cfg Config) {
mtv.validator = cfg.TagValidator
}
func (mtv *memberTagValidator) TagName() string {
return memberTagName
}
func (mtv *memberTagValidator) ValidScopes() sets.Set[Scope] {
return sets.New(ScopeField)
}
func (mtv *memberTagValidator) GetValidations(context Context, tag codetags.Tag) (Validations, error) {
if tag.ValueTag == nil {
return Validations{}, fmt.Errorf("missing required payload")
}
groupName := ""
if modeArg, ok := tag.NamedArg("discriminator"); ok {
groupName = modeArg.Value
}
if groupName == "default" {
return Validations{}, fmt.Errorf("discriminator group name %q is reserved", groupName)
}
value := ""
if valArg, ok := tag.NamedArg("value"); ok {
value = valArg.Value
} else if len(tag.Args) > 0 && tag.Args[0].Name == "" {
// Positional argument
value = tag.Args[0].Value
} else {
return Validations{}, fmt.Errorf("missing required value")
}
if mtv.shared[context.ParentPath.String()] == nil {
mtv.shared[context.ParentPath.String()] = make(discriminatorGroups)
}
group := mtv.shared[context.ParentPath.String()].getOrCreate(groupName)
fieldName := context.Member.Name
if rules, ok := group.members[fieldName]; ok {
if rules.member != context.Member {
return Validations{}, fmt.Errorf("internal error: member mismatch for field %q", fieldName)
}
} else {
group.members[fieldName] = &fieldMemberRules{
member: context.Member,
}
}
payloadValidations, err := mtv.validator.ExtractTagValidations(context, *tag.ValueTag)
if err != nil {
return Validations{}, err
}
group.members[fieldName].rules = append(group.members[fieldName].rules, memberRule{
value: value,
validations: payloadValidations,
})
return Validations{}, nil
}
func (mtv *memberTagValidator) Docs() TagDoc {
return TagDoc{
Tag: mtv.TagName(),
StabilityLevel: TagStabilityLevelAlpha,
Scopes: mtv.ValidScopes().UnsortedList(),
Description: "Indicates that this field's validation depends on a discriminator.",
Args: []TagArgDoc{{
Name: "", // positional
Description: "<string>",
Docs: "the value of the discriminator for which this validation applies",
Type: codetags.ArgTypeString,
}, {
Name: "discriminator",
Description: "<string>",
Docs: "the name of the discriminator group",
Type: codetags.ArgTypeString,
}, {
Name: "value",
Description: "<string>",
Docs: "the value of the discriminator for which this validation applies",
Type: codetags.ArgTypeString,
}},
PayloadsType: codetags.ValueTypeTag,
PayloadsRequired: true,
}
}
type discriminatorFieldValidator struct {
shared map[string]discriminatorGroups
}
func (discriminatorFieldValidator) Init(_ Config) {}
func (discriminatorFieldValidator) Name() string {
return "discriminatorFieldValidator"
}
func (mtfv *discriminatorFieldValidator) GetValidations(context Context) (Validations, error) {
// Extract the most concrete type possible.
if k := util.NonPointer(util.NativeType(context.Type)).Kind; k != types.Struct {
return Validations{}, nil
}
groups, ok := mtfv.shared[context.Path.String()]
if !ok || len(groups) == 0 {
return Validations{}, nil
}
var result Validations
// Sort group names for deterministic output
groupNames := make([]string, 0, len(groups))
for name := range groups {
groupNames = append(groupNames, name)
}
slices.Sort(groupNames)
for _, gn := range groupNames {
group := groups[gn]
if group.discriminatorMember == nil {
if len(group.members) > 0 {
if gn == "default" {
return Validations{}, fmt.Errorf("missing discriminator")
}
return Validations{}, fmt.Errorf("missing discriminator for group %q", gn)
}
continue
}
fieldNames := make([]string, 0, len(group.members))
for name := range group.members {
fieldNames = append(fieldNames, name)
}
slices.Sort(fieldNames)
for _, fn := range fieldNames {
rules := group.members[fn]
v, err := mtfv.generateMemberFieldValidation(context, group, rules)
if err != nil {
return Validations{}, err
}
result.Add(v)
}
}
return result, nil
}
func (mtfv *discriminatorFieldValidator) generateMemberFieldValidation(context Context, group *discriminatorGroup, rules *fieldMemberRules) (Validations, error) {
fieldType := rules.member.Type
// Use the nilable form to handle missing values.
nilableFieldType := fieldType
fieldExprPrefix := ""
if !util.IsNilableType(nilableFieldType) {
nilableFieldType = types.PointerTo(nilableFieldType)
fieldExprPrefix = "&"
}
// Get the JSON name of the field
jsonName := rules.member.Name
if jt, ok := tags.LookupJSON(*rules.member); ok {
jsonName = jt.Name
}
// Default validation is Forbidden
defaultForbidden, err := mtfv.getForbiddenValidation(fieldType)
if err != nil {
return Validations{}, err
}
// Prepare DiscriminatedRules
// Aggregate rules by value
rulesByValue := make(map[string]Validations)
var values []string
for _, rule := range rules.rules {
if _, ok := rulesByValue[rule.value]; !ok {
values = append(values, rule.value)
}
v := rulesByValue[rule.value]
v.Add(rule.validations)
rulesByValue[rule.value] = v
}
slices.Sort(values)
discriminatorType := group.discriminatorMember.Type
var discriminatedRules []any
for _, val := range values {
ruleValidations := rulesByValue[val]
wrapper := MultiWrapperFunction{
Functions: ruleValidations.Functions,
ObjType: nilableFieldType,
}
// Convert the string tag value to the appropriate typed Go literal
// for the discriminator type.
typedValue, err := convertDiscriminatorValue(val, discriminatorType)
if err != nil {
return Validations{}, fmt.Errorf("invalid discriminator value %q: %w", val, err)
}
discriminatedRules = append(discriminatedRules, StructLiteral{
Type: types.Name{Package: libValidationPkg, Name: "DiscriminatedRule"},
TypeArgs: []*types.Type{nilableFieldType, discriminatorType},
Fields: []StructLiteralField{
{Name: "Value", Value: typedValue},
{Name: "Validation", Value: wrapper},
},
})
}
discriminatedValidator := types.Name{Package: libValidationPkg, Name: "Discriminated"}
rulesSlice := SliceLiteral{
ElementType: types.Name{Package: libValidationPkg, Name: "DiscriminatedRule"},
ElementTypeArgs: []*types.Type{nilableFieldType, discriminatorType},
Elements: discriminatedRules,
}
// getValue extractor
getValue := FunctionLiteral{
Parameters: []ParamResult{{Name: "obj", Type: types.PointerTo(context.Type)}},
Results: []ParamResult{{Type: nilableFieldType}},
Body: fmt.Sprintf("return %sobj.%s", fieldExprPrefix, rules.member.Name),
}
// getDiscriminator extractor
getDiscriminator := FunctionLiteral{
Parameters: []ParamResult{{Name: "obj", Type: types.PointerTo(context.Type)}},
Results: []ParamResult{{Type: discriminatorType}},
Body: fmt.Sprintf("return obj.%s", group.discriminatorMember.Name),
}
// directComparable is used to determine whether we can use the direct
// comparison operator "==" or need to use the semantic DeepEqual when
// looking up and comparing correlated list elements for validation ratcheting.
var equivArg any
if util.IsDirectComparable(util.NonPointer(util.NativeType(fieldType))) {
equivArg = Identifier(validateDirectEqualPtr)
} else {
equivArg = Identifier(validateSemanticDeepEqual)
}
fn := Function(discriminatorTagName, DefaultFlags, discriminatedValidator,
Literal(fmt.Sprintf("%q", jsonName)),
getValue,
getDiscriminator,
equivArg,
defaultForbidden,
rulesSlice,
)
return Validations{Functions: []FunctionGen{fn}}, nil
}
func (mtfv *discriminatorFieldValidator) getForbiddenValidation(t *types.Type) (any, error) {
var forbidden types.Name
nt := util.NativeType(t)
switch nt.Kind {
case types.Slice:
forbidden = types.Name{Package: libValidationPkg, Name: "ForbiddenSlice"}
case types.Map:
forbidden = types.Name{Package: libValidationPkg, Name: "ForbiddenMap"}
case types.Pointer:
forbidden = types.Name{Package: libValidationPkg, Name: "ForbiddenPointer"}
case types.Struct:
return nil, fmt.Errorf("discriminated member fields of struct type must be pointers")
default:
forbidden = types.Name{Package: libValidationPkg, Name: "ForbiddenValue"}
}
fg := Function(forbiddenTagName, DefaultFlags, forbidden)
// Use the nilable form to match standard validation function signatures.
wrapperObjType := t
if !util.IsNilableType(t) {
wrapperObjType = types.PointerTo(t)
}
return MultiWrapperFunction{
Functions: []FunctionGen{fg},
ObjType: wrapperObjType,
}, nil
}
// convertDiscriminatorValue converts a string tag value to the appropriate
// typed Go literal for the given discriminator type.
func convertDiscriminatorValue(val string, discType *types.Type) (any, error) {
nt := util.NonPointer(util.NativeType(discType))
if nt.Kind != types.Builtin {
return nil, fmt.Errorf("unsupported discriminator type: %s", nt.Name.Name)
}
switch nt.Name.Name {
case "string":
return val, nil
case "bool":
b, err := strconv.ParseBool(val)
if err != nil {
return nil, fmt.Errorf("cannot parse %q as bool: %w", val, err)
}
return b, nil
default:
if types.IsInteger(nt) {
i, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return nil, fmt.Errorf("cannot parse %q as integer: %w", val, err)
}
return int(i), nil
}
return nil, fmt.Errorf("unsupported discriminator type: %s", nt.Name.Name)
}
}

View file

@ -595,6 +595,14 @@ type WrapperFunction struct {
ObjType *types.Type
}
// MultiWrapperFunction describes a function literal which has the fingerprint
// of a regular validation function (op, fldPath, obj, oldObj) and calls
// multiple other validation functions with the same signature.
type MultiWrapperFunction struct {
Functions []FunctionGen
ObjType *types.Type
}
// Literal is a literal value that, when used as an argument to a validator,
// will be emitted without any further interpretation. Use this with caution,
// it will not be subject to Namers.