diff --git a/mysql_test.go b/mysql_test.go index 4b355d43..d1998edf 100644 --- a/mysql_test.go +++ b/mysql_test.go @@ -15,6 +15,7 @@ import ( type SqlResultMock struct { sql.Result } + type TransactionMock struct { mock.Mock } @@ -95,6 +96,64 @@ func TestRDBWrapper_CheckConnection(t *testing.T) { assert.Equal(t, 11, dbw.ConnectionLostCounter) } +func TestDBWrapper_SqlCommit(t *testing.T) { + mockDb := new(DbMock) + dbw := NewTestDBW(mockDb) + mockTx := new(TransactionMock) + + mockTx.On("Commit").Return(errors.New("whoops")).Once() + mockTx.On("Commit").Return( nil).Once() + mockDb.On("Ping").Return(errors.New("whoops")).Once() + + var err error + done := make(chan bool) + + dbw.CompareAndSetConnected(true) + go func() { + err = dbw.SqlCommit(mockTx, false) + done <- true + }() + + time.Sleep(time.Millisecond * 100) + + dbw.CompareAndSetConnected(true) + dbw.ConnectionUpCondition.Broadcast() + + <- done + + assert.Nil(t, err) + mockTx.AssertExpectations(t) + mockDb.AssertExpectations(t) +} + +func TestDBWrapper_SqlBegin(t *testing.T) { + mockDb := new(DbMock) + dbw := NewTestDBW(mockDb) + + mockDb.On("BeginTx", context.Background(), &sql.TxOptions{Isolation: sql.LevelReadCommitted}).Return(&sql.Tx{}, errors.New("whoops")).Once() + mockDb.On("BeginTx", context.Background(), &sql.TxOptions{Isolation: sql.LevelReadCommitted}).Return(&sql.Tx{}, nil).Once() + mockDb.On("Ping").Return(errors.New("whoops")).Once() + + var err error + done := make(chan bool) + + dbw.CompareAndSetConnected(true) + go func() { + _, err = dbw.SqlBegin(false, false) + done <- true + }() + + time.Sleep(time.Millisecond * 100) + + dbw.CompareAndSetConnected(true) + dbw.ConnectionUpCondition.Broadcast() + + <- done + + assert.Nil(t, err) + mockDb.AssertExpectations(t) +} + func TestDBWrapper_WithRetry(t *testing.T) { mockDb := new(DbMock) dbw := NewTestDBW(mockDb) @@ -113,6 +172,12 @@ func TestDBWrapper_WithRetry(t *testing.T) { assert.Nil(t, err) assert.Equal(t, 2, tries) + + _, err = dbw.WithRetry(func() (result sql.Result, e error) { + return nil, errors.New("something went wrong") + }) + + assert.Error(t, err) } func TestDBWrapper_SqlQuery(t *testing.T) {