Merge remote-tracking branch 'connection/master'

This commit is contained in:
Noah Hilverling 2019-05-13 15:00:22 +02:00
commit 9746eb8aba
15 changed files with 2551 additions and 0 deletions

2
connection/.gitignore vendored Normal file
View file

@ -0,0 +1,2 @@
coverage\.html

30
connection/.gitlab-ci.yml Normal file
View file

@ -0,0 +1,30 @@
image: golang:latest
variables:
REPO_NAME: git.icinga.com/icingadb/icingadb-connection
before_script:
- mkdir -p $GOPATH/src/$(dirname $REPO_NAME)
- ln -svf $CI_PROJECT_DIR $GOPATH/src/$REPO_NAME
- cd $GOPATH/src/$REPO_NAME
- git config --global url."https://gitlab-ci-token:${CI_JOB_TOKEN}@git.icinga.com/".insteadOf "https://git.icinga.com/"
- go get -t ./...
stages:
- test
- coverage
test:
stage: test
script:
- go fmt $(go list ./... | grep -v /vendor/)
- go vet $(go list ./... | grep -v /vendor/)
- go test -race $(go list ./... | grep -v /vendor/) -cover
coverage:
stage: coverage
script:
- ./coverage.sh
artifacts:
paths:
- coverage.html

4
connection/README.md Normal file
View file

