mirror of
https://github.com/grafana/grafana.git
synced 2026-02-18 18:20:52 -05:00
zanzana: Fix batch check to split requests exceeding OpenFGA max checks limit (#118154)
add support for large batch check size
This commit is contained in:
parent
8321b9046a
commit
dd37028512
2 changed files with 124 additions and 1 deletions
|
|
@ -9,6 +9,7 @@ import (
|
|||
authzv1 "github.com/grafana/authlib/authz/proto/v1"
|
||||
"github.com/grafana/authlib/types"
|
||||
openfgav1 "github.com/openfga/api/proto/openfga/v1"
|
||||
serverconfig "github.com/openfga/openfga/pkg/server/config"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/codes"
|
||||
grpccodes "google.golang.org/grpc/codes"
|
||||
|
|
@ -464,7 +465,8 @@ func (s *Server) addTypedResourceDirectChecks(
|
|||
return checks
|
||||
}
|
||||
|
||||
// doBatchCheck executes a batch check against OpenFGA
|
||||
// doBatchCheck executes a batch check against OpenFGA, splitting into
|
||||
// sub-batches if the number of checks exceeds the configured MaxChecksPerBatchCheck limit.
|
||||
func (s *Server) doBatchCheck(
|
||||
ctx context.Context,
|
||||
store *storeInfo,
|
||||
|
|
@ -474,6 +476,38 @@ func (s *Server) doBatchCheck(
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
maxChecks := s.getMaxChecksPerBatchCheck()
|
||||
|
||||
// If within limit, send a single batch
|
||||
if len(checks) <= maxChecks {
|
||||
return s.executeBatchCheck(ctx, store, checks)
|
||||
}
|
||||
|
||||
// Split into sub-batches
|
||||
allResults := make(map[string]*openfgav1.BatchCheckSingleResult, len(checks))
|
||||
for i := 0; i < len(checks); i += maxChecks {
|
||||
end := i + maxChecks
|
||||
if end > len(checks) {
|
||||
end = len(checks)
|
||||
}
|
||||
results, err := s.executeBatchCheck(ctx, store, checks[i:end])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for k, v := range results {
|
||||
allResults[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return allResults, nil
|
||||
}
|
||||
|
||||
// executeBatchCheck sends a single OpenFGA BatchCheck request.
|
||||
func (s *Server) executeBatchCheck(
|
||||
ctx context.Context,
|
||||
store *storeInfo,
|
||||
checks []*openfgav1.BatchCheckItem,
|
||||
) (map[string]*openfgav1.BatchCheckSingleResult, error) {
|
||||
openfgaReq := &openfgav1.BatchCheckRequest{
|
||||
StoreId: store.ID,
|
||||
AuthorizationModelId: store.ModelID,
|
||||
|
|
@ -487,3 +521,12 @@ func (s *Server) doBatchCheck(
|
|||
|
||||
return openfgaRes.GetResult(), nil
|
||||
}
|
||||
|
||||
// getMaxChecksPerBatchCheck returns the configured maximum checks per batch,
|
||||
// falling back to the default if not explicitly set.
|
||||
func (s *Server) getMaxChecksPerBatchCheck() int {
|
||||
if s.cfg.OpenFgaServerSettings.MaxChecksPerBatchCheck > 0 {
|
||||
return int(s.cfg.OpenFgaServerSettings.MaxChecksPerBatchCheck)
|
||||
}
|
||||
return serverconfig.DefaultMaxChecksPerBatchCheck
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
authzv1 "github.com/grafana/authlib/authz/proto/v1"
|
||||
|
|
@ -8,6 +9,7 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/grafana/grafana/pkg/apimachinery/utils"
|
||||
"github.com/grafana/grafana/pkg/setting"
|
||||
"github.com/grafana/grafana/pkg/util/testutil"
|
||||
)
|
||||
|
||||
|
|
@ -225,3 +227,81 @@ func TestIntegrationServerBatchCheck(t *testing.T) {
|
|||
assert.True(t, res.GetResults()["check2"].GetAllowed())
|
||||
})
|
||||
}
|
||||
|
||||
func TestIntegrationServerBatchCheck_SubBatching(t *testing.T) {
|
||||
testutil.SkipIntegrationTestInShortMode(t)
|
||||
|
||||
server := setupOpenFGAServer(t)
|
||||
setup(t, server)
|
||||
|
||||
// Set a low limit to force sub-batching within the test
|
||||
server.cfg.OpenFgaServerSettings = setting.OpenFgaServerSettings{
|
||||
MaxChecksPerBatchCheck: 3,
|
||||
}
|
||||
|
||||
newBatchReq := func(subject string, items []*authzv1.BatchCheckItem) *authzv1.BatchCheckRequest {
|
||||
return &authzv1.BatchCheckRequest{
|
||||
Namespace: namespace,
|
||||
Subject: subject,
|
||||
Checks: items,
|
||||
}
|
||||
}
|
||||
|
||||
newItem := func(correlationID, verb, group, resource, subresource, folder, name string) *authzv1.BatchCheckItem {
|
||||
return &authzv1.BatchCheckItem{
|
||||
CorrelationId: correlationID,
|
||||
Verb: verb,
|
||||
Group: group,
|
||||
Resource: resource,
|
||||
Subresource: subresource,
|
||||
Name: name,
|
||||
Folder: folder,
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("batch exceeding limit returns correct results for all items", func(t *testing.T) {
|
||||
// user:2 has group_resource access to all dashboards, so all should be allowed.
|
||||
// 7 items with MaxChecksPerBatchCheck=3 forces splitting into multiple sub-batches.
|
||||
items := make([]*authzv1.BatchCheckItem, 7)
|
||||
for i := range items {
|
||||
items[i] = newItem(
|
||||
fmt.Sprintf("check-%d", i),
|
||||
utils.VerbGet, dashboardGroup, dashboardResource, "", "1", fmt.Sprintf("%d", i+1),
|
||||
)
|
||||
}
|
||||
|
||||
res, err := server.BatchCheck(newContextWithNamespace(), newBatchReq("user:2", items))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, res.GetResults(), 7)
|
||||
for i := range items {
|
||||
assert.True(t, res.GetResults()[fmt.Sprintf("check-%d", i)].GetAllowed(),
|
||||
"check-%d should be allowed via group_resource access", i)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("batch exceeding limit preserves mixed allowed and denied results", func(t *testing.T) {
|
||||
// user:4 has folder-based access to folders 1 and 3 but not folder 2.
|
||||
// Generate items across folders to get a mix of allowed/denied across sub-batches.
|
||||
items := []*authzv1.BatchCheckItem{
|
||||
newItem("f1-a", utils.VerbGet, dashboardGroup, dashboardResource, "", "1", "100"), // allowed
|
||||
newItem("f1-b", utils.VerbGet, dashboardGroup, dashboardResource, "", "1", "101"), // allowed
|
||||
newItem("f2-a", utils.VerbGet, dashboardGroup, dashboardResource, "", "2", "200"), // denied
|
||||
newItem("f3-a", utils.VerbGet, dashboardGroup, dashboardResource, "", "3", "300"), // allowed
|
||||
newItem("f2-b", utils.VerbGet, dashboardGroup, dashboardResource, "", "2", "201"), // denied
|
||||
newItem("f1-c", utils.VerbGet, dashboardGroup, dashboardResource, "", "1", "102"), // allowed
|
||||
newItem("f3-b", utils.VerbGet, dashboardGroup, dashboardResource, "", "3", "301"), // allowed
|
||||
}
|
||||
|
||||
res, err := server.BatchCheck(newContextWithNamespace(), newBatchReq("user:4", items))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, res.GetResults(), 7)
|
||||
|
||||
assert.True(t, res.GetResults()["f1-a"].GetAllowed())
|
||||
assert.True(t, res.GetResults()["f1-b"].GetAllowed())
|
||||
assert.False(t, res.GetResults()["f2-a"].GetAllowed())
|
||||
assert.True(t, res.GetResults()["f3-a"].GetAllowed())
|
||||
assert.False(t, res.GetResults()["f2-b"].GetAllowed())
|
||||
assert.True(t, res.GetResults()["f1-c"].GetAllowed())
|
||||
assert.True(t, res.GetResults()["f3-b"].GetAllowed())
|
||||
})
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue