From ccde9bcdb5907048cddb8f0563008dccdf0b163d Mon Sep 17 00:00:00 2001 From: Thom Wright Date: Thu, 26 Mar 2026 12:56:10 +0000 Subject: [PATCH] Add scram_iterations config option Allows operators to explicitly set the SCRAM-SHA-256 iteration count, overriding the server's setting. Returns an error if scram_iterations is set without password_authentication=scram-sha-256. --- plugins/database/postgresql/postgresql.go | 16 +++- .../database/postgresql/postgresql_test.go | 92 +++++++++++++++++++ 2 files changed, 107 insertions(+), 1 deletion(-) diff --git a/plugins/database/postgresql/postgresql.go b/plugins/database/postgresql/postgresql.go index 3ed9e8d3b3..a102d85d36 100644 --- a/plugins/database/postgresql/postgresql.go +++ b/plugins/database/postgresql/postgresql.go @@ -183,7 +183,21 @@ func (p *PostgreSQL) Initialize(ctx context.Context, req dbplugin.InitializeRequ p.passwordAuthentication = pwAuthentication } - if p.passwordAuthentication == passwordAuthenticationSCRAMSHA256 { + scramIterationsRaw, err := strutil.GetString(req.Config, "scram_iterations") + if err != nil { + return dbplugin.InitializeResponse{}, fmt.Errorf("failed to retrieve scram_iterations: %w", err) + } + + if scramIterationsRaw != "" { + if p.passwordAuthentication != passwordAuthenticationSCRAMSHA256 { + return dbplugin.InitializeResponse{}, fmt.Errorf("scram_iterations requires password_authentication to be %q", passwordAuthenticationSCRAMSHA256) + } + scramIterations, err := strconv.Atoi(scramIterationsRaw) + if err != nil || scramIterations < 1 { + return dbplugin.InitializeResponse{}, fmt.Errorf("scram_iterations must be a positive integer, got: %s", scramIterationsRaw) + } + p.scramIterations = scramIterations + } else if p.passwordAuthentication == passwordAuthenticationSCRAMSHA256 { if serverIterations, err := p.queryServerSCRAMIterations(ctx); err == nil { p.scramIterations = serverIterations } diff --git a/plugins/database/postgresql/postgresql_test.go b/plugins/database/postgresql/postgresql_test.go index fa982471d0..4dfe746dd8 100644 --- a/plugins/database/postgresql/postgresql_test.go +++ b/plugins/database/postgresql/postgresql_test.go @@ -880,6 +880,98 @@ func TestPostgreSQL_SCRAMIterations_FallbackOnOlderPG(t *testing.T) { assert.Equal(t, scram.DefaultIterations, db.scramIterations) } +// TestPostgreSQL_SCRAMIterations_ExplicitConfig tests that an explicit scram_iterations config +// value overrides the server's setting. +func TestPostgreSQL_SCRAMIterations_ExplicitConfig(t *testing.T) { + ctx := context.Background() + cleanup, connURL := postgresql.PrepareTestContainerWithSCRAMIterations(t, ctx, 100) + defer cleanup() + + dsnConnURL, err := dbutil.ParseURL(connURL) + if err != nil { + t.Fatal(err) + } + + connectionDetails := map[string]interface{}{ + "connection_url": dsnConnURL, + "password_authentication": string(passwordAuthenticationSCRAMSHA256), + "scram_iterations": "200", + } + + req := dbplugin.InitializeRequest{ + Config: connectionDetails, + VerifyConnection: true, + } + + db := new() + _ = dbtesting.AssertInitialize(t, db, req) + + // Explicit config (200) should override server config (100) + assert.Equal(t, 200, db.scramIterations) +} + +// TestPostgreSQL_SCRAMIterations_RequiresSCRAM tests that setting scram_iterations without +// password_authentication=scram-sha-256 returns an error. +func TestPostgreSQL_SCRAMIterations_RequiresSCRAM(t *testing.T) { + cleanup, connURL := postgresql.PrepareTestContainer(t) + defer cleanup() + + dsnConnURL, err := dbutil.ParseURL(connURL) + assert.NoError(t, err) + + connectionDetails := map[string]interface{}{ + "connection_url": dsnConnURL, + "scram_iterations": "4096", + } + + req := dbplugin.InitializeRequest{ + Config: connectionDetails, + VerifyConnection: true, + } + + db := new() + _, err = db.Initialize(context.Background(), req) + assert.Error(t, err) + assert.Contains(t, err.Error(), "scram_iterations requires password_authentication") +} + +// TestPostgreSQL_SCRAMIterations_Invalid tests that invalid scram_iterations values are rejected. +func TestPostgreSQL_SCRAMIterations_Invalid(t *testing.T) { + cleanup, connURL := postgresql.PrepareTestContainer(t) + defer cleanup() + + dsnConnURL, err := dbutil.ParseURL(connURL) + assert.NoError(t, err) + + tcs := map[string]struct { + value string + }{ + "not-a-number": {value: "abc"}, + "zero": {value: "0"}, + "negative": {value: "-1"}, + } + + for name, tc := range tcs { + t.Run(name, func(t *testing.T) { + connectionDetails := map[string]interface{}{ + "connection_url": dsnConnURL, + "password_authentication": string(passwordAuthenticationSCRAMSHA256), + "scram_iterations": tc.value, + } + + req := dbplugin.InitializeRequest{ + Config: connectionDetails, + VerifyConnection: true, + } + + db := new() + _, err := db.Initialize(context.Background(), req) + assert.Error(t, err) + assert.Contains(t, err.Error(), "scram_iterations must be a positive integer") + }) + } +} + func TestPostgreSQL_NewUser(t *testing.T) { type testCase struct { req dbplugin.NewUserRequest