@ -0,0 +1,4 @@
IcingaDB Connection Library
[![pipeline status](https://git.icinga.com/icingadb/icingadb-connection-lib/badges/master/pipeline.svg)](https://git.icinga.com/icingadb/icingadb-connection-lib/commits/master)
[![coverage report](https://git.icinga.com/icingadb/icingadb-connection-lib/badges/master/coverage.svg)](https://git.icinga.com/icingadb/icingadb-connection-lib/-/jobs/artifacts/master/raw/coverage.html?job=coverage)

3
connection/coverage.sh Executable file
View file

@ -0,0 +1,3 @@
go test -race -cover -coverprofile=c.out
go tool cover -html=c.out -o coverage.html
rm c.out

704
connection/mysql.go Normal file
View file

@ -0,0 +1,704 @@
package connection
import (
"container/list"
"context"
"database/sql"
"fmt"
"git.icinga.com/icingadb/icingadb-main/configobject"
"git.icinga.com/icingadb/icingadb-utils"
log "github.com/sirupsen/logrus"
"strings"
"sync"
"sync/atomic"
"time"
)
// This is used in SqlFetchAll and SqlFetchAllQuiet
type DbClientOrTransaction interface {
Query(query string, args ...interface{}) (*sql.Rows, error)
Exec(query string, args ...interface{}) (sql.Result, error)
}
type DbClient interface {
DbClientOrTransaction
Ping() error
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
}
type DbTransaction interface {
DbClientOrTransaction
Commit() error
Rollback() error
}
func NewDBWrapper(dbDsn string) (*DBWrapper, error) {
db, err := mkMysql("mysql", dbDsn)
if err != nil {
return nil, err
}
dbw := DBWrapper{Db: db, ConnectedAtomic: new(uint32), ConnectionLostCounterAtomic: new(uint32)}
dbw.ConnectionUpCondition = sync.NewCond(&sync.Mutex{})
err = dbw.Db.Ping()
if err != nil {
return nil, err
}
go func() {
for {
dbw.checkConnection(true)
time.Sleep(dbw.getConnectionCheckInterval())
}
}()
return &dbw, nil
}
// Database wrapper including helper functions
type DBWrapper struct {
Db DbClient
ConnectedAtomic *uint32 //uint32 to be able to use atomic operations
ConnectionUpCondition *sync.Cond
ConnectionLostCounterAtomic *uint32 //uint32 to be able to use atomic operations
}
func (dbw *DBWrapper) IsConnected() bool {
return atomic.LoadUint32(dbw.ConnectedAtomic) != 0
}
func (dbw *DBWrapper) CompareAndSetConnected(connected bool) (swapped bool) {
if connected {
return atomic.CompareAndSwapUint32(dbw.ConnectedAtomic, 0, 1)
} else {
return atomic.CompareAndSwapUint32(dbw.ConnectedAtomic, 1, 0)
}
}
func (dbw *DBWrapper) getConnectionCheckInterval() time.Duration {
if !dbw.IsConnected() {
v := atomic.LoadUint32(dbw.ConnectionLostCounterAtomic)
if v < 4 {
return 5 * time.Second
} else if v < 8 {
return 10 * time.Second
} else if v < 11 {
return 30 * time.Second
} else if v < 14 {
return 60 * time.Second
} else {
log.Fatal("Could not connect to SQL for over 5 minutes. Shutting down...")
}
}
return 15 * time.Second
}
func (dbw *DBWrapper) checkConnection(isTicker bool) bool {
err := dbw.Db.Ping()
if err != nil {
if dbw.CompareAndSetConnected(false) {
log.WithFields(log.Fields{
"context": "sql",
"error": err,
}).Error("SQL connection lost. Trying to reconnect")
} else if isTicker {
atomic.AddUint32(dbw.ConnectionLostCounterAtomic, 1)
log.WithFields(log.Fields{
"context": "sql",
"error": err,
}).Debugf("SQL connection lost. Trying again in %s", dbw.getConnectionCheckInterval())
}
return false
} else {
if dbw.CompareAndSetConnected(true) {
log.Info("SQL connection established")
atomic.StoreUint32(dbw.ConnectionLostCounterAtomic, 0)
dbw.ConnectionUpCondition.Broadcast()
}
return true
}
}
func (dbw *DBWrapper) WaitForConnection() {
dbw.ConnectionUpCondition.L.Lock()
dbw.ConnectionUpCondition.Wait()
dbw.ConnectionUpCondition.L.Unlock()
}
func (dbw *DBWrapper) WithRetry(f func() (sql.Result, error)) (sql.Result, error) {
for {
res, err := f()
if err != nil {
if isRetryableError(err) {
continue
} else {
return nil, err
}
}
return res, err
}
}
func (dbw *DBWrapper) SqlQuery(query string, args ...interface{}) (*sql.Rows, error) {
for {
if !dbw.IsConnected() {
dbw.WaitForConnection()
continue
}
res, err := dbw.Db.Query(query, args...)
DbOperationsQuery.Inc()
if err != nil {
if !dbw.checkConnection(false) {
continue
}
}
return res, err
}
}
// Wrapper around Db.BeginTx() for auto-logging
func (dbw *DBWrapper) SqlBegin(concurrencySafety bool, quiet bool) (DbTransaction, error) {
var isoLvl sql.IsolationLevel
if concurrencySafety {
isoLvl = sql.LevelSerializable
} else {
isoLvl = sql.LevelReadCommitted
}
for {
if !dbw.IsConnected() {
dbw.WaitForConnection()
continue
}
var err error
var tx DbTransaction
if quiet {
tx, err = dbw.Db.BeginTx(context.Background(), &sql.TxOptions{Isolation: isoLvl})
} else {
benchmarc := icingadb_utils.NewBenchmark()
tx, err = dbw.Db.BeginTx(context.Background(), &sql.TxOptions{Isolation: isoLvl})
benchmarc.Stop()
DbIoSeconds.WithLabelValues("mysql", "begin").Observe(benchmarc.Seconds())
log.WithFields(log.Fields{
"context": "sql",
"benchmark": benchmarc,
}).Debug("BEGIN transaction")
}
if err != nil {
if !dbw.checkConnection(false) {
continue
}
}
return tx, err
}
}
// Wrapper around tx.Commit() for auto-logging
func (dbw *DBWrapper) SqlCommit(tx DbTransaction, quiet bool) error {
for {
if !dbw.IsConnected() {
dbw.WaitForConnection()
continue
}
var err error
if quiet {
err = tx.Commit()
} else {
benchmarc := icingadb_utils.NewBenchmark()
err = tx.Commit()
benchmarc.Stop()
DbIoSeconds.WithLabelValues("mysql", "commit").Observe(benchmarc.Seconds())
log.WithFields(log.Fields{
"context": "sql",
"benchmark": benchmarc,
}).Debug("COMMIT transaction")
}
if err != nil {
if !dbw.checkConnection(false) {
continue
}
}
return err
}
}
// Wrapper around tx.Rollback() for auto-logging
func (dbw *DBWrapper) SqlRollback(tx DbTransaction, quiet bool) error {
for {
if !dbw.IsConnected() {
dbw.WaitForConnection()
continue
}
var err error
if !quiet {
benchmarc := icingadb_utils.NewBenchmark()
err = tx.Rollback()
benchmarc.Stop()
DbIoSeconds.WithLabelValues("mysql", "rollback").Observe(benchmarc.Seconds())
log.WithFields(log.Fields{
"context": "sql",
"benchmark": benchmarc,
}).Debug("ROLLBACK transaction")
} else {
err = tx.Rollback()
}
if err != nil {
if !dbw.checkConnection(false) {
continue
}
}
return err
}
}
// Wrapper around sql.Exec() for auto-logging
func (dbw *DBWrapper) SqlExec(opDescription string, sql string, args ...interface{}) (sql.Result, error) {
return dbw.sqlExecInternal(dbw.Db, opDescription, sql, false, args...)
}
// No logging, no benchmarking
func (dbw *DBWrapper) SqlExecQuiet(opDescription string, sql string, args ...interface{}) (sql.Result, error) {
return dbw.sqlExecInternal(dbw.Db, opDescription, sql, true, args...)
}
// Wrapper around tx.Exec() for auto-logging
func (dbw *DBWrapper) SqlExecTx(tx DbTransaction, opDescription string, sql string, args ...interface{}) (sql.Result, error) {
return dbw.sqlExecInternal(tx, opDescription, sql, false, args...)
}
// No logging, no benchmarking
func (dbw *DBWrapper) SqlExecTxQuiet(tx DbTransaction, opDescription string, sql string, args ...interface{}) (sql.Result, error) {
return dbw.sqlExecInternal(tx, opDescription, sql, true, args...)
}
func (dbw *DBWrapper) SqlFetchAll(queryDescription string, query string, args ...interface{}) ([][]interface{}, error) {
return dbw.sqlFetchAllInternal(dbw.Db, queryDescription, query, false, args...)
}
func (dbw *DBWrapper) SqlFetchAllQuiet(queryDescription string, query string, args ...interface{}) ([][]interface{}, error) {
return dbw.sqlFetchAllInternal(dbw.Db, queryDescription, query, true, args...)
}
func (dbw *DBWrapper) SqlFetchAllTx(tx DbTransaction, queryDescription string, query string, args ...interface{}) ([][]interface{}, error) {
return dbw.sqlFetchAllInternal(tx, queryDescription, query, false, args...)
}
func (dbw *DBWrapper) SqlFetchAllTxQuiet(tx DbTransaction, queryDescription string, query string, args ...interface{}) ([][]interface{}, error) {
return dbw.sqlFetchAllInternal(tx, queryDescription, query, true, args...)
}
// Wrapper around sql.Exec() for auto-logging
func (dbw *DBWrapper) sqlExecInternal(db DbClientOrTransaction, opDescription string, sql string, quiet bool, args ...interface{}) (sql.Result, error) {
for {
if !dbw.IsConnected() {
dbw.WaitForConnection()
continue
}
var benchmarc *icingadb_utils.Benchmark
if !quiet {
benchmarc = icingadb_utils.NewBenchmark()
}
res, err := db.Exec(sql, args...)
DbOperationsExec.Inc()
if !quiet {
benchmarc.Stop()
}
if !quiet {
DbIoSeconds.WithLabelValues("mysql", opDescription).Observe(benchmarc.Seconds())
log.WithFields(log.Fields{
"context": "sql",
"benchmark": benchmarc,
"affected_rows": prettyPrintedRowsAffected{res},
"args": prettyPrintedArgs{args},
"query": prettyPrintedSql{sql},
}).Debug("Finished Exec")
}
if err != nil {
if !dbw.checkConnection(false) {
continue
}
}
return res, err
}
}
// Wrapper around Db.SqlQuery() for auto-logging
func (dbw *DBWrapper) sqlFetchAllInternal(db DbClientOrTransaction, queryDescription string, query string, quiet bool, args ...interface{}) ([][]interface{}, error) {
for {
if !dbw.IsConnected() {
dbw.WaitForConnection()
continue
}
res, err := sqlTryFetchAll(db, queryDescription, query, quiet, args...)
if err != nil {
if _, isDb := db.(*sql.DB); isDb {
if !dbw.checkConnection(false) {
continue
}
}
}
return res, err
}
}
func sqlTryFetchAll(db DbClientOrTransaction, queryDescription string, query string, quiet bool, args ...interface{}) ([][]interface{}, error) {
var benchmarc *icingadb_utils.Benchmark
if !quiet {
benchmarc = icingadb_utils.NewBenchmark()
}
rows, errQuery := db.Query(query, args...)
DbOperationsQuery.Inc()
if !quiet {
benchmarc.Stop()
}
rowsCount := 0
defer func() {
if !quiet {
DbIoSeconds.WithLabelValues("mysql", queryDescription).Observe(benchmarc.Seconds())
log.WithFields(log.Fields{
"context": "sql",
"benchmark": benchmarc,
"query": prettyPrintedSql{query},
"args": prettyPrintedArgs{args},
"affected_Rows": rowsCount,
}).Debug("Finished FetchAll")
}
}()
if errQuery != nil {
return [][]interface{}{}, errQuery
}
defer rows.Close()
columnTypes, errCT := rows.ColumnTypes()
if errCT != nil {
return [][]interface{}{}, errCT
}
colsPerRow := len(columnTypes)
buf := list.New()
bridges := make([]dbTypeBridge, colsPerRow)
scanDest := make([]interface{}, colsPerRow)
for i, columnType := range columnTypes {
typ := columnType.DatabaseTypeName()
factory, hasFactory := dbTypeBridgeFactories[typ]
if hasFactory {
bridges[i] = factory()
} else {
bridges[i] = &dbBrokenBridge{typ: typ}
}
scanDest[i] = bridges[i]
}
for {
if rows.Next() {
if errScan := rows.Scan(scanDest...); errScan != nil {
return [][]interface{}{}, errScan
}
row := make([]interface{}, colsPerRow)
for i, bridge := range bridges {
row[i] = bridge.Result()
}
buf.PushBack(row)
} else if errNx := rows.Err(); errNx == nil {
break
} else {
return nil, errNx
}
}
res := make([][]interface{}, buf.Len())
for current, i := buf.Front(), 0; current != nil; current = current.Next() {
res[i] = current.Value.([]interface{})
i++
}
rowsCount = len(res)
return res, nil
}
// sqlTransaction executes the given function inside a transaction.
func (dbw DBWrapper) SqlTransaction(concurrencySafety bool, retryOnConnectionFailure bool, quiet bool, f func(DbTransaction) error) error {
for {
if !dbw.IsConnected() {
dbw.WaitForConnection()
continue
}
var benchmarc *icingadb_utils.Benchmark
if !quiet {
benchmarc = icingadb_utils.NewBenchmark()
}
errTx := dbw.sqlTryTransaction(f, concurrencySafety, false)
if !quiet {
benchmarc.Stop()
DbIoSeconds.WithLabelValues("mysql", "transaction").Observe(benchmarc.Seconds())
log.WithFields(log.Fields{
"context": "sql",
"benchmark": benchmarc,
}).Debug("Executed transaction")
}
if errTx != nil {
//TODO: Do this only for concurrencySafety = true, once we figure out the serialization errors.
if isSerializationFailure(errTx) {
if !quiet {
log.WithFields(log.Fields{
"context": "sql",
"error": errTx,
}).Debug("Repeating transaction")
}
continue
}
if !dbw.checkConnection(false) {
if retryOnConnectionFailure {
continue
} else {
return MysqlConnectionError{"Transaction failed duo to a connection error"}
}
}
log.WithFields(log.Fields{
"context": "sql",
"error": errTx,
}).Warn("SQL error occurred")
}
return errTx
}
}
// Executes the given function inside a transaction
func (dbw *DBWrapper) sqlTryTransaction(f func(transaction DbTransaction) error, concurrencySafety bool, quiet bool) error {
tx, errBegin := dbw.SqlBegin(concurrencySafety, quiet)
if errBegin != nil {
return errBegin
}
errTx := f(tx)
if errTx != nil {
dbw.SqlRollback(tx, quiet)
return errTx
}
return dbw.SqlCommit(tx, quiet)
}
func (dbw *DBWrapper) SqlFetchIds(envId []byte, table string) ([]string, error) {
var keys []string
for {
if !dbw.IsConnected() {
dbw.WaitForConnection()
continue
}
rows, err := dbw.SqlQuery(fmt.Sprintf("SELECT id FROM %s WHERE env_id=(X'%s')", table, icingadb_utils.DecodeChecksum(envId)))
if err != nil {
if !dbw.checkConnection(false) {
continue
}
return nil, err
}
defer rows.Close()
for rows.Next() {
var id []byte
err = rows.Scan(&id)
if err != nil {
return nil, err
}
keys = append(keys, icingadb_utils.DecodeChecksum(id))
}
err = rows.Err()
if err != nil {
return nil, err
}
return keys, nil
}
}
func (dbw *DBWrapper) SqlFetchChecksums(table string, ids []string) (map[string]map[string]string, error) {
var checksums = map[string]map[string]string{}
done := make(chan struct{})
//TODO: Don't do this hardcoded - Chunksize
for bulk := range icingadb_utils.ChunkKeys(done, ids, 1000) {
//TODO: This should be done in parallel
query := fmt.Sprintf("SELECT id, properties_checksum FROM %s WHERE id IN (X'%s')", table, strings.Join(bulk, "', X'"))
rows, err := dbw.SqlQuery(query)
if err != nil {
if !dbw.checkConnection(false) {
continue
}
return nil, err
}
defer rows.Close()
for rows.Next() {
var id []byte
var propertiesChecksum []byte
err = rows.Scan(&id, &propertiesChecksum)
if err != nil {
return nil, err
}
checksums[icingadb_utils.DecodeChecksum(id)] = map[string]string{
"properties_checksum": icingadb_utils.DecodeChecksum(propertiesChecksum),
}
}
err = rows.Err()
if err != nil {
return nil, err
}
}
return checksums, nil
}
func (dbw *DBWrapper) SqlBulkInsert(rows []configobject.Row, stmt *BulkInsertStmt) error {
if len(rows) == 0 {
return nil
}
placeholders := make([]string, len(rows))
values := make([]interface{}, len(rows)*stmt.NumField)
j := 0
for i, r := range rows {
placeholders[i] = stmt.Placeholder
for _, v := range r.InsertValues() {
values[j] = v
j++
}
}
query := fmt.Sprintf(stmt.Format, strings.Join(placeholders, ", "))
_, err := dbw.WithRetry(func() (result sql.Result, e error) {
return dbw.SqlExec("Bulk insert", query, values...)
})
if err != nil {
return err
}
return nil
}
func (dbw *DBWrapper) SqlBulkDelete(keys []string, stmt *BulkDeleteStmt) error {
if len(keys) == 0 {
return nil
}
done := make(chan struct{})
defer close(done)
//TODO: Don't do this hardcoded - Chunksize
for bulk := range icingadb_utils.ChunkKeys(done, keys, 1000) {
placeholders := strings.TrimSuffix(strings.Repeat("?, ", len(bulk)), ", ")
values := make([]interface{}, len(bulk))
for i, key := range bulk {
values[i] = icingadb_utils.Checksum(key)
}
query := fmt.Sprintf(stmt.Format, placeholders)
_, err := dbw.WithRetry(func() (result sql.Result, e error) {
return dbw.SqlExec("Bulk delete", query, values...)
})
if err != nil {
return err
}
}
return nil
}
func (dbw *DBWrapper) SqlBulkUpdate(rows []configobject.Row, stmt *BulkUpdateStmt) error {
if len(rows) == 0 {
return nil
}
placeholders := make([]string, len(rows))
values := make([]interface{}, len(rows)*stmt.NumField)
j := 0
for i, r := range rows {
placeholders[i] = stmt.Placeholder
for _, v := range r.InsertValues() {
values[j] = v
j++
}
}
query := fmt.Sprintf(stmt.Format, strings.Join(placeholders, ", "))
_, err := dbw.WithRetry(func() (result sql.Result, e error) {
return dbw.SqlExec("Bulk update", query, values...)
})
if err != nil {
return err
}
return nil
}

318
connection/mysql_test.go Normal file
View file

@ -0,0 +1,318 @@
package connection
import (
"context"
"database/sql"
"errors"
"github.com/go-sql-driver/mysql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"sync"
"sync/atomic"
"testing"
"time"
)
type SqlResultMock struct {
sql.Result
}
type TransactionMock struct {
mock.Mock
}
func (m *TransactionMock) Query(query string, args ...interface{}) (*sql.Rows, error) {
args2 := m.Called(query, args)
return args2.Get(0).(*sql.Rows), args2.Error(1)
}
func (m *TransactionMock) Exec(query string, args ...interface{}) (sql.Result, error) {
args2 := m.Called(query, args)
return args2.Get(0).(sql.Result), args2.Error(1)
}
func (m *TransactionMock) Commit() error {
args := m.Called()
return args.Error(0)
}
func (m *TransactionMock) Rollback() error {
args := m.Called()
return args.Error(0)
}
type DbMock struct {
mock.Mock
}
func (m *DbMock) Ping() error {
args := m.Called()
return args.Error(0)
}
func (m *DbMock) Query(query string, args ...interface{}) (*sql.Rows, error) {
args2 := m.Called(query, args)
return args2.Get(0).(*sql.Rows), args2.Error(1)
}
func (m *DbMock) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
args := m.Called(ctx, opts)
return args.Get(0).(*sql.Tx), args.Error(1)
}
func (m *DbMock) Exec(query string, args ...interface{}) (sql.Result, error) {
args2 := m.Called(query, args)
return args2.Get(0).(sql.Result), args2.Error(1)
}
func NewTestDBW(db DbClient) DBWrapper {
dbw := DBWrapper{Db: db, ConnectedAtomic: new(uint32), ConnectionLostCounterAtomic: new(uint32)}
dbw.ConnectionUpCondition = sync.NewCond(&sync.Mutex{})
return dbw
}
func TestNewDBWrapper(t *testing.T) {
_, err := NewDBWrapper("asdasd")
assert.Error(t, err)
//TODO: Add more tests here
}
func TestDBWrapper_CheckConnection(t *testing.T) {
mockDb := new(DbMock)
dbw := NewTestDBW(mockDb)
atomic.StoreUint32(dbw.ConnectionLostCounterAtomic, 512312312)
mockDb.On("Ping").Return(nil).Once()
assert.True(t, dbw.checkConnection(false), "DBWrapper should be connected")
assert.Equal(t, uint32(0), atomic.LoadUint32(dbw.ConnectionLostCounterAtomic))
atomic.StoreUint32(dbw.ConnectionLostCounterAtomic, 0)
mockDb.On("Ping").Return(mysql.ErrInvalidConn).Once()
assert.False(t, dbw.checkConnection(false), "DBWrapper should not be connected")
assert.Equal(t, uint32(0), atomic.LoadUint32(dbw.ConnectionLostCounterAtomic))
atomic.StoreUint32(dbw.ConnectionLostCounterAtomic, 10)
mockDb.On("Ping").Return(mysql.ErrInvalidConn).Once()
assert.False(t, dbw.checkConnection(true), "DBWrapper should not be connected")
assert.Equal(t, uint32(11), atomic.LoadUint32(dbw.ConnectionLostCounterAtomic))
}
func TestDBWrapper_SqlCommit(t *testing.T) {
mockDb := new(DbMock)
dbw := NewTestDBW(mockDb)
mockTx := new(TransactionMock)
mockTx.On("Commit").Return(errors.New("whoops")).Once()
mockTx.On("Commit").Return( nil).Once()
mockDb.On("Ping").Return(errors.New("whoops")).Once()
var err error
done := make(chan bool)
dbw.CompareAndSetConnected(true)
go func() {
err = dbw.SqlCommit(mockTx, false)
done <- true
}()
time.Sleep(time.Millisecond * 50)
dbw.CompareAndSetConnected(true)
dbw.ConnectionUpCondition.Broadcast()
<- done
assert.NoError(t, err)
mockTx.AssertExpectations(t)
mockDb.AssertExpectations(t)
}
func TestDBWrapper_SqlBegin(t *testing.T) {
mockDb := new(DbMock)
dbw := NewTestDBW(mockDb)
mockDb.On("BeginTx", context.Background(), &sql.TxOptions{Isolation: sql.LevelReadCommitted}).Return(&sql.Tx{}, errors.New("whoops")).Once()
mockDb.On("BeginTx", context.Background(), &sql.TxOptions{Isolation: sql.LevelReadCommitted}).Return(&sql.Tx{}, nil).Once()
mockDb.On("Ping").Return(errors.New("whoops")).Once()
var err error
done := make(chan bool)
dbw.CompareAndSetConnected(true)
go func() {
_, err = dbw.SqlBegin(false, false)
done <- true
}()
time.Sleep(time.Millisecond * 50)
dbw.CompareAndSetConnected(true)
dbw.ConnectionUpCondition.Broadcast()
<- done
assert.NoError(t, err)
mockDb.AssertExpectations(t)
}
func TestDBWrapper_SqlTransaction(t *testing.T) {
dbw, err := NewDBWrapper( "module-dev:icinga0815!@tcp(127.0.0.1:3306)/icingadb")
assert.NoError(t, err, "Is the MySQL server running?")
err = dbw.SqlTransaction(false, true, false, func(tx DbTransaction) error {
return nil
})
assert.NoError(t, err)
err = dbw.SqlTransaction(false, true, false, func(tx DbTransaction) error {
return errors.New("whoops")
})
assert.Error(t, err)
}
func TestDBWrapper_WithRetry(t *testing.T) {
mockDb := new(DbMock)
dbw := NewTestDBW(mockDb)
tries := 0
_, err := dbw.WithRetry(func() (result sql.Result, e error) {
if tries > 0 {
tries++
return nil, nil
} else {
tries++
return nil, errors.New("Deadlock found when trying to get lock")
}
})
assert.NoError(t, err)
assert.Equal(t, 2, tries)
_, err = dbw.WithRetry(func() (result sql.Result, e error) {
return nil, errors.New("something went wrong")
})
assert.Error(t, err)
}
func TestDBWrapper_SqlQuery(t *testing.T) {
mockDb := new(DbMock)
dbw := NewTestDBW(mockDb)
mockDb.On("Query", "test", []interface{}(nil)).Return(&sql.Rows{}, errors.New("whoops")).Once()
mockDb.On("Query", "test", []interface{}(nil)).Return(&sql.Rows{}, nil).Once()
mockDb.On("Ping").Return(errors.New("whoops")).Once()
var err error
done := make(chan bool)
dbw.CompareAndSetConnected(true)
go func() {
_, err = dbw.SqlQuery("test")
done <- true
}()
time.Sleep(time.Millisecond * 50)
dbw.CompareAndSetConnected(true)
dbw.ConnectionUpCondition.Broadcast()
<- done
assert.NoError(t, err)
mockDb.AssertExpectations(t)
}
func TestDBWrapper_SqlExec(t *testing.T) {
mockDb := new(DbMock)
dbw := NewTestDBW(mockDb)
mockDb.On("Exec", "test", []interface{}(nil)).Return(SqlResultMock{}, errors.New("whoops")).Once()
mockDb.On("Exec", "test", []interface{}(nil)).Return(SqlResultMock{}, nil).Once()
mockDb.On("Ping").Return(errors.New("whoops")).Once()
var err error
done := make(chan bool)
dbw.CompareAndSetConnected(true)
go func() {
_, err = dbw.SqlExec("test", "test")
done <- true
}()
time.Sleep(time.Millisecond * 50)
dbw.CompareAndSetConnected(true)
dbw.ConnectionUpCondition.Broadcast()
<- done
assert.NoError(t, err)
mockDb.AssertExpectations(t)
}
func TestGetConnectionCheckInterval(t *testing.T) {
dbw := NewTestDBW(nil)
//Should return 15s, if connected - counter doesn't madder
dbw.CompareAndSetConnected(true)
assert.Equal(t, 15*time.Second, dbw.getConnectionCheckInterval())
//Should return 5s, if not connected and counter < 4
dbw.CompareAndSetConnected(false)
atomic.StoreUint32(dbw.ConnectionLostCounterAtomic, 0)
assert.Equal(t, 5*time.Second, dbw.getConnectionCheckInterval())
//Should return 10s, if not connected and 4 <= counter < 8
dbw.CompareAndSetConnected(false)
atomic.StoreUint32(dbw.ConnectionLostCounterAtomic, 4)
assert.Equal(t, 10*time.Second, dbw.getConnectionCheckInterval())
//Should return 30s, if not connected and 8 <= counter < 11
dbw.CompareAndSetConnected(false)
atomic.StoreUint32(dbw.ConnectionLostCounterAtomic, 8)
assert.Equal(t, 30*time.Second, dbw.getConnectionCheckInterval())
//Should return 60s, if not connected and 11 <= counter < 14
dbw.CompareAndSetConnected(false)
atomic.StoreUint32(dbw.ConnectionLostCounterAtomic, 11)
assert.Equal(t, 60*time.Second, dbw.getConnectionCheckInterval())
//dbw.ConnectionLostCounter = 14
//interval = dbw.getConnectionCheckInterval()
//TODO: Check for Fatal
}
func TestDBWrapper_SqlFetchAll(t *testing.T) {
dbw, err := NewDBWrapper("module-dev:icinga0815!@tcp(127.0.0.1:3306)/icingadb")
assert.NoError(t, err, "Is the MySQL server running?")
_, err = dbw.Db.Exec("CREATE TABLE testing0815 (id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, name varchar(255) NOT NULL)")
assert.NoError(t, err)
_, err = dbw.Db.Exec("INSERT INTO testing0815 (name) VALUES ('horst'), ('test')")
assert.NoError(t, err)
var res [][]interface{}
done := make(chan bool)
dbw.CompareAndSetConnected(false)
go func() {
res, err = dbw.SqlFetchAll("test", "SELECT * FROM testing0815")
done <- true
}()
time.Sleep(time.Millisecond * 50)
dbw.checkConnection(true)
<- done
assert.NoError(t, err)
assert.Equal(t, [][]interface {}([][]interface {}{{int64(1), "horst"}, {int64(2), "test"}}), res)
_, err = dbw.Db.Exec("DROP TABLE testing0815")
assert.NoError(t, err)
}

425
connection/mysql_utils.go Normal file
View file

@ -0,0 +1,425 @@
package connection
import (
"database/sql"
"encoding/hex"
"errors"
"fmt"
"github.com/go-sql-driver/mysql"
"io/ioutil"
"reflect"
"sort"
"strconv"
"strings"
log "github.com/sirupsen/logrus"
oldlog "log"
)
// mkMysql creates a new MySQL client.
func mkMysql(dbType string, dbDsn string) (*sql.DB, error) {
log.Info("Connecting to MySQL")
sep := "?"
if dbDsn == "" {
dbDsn = "/"
} else {
dsnParts := strings.Split(dbDsn, "/")
if strings.Contains(dsnParts[len(dsnParts)-1], "?") {
sep = "&"
}
}
dbDsn = dbDsn + sep +
"innodb_strict_mode=1&sql_mode='STRICT_ALL_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE,NO_ENGINE_SUBSTITUTION,PIPES_AS_CONCAT,ANSI_QUOTES,ERROR_FOR_DIVISION_BY_ZERO,NO_AUTO_CREATE_USER'"
db, errConn := sql.Open(dbType, dbDsn)
if errConn != nil {
return nil, errConn
}
mysql.SetLogger(oldlog.New(ioutil.Discard, "", 0))
db.SetMaxOpenConns(50)
db.SetMaxIdleConns(0)
return db, nil
}
type dbTypeBridge interface {
sql.Scanner
Result() interface{}
}
type dbIntBridge struct {
result interface{}
}
func (d *dbIntBridge) Scan(src interface{}) (err error) {
baseScanner := sql.NullInt64{}
err = baseScanner.Scan(src)
if err == nil {
if baseScanner.Valid {
d.result = baseScanner.Int64
} else {
d.result = nil
}
}
return
}
func (d *dbIntBridge) Result() interface{} {
return d.result
}
type dbFloatBridge struct {
result interface{}
}
func (d *dbFloatBridge) Scan(src interface{}) (err error) {
baseScanner := sql.NullFloat64{}
err = baseScanner.Scan(src)
if err == nil {
if baseScanner.Valid {
d.result = baseScanner.Float64
} else {
d.result = nil
}
}
return
}
func (d *dbFloatBridge) Result() interface{} {
return d.result
}
type dbStringBridge struct {
result interface{}
}
func (d *dbStringBridge) Scan(src interface{}) (err error) {
baseScanner := sql.NullString{}
err = baseScanner.Scan(src)
if err == nil {
if baseScanner.Valid {
d.result = baseScanner.String
} else {
d.result = nil
}
}
return
}
func (d *dbStringBridge) Result() interface{} {
return d.result
}
type dbBytesBridge struct {
result interface{}
}
func (d *dbBytesBridge) Scan(src interface{}) (err error) {
baseScanner := sql.NullString{}
err = baseScanner.Scan(src)
if err == nil {
if baseScanner.Valid {
d.result = []byte(baseScanner.String)
} else {
d.result = nil
}
}
return
}
func (d *dbBytesBridge) Result() interface{} {
return d.result
}
var dbTypeBridgeFactories = map[string]func() dbTypeBridge{
// MySQL
"TINYINT": func() dbTypeBridge {
return &dbIntBridge{}
},
"SMALLINT": func() dbTypeBridge {
return &dbIntBridge{}
},
"INT": func() dbTypeBridge {
return &dbIntBridge{}
},
"BIGINT": func() dbTypeBridge {
return &dbIntBridge{}
},
"FLOAT": func() dbTypeBridge {
return &dbFloatBridge{}
},
"CHAR": func() dbTypeBridge {
return &dbStringBridge{}
},
"VARCHAR": func() dbTypeBridge {
return &dbStringBridge{}
},
"ENUM": func() dbTypeBridge {
return &dbStringBridge{}
},
"BINARY": func() dbTypeBridge {
return &dbBytesBridge{}
},
// SQLite
"INTEGER": func() dbTypeBridge {
return &dbIntBridge{}
},
"REAL": func() dbTypeBridge {
return &dbFloatBridge{}
},
"TEXT": func() dbTypeBridge {
return &dbStringBridge{}
},
"BLOB": func() dbTypeBridge {
return &dbBytesBridge{}
},
// SELECT 1 FROM ...
"": func() dbTypeBridge {
return &dbIntBridge{}
},
}
type dbBrokenBridge struct {
typ string
}
func (d *dbBrokenBridge) Scan(src interface{}) error {
types := make([]string, len(dbTypeBridgeFactories))
typeIdx := 0
for typ := range dbTypeBridgeFactories {
types[typeIdx] = typ
typeIdx++
}
sort.Strings(types)
return errors.New(fmt.Sprintf("bad column type %s, expected one of %s", d.typ, strings.Join(types, ", ")))
}
func (d *dbBrokenBridge) Result() interface{} {
return nil
}
var prettyPrintedSqlReplacer = strings.NewReplacer("\n", " ", "\t", "")
type prettyPrintedSql struct {
sql string
}
// String implements and interface from Stringer
func (p prettyPrintedSql) String() string {
return strings.TrimSpace(prettyPrintedSqlReplacer.Replace(p.sql))
}
// MarshalText implements an interface from TextMarshaler
func (p prettyPrintedSql) MarshalText() (text []byte, err error) {
return []byte(p.String()), nil
}
type prettyPrintedArgs struct {
args []interface{}
}
func (p *prettyPrintedArgs) String() string {
res := "["
for _, v := range p.args {
if byteArray, isByteArray := v.([]byte); isByteArray {
res = fmt.Sprintf("%s hex.DecodeString(\"%s\"),", res, hex.EncodeToString(byteArray))
} else {
res = fmt.Sprintf("%s %#v,", res, v)
}
}
return res + " ]"
}
// MarshalText implements an interface from TextMarshaler
func (p prettyPrintedArgs) MarshalText() (text []byte, err error) {
return []byte(p.String()), nil
}
type prettyPrintedRowsAffected struct {
result sql.Result
}
// String implements and interface from Stringer
func (d prettyPrintedRowsAffected) String() string {
if d.result != nil {
rows, errRA := d.result.RowsAffected()
if errRA == nil {
return strconv.FormatInt(rows, 10)
}
}
return "N/A"
}
// MarshalText implements an interface from TextMarshaler
func (d prettyPrintedRowsAffected) MarshalText() (text []byte, err error) {
return []byte(d.String()), nil
}
type MysqlConnectionError struct {
err string
}
func (e MysqlConnectionError) Error() string {
return e.err
}
// Returns whether the given error signals serialization failure
// https://dev.mysql.com/doc/refman/5.5/en/error-messages-server.html#error_er_lock_deadlock
func isSerializationFailure(e error) bool {
switch err := e.(type) {
case *mysql.MySQLError:
switch err.Number {
// Those are the error numbers for serialization failures, upon which we retry
case 1205, 1213:
return true
}
}
return false
}
func formatLogQuery(query string) string {
r := strings.NewReplacer("\n", " ", "\t", "")
return strings.TrimSpace(r.Replace(query))
}
// Go bool -> DB bool
var yesNo = map[bool]string{
true: "y",
false: "n",
}
func ConvertValueForDb(in interface{}) (interface{}, error) {
switch value := in.(type) {
case []byte:
case string:
case float64:
case int64:
case nil:
case float32:
return float64(value), nil
case uint:
return int64(value), nil
case uint8:
return int64(value), nil
case uint16:
return int64(value), nil
case uint32:
return int64(value), nil
case uint64:
return int64(value), nil
case int:
return int64(value), nil
case int8:
return int64(value), nil
case int16:
return int64(value), nil
case int32:
return int64(value), nil
case bool:
return yesNo[value], nil
default:
return nil, errors.New(fmt.Sprintf(
"bad type %s, expected one of []byte, string, float{32,64}, {,u}int{,8,16,32,64}, bool, nil",
reflect.TypeOf(in).Name(),
))
}
return in, nil
}
func MakePlaceholderList(x int) string {
runes := make([]rune, 1+x*2)
i := 1
for j := 0; j < x; j++ {
runes[i] = '?'
i++
runes[i] = ','
i++
}
runes[0] = '('
runes[len(runes)-1] = ')'
return string(runes)
}
func isRetryableError(err error) bool {
if strings.Contains(err.Error(), "Deadlock found when trying to get lock") {
return true
}
return false
}
type BulkInsertStmt struct {
Format string
Fields []string
Placeholder string
NumField int
}
func NewBulkInsertStmt(table string, fields []string) *BulkInsertStmt {
numField := len(fields)
placeholder := fmt.Sprintf("(%s)", strings.TrimSuffix(strings.Repeat("?, ", numField), ", "))
stmt := BulkInsertStmt{
Format: fmt.Sprintf("INSERT INTO %s (%s) VALUES %s ON DUPLICATE KEY UPDATE id = id", table, strings.Join(fields, ", "), "%s"),
Fields: fields,
Placeholder: placeholder,
NumField: numField,
}
return &stmt
}
type BulkDeleteStmt struct {
Format string
}
func NewBulkDeleteStmt(table string) *BulkDeleteStmt {
stmt := BulkDeleteStmt{
Format: fmt.Sprintf("DELETE FROM %s WHERE id IN (%s)", table, "%s"),
}
return &stmt
}
type BulkUpdateStmt struct {
Format string
Fields []string
Placeholder string
NumField int
}
func NewBulkUpdateStmt(table string, fields []string) *BulkUpdateStmt {
numField := len(fields)
placeholder := fmt.Sprintf("(%s)", strings.TrimSuffix(strings.Repeat("?, ", numField), ", "))
stmt := BulkUpdateStmt{
Format: fmt.Sprintf("REPLACE INTO %s (%s) VALUES %s", table, strings.Join(fields, ", "), "%s"),
Fields: fields,
Placeholder: placeholder,
NumField: numField,
}
return &stmt
}

View file

@ -0,0 +1,104 @@
package connection
import (
"errors"
"github.com/go-sql-driver/mysql"
"github.com/stretchr/testify/assert"
"testing"
)
func TestMakePlaceholderList(t *testing.T) {
assert.Equal(t, "(?)", MakePlaceholderList(1))
assert.Equal(t, "(?,?,?,?,?)", MakePlaceholderList(5))
assert.Equal(t, "(?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?,?)", MakePlaceholderList(20))
}
func TestConvertValueForDb(t *testing.T) {
var v interface{}
var err error
v, err = ConvertValueForDb(nil)
assert.IsType(t, nil, v)
assert.Nil(t, err)
v, err = ConvertValueForDb([]byte{100})
assert.IsType(t, []byte{100}, v)
assert.Nil(t, err)
v, err = ConvertValueForDb("this-is-a-string")
assert.IsType(t, "this-is-a-string", v)
assert.Nil(t, err)
v, err = ConvertValueForDb(float32(123.456))
assert.IsType(t, float64(123.456), v)
assert.Nil(t, err)
v, err = ConvertValueForDb(float64(123.456))
assert.IsType(t, float64(123.456), v)
assert.Nil(t, err)
v, err = ConvertValueForDb(uint(20))
assert.IsType(t, int64(10), v)
assert.Nil(t, err)
v, err = ConvertValueForDb(uint8(30))
assert.IsType(t, int64(10), v)
assert.Nil(t, err)
v, err = ConvertValueForDb(uint16(40))
assert.IsType(t, int64(10), v)
assert.Nil(t, err)
v, err = ConvertValueForDb(uint32(50))
assert.IsType(t, int64(10), v)
assert.Nil(t, err)
v, err = ConvertValueForDb(uint64(60))
assert.IsType(t, int64(10), v)
assert.Nil(t, err)
v, err = ConvertValueForDb(int(70))
assert.IsType(t, int64(10), v)
assert.Nil(t, err)
v, err = ConvertValueForDb(int8(80))
assert.IsType(t, int64(10), v)
assert.Nil(t, err)
v, err = ConvertValueForDb(int16(90))
assert.IsType(t, int64(10), v)
assert.Nil(t, err)
v, err = ConvertValueForDb(int32(100))
assert.IsType(t, int64(10), v)
assert.Nil(t, err)
v, err = ConvertValueForDb(int64(10))
assert.IsType(t, int64(10), v)
assert.Nil(t, err)
v, err = ConvertValueForDb(true)
assert.IsType(t, "y/n-string", v)
assert.Nil(t, err)
//Should not be possible
v, err = ConvertValueForDb(errors.New("test"))
assert.NotNil(t, err)
}
func TestIsSerializationFailure(t *testing.T) {
assert.True(t, isSerializationFailure(&mysql.MySQLError{Number: 1205}))
assert.True(t, isSerializationFailure(&mysql.MySQLError{Number: 1213}))
assert.False(t, isSerializationFailure(&mysql.MySQLError{Number: 6342}))
assert.False(t, isSerializationFailure(errors.New("random error")))
}
func TestMysqlConnectionError_Error(t *testing.T) {
err := MysqlConnectionError{"The chicken has left the database!"}
assert.Equal(t, "The chicken has left the database!", err.Error())
}
func TestFormatLogQuery(t *testing.T) {
assert.Equal(t, "This is my string", formatLogQuery("\tThis is\nmy string\n"))
}

29
connection/prometheus.go Normal file
View file

@ -0,0 +1,29 @@
package connection
import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
var DbIoSeconds = promauto.NewSummaryVec(
prometheus.SummaryOpts{
Name: "db_io_seconds",
Help: "Database I/O (s)",
},
[]string{"backend_type", "operation"},
)
var DbOperationsTotal = promauto.NewCounter(prometheus.CounterOpts{
Name: "db_operations_total",
Help: "Database operations since startup",
})
var DbOperationsQuery = promauto.NewCounter(prometheus.CounterOpts{
Name: "db_operations_query",
Help: "Database query operations since startup",
})
var DbOperationsExec = promauto.NewCounter(prometheus.CounterOpts{
Name: "db_operations_exec",
Help: "Database exec operations since startup",
})

466
connection/redis.go Normal file
View file

@ -0,0 +1,466 @@
package connection
import (
"fmt"
"git.icinga.com/icingadb/icingadb-utils"
"github.com/go-redis/redis"
log "github.com/sirupsen/logrus"
"sync"
"sync/atomic"
"time"
)
type Icinga2RedisWriterEventsConfig struct {
Update, Delete, Dump string
}
type Icinga2RedisWriterKeyPrefixesConfig struct {
Checksum, Object, Customvar string
}
type Icinga2RedisWriterKeyPrefixesStatus struct {
Object string
}
type Icinga2RedisWriterEvents struct {
Config Icinga2RedisWriterEventsConfig
Stats string
}
type Icinga2RedisWriterKeyPrefixes struct {
Config Icinga2RedisWriterKeyPrefixesConfig
Status Icinga2RedisWriterKeyPrefixesStatus
}
type Icinga2RedisWriter struct {
Events Icinga2RedisWriterEvents
KeyPrefixes Icinga2RedisWriterKeyPrefixes
}
var RedisWriter = Icinga2RedisWriter{
Events: Icinga2RedisWriterEvents{
Config: Icinga2RedisWriterEventsConfig{
Dump: "icinga:config:dump",
Delete: "icinga:config:delete",
Update: "icinga:config:update",
},
Stats: "icinga:stats",
},
KeyPrefixes: Icinga2RedisWriterKeyPrefixes{
Config: Icinga2RedisWriterKeyPrefixesConfig{
Checksum: "icinga:config:checksum:",
Object: "icinga:config:object:",
Customvar: "icinga:config:customvar:",
},
Status: Icinga2RedisWriterKeyPrefixesStatus{
Object: "icinga:state:object:",
},
},
}
type RedisClient interface {
Ping() *redis.StatusCmd
Publish(channel string, message interface{}) *redis.IntCmd
XRead(a *redis.XReadArgs) *redis.XStreamSliceCmd
XDel(stream string, ids ...string) *redis.IntCmd
HKeys(key string) *redis.StringSliceCmd
HMGet(key string, fields ...string) *redis.SliceCmd
HGetAll(key string) *redis.StringStringMapCmd
TxPipelined(fn func(redis.Pipeliner) error) ([]redis.Cmder, error)
Pipeline() redis.Pipeliner
Subscribe(channels ...string) *redis.PubSub
}
type StatusCmd interface {
}
// Redis wrapper including helper functions
type RDBWrapper struct {
Rdb RedisClient
ConnectedAtomic *uint32 //uint32 to be able to use atomic operations
ConnectionUpCondition *sync.Cond
ConnectionLostCounterAtomic *uint32 //uint32 to be able to use atomic operations
}
func (rdbw *RDBWrapper) IsConnected() bool {
return atomic.LoadUint32(rdbw.ConnectedAtomic) != 0
}
func (rdbw *RDBWrapper) CompareAndSetConnected(connected bool) (swapped bool) {
if connected {
return atomic.CompareAndSwapUint32(rdbw.ConnectedAtomic, 0, 1)
} else {
return atomic.CompareAndSwapUint32(rdbw.ConnectedAtomic, 1, 0)
}
}
func NewRDBWrapper(address string) (*RDBWrapper, error) {
rdb := redis.NewClient(&redis.Options{
Addr: address,
DialTimeout: time.Minute / 2,
ReadTimeout: time.Minute,
WriteTimeout: time.Minute,
})
rdbw := RDBWrapper{
Rdb: rdb, ConnectedAtomic: new(uint32),
ConnectionLostCounterAtomic: new(uint32),
ConnectionUpCondition: sync.NewCond(&sync.Mutex{}),
}
_, err := rdbw.Rdb.Ping().Result()
if err != nil {
return nil, err
}
go func() {
for {
rdbw.CheckConnection(true)
time.Sleep(rdbw.getConnectionCheckInterval())
}
}()
return &rdbw, nil
}
func (rdbw *RDBWrapper) getConnectionCheckInterval() time.Duration {
if !rdbw.IsConnected() {
v := atomic.LoadUint32(rdbw.ConnectionLostCounterAtomic)
if v < 4 {
return 5 * time.Second
} else if v < 8 {
return 10 * time.Second
} else if v < 11 {
return 30 * time.Second
} else if v < 14 {
return 60 * time.Second
} else {
log.Fatal("Could not connect to Redis for over 5 minutes. Shutting down...")
}
}
return 15 * time.Second
}
func (rdbw *RDBWrapper) CheckConnection(isTicker bool) bool {
_, err := rdbw.Rdb.Ping().Result()
if err != nil {
if rdbw.CompareAndSetConnected(false) {
log.WithFields(log.Fields{
"context": "redis",
"error": err,
}).Error("Redis connection lost. Trying to reconnect")
} else if isTicker {
atomic.AddUint32(rdbw.ConnectionLostCounterAtomic, 1)
log.WithFields(log.Fields{
"context": "redis",
"error": err,
}).Debugf("Redis connection lost. Trying again in %s", rdbw.getConnectionCheckInterval())
}
return false
} else {
if rdbw.CompareAndSetConnected(true) {
log.Info("Redis connection established")
atomic.StoreUint32(rdbw.ConnectionLostCounterAtomic, 0)
rdbw.ConnectionUpCondition.Broadcast()
}
return true
}
}
func (rdbw *RDBWrapper) WaitForConnection() {
rdbw.ConnectionUpCondition.L.Lock()
rdbw.ConnectionUpCondition.Wait()
rdbw.ConnectionUpCondition.L.Unlock()
}
// Wrapper for connection handling
func (rdbw *RDBWrapper) Publish(channel string, message interface{}) *redis.IntCmd {
for {
if !rdbw.IsConnected() {
rdbw.WaitForConnection()
continue
}
cmd := rdbw.Rdb.Publish(channel, message)
_, err := cmd.Result()
if err != nil {
if !rdbw.CheckConnection(false) {
continue
}
}
return cmd
}
}
// Wrapper for connection handling
func (rdbw *RDBWrapper) XRead(args *redis.XReadArgs) *redis.XStreamSliceCmd {
for {
if !rdbw.IsConnected() {
rdbw.WaitForConnection()
continue
}
cmd := rdbw.Rdb.XRead(args)
_, err := cmd.Result()
if err != nil {
if !rdbw.CheckConnection(false) {
continue
}
}
return cmd
}
}
// Wrapper for connection handling
func (rdbw *RDBWrapper) XDel(stream string, ids ...string) *redis.IntCmd {
for {
if !rdbw.IsConnected() {
rdbw.WaitForConnection()
continue
}
cmd := rdbw.Rdb.XDel(stream, ids...)
_, err := cmd.Result()
if err != nil {
if !rdbw.CheckConnection(false) {
continue
}
}
return cmd
}
}
// Wrapper for connection handling
func (rdbw *RDBWrapper) HKeys(key string) *redis.StringSliceCmd {
for {
if !rdbw.IsConnected() {
rdbw.WaitForConnection()
continue
}
cmd := rdbw.Rdb.HKeys(key)
_, err := cmd.Result()
if err != nil {
if !rdbw.CheckConnection(false) {
continue
}
}
return cmd
}
}
func (rdbw * RDBWrapper) HMGet(key string, fields ...string) *redis.SliceCmd {
for {
if !rdbw.IsConnected() {
rdbw.WaitForConnection()
continue
}
cmd := rdbw.Rdb.HMGet(key, fields...)
_, err := cmd.Result()
if err != nil {
if !rdbw.CheckConnection(false) {
continue
}
}
return cmd
}
}
// Wrapper for auto-logging and connection handling
func (rdbw *RDBWrapper) HGetAll(key string) *redis.StringStringMapCmd {
for {
if !rdbw.IsConnected() {
rdbw.WaitForConnection()
continue
}
benchmarc := icingadb_utils.NewBenchmark()
res := rdbw.Rdb.HGetAll(key)
if _, err := res.Result(); err != nil {
if !rdbw.CheckConnection(false) {
continue
}
}
benchmarc.Stop()
DbIoSeconds.WithLabelValues("redis", "hgetall").Observe(benchmarc.Seconds())
log.WithFields(log.Fields{
"context": "redis",
"benchmark": benchmarc,
"query": "HGETALL " + key,
"result": res.Val(),
}).Debug("Ran Query")
return res
}
}
// Wrapper for auto-logging and connection handling
func (rdbw *RDBWrapper) TxPipelined(fn func(pipeliner redis.Pipeliner) error) ([]redis.Cmder, error) {
for {
if !rdbw.IsConnected() {
rdbw.WaitForConnection()
continue
}
benchmarc := icingadb_utils.NewBenchmark()
cmd, err := rdbw.Rdb.TxPipelined(fn)
if err != nil {
if !rdbw.CheckConnection(false) {
continue
}
}
benchmarc.Stop()
DbIoSeconds.WithLabelValues("redis", "multi").Observe(benchmarc.Seconds())
log.WithFields(log.Fields{
"context": "redis",
"benchmark": benchmarc,
"query": "MULTI/EXEC",
}).Debug("Ran pipelined transaction")
return cmd, err
}
}
func (rdbw *RDBWrapper) Pipeline() PipelinerWrapper {
pipeliner := rdbw.Rdb.Pipeline()
plw := PipelinerWrapper{pipeliner: pipeliner, rdbw: rdbw}
return plw
}
func (rdbw *RDBWrapper) Subscribe() PubSubWrapper {
ps := rdbw.Rdb.Subscribe()
psw := PubSubWrapper{ps: ps, rdbw: rdbw}
return psw
}
type ConfigChunk struct {
Keys []string
Configs []interface{}
Checksums []interface{}
}
type ChecksumChunk struct {
Keys []string
Checksums []interface{}
}
func (rdbw *RDBWrapper) PipeConfigChunks(done <-chan struct{}, keys []string, objectType string) <-chan *ConfigChunk {
out := make(chan *ConfigChunk)
worker := func(chunk <-chan []string) {
for k := range chunk {
pipe := rdbw.Pipeline()
cmds := make([]*redis.SliceCmd, 2)
cmds[0] = pipe.HMGet(fmt.Sprintf("icinga:config:object:%s", objectType), k...)
cmds[1] = pipe.HMGet(fmt.Sprintf("icinga:config:checksum:%s", objectType), k...)
_, err := pipe.Exec() // TODO(el): What to do with the Cmder slice?
if err != nil {
panic(err)
}
configs, err := cmds[0].Result()
if err != nil {
panic(err)
}
checksums, err := cmds[1].Result()
if err != nil {
panic(err)
}
select {
case out <- &ConfigChunk{Keys: k, Configs: configs, Checksums: checksums}:
case <-done:
return
}
}
}
//TODO: Replace fixed chunkSize
work := icingadb_utils.ChunkKeys(done, keys, 500)
go func() {
defer close(out)
wg := &sync.WaitGroup{}
for i := 0; i < 32; i++ {
wg.Add(1)
go func() {
defer wg.Done()
worker(work)
}()
}
wg.Wait()
}()
return out
}
func (rdbw *RDBWrapper) PipeChecksumChunks(done <-chan struct{}, keys []string, objectType string) <-chan *ChecksumChunk {
out := make(chan *ChecksumChunk)
worker := func(chunk <-chan []string) {
for k := range chunk {
cmd := rdbw.HMGet(fmt.Sprintf("icinga:config:checksum:%s", objectType), k...)
checksums, err := cmd.Result()
if err != nil {
panic(err)
}
select {
case out <- &ChecksumChunk{Keys: k, Checksums: checksums}:
case <-done:
return
}
}
}
//TODO: Replace fixed chunkSize
work := icingadb_utils.ChunkKeys(done, keys, 500)
go func() {
defer close(out)
wg := &sync.WaitGroup{}
for i := 0; i < 32; i++ {
wg.Add(1)
go func() {
defer wg.Done()
worker(work)
}()
}
wg.Wait()
}()
return out
}

View file

@ -0,0 +1,47 @@
package connection
import "github.com/go-redis/redis"
type PipelinerWrapper struct {
pipeliner redis.Pipeliner
rdbw *RDBWrapper
}
func (plw *PipelinerWrapper) Exec() ([]redis.Cmder, error) {
for {
if !plw.rdbw.IsConnected() {
plw.rdbw.WaitForConnection()
continue
}
cmder, err := plw.pipeliner.Exec()
if err != nil {
if !plw.rdbw.CheckConnection(false) {
continue
}
}
return cmder, err
}
}
func (plw *PipelinerWrapper) HMGet(key string, fields ...string) *redis.SliceCmd {
for {
if !plw.rdbw.IsConnected() {
plw.rdbw.WaitForConnection()
continue
}
cmd := plw.pipeliner.HMGet(key, fields...)
_, err := cmd.Result()
if err != nil {
if !plw.rdbw.CheckConnection(false) {
continue
}
}
return cmd
}
}

View file

@ -0,0 +1,67 @@
package connection
import (
"github.com/go-redis/redis"
)
type PubSubWrapper struct {
ps *redis.PubSub
rdbw *RDBWrapper
}
func (psw *PubSubWrapper) Subscribe(channels ...string) error {
for {
if !psw.rdbw.IsConnected() {
psw.rdbw.WaitForConnection()
continue
}
err := psw.ps.Subscribe(channels...)
if err != nil {
if !psw.rdbw.CheckConnection(false) {
continue
}
}
return err
}
}
func (psw *PubSubWrapper) ReceiveMessage() (*redis.Message, error) {
for {
if !psw.rdbw.IsConnected() {
psw.rdbw.WaitForConnection()
continue
}
msg, err := psw.ps.ReceiveMessage()
if err != nil {
if !psw.rdbw.CheckConnection(false) {
continue
}
}
return msg, err
}
}
func (psw *PubSubWrapper) Close() error {
for {
if !psw.rdbw.IsConnected() {
psw.rdbw.WaitForConnection()
continue
}
err := psw.ps.Close()
if err != nil {
if !psw.rdbw.CheckConnection(false) {
continue
}
}
return err
}
}

View file

@ -0,0 +1,74 @@
package connection
import (
"github.com/go-redis/redis"
"github.com/stretchr/testify/assert"
"testing"
"time"
)
func TestPubSubWrapper(t *testing.T) {
rdb := redis.NewClient(&redis.Options{
Addr: "127.0.0.1:6379",
DialTimeout: time.Minute / 2,
ReadTimeout: time.Minute,
WriteTimeout: time.Minute,
})
rdbw := NewTestRDBW(rdb)
if !rdbw.CheckConnection(true) {
t.Fatal("This test needs a working Redis connection")
}
ps := rdbw.Subscribe()
rdbw.CompareAndSetConnected(false)
var errSubscribe error
done1:= make(chan bool)
go func () {
errSubscribe = ps.Subscribe("testchannel")
done1 <- true
}()
time.Sleep(50 * time.Millisecond)
rdbw.CheckConnection(true)
<- done1
rdbw.CompareAndSetConnected(false)
var msg *redis.Message
var errReceive error
done2 := make(chan bool)
go func() {
msg, errReceive = ps.ReceiveMessage()
done2 <- true
}()
time.Sleep(50 * time.Millisecond)
rdbw.CheckConnection(true)
rdbw.Publish("testchannel", "Hello there")
<- done2
rdbw.CompareAndSetConnected(false)
var errClose error
done3:= make(chan bool)
go func () {
errClose = ps.Close()
done3 <- true
}()
time.Sleep(50 * time.Millisecond)
rdbw.CheckConnection(true)
<- done3
assert.NoError(t, errSubscribe)
assert.NoError(t, errReceive)
assert.NoError(t, errClose)
assert.Equal(t, "Hello there", msg.Payload)
}

275
connection/redis_test.go Normal file
View file

@ -0,0 +1,275 @@
package connection
import (
"github.com/go-redis/redis"
"github.com/stretchr/testify/assert"
"sync"
"sync/atomic"
"testing"
"time"
)
func NewTestRDBW(rdb RedisClient) RDBWrapper {
dbw := RDBWrapper{Rdb: rdb, ConnectedAtomic: new(uint32), ConnectionLostCounterAtomic: new(uint32)}
dbw.ConnectionUpCondition = sync.NewCond(&sync.Mutex{})
return dbw
}
func TestNewRDBWrapper(t *testing.T) {
_, err := NewRDBWrapper("127.0.0.1:6379")
assert.NoError(t, err, "Redis should be connected")
_, err = NewRDBWrapper("asdasdasdasdasd:5123")
assert.Error(t, err, "Redis should not be connected")
//TODO: Add more tests here
}
func TestRDBWrapper_GetConnectionCheckInterval(t *testing.T) {
rdbw := NewTestRDBW(nil)
//Should return 15s, if connected - counter doesn't madder
rdbw.CompareAndSetConnected(true)
assert.Equal(t, 15*time.Second, rdbw.getConnectionCheckInterval())
//Should return 5s, if not connected and counter < 4
rdbw.CompareAndSetConnected(false)
atomic.StoreUint32(rdbw.ConnectionLostCounterAtomic, 0)
assert.Equal(t, 5*time.Second, rdbw.getConnectionCheckInterval())
//Should return 10s, if not connected and 4 <= counter < 8
rdbw.CompareAndSetConnected(false)
atomic.StoreUint32(rdbw.ConnectionLostCounterAtomic, 4)
assert.Equal(t, 10*time.Second, rdbw.getConnectionCheckInterval())
//Should return 30s, if not connected and 8 <= counter < 11
rdbw.CompareAndSetConnected(false)
atomic.StoreUint32(rdbw.ConnectionLostCounterAtomic, 8)
assert.Equal(t, 30*time.Second, rdbw.getConnectionCheckInterval())
//Should return 60s, if not connected and 11 <= counter < 14
rdbw.CompareAndSetConnected(false)
atomic.StoreUint32(rdbw.ConnectionLostCounterAtomic, 11)
assert.Equal(t, 60*time.Second, rdbw.getConnectionCheckInterval())
//dbw.ConnectionLostCounter = 14
//interval = dbw.getConnectionCheckInterval()
//TODO: Check for Fatal
}
func TestRDBWrapper_CheckConnection(t *testing.T) {
rdbw := NewTestRDBW(nil)
rdbw.Rdb = redis.NewClient(&redis.Options{
Addr: "127.0.0.1:6379",
DialTimeout: time.Minute / 2,
ReadTimeout: time.Minute,
WriteTimeout: time.Minute,
})
atomic.StoreUint32(rdbw.ConnectionLostCounterAtomic, 512312312)
assert.True(t, rdbw.CheckConnection(false), "DBWrapper should be connected")
assert.Equal(t, uint32(0), atomic.LoadUint32(rdbw.ConnectionLostCounterAtomic))
rdbw.Rdb = redis.NewClient(&redis.Options{
Addr: "dasdasdasdasdasd:5123",
DialTimeout: time.Minute / 2,
ReadTimeout: time.Minute,
WriteTimeout: time.Minute,
})
atomic.StoreUint32(rdbw.ConnectionLostCounterAtomic, 0)
assert.False(t, rdbw.CheckConnection(false), "DBWrapper should not be connected")
assert.Equal(t, uint32(0), atomic.LoadUint32(rdbw.ConnectionLostCounterAtomic))
atomic.StoreUint32(rdbw.ConnectionLostCounterAtomic, 10)
assert.False(t, rdbw.CheckConnection(true), "DBWrapper should not be connected")
assert.Equal(t, uint32(11), atomic.LoadUint32(rdbw.ConnectionLostCounterAtomic))
}
func TestRDBWrapper_HGetAll(t *testing.T) {
rdb := redis.NewClient(&redis.Options{
Addr: "127.0.0.1:6379",
DialTimeout: time.Minute / 2,
ReadTimeout: time.Minute,
WriteTimeout: time.Minute,
})
rdbw := NewTestRDBW(rdb)
if !rdbw.CheckConnection(true) {
t.Fatal("This test needs a working Redis connection")
}
rdb.Del("herpdaderp")
rdb.HSet("herpdaderp", "one", 5)
rdb.HSet("herpdaderp", "two", 11)
rdbw.CompareAndSetConnected(false)
var data map[string]string
var err error
done := make(chan bool)
go func() {
data, err = rdbw.HGetAll("herpdaderp")
done <- true
}()
time.Sleep(50 * time.Millisecond)
rdbw.CheckConnection(true)
<- done
assert.NoError(t, err)
assert.Contains(t, data, "one")
assert.Contains(t, data, "two")
}
func TestRDBWrapper_XRead(t *testing.T) {
rdb := redis.NewClient(&redis.Options{
Addr: "127.0.0.1:6379",
DialTimeout: time.Minute / 2,
ReadTimeout: time.Minute,
WriteTimeout: time.Minute,
})
rdbw := NewTestRDBW(rdb)
if !rdbw.CheckConnection(true) {
t.Fatal("This test needs a working Redis connection")
}
rdb.XTrim("teststream", 0)
rdb.XAdd(&redis.XAddArgs{Stream: "teststream", Values: map[string]interface{}{"one": "5", "two": "11", "herp": "11"}})
rdbw.CompareAndSetConnected(false)
var data *redis.XStreamSliceCmd
done := make(chan bool)
go func() {
data = rdbw.XRead(&redis.XReadArgs{Streams: []string{"teststream", "0"}})
done <- true
}()
time.Sleep(50 * time.Millisecond)
rdbw.CheckConnection(true)
<- done
streams, err := data.Result()
assert.NoError(t, err)
value := streams[0].Messages[0].Values
assert.Contains(t, value, "one")
assert.Contains(t, value, "two")
}
func TestRDBWrapper_XDel(t *testing.T) {
rdb := redis.NewClient(&redis.Options{
Addr: "127.0.0.1:6379",
DialTimeout: time.Minute / 2,
ReadTimeout: time.Minute,
WriteTimeout: time.Minute,
})
rdbw := NewTestRDBW(rdb)
if !rdbw.CheckConnection(true) {
t.Fatal("This test needs a working Redis connection")
}
rdb.XTrim("teststream", 0)
adds := rdb.XAdd(&redis.XAddArgs{Stream: "teststream", Values: map[string]interface{}{"one": "5", "two": "11", "herp": "11"}})
rdbw.CompareAndSetConnected(false)
done := make(chan bool)
go func() {
rdbw.XDel("teststream", adds.Val())
done <- true
}()
time.Sleep(50 * time.Millisecond)
rdbw.CheckConnection(true)
<- done
data := rdbw.XRead(&redis.XReadArgs{Streams: []string{"teststream", "0"}, Block: -1})
streams, err := data.Result()
assert.Error(t, err)
assert.Len(t, streams, 0)
}
func TestRDBWrapper_Publish(t *testing.T) {
rdb := redis.NewClient(&redis.Options{
Addr: "127.0.0.1:6379",
DialTimeout: time.Minute / 2,
ReadTimeout: time.Minute,
WriteTimeout: time.Minute,
})
rdbw := NewTestRDBW(rdb)
if !rdbw.CheckConnection(true) {
t.Fatal("This test needs a working Redis connection")
}
var msg *redis.Message
var err error
done := make(chan bool)
go func() {
msg, err = rdb.Subscribe("testchannel").ReceiveMessage()
done <- true
}()
rdbw.CompareAndSetConnected(false)
go func () {
rdbw.Publish("testchannel", "Hello there")
}()
time.Sleep(50 * time.Millisecond)
rdbw.CheckConnection(true)
<- done
assert.NoError(t, err)
assert.Equal(t, "Hello there", msg.Payload)
}
func TestRDBWrapper_TxPipelined(t *testing.T) {
rdb := redis.NewClient(&redis.Options{
Addr: "127.0.0.1:6379",
DialTimeout: time.Minute / 2,
ReadTimeout: time.Minute,
WriteTimeout: time.Minute,
})
rdbw := NewTestRDBW(rdb)
if !rdbw.CheckConnection(true) {
t.Fatal("This test needs a working Redis connection")
}
rdb.Del("firstKey")
rdb.Del("secondKey")
rdb.HSet("firstKey", "foo", 5)
rdb.HSet("secondKey", "bar", 11)
rdbw.CompareAndSetConnected(false)
var firstMap *redis.StringStringMapCmd
var secondMap *redis.StringStringMapCmd
var err error
done := make(chan bool)
go func() {
_, err = rdbw.TxPipelined(func(pipe redis.Pipeliner) error {
firstMap = pipe.HGetAll("firstKey")
secondMap = pipe.HGetAll("secondKey")
return nil
})
done <- true
}()
time.Sleep(50 * time.Millisecond)
rdbw.CheckConnection(true)
<- done
assert.NoError(t, err)
assert.Contains(t, firstMap.Val(), "foo")
assert.Contains(t, secondMap.Val(), "bar")
}

3
connection/test_db.sql Normal file
View file

@ -0,0 +1,3 @@
CREATE database icingadb;
CREATE USER 'module-dev'@'127.0.0.1' IDENTIFIED BY 'icinga0815!';
GRANT ALL PRIVILEGES ON icingadb.* TO 'module-dev'@'127.0.0.1';