diff --git a/connection/.gitignore b/connection/.gitignore new file mode 100644 index 00000000..3d800336 --- /dev/null +++ b/connection/.gitignore @@ -0,0 +1,2 @@ + +coverage\.html diff --git a/connection/.gitlab-ci.yml b/connection/.gitlab-ci.yml new file mode 100644 index 00000000..56b155c4 --- /dev/null +++ b/connection/.gitlab-ci.yml @@ -0,0 +1,30 @@ +image: golang:latest + +variables: + REPO_NAME: git.icinga.com/icingadb/icingadb-connection + +before_script: + - mkdir -p $GOPATH/src/$(dirname $REPO_NAME) + - ln -svf $CI_PROJECT_DIR $GOPATH/src/$REPO_NAME + - cd $GOPATH/src/$REPO_NAME + - git config --global url."https://gitlab-ci-token:${CI_JOB_TOKEN}@git.icinga.com/".insteadOf "https://git.icinga.com/" + - go get -t ./... + +stages: + - test + - coverage + +test: + stage: test + script: + - go fmt $(go list ./... | grep -v /vendor/) + - go vet $(go list ./... | grep -v /vendor/) + - go test -race $(go list ./... | grep -v /vendor/) -cover + +coverage: + stage: coverage + script: + - ./coverage.sh + artifacts: + paths: + - coverage.html diff --git a/connection/README.md b/connection/README.md new file mode 100644 index 00000000..57eedb83 --- /dev/null +++ b/connection/README.md @@ -0,0 +1,4 @@ +IcingaDB Connection Library + +[![pipeline status](https://git.icinga.com/icingadb/icingadb-connection-lib/badges/master/pipeline.svg)](https://git.icinga.com/icingadb/icingadb-connection-lib/commits/master) +[![coverage report](https://git.icinga.com/icingadb/icingadb-connection-lib/badges/master/coverage.svg)](https://git.icinga.com/icingadb/icingadb-connection-lib/-/jobs/artifacts/master/raw/coverage.html?job=coverage) \ No newline at end of file diff --git a/connection/coverage.sh b/connection/coverage.sh new file mode 100755 index 00000000..4b35f56d --- /dev/null +++ b/connection/coverage.sh @@ -0,0 +1,3 @@ +go test -race -cover -coverprofile=c.out +go tool cover -html=c.out -o coverage.html +rm c.out diff --git a/connection/mysql.go b/connection/mysql.go new file mode 100644 index 00000000..06cbef43 --- /dev/null +++ b/connection/mysql.go @@ -0,0 +1,704 @@ +package connection + +import ( + "container/list" + "context" + "database/sql" + "fmt" + "git.icinga.com/icingadb/icingadb-main/configobject" + "git.icinga.com/icingadb/icingadb-utils" + log "github.com/sirupsen/logrus" + "strings" + "sync" + "sync/atomic" + "time" +) + +// This is used in SqlFetchAll and SqlFetchAllQuiet +type DbClientOrTransaction interface { + Query(query string, args ...interface{}) (*sql.Rows, error) + Exec(query string, args ...interface{}) (sql.Result, error) +} + +type DbClient interface { + DbClientOrTransaction + Ping() error + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) +} + +type DbTransaction interface { + DbClientOrTransaction + Commit() error + Rollback() error +} + +func NewDBWrapper(dbDsn string) (*DBWrapper, error) { + db, err := mkMysql("mysql", dbDsn) + + if err != nil { + return nil, err + } + + dbw := DBWrapper{Db: db, ConnectedAtomic: new(uint32), ConnectionLostCounterAtomic: new(uint32)} + dbw.ConnectionUpCondition = sync.NewCond(&sync.Mutex{}) + + err = dbw.Db.Ping() + if err != nil { + return nil, err + } + + go func() { + for { + dbw.checkConnection(true) + time.Sleep(dbw.getConnectionCheckInterval()) + } + }() + + return &dbw, nil +} + +// Database wrapper including helper functions +type DBWrapper struct { + Db DbClient + ConnectedAtomic *uint32 //uint32 to be able to use atomic operations + ConnectionUpCondition *sync.Cond + ConnectionLostCounterAtomic *uint32 //uint32 to be able to use atomic operations +} + +func (dbw *DBWrapper) IsConnected() bool { + return atomic.LoadUint32(dbw.ConnectedAtomic) != 0 +} + +func (dbw *DBWrapper) CompareAndSetConnected(connected bool) (swapped bool) { + if connected { + return atomic.CompareAndSwapUint32(dbw.ConnectedAtomic, 0, 1) + } else { + return atomic.CompareAndSwapUint32(dbw.ConnectedAtomic, 1, 0) + } +} + +func (dbw *DBWrapper) getConnectionCheckInterval() time.Duration { + if !dbw.IsConnected() { + v := atomic.LoadUint32(dbw.ConnectionLostCounterAtomic) + if v < 4 { + return 5 * time.Second + } else if v < 8 { + return 10 * time.Second + } else if v < 11 { + return 30 * time.Second + } else if v < 14 { + return 60 * time.Second + } else { + log.Fatal("Could not connect to SQL for over 5 minutes. Shutting down...") + } + } + + return 15 * time.Second +} + +func (dbw *DBWrapper) checkConnection(isTicker bool) bool { + err := dbw.Db.Ping() + if err != nil { + if dbw.CompareAndSetConnected(false) { + log.WithFields(log.Fields{ + "context": "sql", + "error": err, + }).Error("SQL connection lost. Trying to reconnect") + } else if isTicker { + atomic.AddUint32(dbw.ConnectionLostCounterAtomic, 1) + + log.WithFields(log.Fields{ + "context": "sql", + "error": err, + }).Debugf("SQL connection lost. Trying again in %s", dbw.getConnectionCheckInterval()) + } + + return false + } else { + if dbw.CompareAndSetConnected(true) { + log.Info("SQL connection established") + atomic.StoreUint32(dbw.ConnectionLostCounterAtomic, 0) + dbw.ConnectionUpCondition.Broadcast() + } + + return true + } +} + +func (dbw *DBWrapper) WaitForConnection() { + dbw.ConnectionUpCondition.L.Lock() + dbw.ConnectionUpCondition.Wait() + dbw.ConnectionUpCondition.L.Unlock() +} + +func (dbw *DBWrapper) WithRetry(f func() (sql.Result, error)) (sql.Result, error) { + for { + res, err := f() + + if err != nil { + if isRetryableError(err) { + continue + } else { + return nil, err + } + } + + return res, err + } +} + +func (dbw *DBWrapper) SqlQuery(query string, args ...interface{}) (*sql.Rows, error) { + for { + if !dbw.IsConnected() { + dbw.WaitForConnection() + continue + } + + res, err := dbw.Db.Query(query, args...) + DbOperationsQuery.Inc() + + if err != nil { + if !dbw.checkConnection(false) { + continue + } + } + + return res, err + } +} + +// Wrapper around Db.BeginTx() for auto-logging +func (dbw *DBWrapper) SqlBegin(concurrencySafety bool, quiet bool) (DbTransaction, error) { + var isoLvl sql.IsolationLevel + if concurrencySafety { + isoLvl = sql.LevelSerializable + } else { + isoLvl = sql.LevelReadCommitted + } + + for { + if !dbw.IsConnected() { + dbw.WaitForConnection() + continue + } + + var err error + var tx DbTransaction + if quiet { + tx, err = dbw.Db.BeginTx(context.Background(), &sql.TxOptions{Isolation: isoLvl}) + } else { + benchmarc := icingadb_utils.NewBenchmark() + tx, err = dbw.Db.BeginTx(context.Background(), &sql.TxOptions{Isolation: isoLvl}) + benchmarc.Stop() + + DbIoSeconds.WithLabelValues("mysql", "begin").Observe(benchmarc.Seconds()) + + log.WithFields(log.Fields{ + "context": "sql", + "benchmark": benchmarc, + }).Debug("BEGIN transaction") + } + + if err != nil { + if !dbw.checkConnection(false) { + continue + } + } + + return tx, err + } +} + +// Wrapper around tx.Commit() for auto-logging +func (dbw *DBWrapper) SqlCommit(tx DbTransaction, quiet bool) error { + for { + if !dbw.IsConnected() { + dbw.WaitForConnection() + continue + } + + var err error + if quiet { + err = tx.Commit() + } else { + benchmarc := icingadb_utils.NewBenchmark() + err = tx.Commit() + benchmarc.Stop() + + DbIoSeconds.WithLabelValues("mysql", "commit").Observe(benchmarc.Seconds()) + + log.WithFields(log.Fields{ + "context": "sql", + "benchmark": benchmarc, + }).Debug("COMMIT transaction") + } + + if err != nil { + if !dbw.checkConnection(false) { + continue + } + } + + return err + } +} + +// Wrapper around tx.Rollback() for auto-logging +func (dbw *DBWrapper) SqlRollback(tx DbTransaction, quiet bool) error { + for { + if !dbw.IsConnected() { + dbw.WaitForConnection() + continue + } + + var err error + if !quiet { + benchmarc := icingadb_utils.NewBenchmark() + err = tx.Rollback() + benchmarc.Stop() + + DbIoSeconds.WithLabelValues("mysql", "rollback").Observe(benchmarc.Seconds()) + + log.WithFields(log.Fields{ + "context": "sql", + "benchmark": benchmarc, + }).Debug("ROLLBACK transaction") + } else { + err = tx.Rollback() + } + + if err != nil { + if !dbw.checkConnection(false) { + continue + } + } + + return err + } +} + +// Wrapper around sql.Exec() for auto-logging +func (dbw *DBWrapper) SqlExec(opDescription string, sql string, args ...interface{}) (sql.Result, error) { + return dbw.sqlExecInternal(dbw.Db, opDescription, sql, false, args...) +} + +// No logging, no benchmarking +func (dbw *DBWrapper) SqlExecQuiet(opDescription string, sql string, args ...interface{}) (sql.Result, error) { + return dbw.sqlExecInternal(dbw.Db, opDescription, sql, true, args...) +} + +// Wrapper around tx.Exec() for auto-logging +func (dbw *DBWrapper) SqlExecTx(tx DbTransaction, opDescription string, sql string, args ...interface{}) (sql.Result, error) { + return dbw.sqlExecInternal(tx, opDescription, sql, false, args...) +} + +// No logging, no benchmarking +func (dbw *DBWrapper) SqlExecTxQuiet(tx DbTransaction, opDescription string, sql string, args ...interface{}) (sql.Result, error) { + return dbw.sqlExecInternal(tx, opDescription, sql, true, args...) +} + +func (dbw *DBWrapper) SqlFetchAll(queryDescription string, query string, args ...interface{}) ([][]interface{}, error) { + return dbw.sqlFetchAllInternal(dbw.Db, queryDescription, query, false, args...) +} + +func (dbw *DBWrapper) SqlFetchAllQuiet(queryDescription string, query string, args ...interface{}) ([][]interface{}, error) { + return dbw.sqlFetchAllInternal(dbw.Db, queryDescription, query, true, args...) +} + +func (dbw *DBWrapper) SqlFetchAllTx(tx DbTransaction, queryDescription string, query string, args ...interface{}) ([][]interface{}, error) { + return dbw.sqlFetchAllInternal(tx, queryDescription, query, false, args...) +} + +func (dbw *DBWrapper) SqlFetchAllTxQuiet(tx DbTransaction, queryDescription string, query string, args ...interface{}) ([][]interface{}, error) { + return dbw.sqlFetchAllInternal(tx, queryDescription, query, true, args...) +} + +// Wrapper around sql.Exec() for auto-logging +func (dbw *DBWrapper) sqlExecInternal(db DbClientOrTransaction, opDescription string, sql string, quiet bool, args ...interface{}) (sql.Result, error) { + for { + if !dbw.IsConnected() { + dbw.WaitForConnection() + continue + } + + var benchmarc *icingadb_utils.Benchmark + if !quiet { + benchmarc = icingadb_utils.NewBenchmark() + } + res, err := db.Exec(sql, args...) + DbOperationsExec.Inc() + if !quiet { + benchmarc.Stop() + } + + if !quiet { + DbIoSeconds.WithLabelValues("mysql", opDescription).Observe(benchmarc.Seconds()) + log.WithFields(log.Fields{ + "context": "sql", + "benchmark": benchmarc, + "affected_rows": prettyPrintedRowsAffected{res}, + "args": prettyPrintedArgs{args}, + "query": prettyPrintedSql{sql}, + }).Debug("Finished Exec") + } + + + if err != nil { + if !dbw.checkConnection(false) { + continue + } + } + + return res, err + } +} + +// Wrapper around Db.SqlQuery() for auto-logging +func (dbw *DBWrapper) sqlFetchAllInternal(db DbClientOrTransaction, queryDescription string, query string, quiet bool, args ...interface{}) ([][]interface{}, error) { + for { + if !dbw.IsConnected() { + dbw.WaitForConnection() + continue + } + + res, err := sqlTryFetchAll(db, queryDescription, query, quiet, args...) + + if err != nil { + if _, isDb := db.(*sql.DB); isDb { + if !dbw.checkConnection(false) { + continue + } + } + } + + return res, err + } +} + +func sqlTryFetchAll(db DbClientOrTransaction, queryDescription string, query string, quiet bool, args ...interface{}) ([][]interface{}, error) { + var benchmarc *icingadb_utils.Benchmark + if !quiet { + benchmarc = icingadb_utils.NewBenchmark() + } + rows, errQuery := db.Query(query, args...) + DbOperationsQuery.Inc() + if !quiet { + benchmarc.Stop() + } + + rowsCount := 0 + + defer func() { + if !quiet { + DbIoSeconds.WithLabelValues("mysql", queryDescription).Observe(benchmarc.Seconds()) + log.WithFields(log.Fields{ + "context": "sql", + "benchmark": benchmarc, + "query": prettyPrintedSql{query}, + "args": prettyPrintedArgs{args}, + "affected_Rows": rowsCount, + }).Debug("Finished FetchAll") + } + }() + + if errQuery != nil { + return [][]interface{}{}, errQuery + } + + defer rows.Close() + + columnTypes, errCT := rows.ColumnTypes() + if errCT != nil { + return [][]interface{}{}, errCT + } + + colsPerRow := len(columnTypes) + buf := list.New() + bridges := make([]dbTypeBridge, colsPerRow) + scanDest := make([]interface{}, colsPerRow) + + for i, columnType := range columnTypes { + typ := columnType.DatabaseTypeName() + factory, hasFactory := dbTypeBridgeFactories[typ] + if hasFactory { + bridges[i] = factory() + } else { + bridges[i] = &dbBrokenBridge{typ: typ} + } + + scanDest[i] = bridges[i] + } + + for { + if rows.Next() { + if errScan := rows.Scan(scanDest...); errScan != nil { + return [][]interface{}{}, errScan + } + + row := make([]interface{}, colsPerRow) + + for i, bridge := range bridges { + row[i] = bridge.Result() + } + + buf.PushBack(row) + } else if errNx := rows.Err(); errNx == nil { + break + } else { + return nil, errNx + } + } + + res := make([][]interface{}, buf.Len()) + + for current, i := buf.Front(), 0; current != nil; current = current.Next() { + res[i] = current.Value.([]interface{}) + i++ + } + + rowsCount = len(res) + + return res, nil +} + +// sqlTransaction executes the given function inside a transaction. +func (dbw DBWrapper) SqlTransaction(concurrencySafety bool, retryOnConnectionFailure bool, quiet bool, f func(DbTransaction) error) error { + for { + if !dbw.IsConnected() { + dbw.WaitForConnection() + continue + } + + var benchmarc *icingadb_utils.Benchmark + if !quiet { + benchmarc = icingadb_utils.NewBenchmark() + } + errTx := dbw.sqlTryTransaction(f, concurrencySafety, false) + if !quiet { + benchmarc.Stop() + DbIoSeconds.WithLabelValues("mysql", "transaction").Observe(benchmarc.Seconds()) + + log.WithFields(log.Fields{ + "context": "sql", + "benchmark": benchmarc, + }).Debug("Executed transaction") + } + + if errTx != nil { + //TODO: Do this only for concurrencySafety = true, once we figure out the serialization errors. + if isSerializationFailure(errTx) { + if !quiet { + log.WithFields(log.Fields{ + "context": "sql", + "error": errTx, + }).Debug("Repeating transaction") + } + continue + } + + if !dbw.checkConnection(false) { + if retryOnConnectionFailure { + continue + } else { + return MysqlConnectionError{"Transaction failed duo to a connection error"} + } + } + + log.WithFields(log.Fields{ + "context": "sql", + "error": errTx, + }).Warn("SQL error occurred") + } + + return errTx + } +} + +// Executes the given function inside a transaction +func (dbw *DBWrapper) sqlTryTransaction(f func(transaction DbTransaction) error, concurrencySafety bool, quiet bool) error { + tx, errBegin := dbw.SqlBegin(concurrencySafety, quiet) + if errBegin != nil { + return errBegin + } + + errTx := f(tx) + if errTx != nil { + dbw.SqlRollback(tx, quiet) + return errTx + } + + return dbw.SqlCommit(tx, quiet) +} + +func (dbw *DBWrapper) SqlFetchIds(envId []byte, table string) ([]string, error) { + var keys []string + for { + if !dbw.IsConnected() { + dbw.WaitForConnection() + continue + } + + rows, err := dbw.SqlQuery(fmt.Sprintf("SELECT id FROM %s WHERE env_id=(X'%s')", table, icingadb_utils.DecodeChecksum(envId))) + + if err != nil { + if !dbw.checkConnection(false) { + continue + } + + return nil, err + } + + defer rows.Close() + + for rows.Next() { + var id []byte + + err = rows.Scan(&id) + if err != nil { + return nil, err + } + + keys = append(keys, icingadb_utils.DecodeChecksum(id)) + } + + err = rows.Err() + if err != nil { + return nil, err + } + + return keys, nil + } +} + +func (dbw *DBWrapper) SqlFetchChecksums(table string, ids []string) (map[string]map[string]string, error) { + var checksums = map[string]map[string]string{} + + done := make(chan struct{}) + //TODO: Don't do this hardcoded - Chunksize + for bulk := range icingadb_utils.ChunkKeys(done, ids, 1000) { + //TODO: This should be done in parallel + query := fmt.Sprintf("SELECT id, properties_checksum FROM %s WHERE id IN (X'%s')", table, strings.Join(bulk, "', X'")) + rows, err := dbw.SqlQuery(query) + + if err != nil { + if !dbw.checkConnection(false) { + continue + } + + return nil, err + } + + defer rows.Close() + + for rows.Next() { + var id []byte + var propertiesChecksum []byte + + err = rows.Scan(&id, &propertiesChecksum) + if err != nil { + return nil, err + } + + checksums[icingadb_utils.DecodeChecksum(id)] = map[string]string{ + "properties_checksum": icingadb_utils.DecodeChecksum(propertiesChecksum), + } + } + + err = rows.Err() + if err != nil { + return nil, err + } + } + + return checksums, nil +} + +func (dbw *DBWrapper) SqlBulkInsert(rows []configobject.Row, stmt *BulkInsertStmt) error { + if len(rows) == 0 { + return nil + } + + placeholders := make([]string, len(rows)) + values := make([]interface{}, len(rows)*stmt.NumField) + j := 0 + + for i, r := range rows { + placeholders[i] = stmt.Placeholder + + for _, v := range r.InsertValues() { + values[j] = v + j++ + } + } + + query := fmt.Sprintf(stmt.Format, strings.Join(placeholders, ", ")) + + _, err := dbw.WithRetry(func() (result sql.Result, e error) { + return dbw.SqlExec("Bulk insert", query, values...) + }) + + if err != nil { + return err + } + + return nil +} + +func (dbw *DBWrapper) SqlBulkDelete(keys []string, stmt *BulkDeleteStmt) error { + if len(keys) == 0 { + return nil + } + + done := make(chan struct{}) + defer close(done) + + //TODO: Don't do this hardcoded - Chunksize + for bulk := range icingadb_utils.ChunkKeys(done, keys, 1000) { + placeholders := strings.TrimSuffix(strings.Repeat("?, ", len(bulk)), ", ") + values := make([]interface{}, len(bulk)) + + for i, key := range bulk { + values[i] = icingadb_utils.Checksum(key) + } + query := fmt.Sprintf(stmt.Format, placeholders) + + _, err := dbw.WithRetry(func() (result sql.Result, e error) { + return dbw.SqlExec("Bulk delete", query, values...) + }) + if err != nil { + return err + } + } + + return nil +} + +func (dbw *DBWrapper) SqlBulkUpdate(rows []configobject.Row, stmt *BulkUpdateStmt) error { + if len(rows) == 0 { + return nil + } + + placeholders := make([]string, len(rows)) + values := make([]interface{}, len(rows)*stmt.NumField) + j := 0 + + for i, r := range rows { + placeholders[i] = stmt.Placeholder + + for _, v := range r.InsertValues() { + values[j] = v + j++ + } + } + + query := fmt.Sprintf(stmt.Format, strings.Join(placeholders, ", ")) + + _, err := dbw.WithRetry(func() (result sql.Result, e error) { + return dbw.SqlExec("Bulk update", query, values...) + }) + if err != nil { + return err + } + + return nil +} \ No newline at end of file diff --git a/connection/mysql_test.go b/connection/mysql_test.go new file mode 100644 index 00000000..fafd95e9 --- /dev/null +++ b/connection/mysql_test.go @@ -0,0 +1,318 @@ +package connection + +import ( + "context" + "database/sql" + "errors" + "github.com/go-sql-driver/mysql" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "sync" + "sync/atomic" + "testing" + "time" +) + +type SqlResultMock struct { + sql.Result +} + +type TransactionMock struct { + mock.Mock +} + +func (m *TransactionMock) Query(query string, args ...interface{}) (*sql.Rows, error) { + args2 := m.Called(query, args) + return args2.Get(0).(*sql.Rows), args2.Error(1) +} + +func (m *TransactionMock) Exec(query string, args ...interface{}) (sql.Result, error) { + args2 := m.Called(query, args) + return args2.Get(0).(sql.Result), args2.Error(1) +} + +func (m *TransactionMock) Commit() error { + args := m.Called() + return args.Error(0) +} + +func (m *TransactionMock) Rollback() error { + args := m.Called() + return args.Error(0) +} + +type DbMock struct { + mock.Mock +} + +func (m *DbMock) Ping() error { + args := m.Called() + return args.Error(0) +} + +func (m *DbMock) Query(query string, args ...interface{}) (*sql.Rows, error) { + args2 := m.Called(query, args) + return args2.Get(0).(*sql.Rows), args2.Error(1) +} + +func (m *DbMock) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { + args := m.Called(ctx, opts) + return args.Get(0).(*sql.Tx), args.Error(1) +} + +func (m *DbMock) Exec(query string, args ...interface{}) (sql.Result, error) { + args2 := m.Called(query, args) + return args2.Get(0).(sql.Result), args2.Error(1) +} + +func NewTestDBW(db DbClient) DBWrapper { + dbw := DBWrapper{Db: db, ConnectedAtomic: new(uint32), ConnectionLostCounterAtomic: new(uint32)} + dbw.ConnectionUpCondition = sync.NewCond(&sync.Mutex{}) + return dbw +} + +func TestNewDBWrapper(t *testing.T) { + _, err := NewDBWrapper("asdasd") + assert.Error(t, err) + //TODO: Add more tests here +} + +func TestDBWrapper_CheckConnection(t *testing.T) { + mockDb := new(DbMock) + dbw := NewTestDBW(mockDb) + + atomic.StoreUint32(dbw.ConnectionLostCounterAtomic, 512312312) + mockDb.On("Ping").Return(nil).Once() + assert.True(t, dbw.checkConnection(false), "DBWrapper should be connected") + assert.Equal(t, uint32(0), atomic.LoadUint32(dbw.ConnectionLostCounterAtomic)) + + atomic.StoreUint32(dbw.ConnectionLostCounterAtomic, 0) + mockDb.On("Ping").Return(mysql.ErrInvalidConn).Once() + assert.False(t, dbw.checkConnection(false), "DBWrapper should not be connected") + assert.Equal(t, uint32(0), atomic.LoadUint32(dbw.ConnectionLostCounterAtomic)) + + atomic.StoreUint32(dbw.ConnectionLostCounterAtomic, 10) + mockDb.On("Ping").Return(mysql.ErrInvalidConn).Once() + assert.False(t, dbw.checkConnection(true), "DBWrapper should not be connected") + assert.Equal(t, uint32(11), atomic.LoadUint32(dbw.ConnectionLostCounterAtomic)) +} + +func TestDBWrapper_SqlCommit(t *testing.T) { + mockDb := new(DbMock) + dbw := NewTestDBW(mockDb) + mockTx := new(TransactionMock) + + mockTx.On("Commit").Return(errors.New("whoops")).Once() + mockTx.On("Commit").Return( nil).Once() + mockDb.On("Ping").Return(errors.New("whoops")).Once() + + var err error + done := make(chan bool) + + dbw.CompareAndSetConnected(true) + go func() { + err = dbw.SqlCommit(mockTx, false) + done <- true + }() + + time.Sleep(time.Millisecond * 50) + + dbw.CompareAndSetConnected(true) + dbw.ConnectionUpCondition.Broadcast() + + <- done + + assert.NoError(t, err) + mockTx.AssertExpectations(t) + mockDb.AssertExpectations(t) +} + +func TestDBWrapper_SqlBegin(t *testing.T) { + mockDb := new(DbMock) + dbw := NewTestDBW(mockDb) + + mockDb.On("BeginTx", context.Background(), &sql.TxOptions{Isolation: sql.LevelReadCommitted}).Return(&sql.Tx{}, errors.New("whoops")).Once() + mockDb.On("BeginTx", context.Background(), &sql.TxOptions{Isolation: sql.LevelReadCommitted}).Return(&sql.Tx{}, nil).Once() + mockDb.On("Ping").Return(errors.New("whoops")).Once() + + var err error + done := make(chan bool) + + dbw.CompareAndSetConnected(true) + go func() { + _, err = dbw.SqlBegin(false, false) + done <- true + }() + + time.Sleep(time.Millisecond * 50) + + dbw.CompareAndSetConnected(true) + dbw.ConnectionUpCondition.Broadcast() + + <- done + + assert.NoError(t, err) + mockDb.AssertExpectations(t) +} + +func TestDBWrapper_SqlTransaction(t *testing.T) { + dbw, err := NewDBWrapper( "module-dev:icinga0815!@tcp(127.0.0.1:3306)/icingadb") + assert.NoError(t, err, "Is the MySQL server running?") + + err = dbw.SqlTransaction(false, true, false, func(tx DbTransaction) error { + return nil + }) + + assert.NoError(t, err) + + err = dbw.SqlTransaction(false, true, false, func(tx DbTransaction) error { + return errors.New("whoops") + }) + + assert.Error(t, err) +} + +func TestDBWrapper_WithRetry(t *testing.T) { + mockDb := new(DbMock) + dbw := NewTestDBW(mockDb) + + tries := 0 + + _, err := dbw.WithRetry(func() (result sql.Result, e error) { + if tries > 0 { + tries++ + return nil, nil + } else { + tries++ + return nil, errors.New("Deadlock found when trying to get lock") + } + }) + + assert.NoError(t, err) + assert.Equal(t, 2, tries) + + _, err = dbw.WithRetry(func() (result sql.Result, e error) { + return nil, errors.New("something went wrong") + }) + + assert.Error(t, err) +} + +func TestDBWrapper_SqlQuery(t *testing.T) { + mockDb := new(DbMock) + dbw := NewTestDBW(mockDb) + + mockDb.On("Query", "test", []interface{}(nil)).Return(&sql.Rows{}, errors.New("whoops")).Once() + mockDb.On("Query", "test", []interface{}(nil)).Return(&sql.Rows{}, nil).Once() + mockDb.On("Ping").Return(errors.New("whoops")).Once() + + var err error + done := make(chan bool) + + dbw.CompareAndSetConnected(true) + go func() { + _, err = dbw.SqlQuery("test") + done <- true + }() + + time.Sleep(time.Millisecond * 50) + + dbw.CompareAndSetConnected(true) + dbw.ConnectionUpCondition.Broadcast() + + <- done + + assert.NoError(t, err) + mockDb.AssertExpectations(t) +} + +func TestDBWrapper_SqlExec(t *testing.T) { + mockDb := new(DbMock) + dbw := NewTestDBW(mockDb) + + mockDb.On("Exec", "test", []interface{}(nil)).Return(SqlResultMock{}, errors.New("whoops")).Once() + mockDb.On("Exec", "test", []interface{}(nil)).Return(SqlResultMock{}, nil).Once() + mockDb.On("Ping").Return(errors.New("whoops")).Once() + + var err error + done := make(chan bool) + + dbw.CompareAndSetConnected(true) + go func() { + _, err = dbw.SqlExec("test", "test") + done <- true + }() + + time.Sleep(time.Millisecond * 50) + + dbw.CompareAndSetConnected(true) + dbw.ConnectionUpCondition.Broadcast() + + <- done + + assert.NoError(t, err) + mockDb.AssertExpectations(t) +} + +func TestGetConnectionCheckInterval(t *testing.T) { + dbw := NewTestDBW(nil) + + //Should return 15s, if connected - counter doesn't madder + dbw.CompareAndSetConnected(true) + assert.Equal(t, 15*time.Second, dbw.getConnectionCheckInterval()) + + //Should return 5s, if not connected and counter < 4 + dbw.CompareAndSetConnected(false) + atomic.StoreUint32(dbw.ConnectionLostCounterAtomic, 0) + assert.Equal(t, 5*time.Second, dbw.getConnectionCheckInterval()) + + //Should return 10s, if not connected and 4 <= counter < 8 + dbw.CompareAndSetConnected(false) + atomic.StoreUint32(dbw.ConnectionLostCounterAtomic, 4) + assert.Equal(t, 10*time.Second, dbw.getConnectionCheckInterval()) + + //Should return 30s, if not connected and 8 <= counter < 11 + dbw.CompareAndSetConnected(false) + atomic.StoreUint32(dbw.ConnectionLostCounterAtomic, 8) + assert.Equal(t, 30*time.Second, dbw.getConnectionCheckInterval()) + + //Should return 60s, if not connected and 11 <= counter < 14 + dbw.CompareAndSetConnected(false) + atomic.StoreUint32(dbw.ConnectionLostCounterAtomic, 11) + assert.Equal(t, 60*time.Second, dbw.getConnectionCheckInterval()) + + //dbw.ConnectionLostCounter = 14 + //interval = dbw.getConnectionCheckInterval() + //TODO: Check for Fatal +} + +func TestDBWrapper_SqlFetchAll(t *testing.T) { + dbw, err := NewDBWrapper("module-dev:icinga0815!@tcp(127.0.0.1:3306)/icingadb") + assert.NoError(t, err, "Is the MySQL server running?") + + _, err = dbw.Db.Exec("CREATE TABLE testing0815 (id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, name varchar(255) NOT NULL)") + assert.NoError(t, err) + + _, err = dbw.Db.Exec("INSERT INTO testing0815 (name) VALUES ('horst'), ('test')") + assert.NoError(t, err) + + var res [][]interface{} + done := make(chan bool) + dbw.CompareAndSetConnected(false) + go func() { + res, err = dbw.SqlFetchAll("test", "SELECT * FROM testing0815") + done <- true + }() + + time.Sleep(time.Millisecond * 50) + + dbw.checkConnection(true) + + <- done + + assert.NoError(t, err) + assert.Equal(t, [][]interface {}([][]interface {}{{int64(1), "horst"}, {int64(2), "test"}}), res) + + _, err = dbw.Db.Exec("DROP TABLE testing0815") + assert.NoError(t, err) +} diff --git a/connection/mysql_utils.go b/connection/mysql_utils.go new file mode 100644 index 00000000..8220d5e9 --- /dev/null +++ b/connection/mysql_utils.go @@ -0,0 +1,425 @@ +package connection + +import ( + "database/sql" + "encoding/hex" + "errors" + "fmt" + "github.com/go-sql-driver/mysql" + "io/ioutil" + "reflect" + "sort" + "strconv" + "strings" + log "github.com/sirupsen/logrus" + oldlog "log" +) + +// mkMysql creates a new MySQL client. +func mkMysql(dbType string, dbDsn string) (*sql.DB, error) { + log.Info("Connecting to MySQL") + + sep := "?" + + if dbDsn == "" { + dbDsn = "/" + } else { + dsnParts := strings.Split(dbDsn, "/") + if strings.Contains(dsnParts[len(dsnParts)-1], "?") { + sep = "&" + } + } + + dbDsn = dbDsn + sep + + "innodb_strict_mode=1&sql_mode='STRICT_ALL_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE,NO_ENGINE_SUBSTITUTION,PIPES_AS_CONCAT,ANSI_QUOTES,ERROR_FOR_DIVISION_BY_ZERO,NO_AUTO_CREATE_USER'" + + db, errConn := sql.Open(dbType, dbDsn) + if errConn != nil { + return nil, errConn + } + + mysql.SetLogger(oldlog.New(ioutil.Discard, "", 0)) + + db.SetMaxOpenConns(50) + db.SetMaxIdleConns(0) + + return db, nil +} + +type dbTypeBridge interface { + sql.Scanner + Result() interface{} +} + +type dbIntBridge struct { + result interface{} +} + +func (d *dbIntBridge) Scan(src interface{}) (err error) { + baseScanner := sql.NullInt64{} + err = baseScanner.Scan(src) + + if err == nil { + if baseScanner.Valid { + d.result = baseScanner.Int64 + } else { + d.result = nil + } + } + + return +} + +func (d *dbIntBridge) Result() interface{} { + return d.result +} + +type dbFloatBridge struct { + result interface{} +} + +func (d *dbFloatBridge) Scan(src interface{}) (err error) { + baseScanner := sql.NullFloat64{} + err = baseScanner.Scan(src) + + if err == nil { + if baseScanner.Valid { + d.result = baseScanner.Float64 + } else { + d.result = nil + } + } + + return +} + +func (d *dbFloatBridge) Result() interface{} { + return d.result +} + +type dbStringBridge struct { + result interface{} +} + +func (d *dbStringBridge) Scan(src interface{}) (err error) { + baseScanner := sql.NullString{} + err = baseScanner.Scan(src) + + if err == nil { + if baseScanner.Valid { + d.result = baseScanner.String + } else { + d.result = nil + } + } + + return +} + +func (d *dbStringBridge) Result() interface{} { + return d.result +} + +type dbBytesBridge struct { + result interface{} +} + +func (d *dbBytesBridge) Scan(src interface{}) (err error) { + baseScanner := sql.NullString{} + err = baseScanner.Scan(src) + + if err == nil { + if baseScanner.Valid { + d.result = []byte(baseScanner.String) + } else { + d.result = nil + } + } + + return +} + +func (d *dbBytesBridge) Result() interface{} { + return d.result +} + +var dbTypeBridgeFactories = map[string]func() dbTypeBridge{ + // MySQL + "TINYINT": func() dbTypeBridge { + return &dbIntBridge{} + }, + "SMALLINT": func() dbTypeBridge { + return &dbIntBridge{} + }, + "INT": func() dbTypeBridge { + return &dbIntBridge{} + }, + "BIGINT": func() dbTypeBridge { + return &dbIntBridge{} + }, + "FLOAT": func() dbTypeBridge { + return &dbFloatBridge{} + }, + "CHAR": func() dbTypeBridge { + return &dbStringBridge{} + }, + "VARCHAR": func() dbTypeBridge { + return &dbStringBridge{} + }, + "ENUM": func() dbTypeBridge { + return &dbStringBridge{} + }, + "BINARY": func() dbTypeBridge { + return &dbBytesBridge{} + }, + + // SQLite + "INTEGER": func() dbTypeBridge { + return &dbIntBridge{} + }, + "REAL": func() dbTypeBridge { + return &dbFloatBridge{} + }, + "TEXT": func() dbTypeBridge { + return &dbStringBridge{} + }, + "BLOB": func() dbTypeBridge { + return &dbBytesBridge{} + }, + // SELECT 1 FROM ... + "": func() dbTypeBridge { + return &dbIntBridge{} + }, +} + +type dbBrokenBridge struct { + typ string +} + +func (d *dbBrokenBridge) Scan(src interface{}) error { + types := make([]string, len(dbTypeBridgeFactories)) + typeIdx := 0 + + for typ := range dbTypeBridgeFactories { + types[typeIdx] = typ + typeIdx++ + } + + sort.Strings(types) + + return errors.New(fmt.Sprintf("bad column type %s, expected one of %s", d.typ, strings.Join(types, ", "))) +} + +func (d *dbBrokenBridge) Result() interface{} { + return nil +} + +var prettyPrintedSqlReplacer = strings.NewReplacer("\n", " ", "\t", "") + +type prettyPrintedSql struct { + sql string +} + +// String implements and interface from Stringer +func (p prettyPrintedSql) String() string { + return strings.TrimSpace(prettyPrintedSqlReplacer.Replace(p.sql)) +} + +// MarshalText implements an interface from TextMarshaler +func (p prettyPrintedSql) MarshalText() (text []byte, err error) { + return []byte(p.String()), nil +} + +type prettyPrintedArgs struct { + args []interface{} +} + +func (p *prettyPrintedArgs) String() string { + res := "[" + + for _, v := range p.args { + if byteArray, isByteArray := v.([]byte); isByteArray { + res = fmt.Sprintf("%s hex.DecodeString(\"%s\"),", res, hex.EncodeToString(byteArray)) + } else { + res = fmt.Sprintf("%s %#v,", res, v) + } + } + + return res + " ]" +} + +// MarshalText implements an interface from TextMarshaler +func (p prettyPrintedArgs) MarshalText() (text []byte, err error) { + return []byte(p.String()), nil +} + +type prettyPrintedRowsAffected struct { + result sql.Result +} + +// String implements and interface from Stringer +func (d prettyPrintedRowsAffected) String() string { + if d.result != nil { + rows, errRA := d.result.RowsAffected() + if errRA == nil { + return strconv.FormatInt(rows, 10) + } + } + + return "N/A" +} + +// MarshalText implements an interface from TextMarshaler +func (d prettyPrintedRowsAffected) MarshalText() (text []byte, err error) { + return []byte(d.String()), nil +} + +type MysqlConnectionError struct { + err string +} + +func (e MysqlConnectionError) Error() string { + return e.err +} + +// Returns whether the given error signals serialization failure +// https://dev.mysql.com/doc/refman/5.5/en/error-messages-server.html#error_er_lock_deadlock +func isSerializationFailure(e error) bool { + switch err := e.(type) { + case *mysql.MySQLError: + switch err.Number { + // Those are the error numbers for serialization failures, upon which we retry + case 1205, 1213: + return true + } + } + + return false +} + +func formatLogQuery(query string) string { + r := strings.NewReplacer("\n", " ", "\t", "") + return strings.TrimSpace(r.Replace(query)) +} + +// Go bool -> DB bool +var yesNo = map[bool]string{ + true: "y", + false: "n", +} + +func ConvertValueForDb(in interface{}) (interface{}, error) { + switch value := in.(type) { + case []byte: + case string: + case float64: + case int64: + case nil: + case float32: + return float64(value), nil + case uint: + return int64(value), nil + case uint8: + return int64(value), nil + case uint16: + return int64(value), nil + case uint32: + return int64(value), nil + case uint64: + return int64(value), nil + case int: + return int64(value), nil + case int8: + return int64(value), nil + case int16: + return int64(value), nil + case int32: + return int64(value), nil + case bool: + return yesNo[value], nil + default: + return nil, errors.New(fmt.Sprintf( + "bad type %s, expected one of []byte, string, float{32,64}, {,u}int{,8,16,32,64}, bool, nil", + reflect.TypeOf(in).Name(), + )) + } + + return in, nil +} + +func MakePlaceholderList(x int) string { + runes := make([]rune, 1+x*2) + + i := 1 + for j := 0; j < x; j++ { + runes[i] = '?' + i++ + + runes[i] = ',' + i++ + } + + runes[0] = '(' + runes[len(runes)-1] = ')' + + return string(runes) +} + +func isRetryableError(err error) bool { + if strings.Contains(err.Error(), "Deadlock found when trying to get lock") { + return true + } + return false +} + +type BulkInsertStmt struct { + Format string + Fields []string + Placeholder string + NumField int +} + +func NewBulkInsertStmt(table string, fields []string) *BulkInsertStmt { + numField := len(fields) + placeholder := fmt.Sprintf("(%s)", strings.TrimSuffix(strings.Repeat("?, ", numField), ", ")) + stmt := BulkInsertStmt{ + Format: fmt.Sprintf("INSERT INTO %s (%s) VALUES %s ON DUPLICATE KEY UPDATE id = id", table, strings.Join(fields, ", "), "%s"), + Fields: fields, + Placeholder: placeholder, + NumField: numField, + } + + return &stmt +} + +type BulkDeleteStmt struct { + Format string +} + +func NewBulkDeleteStmt(table string) *BulkDeleteStmt { + stmt := BulkDeleteStmt{ + Format: fmt.Sprintf("DELETE FROM %s WHERE id IN (%s)", table, "%s"), + } + + return &stmt +} + +type BulkUpdateStmt struct { + Format string + Fields []string + Placeholder string + NumField int +} + +func NewBulkUpdateStmt(table string, fields []string) *BulkUpdateStmt { + numField := len(fields) + placeholder := fmt.Sprintf("(%s)", strings.TrimSuffix(strings.Repeat("?, ", numField), ", ")) + stmt := BulkUpdateStmt{ + Format: fmt.Sprintf("REPLACE INTO %s (%s) VALUES %s", table, strings.Join(fields, ", "), "%s"), + Fields: fields, + Placeholder: placeholder, + NumField: numField, + } + + return &stmt +} \ No newline at end of file diff --git a/connection/mysql_utils_test.go b/connection/mysql_utils_test.go new file mode 100644 index 00000000..8b5346ae --- /dev/null +++ b/connection/mysql_utils_test.go @@ -0,0 +1,104 @@ +package connection + +import ( + "errors" + "github.com/go-sql-driver/mysql" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestMakePlaceholderList(t *testing.T) { + assert.Equal(t, "(?)", MakePlaceholderList(1)) + assert.Equal(t, "(?,?,?,?,?)", MakePlaceholderList(5)) + assert.Equal(t, "(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", MakePlaceholderList(20)) +} + +func TestConvertValueForDb(t *testing.T) { + var v interface{} + var err error + + v, err = ConvertValueForDb(nil) + assert.IsType(t, nil, v) + assert.Nil(t, err) + + v, err = ConvertValueForDb([]byte{100}) + assert.IsType(t, []byte{100}, v) + assert.Nil(t, err) + + v, err = ConvertValueForDb("this-is-a-string") + assert.IsType(t, "this-is-a-string", v) + assert.Nil(t, err) + + v, err = ConvertValueForDb(float32(123.456)) + assert.IsType(t, float64(123.456), v) + assert.Nil(t, err) + + v, err = ConvertValueForDb(float64(123.456)) + assert.IsType(t, float64(123.456), v) + assert.Nil(t, err) + + v, err = ConvertValueForDb(uint(20)) + assert.IsType(t, int64(10), v) + assert.Nil(t, err) + + v, err = ConvertValueForDb(uint8(30)) + assert.IsType(t, int64(10), v) + assert.Nil(t, err) + + v, err = ConvertValueForDb(uint16(40)) + assert.IsType(t, int64(10), v) + assert.Nil(t, err) + + v, err = ConvertValueForDb(uint32(50)) + assert.IsType(t, int64(10), v) + assert.Nil(t, err) + + v, err = ConvertValueForDb(uint64(60)) + assert.IsType(t, int64(10), v) + assert.Nil(t, err) + + v, err = ConvertValueForDb(int(70)) + assert.IsType(t, int64(10), v) + assert.Nil(t, err) + + v, err = ConvertValueForDb(int8(80)) + assert.IsType(t, int64(10), v) + assert.Nil(t, err) + + v, err = ConvertValueForDb(int16(90)) + assert.IsType(t, int64(10), v) + assert.Nil(t, err) + + v, err = ConvertValueForDb(int32(100)) + assert.IsType(t, int64(10), v) + assert.Nil(t, err) + + v, err = ConvertValueForDb(int64(10)) + assert.IsType(t, int64(10), v) + assert.Nil(t, err) + + v, err = ConvertValueForDb(true) + assert.IsType(t, "y/n-string", v) + assert.Nil(t, err) + + //Should not be possible + v, err = ConvertValueForDb(errors.New("test")) + assert.NotNil(t, err) +} + +func TestIsSerializationFailure(t *testing.T) { + assert.True(t, isSerializationFailure(&mysql.MySQLError{Number: 1205})) + assert.True(t, isSerializationFailure(&mysql.MySQLError{Number: 1213})) + + assert.False(t, isSerializationFailure(&mysql.MySQLError{Number: 6342})) + assert.False(t, isSerializationFailure(errors.New("random error"))) +} + +func TestMysqlConnectionError_Error(t *testing.T) { + err := MysqlConnectionError{"The chicken has left the database!"} + assert.Equal(t, "The chicken has left the database!", err.Error()) +} + +func TestFormatLogQuery(t *testing.T) { + assert.Equal(t, "This is my string", formatLogQuery("\tThis is\nmy string\n")) +} \ No newline at end of file diff --git a/connection/prometheus.go b/connection/prometheus.go new file mode 100644 index 00000000..00094c8d --- /dev/null +++ b/connection/prometheus.go @@ -0,0 +1,29 @@ +package connection + +import ( + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +var DbIoSeconds = promauto.NewSummaryVec( + prometheus.SummaryOpts{ + Name: "db_io_seconds", + Help: "Database I/O (s)", + }, + []string{"backend_type", "operation"}, +) + +var DbOperationsTotal = promauto.NewCounter(prometheus.CounterOpts{ + Name: "db_operations_total", + Help: "Database operations since startup", +}) + +var DbOperationsQuery = promauto.NewCounter(prometheus.CounterOpts{ + Name: "db_operations_query", + Help: "Database query operations since startup", +}) + +var DbOperationsExec = promauto.NewCounter(prometheus.CounterOpts{ + Name: "db_operations_exec", + Help: "Database exec operations since startup", +}) diff --git a/connection/redis.go b/connection/redis.go new file mode 100644 index 00000000..9e47b651 --- /dev/null +++ b/connection/redis.go @@ -0,0 +1,466 @@ +package connection + +import ( + "fmt" + "git.icinga.com/icingadb/icingadb-utils" + "github.com/go-redis/redis" + log "github.com/sirupsen/logrus" + "sync" + "sync/atomic" + "time" +) + +type Icinga2RedisWriterEventsConfig struct { + Update, Delete, Dump string +} + +type Icinga2RedisWriterKeyPrefixesConfig struct { + Checksum, Object, Customvar string +} + +type Icinga2RedisWriterKeyPrefixesStatus struct { + Object string +} + +type Icinga2RedisWriterEvents struct { + Config Icinga2RedisWriterEventsConfig + Stats string +} + +type Icinga2RedisWriterKeyPrefixes struct { + Config Icinga2RedisWriterKeyPrefixesConfig + Status Icinga2RedisWriterKeyPrefixesStatus +} + +type Icinga2RedisWriter struct { + Events Icinga2RedisWriterEvents + KeyPrefixes Icinga2RedisWriterKeyPrefixes +} + +var RedisWriter = Icinga2RedisWriter{ + Events: Icinga2RedisWriterEvents{ + Config: Icinga2RedisWriterEventsConfig{ + Dump: "icinga:config:dump", + Delete: "icinga:config:delete", + Update: "icinga:config:update", + }, + Stats: "icinga:stats", + }, + KeyPrefixes: Icinga2RedisWriterKeyPrefixes{ + Config: Icinga2RedisWriterKeyPrefixesConfig{ + Checksum: "icinga:config:checksum:", + Object: "icinga:config:object:", + Customvar: "icinga:config:customvar:", + }, + Status: Icinga2RedisWriterKeyPrefixesStatus{ + Object: "icinga:state:object:", + }, + }, +} + +type RedisClient interface { + Ping() *redis.StatusCmd + Publish(channel string, message interface{}) *redis.IntCmd + XRead(a *redis.XReadArgs) *redis.XStreamSliceCmd + XDel(stream string, ids ...string) *redis.IntCmd + HKeys(key string) *redis.StringSliceCmd + HMGet(key string, fields ...string) *redis.SliceCmd + HGetAll(key string) *redis.StringStringMapCmd + TxPipelined(fn func(redis.Pipeliner) error) ([]redis.Cmder, error) + Pipeline() redis.Pipeliner + Subscribe(channels ...string) *redis.PubSub +} + +type StatusCmd interface { +} + +// Redis wrapper including helper functions +type RDBWrapper struct { + Rdb RedisClient + ConnectedAtomic *uint32 //uint32 to be able to use atomic operations + ConnectionUpCondition *sync.Cond + ConnectionLostCounterAtomic *uint32 //uint32 to be able to use atomic operations +} + +func (rdbw *RDBWrapper) IsConnected() bool { + return atomic.LoadUint32(rdbw.ConnectedAtomic) != 0 +} + +func (rdbw *RDBWrapper) CompareAndSetConnected(connected bool) (swapped bool) { + if connected { + return atomic.CompareAndSwapUint32(rdbw.ConnectedAtomic, 0, 1) + } else { + return atomic.CompareAndSwapUint32(rdbw.ConnectedAtomic, 1, 0) + } +} + +func NewRDBWrapper(address string) (*RDBWrapper, error) { + rdb := redis.NewClient(&redis.Options{ + Addr: address, + DialTimeout: time.Minute / 2, + ReadTimeout: time.Minute, + WriteTimeout: time.Minute, + }) + + rdbw := RDBWrapper{ + Rdb: rdb, ConnectedAtomic: new(uint32), + ConnectionLostCounterAtomic: new(uint32), + ConnectionUpCondition: sync.NewCond(&sync.Mutex{}), + } + + _, err := rdbw.Rdb.Ping().Result() + if err != nil { + return nil, err + } + + go func() { + for { + rdbw.CheckConnection(true) + time.Sleep(rdbw.getConnectionCheckInterval()) + } + }() + + return &rdbw, nil +} + +func (rdbw *RDBWrapper) getConnectionCheckInterval() time.Duration { + if !rdbw.IsConnected() { + v := atomic.LoadUint32(rdbw.ConnectionLostCounterAtomic) + if v < 4 { + return 5 * time.Second + } else if v < 8 { + return 10 * time.Second + } else if v < 11 { + return 30 * time.Second + } else if v < 14 { + return 60 * time.Second + } else { + log.Fatal("Could not connect to Redis for over 5 minutes. Shutting down...") + } + } + + return 15 * time.Second +} + +func (rdbw *RDBWrapper) CheckConnection(isTicker bool) bool { + _, err := rdbw.Rdb.Ping().Result() + if err != nil { + if rdbw.CompareAndSetConnected(false) { + log.WithFields(log.Fields{ + "context": "redis", + "error": err, + }).Error("Redis connection lost. Trying to reconnect") + } else if isTicker { + atomic.AddUint32(rdbw.ConnectionLostCounterAtomic, 1) + + log.WithFields(log.Fields{ + "context": "redis", + "error": err, + }).Debugf("Redis connection lost. Trying again in %s", rdbw.getConnectionCheckInterval()) + } + + return false + } else { + if rdbw.CompareAndSetConnected(true) { + log.Info("Redis connection established") + atomic.StoreUint32(rdbw.ConnectionLostCounterAtomic, 0) + rdbw.ConnectionUpCondition.Broadcast() + } + + return true + } +} + +func (rdbw *RDBWrapper) WaitForConnection() { + rdbw.ConnectionUpCondition.L.Lock() + rdbw.ConnectionUpCondition.Wait() + rdbw.ConnectionUpCondition.L.Unlock() +} + +// Wrapper for connection handling +func (rdbw *RDBWrapper) Publish(channel string, message interface{}) *redis.IntCmd { + for { + if !rdbw.IsConnected() { + rdbw.WaitForConnection() + continue + } + + cmd := rdbw.Rdb.Publish(channel, message) + _, err := cmd.Result() + + if err != nil { + if !rdbw.CheckConnection(false) { + continue + } + } + + return cmd + } +} + +// Wrapper for connection handling +func (rdbw *RDBWrapper) XRead(args *redis.XReadArgs) *redis.XStreamSliceCmd { + for { + if !rdbw.IsConnected() { + rdbw.WaitForConnection() + continue + } + + cmd := rdbw.Rdb.XRead(args) + _, err := cmd.Result() + + if err != nil { + if !rdbw.CheckConnection(false) { + continue + } + } + + return cmd + } +} + +// Wrapper for connection handling +func (rdbw *RDBWrapper) XDel(stream string, ids ...string) *redis.IntCmd { + for { + if !rdbw.IsConnected() { + rdbw.WaitForConnection() + continue + } + + cmd := rdbw.Rdb.XDel(stream, ids...) + _, err := cmd.Result() + + if err != nil { + if !rdbw.CheckConnection(false) { + continue + } + } + + return cmd + } +} + +// Wrapper for connection handling +func (rdbw *RDBWrapper) HKeys(key string) *redis.StringSliceCmd { + for { + if !rdbw.IsConnected() { + rdbw.WaitForConnection() + continue + } + + cmd := rdbw.Rdb.HKeys(key) + _, err := cmd.Result() + + if err != nil { + if !rdbw.CheckConnection(false) { + continue + } + } + + return cmd + } +} + +func (rdbw * RDBWrapper) HMGet(key string, fields ...string) *redis.SliceCmd { + for { + if !rdbw.IsConnected() { + rdbw.WaitForConnection() + continue + } + + cmd := rdbw.Rdb.HMGet(key, fields...) + _, err := cmd.Result() + + if err != nil { + if !rdbw.CheckConnection(false) { + continue + } + } + + return cmd + } +} + +// Wrapper for auto-logging and connection handling +func (rdbw *RDBWrapper) HGetAll(key string) *redis.StringStringMapCmd { + for { + if !rdbw.IsConnected() { + rdbw.WaitForConnection() + continue + } + + benchmarc := icingadb_utils.NewBenchmark() + res := rdbw.Rdb.HGetAll(key) + + if _, err := res.Result(); err != nil { + if !rdbw.CheckConnection(false) { + continue + } + } + + benchmarc.Stop() + + DbIoSeconds.WithLabelValues("redis", "hgetall").Observe(benchmarc.Seconds()) + + log.WithFields(log.Fields{ + "context": "redis", + "benchmark": benchmarc, + "query": "HGETALL " + key, + "result": res.Val(), + }).Debug("Ran Query") + + return res + } +} + +// Wrapper for auto-logging and connection handling +func (rdbw *RDBWrapper) TxPipelined(fn func(pipeliner redis.Pipeliner) error) ([]redis.Cmder, error) { + for { + if !rdbw.IsConnected() { + rdbw.WaitForConnection() + continue + } + + benchmarc := icingadb_utils.NewBenchmark() + cmd, err := rdbw.Rdb.TxPipelined(fn) + + if err != nil { + if !rdbw.CheckConnection(false) { + continue + } + } + + benchmarc.Stop() + + DbIoSeconds.WithLabelValues("redis", "multi").Observe(benchmarc.Seconds()) + + log.WithFields(log.Fields{ + "context": "redis", + "benchmark": benchmarc, + "query": "MULTI/EXEC", + }).Debug("Ran pipelined transaction") + + return cmd, err + } +} + +func (rdbw *RDBWrapper) Pipeline() PipelinerWrapper { + pipeliner := rdbw.Rdb.Pipeline() + plw := PipelinerWrapper{pipeliner: pipeliner, rdbw: rdbw} + return plw +} + +func (rdbw *RDBWrapper) Subscribe() PubSubWrapper { + ps := rdbw.Rdb.Subscribe() + psw := PubSubWrapper{ps: ps, rdbw: rdbw} + return psw +} + +type ConfigChunk struct { + Keys []string + Configs []interface{} + Checksums []interface{} +} + +type ChecksumChunk struct { + Keys []string + Checksums []interface{} +} + +func (rdbw *RDBWrapper) PipeConfigChunks(done <-chan struct{}, keys []string, objectType string) <-chan *ConfigChunk { + out := make(chan *ConfigChunk) + + worker := func(chunk <-chan []string) { + for k := range chunk { + pipe := rdbw.Pipeline() + cmds := make([]*redis.SliceCmd, 2) + + cmds[0] = pipe.HMGet(fmt.Sprintf("icinga:config:object:%s", objectType), k...) + cmds[1] = pipe.HMGet(fmt.Sprintf("icinga:config:checksum:%s", objectType), k...) + + _, err := pipe.Exec() // TODO(el): What to do with the Cmder slice? + if err != nil { + panic(err) + } + + configs, err := cmds[0].Result() + if err != nil { + panic(err) + } + checksums, err := cmds[1].Result() + if err != nil { + panic(err) + } + + select { + case out <- &ConfigChunk{Keys: k, Configs: configs, Checksums: checksums}: + case <-done: + return + } + } + } + + //TODO: Replace fixed chunkSize + work := icingadb_utils.ChunkKeys(done, keys, 500) + + go func() { + defer close(out) + + wg := &sync.WaitGroup{} + + for i := 0; i < 32; i++ { + wg.Add(1) + go func() { + defer wg.Done() + worker(work) + }() + } + + wg.Wait() + }() + + return out +} + +func (rdbw *RDBWrapper) PipeChecksumChunks(done <-chan struct{}, keys []string, objectType string) <-chan *ChecksumChunk { + out := make(chan *ChecksumChunk) + + worker := func(chunk <-chan []string) { + for k := range chunk { + cmd := rdbw.HMGet(fmt.Sprintf("icinga:config:checksum:%s", objectType), k...) + + checksums, err := cmd.Result() + if err != nil { + panic(err) + } + + select { + case out <- &ChecksumChunk{Keys: k, Checksums: checksums}: + case <-done: + return + } + } + } + + //TODO: Replace fixed chunkSize + work := icingadb_utils.ChunkKeys(done, keys, 500) + + go func() { + defer close(out) + + wg := &sync.WaitGroup{} + + for i := 0; i < 32; i++ { + wg.Add(1) + go func() { + defer wg.Done() + worker(work) + }() + } + + wg.Wait() + }() + + return out +} + diff --git a/connection/redis_pipeliner.go b/connection/redis_pipeliner.go new file mode 100644 index 00000000..e58e39a3 --- /dev/null +++ b/connection/redis_pipeliner.go @@ -0,0 +1,47 @@ +package connection + +import "github.com/go-redis/redis" + +type PipelinerWrapper struct { + pipeliner redis.Pipeliner + rdbw *RDBWrapper +} + +func (plw *PipelinerWrapper) Exec() ([]redis.Cmder, error) { + for { + if !plw.rdbw.IsConnected() { + plw.rdbw.WaitForConnection() + continue + } + + cmder, err := plw.pipeliner.Exec() + + if err != nil { + if !plw.rdbw.CheckConnection(false) { + continue + } + } + + return cmder, err + } +} + +func (plw *PipelinerWrapper) HMGet(key string, fields ...string) *redis.SliceCmd { + for { + if !plw.rdbw.IsConnected() { + plw.rdbw.WaitForConnection() + continue + } + + cmd := plw.pipeliner.HMGet(key, fields...) + _, err := cmd.Result() + + if err != nil { + if !plw.rdbw.CheckConnection(false) { + continue + } + } + + return cmd + } +} \ No newline at end of file diff --git a/connection/redis_pubsub.go b/connection/redis_pubsub.go new file mode 100644 index 00000000..a373ba27 --- /dev/null +++ b/connection/redis_pubsub.go @@ -0,0 +1,67 @@ +package connection + +import ( + "github.com/go-redis/redis" +) + +type PubSubWrapper struct { + ps *redis.PubSub + rdbw *RDBWrapper +} + +func (psw *PubSubWrapper) Subscribe(channels ...string) error { + for { + if !psw.rdbw.IsConnected() { + psw.rdbw.WaitForConnection() + continue + } + + err := psw.ps.Subscribe(channels...) + + if err != nil { + if !psw.rdbw.CheckConnection(false) { + continue + } + } + + return err + } +} + +func (psw *PubSubWrapper) ReceiveMessage() (*redis.Message, error) { + for { + if !psw.rdbw.IsConnected() { + psw.rdbw.WaitForConnection() + continue + } + + msg, err := psw.ps.ReceiveMessage() + + if err != nil { + if !psw.rdbw.CheckConnection(false) { + continue + } + } + + return msg, err + } +} + +func (psw *PubSubWrapper) Close() error { + for { + if !psw.rdbw.IsConnected() { + psw.rdbw.WaitForConnection() + continue + } + + err := psw.ps.Close() + + if err != nil { + if !psw.rdbw.CheckConnection(false) { + continue + } + } + + return err + } +} \ No newline at end of file diff --git a/connection/redis_pubsub_test.go b/connection/redis_pubsub_test.go new file mode 100644 index 00000000..bee59d05 --- /dev/null +++ b/connection/redis_pubsub_test.go @@ -0,0 +1,74 @@ +package connection + +import ( + "github.com/go-redis/redis" + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func TestPubSubWrapper(t *testing.T) { + rdb := redis.NewClient(&redis.Options{ + Addr: "127.0.0.1:6379", + DialTimeout: time.Minute / 2, + ReadTimeout: time.Minute, + WriteTimeout: time.Minute, + }) + rdbw := NewTestRDBW(rdb) + + if !rdbw.CheckConnection(true) { + t.Fatal("This test needs a working Redis connection") + } + + ps := rdbw.Subscribe() + + rdbw.CompareAndSetConnected(false) + + var errSubscribe error + done1:= make(chan bool) + go func () { + errSubscribe = ps.Subscribe("testchannel") + done1 <- true + }() + + time.Sleep(50 * time.Millisecond) + rdbw.CheckConnection(true) + + <- done1 + + rdbw.CompareAndSetConnected(false) + + var msg *redis.Message + var errReceive error + done2 := make(chan bool) + go func() { + msg, errReceive = ps.ReceiveMessage() + done2 <- true + }() + + time.Sleep(50 * time.Millisecond) + rdbw.CheckConnection(true) + + rdbw.Publish("testchannel", "Hello there") + + <- done2 + + rdbw.CompareAndSetConnected(false) + + var errClose error + done3:= make(chan bool) + go func () { + errClose = ps.Close() + done3 <- true + }() + + time.Sleep(50 * time.Millisecond) + rdbw.CheckConnection(true) + + <- done3 + + assert.NoError(t, errSubscribe) + assert.NoError(t, errReceive) + assert.NoError(t, errClose) + assert.Equal(t, "Hello there", msg.Payload) +} \ No newline at end of file diff --git a/connection/redis_test.go b/connection/redis_test.go new file mode 100644 index 00000000..edc62339 --- /dev/null +++ b/connection/redis_test.go @@ -0,0 +1,275 @@ +package connection + +import ( + "github.com/go-redis/redis" + "github.com/stretchr/testify/assert" + "sync" + "sync/atomic" + "testing" + "time" +) + +func NewTestRDBW(rdb RedisClient) RDBWrapper { + dbw := RDBWrapper{Rdb: rdb, ConnectedAtomic: new(uint32), ConnectionLostCounterAtomic: new(uint32)} + dbw.ConnectionUpCondition = sync.NewCond(&sync.Mutex{}) + return dbw +} + +func TestNewRDBWrapper(t *testing.T) { + _, err := NewRDBWrapper("127.0.0.1:6379") + assert.NoError(t, err, "Redis should be connected") + + _, err = NewRDBWrapper("asdasdasdasdasd:5123") + assert.Error(t, err, "Redis should not be connected") + //TODO: Add more tests here +} + +func TestRDBWrapper_GetConnectionCheckInterval(t *testing.T) { + rdbw := NewTestRDBW(nil) + + //Should return 15s, if connected - counter doesn't madder + rdbw.CompareAndSetConnected(true) + assert.Equal(t, 15*time.Second, rdbw.getConnectionCheckInterval()) + + //Should return 5s, if not connected and counter < 4 + rdbw.CompareAndSetConnected(false) + atomic.StoreUint32(rdbw.ConnectionLostCounterAtomic, 0) + assert.Equal(t, 5*time.Second, rdbw.getConnectionCheckInterval()) + + //Should return 10s, if not connected and 4 <= counter < 8 + rdbw.CompareAndSetConnected(false) + atomic.StoreUint32(rdbw.ConnectionLostCounterAtomic, 4) + assert.Equal(t, 10*time.Second, rdbw.getConnectionCheckInterval()) + + //Should return 30s, if not connected and 8 <= counter < 11 + rdbw.CompareAndSetConnected(false) + atomic.StoreUint32(rdbw.ConnectionLostCounterAtomic, 8) + assert.Equal(t, 30*time.Second, rdbw.getConnectionCheckInterval()) + + //Should return 60s, if not connected and 11 <= counter < 14 + rdbw.CompareAndSetConnected(false) + atomic.StoreUint32(rdbw.ConnectionLostCounterAtomic, 11) + assert.Equal(t, 60*time.Second, rdbw.getConnectionCheckInterval()) + + //dbw.ConnectionLostCounter = 14 + //interval = dbw.getConnectionCheckInterval() + //TODO: Check for Fatal +} + +func TestRDBWrapper_CheckConnection(t *testing.T) { + rdbw := NewTestRDBW(nil) + + rdbw.Rdb = redis.NewClient(&redis.Options{ + Addr: "127.0.0.1:6379", + DialTimeout: time.Minute / 2, + ReadTimeout: time.Minute, + WriteTimeout: time.Minute, + }) + atomic.StoreUint32(rdbw.ConnectionLostCounterAtomic, 512312312) + assert.True(t, rdbw.CheckConnection(false), "DBWrapper should be connected") + assert.Equal(t, uint32(0), atomic.LoadUint32(rdbw.ConnectionLostCounterAtomic)) + + rdbw.Rdb = redis.NewClient(&redis.Options{ + Addr: "dasdasdasdasdasd:5123", + DialTimeout: time.Minute / 2, + ReadTimeout: time.Minute, + WriteTimeout: time.Minute, + }) + atomic.StoreUint32(rdbw.ConnectionLostCounterAtomic, 0) + assert.False(t, rdbw.CheckConnection(false), "DBWrapper should not be connected") + assert.Equal(t, uint32(0), atomic.LoadUint32(rdbw.ConnectionLostCounterAtomic)) + + atomic.StoreUint32(rdbw.ConnectionLostCounterAtomic, 10) + assert.False(t, rdbw.CheckConnection(true), "DBWrapper should not be connected") + assert.Equal(t, uint32(11), atomic.LoadUint32(rdbw.ConnectionLostCounterAtomic)) +} + +func TestRDBWrapper_HGetAll(t *testing.T) { + rdb := redis.NewClient(&redis.Options{ + Addr: "127.0.0.1:6379", + DialTimeout: time.Minute / 2, + ReadTimeout: time.Minute, + WriteTimeout: time.Minute, + }) + rdbw := NewTestRDBW(rdb) + + if !rdbw.CheckConnection(true) { + t.Fatal("This test needs a working Redis connection") + } + + rdb.Del("herpdaderp") + rdb.HSet("herpdaderp", "one", 5) + rdb.HSet("herpdaderp", "two", 11) + + rdbw.CompareAndSetConnected(false) + + var data map[string]string + var err error + done := make(chan bool) + go func() { + data, err = rdbw.HGetAll("herpdaderp") + done <- true + }() + + time.Sleep(50 * time.Millisecond) + rdbw.CheckConnection(true) + + <- done + + assert.NoError(t, err) + assert.Contains(t, data, "one") + assert.Contains(t, data, "two") +} + +func TestRDBWrapper_XRead(t *testing.T) { + rdb := redis.NewClient(&redis.Options{ + Addr: "127.0.0.1:6379", + DialTimeout: time.Minute / 2, + ReadTimeout: time.Minute, + WriteTimeout: time.Minute, + }) + rdbw := NewTestRDBW(rdb) + + if !rdbw.CheckConnection(true) { + t.Fatal("This test needs a working Redis connection") + } + + rdb.XTrim("teststream", 0) + rdb.XAdd(&redis.XAddArgs{Stream: "teststream", Values: map[string]interface{}{"one": "5", "two": "11", "herp": "11"}}) + + rdbw.CompareAndSetConnected(false) + + var data *redis.XStreamSliceCmd + done := make(chan bool) + go func() { + data = rdbw.XRead(&redis.XReadArgs{Streams: []string{"teststream", "0"}}) + done <- true + }() + + time.Sleep(50 * time.Millisecond) + rdbw.CheckConnection(true) + + <- done + + streams, err := data.Result() + assert.NoError(t, err) + value := streams[0].Messages[0].Values + + assert.Contains(t, value, "one") + assert.Contains(t, value, "two") +} + +func TestRDBWrapper_XDel(t *testing.T) { + rdb := redis.NewClient(&redis.Options{ + Addr: "127.0.0.1:6379", + DialTimeout: time.Minute / 2, + ReadTimeout: time.Minute, + WriteTimeout: time.Minute, + }) + rdbw := NewTestRDBW(rdb) + + if !rdbw.CheckConnection(true) { + t.Fatal("This test needs a working Redis connection") + } + + rdb.XTrim("teststream", 0) + adds := rdb.XAdd(&redis.XAddArgs{Stream: "teststream", Values: map[string]interface{}{"one": "5", "two": "11", "herp": "11"}}) + + rdbw.CompareAndSetConnected(false) + + done := make(chan bool) + go func() { + rdbw.XDel("teststream", adds.Val()) + done <- true + }() + + time.Sleep(50 * time.Millisecond) + rdbw.CheckConnection(true) + + <- done + + data := rdbw.XRead(&redis.XReadArgs{Streams: []string{"teststream", "0"}, Block: -1}) + streams, err := data.Result() + assert.Error(t, err) + assert.Len(t, streams, 0) +} + +func TestRDBWrapper_Publish(t *testing.T) { + rdb := redis.NewClient(&redis.Options{ + Addr: "127.0.0.1:6379", + DialTimeout: time.Minute / 2, + ReadTimeout: time.Minute, + WriteTimeout: time.Minute, + }) + rdbw := NewTestRDBW(rdb) + + if !rdbw.CheckConnection(true) { + t.Fatal("This test needs a working Redis connection") + } + + var msg *redis.Message + var err error + done := make(chan bool) + go func() { + msg, err = rdb.Subscribe("testchannel").ReceiveMessage() + done <- true + }() + + rdbw.CompareAndSetConnected(false) + + go func () { + rdbw.Publish("testchannel", "Hello there") + }() + + time.Sleep(50 * time.Millisecond) + rdbw.CheckConnection(true) + + <- done + + assert.NoError(t, err) + assert.Equal(t, "Hello there", msg.Payload) +} + +func TestRDBWrapper_TxPipelined(t *testing.T) { + rdb := redis.NewClient(&redis.Options{ + Addr: "127.0.0.1:6379", + DialTimeout: time.Minute / 2, + ReadTimeout: time.Minute, + WriteTimeout: time.Minute, + }) + rdbw := NewTestRDBW(rdb) + + if !rdbw.CheckConnection(true) { + t.Fatal("This test needs a working Redis connection") + } + + rdb.Del("firstKey") + rdb.Del("secondKey") + rdb.HSet("firstKey", "foo", 5) + rdb.HSet("secondKey", "bar", 11) + + rdbw.CompareAndSetConnected(false) + + var firstMap *redis.StringStringMapCmd + var secondMap *redis.StringStringMapCmd + var err error + done := make(chan bool) + go func() { + _, err = rdbw.TxPipelined(func(pipe redis.Pipeliner) error { + firstMap = pipe.HGetAll("firstKey") + secondMap = pipe.HGetAll("secondKey") + return nil + }) + done <- true + }() + + time.Sleep(50 * time.Millisecond) + rdbw.CheckConnection(true) + + <- done + + assert.NoError(t, err) + assert.Contains(t, firstMap.Val(), "foo") + assert.Contains(t, secondMap.Val(), "bar") + +} diff --git a/connection/test_db.sql b/connection/test_db.sql new file mode 100644 index 00000000..69cbd78d --- /dev/null +++ b/connection/test_db.sql @@ -0,0 +1,3 @@ +CREATE database icingadb; +CREATE USER 'module-dev'@'127.0.0.1' IDENTIFIED BY 'icinga0815!'; +GRANT ALL PRIVILEGES ON icingadb.* TO 'module-dev'@'127.0.0.1'; \ No newline at end of file