Add scram_iterations config option

Allows operators to explicitly set the SCRAM-SHA-256 iteration count,
overriding the server's setting. Returns an error if scram_iterations
is set without password_authentication=scram-sha-256.
This commit is contained in:
Thom Wright 2026-03-26 12:56:10 +00:00
parent cddd8a631a
commit ccde9bcdb5
No known key found for this signature in database
GPG key ID: 0315D0ABDE80BF78
2 changed files with 107 additions and 1 deletions

View file

@ -183,7 +183,21 @@ func (p *PostgreSQL) Initialize(ctx context.Context, req dbplugin.InitializeRequ
p.passwordAuthentication = pwAuthentication
}
if p.passwordAuthentication == passwordAuthenticationSCRAMSHA256 {
scramIterationsRaw, err := strutil.GetString(req.Config, "scram_iterations")
if err != nil {
return dbplugin.InitializeResponse{}, fmt.Errorf("failed to retrieve scram_iterations: %w", err)
}
if scramIterationsRaw != "" {
if p.passwordAuthentication != passwordAuthenticationSCRAMSHA256 {
return dbplugin.InitializeResponse{}, fmt.Errorf("scram_iterations requires password_authentication to be %q", passwordAuthenticationSCRAMSHA256)
}
scramIterations, err := strconv.Atoi(scramIterationsRaw)
if err != nil || scramIterations < 1 {
return dbplugin.InitializeResponse{}, fmt.Errorf("scram_iterations must be a positive integer, got: %s", scramIterationsRaw)
}
p.scramIterations = scramIterations
} else if p.passwordAuthentication == passwordAuthenticationSCRAMSHA256 {
if serverIterations, err := p.queryServerSCRAMIterations(ctx); err == nil {
p.scramIterations = serverIterations
}

View file

@ -880,6 +880,98 @@ func TestPostgreSQL_SCRAMIterations_FallbackOnOlderPG(t *testing.T) {
assert.Equal(t, scram.DefaultIterations, db.scramIterations)
}
// TestPostgreSQL_SCRAMIterations_ExplicitConfig tests that an explicit scram_iterations config
// value overrides the server's setting.
func TestPostgreSQL_SCRAMIterations_ExplicitConfig(t *testing.T) {
ctx := context.Background()
cleanup, connURL := postgresql.PrepareTestContainerWithSCRAMIterations(t, ctx, 100)
defer cleanup()
dsnConnURL, err := dbutil.ParseURL(connURL)
if err != nil {
t.Fatal(err)
}
connectionDetails := map[string]interface{}{
"connection_url": dsnConnURL,
"password_authentication": string(passwordAuthenticationSCRAMSHA256),
"scram_iterations": "200",
}
req := dbplugin.InitializeRequest{
Config: connectionDetails,
VerifyConnection: true,
}
db := new()
_ = dbtesting.AssertInitialize(t, db, req)
// Explicit config (200) should override server config (100)
assert.Equal(t, 200, db.scramIterations)
}
// TestPostgreSQL_SCRAMIterations_RequiresSCRAM tests that setting scram_iterations without
// password_authentication=scram-sha-256 returns an error.
func TestPostgreSQL_SCRAMIterations_RequiresSCRAM(t *testing.T) {
cleanup, connURL := postgresql.PrepareTestContainer(t)
defer cleanup()
dsnConnURL, err := dbutil.ParseURL(connURL)
assert.NoError(t, err)
connectionDetails := map[string]interface{}{
"connection_url": dsnConnURL,
"scram_iterations": "4096",
}
req := dbplugin.InitializeRequest{
Config: connectionDetails,
VerifyConnection: true,
}
db := new()
_, err = db.Initialize(context.Background(), req)
assert.Error(t, err)
assert.Contains(t, err.Error(), "scram_iterations requires password_authentication")
}
// TestPostgreSQL_SCRAMIterations_Invalid tests that invalid scram_iterations values are rejected.
func TestPostgreSQL_SCRAMIterations_Invalid(t *testing.T) {
cleanup, connURL := postgresql.PrepareTestContainer(t)
defer cleanup()
dsnConnURL, err := dbutil.ParseURL(connURL)
assert.NoError(t, err)
tcs := map[string]struct {
value string
}{
"not-a-number": {value: "abc"},
"zero": {value: "0"},
"negative": {value: "-1"},
}
for name, tc := range tcs {
t.Run(name, func(t *testing.T) {
connectionDetails := map[string]interface{}{
"connection_url": dsnConnURL,
"password_authentication": string(passwordAuthenticationSCRAMSHA256),
"scram_iterations": tc.value,
}
req := dbplugin.InitializeRequest{
Config: connectionDetails,
VerifyConnection: true,
}
db := new()
_, err := db.Initialize(context.Background(), req)
assert.Error(t, err)
assert.Contains(t, err.Error(), "scram_iterations must be a positive integer")
})
}
}
func TestPostgreSQL_NewUser(t *testing.T) {
type testCase struct {
req dbplugin.NewUserRequest