mirror of
https://github.com/hashicorp/vault.git
synced 2026-06-09 08:55:13 -04:00
Add scram_iterations config option
Allows operators to explicitly set the SCRAM-SHA-256 iteration count, overriding the server's setting. Returns an error if scram_iterations is set without password_authentication=scram-sha-256.
This commit is contained in:
parent
cddd8a631a
commit
ccde9bcdb5
2 changed files with 107 additions and 1 deletions
|
|
@ -183,7 +183,21 @@ func (p *PostgreSQL) Initialize(ctx context.Context, req dbplugin.InitializeRequ
|
|||
p.passwordAuthentication = pwAuthentication
|
||||
}
|
||||
|
||||
if p.passwordAuthentication == passwordAuthenticationSCRAMSHA256 {
|
||||
scramIterationsRaw, err := strutil.GetString(req.Config, "scram_iterations")
|
||||
if err != nil {
|
||||
return dbplugin.InitializeResponse{}, fmt.Errorf("failed to retrieve scram_iterations: %w", err)
|
||||
}
|
||||
|
||||
if scramIterationsRaw != "" {
|
||||
if p.passwordAuthentication != passwordAuthenticationSCRAMSHA256 {
|
||||
return dbplugin.InitializeResponse{}, fmt.Errorf("scram_iterations requires password_authentication to be %q", passwordAuthenticationSCRAMSHA256)
|
||||
}
|
||||
scramIterations, err := strconv.Atoi(scramIterationsRaw)
|
||||
if err != nil || scramIterations < 1 {
|
||||
return dbplugin.InitializeResponse{}, fmt.Errorf("scram_iterations must be a positive integer, got: %s", scramIterationsRaw)
|
||||
}
|
||||
p.scramIterations = scramIterations
|
||||
} else if p.passwordAuthentication == passwordAuthenticationSCRAMSHA256 {
|
||||
if serverIterations, err := p.queryServerSCRAMIterations(ctx); err == nil {
|
||||
p.scramIterations = serverIterations
|
||||
}
|
||||
|
|
|
|||
|
|
@ -880,6 +880,98 @@ func TestPostgreSQL_SCRAMIterations_FallbackOnOlderPG(t *testing.T) {
|
|||
assert.Equal(t, scram.DefaultIterations, db.scramIterations)
|
||||
}
|
||||
|
||||
// TestPostgreSQL_SCRAMIterations_ExplicitConfig tests that an explicit scram_iterations config
|
||||
// value overrides the server's setting.
|
||||
func TestPostgreSQL_SCRAMIterations_ExplicitConfig(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
cleanup, connURL := postgresql.PrepareTestContainerWithSCRAMIterations(t, ctx, 100)
|
||||
defer cleanup()
|
||||
|
||||
dsnConnURL, err := dbutil.ParseURL(connURL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"connection_url": dsnConnURL,
|
||||
"password_authentication": string(passwordAuthenticationSCRAMSHA256),
|
||||
"scram_iterations": "200",
|
||||
}
|
||||
|
||||
req := dbplugin.InitializeRequest{
|
||||
Config: connectionDetails,
|
||||
VerifyConnection: true,
|
||||
}
|
||||
|
||||
db := new()
|
||||
_ = dbtesting.AssertInitialize(t, db, req)
|
||||
|
||||
// Explicit config (200) should override server config (100)
|
||||
assert.Equal(t, 200, db.scramIterations)
|
||||
}
|
||||
|
||||
// TestPostgreSQL_SCRAMIterations_RequiresSCRAM tests that setting scram_iterations without
|
||||
// password_authentication=scram-sha-256 returns an error.
|
||||
func TestPostgreSQL_SCRAMIterations_RequiresSCRAM(t *testing.T) {
|
||||
cleanup, connURL := postgresql.PrepareTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
dsnConnURL, err := dbutil.ParseURL(connURL)
|
||||
assert.NoError(t, err)
|
||||
|
||||
connectionDetails := map[string]interface{}{
|
||||
"connection_url": dsnConnURL,
|
||||
"scram_iterations": "4096",
|
||||
}
|
||||
|
||||
req := dbplugin.InitializeRequest{
|
||||
Config: connectionDetails,
|
||||
VerifyConnection: true,
|
||||
}
|
||||
|
||||
db := new()
|
||||
_, err = db.Initialize(context.Background(), req)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "scram_iterations requires password_authentication")
|
||||
}
|
||||
|
||||
// TestPostgreSQL_SCRAMIterations_Invalid tests that invalid scram_iterations values are rejected.
|
||||
func TestPostgreSQL_SCRAMIterations_Invalid(t *testing.T) {
|
||||
cleanup, connURL := postgresql.PrepareTestContainer(t)
|
||||
defer cleanup()
|
||||
|
||||
dsnConnURL, err := dbutil.ParseURL(connURL)
|
||||
assert.NoError(t, err)
|
||||
|
||||
tcs := map[string]struct {
|
||||
value string
|
||||
}{
|
||||
"not-a-number": {value: "abc"},
|
||||
"zero": {value: "0"},
|
||||
"negative": {value: "-1"},
|
||||
}
|
||||
|
||||
for name, tc := range tcs {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
connectionDetails := map[string]interface{}{
|
||||
"connection_url": dsnConnURL,
|
||||
"password_authentication": string(passwordAuthenticationSCRAMSHA256),
|
||||
"scram_iterations": tc.value,
|
||||
}
|
||||
|
||||
req := dbplugin.InitializeRequest{
|
||||
Config: connectionDetails,
|
||||
VerifyConnection: true,
|
||||
}
|
||||
|
||||
db := new()
|
||||
_, err := db.Initialize(context.Background(), req)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "scram_iterations must be a positive integer")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgreSQL_NewUser(t *testing.T) {
|
||||
type testCase struct {
|
||||
req dbplugin.NewUserRequest
|
||||
|
|
|
|||
Loading…
Reference in a new issue