Tests: connect to temporary MySQLd

This commit is contained in:
Alexander A. Klimov 2019-09-24 16:21:37 +02:00 committed by Noah Hilverling
parent d9717d5bbb
commit fa9941d0c9
2 changed files with 350 additions and 3 deletions

View file

@ -1,14 +1,16 @@
package connection
import (
"bytes"
"context"
"database/sql"
"errors"
"fmt"
"git.icinga.com/icingadb/icingadb-main/connection/mysqld"
"github.com/go-sql-driver/mysql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"os"
"sync"
"sync/atomic"
"testing"
@ -161,7 +163,22 @@ func TestDBWrapper_SqlBegin(t *testing.T) {
}
func TestDBWrapper_SqlTransaction(t *testing.T) {
dbw, err := NewDBWrapper(os.Getenv("ICINGADB_TEST_MYSQL"))
var server mysqld.Server
host, errSt := server.Start()
if errSt != nil {
t.Fatal(errSt)
return
}
defer server.Stop()
if errMTD := mkTestDb(host); errMTD != nil {
t.Fatal(errMTD)
return
}
dbw, err := NewDBWrapper(fmt.Sprintf("icingadb:icingadb@%s/icingadb", host))
require.NoError(t, err, "Is the MySQL server running?")
err = dbw.SqlTransaction(false, true, false, func(tx DbTransaction) error {
@ -295,7 +312,22 @@ func TestGetConnectionCheckInterval(t *testing.T) {
}
func TestDBWrapper_SqlFetchAll(t *testing.T) {
dbw, err := NewDBWrapper(os.Getenv("ICINGADB_TEST_MYSQL"))
var server mysqld.Server
host, errSt := server.Start()
if errSt != nil {
t.Fatal(errSt)
return
}
defer server.Stop()
if errMTD := mkTestDb(host); errMTD != nil {
t.Fatal(errMTD)
return
}
dbw, err := NewDBWrapper(fmt.Sprintf("icingadb:icingadb@%s/icingadb", host))
require.NoError(t, err, "Is the MySQL server running?")
_, err = dbw.Db.Exec("CREATE TABLE testing0815 (id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, name varchar(255) NOT NULL)")
@ -324,3 +356,57 @@ func TestDBWrapper_SqlFetchAll(t *testing.T) {
_, err = dbw.Db.Exec("DROP TABLE testing0815")
assert.NoError(t, err)
}
var cComment = regexp.MustCompile(`/\*.*?\*/`)
func mkTestDb(host string) error {
noDb, errNoDb := sql.Open("mysql", fmt.Sprintf("root@%s/", host))
if errNoDb != nil {
return errNoDb
}
defer noDb.Close()
for _, ddl := range []string{
"CREATE DATABASE icingadb",
"GRANT ALL ON icingadb.* TO icingadb@localhost IDENTIFIED BY 'icingadb'",
} {
if _, errEx := noDb.Exec(ddl); errEx != nil {
return errEx
}
}
db, errDb := sql.Open("mysql", fmt.Sprintf("icingadb:icingadb@%s/icingadb", host))
if errDb != nil {
return errDb
}
defer db.Close()
_, thisFile, _, _ := runtime.Caller(0)
schema := path.Join(filepath.Dir(filepath.Dir(thisFile)), "etc/schema/mysql")
entries, errRD := ioutil.ReadDir(schema)
if errRD != nil {
return errRD
}
for _, entry := range entries {
if name := entry.Name(); strings.HasSuffix(name, ".sql") {
ddls, errRF := ioutil.ReadFile(path.Join(schema, name))
if errRF != nil {
return errRF
}
for _, ddl := range bytes.Split(ddls, []byte{';'}) {
if ddl = bytes.TrimSpace(cComment.ReplaceAll(ddl, nil)); len(ddl) > 0 {
if _, errEx := db.Exec(string(ddl)); errEx != nil {
return errEx
}
}
}
}
}
return nil
}

261
connection/mysqld/server.go Normal file
View file

@ -0,0 +1,261 @@
package mysqld
import (
"bufio"
"database/sql"
"fmt"
log "github.com/sirupsen/logrus"
"io"
"io/ioutil"
"os"
"os/exec"
"os/user"
"path"
"regexp"
"strconv"
"syscall"
"time"
)
// logLineRgx matches e.g. "2019-09-23 13:39:55 0 [Note] InnoDB: Using Linux native AIO".
var logLineRgx = regexp.MustCompile(`\A(\S+ \S+) (\d+) \[(\w+)] (.+)\z`)
// logLevels maps the MySQLd log levels (as in log messages) to logrus log levels.
var logLevels = map[string]log.Level{
"Note": log.InfoLevel,
"Warning": log.WarnLevel,
"ERROR": log.ErrorLevel,
}
// Server represents a managed MySQL server.
type Server struct {
// basedir is a directory containing the MySQL server context.
basedir string
// cmd represents the main MySQLd process if any.
cmd *exec.Cmd
// stopped is closed as soon as the main MySQLd process is stopped.
stopped chan struct{}
// errorLogEof is closed on the main MySQLd process' error log EOF.
errorLogEof chan struct{}
// logPipeCloser is the result of opening the log pipe for writing to ensure our reader's termination.
logPipeCloser io.Closer
}
// Start starts *s and returns the host to connect to.
func (s *Server) Start() (string, error) {
me, errUC := user.Current()
if errUC != nil {
return "", errUC
}
{
var errTD error
s.basedir, errTD = ioutil.TempDir("", "")
if errTD != nil {
return "", errTD
}
}
log.WithFields(log.Fields{"basedir": s.basedir}).Info("starting MySQL server")
socket := path.Join(s.basedir, "socket")
host := fmt.Sprintf("unix(%s)", socket)
db, errOpen := sql.Open("mysql", fmt.Sprintf("root@%s/", host))
if errOpen != nil {
os.RemoveAll(s.basedir)
s.basedir = ""
return "", errOpen
}
defer db.Close()
dataDir := path.Join(s.basedir, "data")
if errMkdir := os.Mkdir(dataDir, 0700); errMkdir != nil {
os.RemoveAll(s.basedir)
s.basedir = ""
return "", errMkdir
}
logPipe := path.Join(s.basedir, "log")
if errMkfifo := syscall.Mkfifo(logPipe, 0700); errMkfifo != nil {
os.RemoveAll(s.basedir)
s.basedir = ""
return "", errMkfifo
}
params := []string{
"--no-defaults",
"--user=" + me.Username,
"--pid-file=" + path.Join(s.basedir, "pid"),
"--socket=" + socket,
"--basedir=/usr",
"--datadir=" + dataDir,
"--tmpdir=/tmp",
"--lc-messages-dir=/usr/share/mysql",
"--skip-networking",
"--query_cache_size=16M",
"--expire_logs_days=10",
"--character-set-server=utf8mb4",
"--collation-server=utf8mb4_general_ci",
}
{
cmd := exec.Command("mysql_install_db", append(params, "--log_error=/dev/null")...)
cmd.Dir = s.basedir
if errRun := cmd.Run(); errRun != nil {
os.RemoveAll(s.basedir)
s.basedir = ""
return "", errRun
}
}
s.errorLogEof = make(chan struct{})
go s.file2log(logPipe)
logPipeWriter, errCr := os.Create(logPipe)
if errCr != nil {
os.RemoveAll(s.basedir)
s.basedir = ""
return "", errCr
}
cmd := exec.Command("mysqld", append(params, "--log_error="+logPipe)...)
cmd.Dir = s.basedir
stderr, errStderr := cmd.StderrPipe()
if errStderr != nil {
logPipeWriter.Close()
os.RemoveAll(s.basedir)
s.basedir = ""
return "", errStderr
}
if errStart := cmd.Start(); errStart != nil {
logPipeWriter.Close()
os.RemoveAll(s.basedir)
s.basedir = ""
return "", errStart
}
s.cmd = cmd
s.stopped = make(chan struct{})
s.logPipeCloser = logPipeWriter
go s.stderr2log(stderr)
log.WithFields(log.Fields{"basedir": s.basedir}).Debug("checking the MySQL server for actual serving")
for {
errPing := db.Ping()
if errPing == nil {
log.WithFields(log.Fields{"basedir": s.basedir}).Debug("MySQL server is actually serving now")
return host, nil
}
select {
case <-s.stopped:
return "", errPing
default:
log.WithFields(log.Fields{
"basedir": s.basedir,
"error": errPing,
}).Debug("MySQL server isn't actually serving, yet")
time.Sleep(time.Second)
}
}
}
// Stop stops *s.
func (s *Server) Stop() error {
log.Info("stopping MySQL server")
if errSignal := s.cmd.Process.Signal(syscall.SIGTERM); errSignal != nil {
return errSignal
}
<-s.stopped
return nil
}
// file2log forwards the MySQL server's log from path to logrus.
func (s *Server) file2log(path string) {
defer close(s.errorLogEof)
stream, errOpen := os.Open(path)
if errOpen != nil {
log.WithFields(log.Fields{"source": "log file", "error": errOpen}).Error(
"got unexpected error while forwarding MySQL server logs",
)
return
}
defer stream.Close()
stream2log(stream, "log file")
}
// stderr2log forwards the MySQL server's log from stderr to logrus and cleans up *s.
func (s *Server) stderr2log(stderr io.Reader) {
stream2log(stderr, "stderr")
if errWait := s.cmd.Wait(); errWait != nil {
log.WithFields(log.Fields{"error": errWait}).Error("MySQL server terminated with an error")
}
s.logPipeCloser.Close()
<-s.errorLogEof
os.RemoveAll(s.basedir)
s.basedir = ""
close(s.stopped)
}
// stream2log forwards the MySQL server's log from stream to logrus.
func stream2log(stream io.Reader, source string) {
buffer := bufio.NewReader(stream)
for {
line, errRead := buffer.ReadBytes('\n')
if errRead != nil {
if errRead != io.EOF || len(line) > 0 {
log.WithFields(log.Fields{"source": source, "error": errRead}).Error(
"got unexpected error while forwarding MySQL server logs",
)
}
break
}
line = line[:len(line)-1]
if len(line) > 0 {
if submatch := logLineRgx.FindSubmatch(line); submatch != nil {
timeStamp, errTime := time.ParseInLocation("2006-01-02 15:04:05", string(submatch[1]), time.Local)
if errTime != nil {
timeStamp = time.Now()
}
thread, errPU := strconv.ParseUint(string(submatch[2]), 10, 64)
if errPU != nil {
thread = 0
}
log.WithTime(timeStamp).WithFields(log.Fields{
"component": "mysqld",
"source": source,
"thread": thread,
}).Log(logLevels[string(submatch[3])], string(submatch[4]))
}
}
}
}