diff --git a/ha/ha_test.go b/ha/ha_test.go index 780e2b9d..71faa3e5 100644 --- a/ha/ha_test.go +++ b/ha/ha_test.go @@ -12,6 +12,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "strconv" "sync" "testing" "time" @@ -21,9 +22,7 @@ func createTestingHA(t *testing.T, redisAddr string) *HA { redisConn := connection.NewRDBWrapper(redisAddr, "", 64) mysqlConn, err := connection.NewDBWrapper(testbackends.MysqlTestDsn, 50) - if err != nil { - assert.Fail(t, "This test needs a working MySQL connection!") - } + require.NoError(t, err, "This test needs a working MySQL connection!") super := supervisor.Supervisor{ ChErr: make(chan error), @@ -49,6 +48,44 @@ func createTestingHA(t *testing.T, redisAddr string) *HA { return ha } +func createTestingMultipleHA(t *testing.T, redisAddr string, numInstances int) ([]*HA, <-chan error) { + redisConn := connection.NewRDBWrapper(redisAddr, "", 64) + + mysqlConn, err := connection.NewDBWrapper(testbackends.MysqlTestDsn, 50) + require.NoError(t, err, "This test needs a working MySQL connection!") + + _, err = mysqlConn.SqlExec(mysqlTestObserver, "TRUNCATE TABLE icingadb_instance") + require.NoError(t, err, "This test needs a working MySQL connection!") + + instances := make([]*HA, numInstances) + chErr := make(chan error) + + for i := 0; i < numInstances; i++ { + + super := supervisor.Supervisor{ + ChErr: chErr, + Rdbw: redisConn, + Dbw: mysqlConn, + } + + ha, _ := NewHA(&super) + + hash := sha1.New() + hash.Write([]byte("derp")) + ha.super.EnvId = hash.Sum(nil) + ha.uid = uuid.NewSHA1(uuid.MustParse("551bc748-94b2-4d27-b6a4-15c52aecfe85"), []byte(strconv.Itoa(i))) + + ha.logger = log.WithFields(log.Fields{ + "context": "HA-Testing", + "UUID": ha.uid, + }) + + instances[i] = ha + } + + return instances, chErr +} + var mysqlTestObserver = connection.DbIoSeconds.WithLabelValues("mysql", "test") func TestHA_InsertInstance(t *testing.T) { @@ -202,3 +239,54 @@ func TestHA_runHA(t *testing.T) { wg.Wait() } + +func TestHA_ConcurrentCheckResponsibility(t *testing.T) { + numAttempts := 10 + numConcurrentTakeovers := 2 + failed := false + + for attempt := 0; !failed && attempt < numAttempts; attempt++ { + wg := sync.WaitGroup{} + wg.Add(numConcurrentTakeovers) + + haInstances, chErr := createTestingMultipleHA(t, testbackends.RedisTestAddr, numConcurrentTakeovers) + for _, ha := range haInstances { + ha := ha + go func() { + defer wg.Done() + ha.checkResponsibility(&Environment{}) + }() + } + + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + loop: + for { + select { + case err := <-chErr: + assert.NoError(t, err, "checkResponsibility() should return no error") + if err != nil { + failed = true + } + case <-done: + break loop + } + } + + numActive := 0 + for _, ha := range haInstances { + if ha.state == StateActive { + numActive++ + } + } + + assert.Equal(t, 1, numActive, "exactly 1 instance must be active after checkResponsibility() but %d are active", numActive) + if numActive != 1 { + failed = true + } + } +}