mirror of
https://github.com/Icinga/icingadb.git
synced 2026-06-09 08:56:54 -04:00
Remove library code
This commit is contained in:
parent
7c068d4adf
commit
be4b450f5c
48 changed files with 5 additions and 4830 deletions
10
go.mod
10
go.mod
|
|
@ -4,23 +4,18 @@ go 1.22
|
|||
|
||||
require (
|
||||
github.com/creasty/defaults v1.7.0
|
||||
github.com/go-sql-driver/mysql v1.8.1
|
||||
github.com/goccy/go-yaml v1.11.3
|
||||
github.com/google/go-cmp v0.6.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/icinga/icinga-go-library v0.1.0
|
||||
github.com/jessevdk/go-flags v1.5.0
|
||||
github.com/jmoiron/sqlx v1.4.0
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/mattn/go-sqlite3 v1.14.22
|
||||
github.com/okzk/sdnotify v0.0.0-20180710141335-d9becc38acbd
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/redis/go-redis/v9 v9.5.1
|
||||
github.com/ssgreg/journald v1.0.0
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/vbauerster/mpb/v6 v6.0.4
|
||||
go.uber.org/zap v1.27.0
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
|
||||
golang.org/x/sync v0.7.0
|
||||
)
|
||||
|
||||
|
|
@ -32,12 +27,17 @@ require (
|
|||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/fatih/color v1.16.0 // indirect
|
||||
github.com/go-sql-driver/mysql v1.8.1 // indirect
|
||||
github.com/lib/pq v1.10.9 // indirect
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.12 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/redis/go-redis/v9 v9.5.1 // indirect
|
||||
github.com/rivo/uniseg v0.2.0 // indirect
|
||||
github.com/ssgreg/journald v1.0.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect
|
||||
golang.org/x/sys v0.14.0 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
|
|
|
|||
|
|
@ -1,43 +0,0 @@
|
|||
package backoff
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Backoff returns the backoff duration for a specific retry attempt.
|
||||
type Backoff func(uint64) time.Duration
|
||||
|
||||
// NewExponentialWithJitter returns a backoff implementation that
|
||||
// exponentially increases the backoff duration for each retry from min,
|
||||
// never exceeding max. Some randomization is added to the backoff duration.
|
||||
// It panics if min >= max.
|
||||
func NewExponentialWithJitter(min, max time.Duration) Backoff {
|
||||
if min <= 0 {
|
||||
min = 100 * time.Millisecond
|
||||
}
|
||||
if max <= 0 {
|
||||
max = 10 * time.Second
|
||||
}
|
||||
if min >= max {
|
||||
panic("max must be larger than min")
|
||||
}
|
||||
|
||||
return func(attempt uint64) time.Duration {
|
||||
e := min << attempt
|
||||
if e <= 0 || e > max {
|
||||
e = max
|
||||
}
|
||||
|
||||
return time.Duration(jitter(int64(e)))
|
||||
}
|
||||
}
|
||||
|
||||
// jitter returns a random integer distributed in the range [n/2..n).
|
||||
func jitter(n int64) int64 {
|
||||
if n == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
return n/2 + rand.Int63n(n/2)
|
||||
}
|
||||
|
|
@ -1,38 +0,0 @@
|
|||
package com
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
// Atomic is a type-safe wrapper around atomic.Value.
|
||||
type Atomic[T any] struct {
|
||||
v atomic.Value
|
||||
}
|
||||
|
||||
func (a *Atomic[T]) Load() (_ T, ok bool) {
|
||||
if v, ok := a.v.Load().(box[T]); ok {
|
||||
return v.v, true
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Atomic[T]) Store(v T) {
|
||||
a.v.Store(box[T]{v})
|
||||
}
|
||||
|
||||
func (a *Atomic[T]) Swap(new T) (old T, ok bool) {
|
||||
if old, ok := a.v.Swap(box[T]{new}).(box[T]); ok {
|
||||
return old.v, true
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Atomic[T]) CompareAndSwap(old, new T) (swapped bool) {
|
||||
return a.v.CompareAndSwap(box[T]{old}, box[T]{new})
|
||||
}
|
||||
|
||||
// box allows, for the case T is an interface, nil values and values of different specific types implementing T
|
||||
// to be stored in Atomic[T]#v (bypassing atomic.Value#Store()'s policy) by wrapping it (into a non-interface).
|
||||
type box[T any] struct {
|
||||
v T
|
||||
}
|
||||
|
|
@ -1,166 +0,0 @@
|
|||
package com
|
||||
|
||||
import (
|
||||
"context"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// BulkChunkSplitPolicy is a state machine which tracks the items of a chunk a bulker assembles.
|
||||
// A call takes an item for the current chunk into account.
|
||||
// Output true indicates that the state machine was reset first and the bulker
|
||||
// shall finish the current chunk now (not e.g. once $size is reached) without the given item.
|
||||
type BulkChunkSplitPolicy[T any] func(T) bool
|
||||
|
||||
type BulkChunkSplitPolicyFactory[T any] func() BulkChunkSplitPolicy[T]
|
||||
|
||||
// NeverSplit returns a pseudo state machine which never demands splitting.
|
||||
func NeverSplit[T any]() BulkChunkSplitPolicy[T] {
|
||||
return neverSplit[T]
|
||||
}
|
||||
|
||||
func neverSplit[T any](T) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Bulker reads all values from a channel and streams them in chunks into a Bulk channel.
|
||||
type Bulker[T any] struct {
|
||||
ch chan []T
|
||||
ctx context.Context
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewBulker returns a new Bulker and starts streaming.
|
||||
func NewBulker[T any](
|
||||
ctx context.Context, ch <-chan T, count int, splitPolicyFactory BulkChunkSplitPolicyFactory[T],
|
||||
) *Bulker[T] {
|
||||
b := &Bulker[T]{
|
||||
ch: make(chan []T),
|
||||
ctx: ctx,
|
||||
mu: sync.Mutex{},
|
||||
}
|
||||
|
||||
go b.run(ch, count, splitPolicyFactory)
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// Bulk returns the channel on which the bulks are delivered.
|
||||
func (b *Bulker[T]) Bulk() <-chan []T {
|
||||
return b.ch
|
||||
}
|
||||
|
||||
func (b *Bulker[T]) run(ch <-chan T, count int, splitPolicyFactory BulkChunkSplitPolicyFactory[T]) {
|
||||
defer close(b.ch)
|
||||
|
||||
bufCh := make(chan T, count)
|
||||
splitPolicy := splitPolicyFactory()
|
||||
g, ctx := errgroup.WithContext(b.ctx)
|
||||
|
||||
g.Go(func() error {
|
||||
defer close(bufCh)
|
||||
|
||||
for {
|
||||
select {
|
||||
case v, ok := <-ch:
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
bufCh <- v
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
g.Go(func() error {
|
||||
for done := false; !done; {
|
||||
buf := make([]T, 0, count)
|
||||
timeout := time.After(256 * time.Millisecond)
|
||||
|
||||
for drain := true; drain && len(buf) < count; {
|
||||
select {
|
||||
case v, ok := <-bufCh:
|
||||
if !ok {
|
||||
drain = false
|
||||
done = true
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
if splitPolicy(v) {
|
||||
if len(buf) > 0 {
|
||||
b.ch <- buf
|
||||
buf = make([]T, 0, count)
|
||||
}
|
||||
|
||||
timeout = time.After(256 * time.Millisecond)
|
||||
}
|
||||
|
||||
buf = append(buf, v)
|
||||
case <-timeout:
|
||||
drain = false
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
if len(buf) > 0 {
|
||||
b.ch <- buf
|
||||
}
|
||||
|
||||
splitPolicy = splitPolicyFactory()
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
// We don't expect an error here.
|
||||
// We only use errgroup for the encapsulated use of sync.WaitGroup.
|
||||
_ = g.Wait()
|
||||
}
|
||||
|
||||
// Bulk reads all values from a channel and streams them in chunks into a returned channel.
|
||||
func Bulk[T any](
|
||||
ctx context.Context, ch <-chan T, count int, splitPolicyFactory BulkChunkSplitPolicyFactory[T],
|
||||
) <-chan []T {
|
||||
if count <= 1 {
|
||||
return oneBulk(ctx, ch)
|
||||
}
|
||||
|
||||
return NewBulker(ctx, ch, count, splitPolicyFactory).Bulk()
|
||||
}
|
||||
|
||||
// oneBulk operates just as NewBulker(ctx, ch, 1, splitPolicy).Bulk(),
|
||||
// but without the overhead of the actual bulk creation with a buffer channel, timeout and BulkChunkSplitPolicy.
|
||||
func oneBulk[T any](ctx context.Context, ch <-chan T) <-chan []T {
|
||||
out := make(chan []T)
|
||||
go func() {
|
||||
defer close(out)
|
||||
|
||||
for {
|
||||
select {
|
||||
case item, ok := <-ch:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case out <- []T{item}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
var (
|
||||
_ BulkChunkSplitPolicyFactory[struct{}] = NeverSplit[struct{}]
|
||||
)
|
||||
101
pkg/com/com.go
101
pkg/com/com.go
|
|
@ -1,101 +0,0 @@
|
|||
package com
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// Waiter implements the Wait method,
|
||||
// which blocks until execution is complete.
|
||||
type Waiter interface {
|
||||
Wait() error // Wait waits for execution to complete.
|
||||
}
|
||||
|
||||
// The WaiterFunc type is an adapter to allow the use of ordinary functions as Waiter.
|
||||
// If f is a function with the appropriate signature, WaiterFunc(f) is a Waiter that calls f.
|
||||
type WaiterFunc func() error
|
||||
|
||||
// Wait implements the Waiter interface.
|
||||
func (f WaiterFunc) Wait() error {
|
||||
return f()
|
||||
}
|
||||
|
||||
// WaitAsync calls Wait() on the passed Waiter in a new goroutine and
|
||||
// sends the first non-nil error (if any) to the returned channel.
|
||||
// The returned channel is always closed when the Waiter is done.
|
||||
func WaitAsync(w Waiter) <-chan error {
|
||||
errs := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
defer close(errs)
|
||||
|
||||
if e := w.Wait(); e != nil {
|
||||
errs <- e
|
||||
}
|
||||
}()
|
||||
|
||||
return errs
|
||||
}
|
||||
|
||||
// ErrgroupReceive adds a goroutine to the specified group that
|
||||
// returns the first non-nil error (if any) from the specified channel.
|
||||
// If the channel is closed, it will return nil.
|
||||
func ErrgroupReceive(g *errgroup.Group, err <-chan error) {
|
||||
g.Go(func() error {
|
||||
if e := <-err; e != nil {
|
||||
return e
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// CopyFirst asynchronously forwards all items from input to forward and synchronously returns the first item.
|
||||
func CopyFirst[T any](
|
||||
ctx context.Context, input <-chan T,
|
||||
) (first T, forward <-chan T, err error) {
|
||||
var ok bool
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
var zero T
|
||||
|
||||
return zero, nil, ctx.Err()
|
||||
case first, ok = <-input:
|
||||
}
|
||||
|
||||
if !ok {
|
||||
err = errors.New("can't copy from closed channel")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Buffer of one because we receive an entity and send it back immediately.
|
||||
fwd := make(chan T, 1)
|
||||
fwd <- first
|
||||
|
||||
forward = fwd
|
||||
|
||||
go func() {
|
||||
defer close(fwd)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case e, ok := <-input:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case fwd <- e:
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return
|
||||
}
|
||||
|
|
@ -1,90 +0,0 @@
|
|||
package com
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Cond implements a channel-based synchronization for goroutines that wait for signals or send them.
|
||||
// Internally based on a controller loop that handles the synchronization of new listeners and signal propagation,
|
||||
// which is only started when NewCond is called. Thus the zero value cannot be used.
|
||||
type Cond struct {
|
||||
broadcast chan struct{}
|
||||
done chan struct{}
|
||||
cancel context.CancelFunc
|
||||
listeners chan chan struct{}
|
||||
}
|
||||
|
||||
// NewCond returns a new Cond and starts the controller loop.
|
||||
func NewCond(ctx context.Context) *Cond {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
c := &Cond{
|
||||
broadcast: make(chan struct{}),
|
||||
cancel: cancel,
|
||||
done: make(chan struct{}),
|
||||
listeners: make(chan chan struct{}),
|
||||
}
|
||||
|
||||
go c.controller(ctx)
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// Broadcast sends a signal to all current listeners by closing the previously returned channel from Wait.
|
||||
// Panics if the controller loop has already ended.
|
||||
func (c *Cond) Broadcast() {
|
||||
select {
|
||||
case c.broadcast <- struct{}{}:
|
||||
case <-c.done:
|
||||
panic(errors.New("condition closed"))
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the controller loop, waits for it to finish, and returns an error if any.
|
||||
// Implements the io.Closer interface.
|
||||
func (c *Cond) Close() error {
|
||||
c.cancel()
|
||||
<-c.done
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Done returns a channel that will be closed when the controller loop has ended.
|
||||
func (c *Cond) Done() <-chan struct{} {
|
||||
return c.done
|
||||
}
|
||||
|
||||
// Wait returns a channel that is closed with the next signal.
|
||||
// Panics if the controller loop has already ended.
|
||||
func (c *Cond) Wait() <-chan struct{} {
|
||||
select {
|
||||
case l := <-c.listeners:
|
||||
return l
|
||||
case <-c.done:
|
||||
panic(errors.New("condition closed"))
|
||||
}
|
||||
}
|
||||
|
||||
// controller loop.
|
||||
func (c *Cond) controller(ctx context.Context) {
|
||||
defer close(c.done)
|
||||
|
||||
// Note that the notify channel does not close when the controller loop ends
|
||||
// in order not to notify pending listeners.
|
||||
notify := make(chan struct{})
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.broadcast:
|
||||
// Close channel to notify all current listeners.
|
||||
close(notify)
|
||||
// Create a new channel for the next listeners.
|
||||
notify = make(chan struct{})
|
||||
case c.listeners <- notify:
|
||||
// A new listener received the channel.
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,48 +0,0 @@
|
|||
package com
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// Counter implements an atomic counter.
|
||||
type Counter struct {
|
||||
value uint64
|
||||
mu sync.Mutex // Protects total.
|
||||
total uint64
|
||||
}
|
||||
|
||||
// Add adds the given delta to the counter.
|
||||
func (c *Counter) Add(delta uint64) {
|
||||
atomic.AddUint64(&c.value, delta)
|
||||
}
|
||||
|
||||
// Inc increments the counter by one.
|
||||
func (c *Counter) Inc() {
|
||||
c.Add(1)
|
||||
}
|
||||
|
||||
// Reset resets the counter to 0 and returns its previous value.
|
||||
// Does not reset the total value returned from Total.
|
||||
func (c *Counter) Reset() uint64 {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
v := atomic.SwapUint64(&c.value, 0)
|
||||
c.total += v
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
// Total returns the total counter value.
|
||||
func (c *Counter) Total() uint64 {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
return c.total + c.Val()
|
||||
}
|
||||
|
||||
// Val returns the current counter value.
|
||||
func (c *Counter) Val() uint64 {
|
||||
return atomic.LoadUint64(&c.value)
|
||||
}
|
||||
|
|
@ -1,80 +0,0 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
stderrors "errors"
|
||||
"fmt"
|
||||
"github.com/creasty/defaults"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/jessevdk/go-flags"
|
||||
"github.com/pkg/errors"
|
||||
"os"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// ErrInvalidArgument is the error returned by [ParseFlags] or [FromYAMLFile] if
|
||||
// its parsing result cannot be stored in the value pointed to by the designated passed argument which
|
||||
// must be a non-nil pointer.
|
||||
var ErrInvalidArgument = stderrors.New("invalid argument")
|
||||
|
||||
// FromYAMLFile parses the given YAML file and stores the result
|
||||
// in the value pointed to by v. If v is nil or not a pointer,
|
||||
// FromYAMLFile returns an [ErrInvalidArgument] error.
|
||||
func FromYAMLFile(name string, v Validator) error {
|
||||
rv := reflect.ValueOf(v)
|
||||
if rv.Kind() != reflect.Pointer || rv.IsNil() {
|
||||
return errors.Wrapf(ErrInvalidArgument, "non-nil pointer expected, got %T", v)
|
||||
}
|
||||
|
||||
f, err := os.Open(name)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "can't open YAML file "+name)
|
||||
}
|
||||
defer func(f *os.File) {
|
||||
_ = f.Close()
|
||||
}(f)
|
||||
|
||||
if err := defaults.Set(v); err != nil {
|
||||
return errors.Wrap(err, "can't set config defaults")
|
||||
}
|
||||
|
||||
d := yaml.NewDecoder(f, yaml.DisallowUnknownField())
|
||||
if err := d.Decode(v); err != nil {
|
||||
return errors.Wrap(err, "can't parse YAML file "+name)
|
||||
}
|
||||
|
||||
if err := v.Validate(); err != nil {
|
||||
return errors.Wrap(err, "invalid configuration")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParseFlags parses CLI flags and stores the result
|
||||
// in the value pointed to by v. If v is nil or not a pointer,
|
||||
// ParseFlags returns an [ErrInvalidArgument] error.
|
||||
// ParseFlags adds a default Help Options group,
|
||||
// which contains the options -h and --help.
|
||||
// If either option is specified on the command line,
|
||||
// ParseFlags prints the help message to [os.Stdout] and exits.
|
||||
// Note that errors are not printed automatically,
|
||||
// so error handling is the sole responsibility of the caller.
|
||||
func ParseFlags(v any) error {
|
||||
rv := reflect.ValueOf(v)
|
||||
if rv.Kind() != reflect.Pointer || rv.IsNil() {
|
||||
return errors.Wrapf(ErrInvalidArgument, "non-nil pointer expected, got %T", v)
|
||||
}
|
||||
|
||||
parser := flags.NewParser(v, flags.Default^flags.PrintErrors)
|
||||
|
||||
if _, err := parser.Parse(); err != nil {
|
||||
var flagErr *flags.Error
|
||||
if errors.As(err, &flagErr) && flagErr.Type == flags.ErrHelp {
|
||||
fmt.Fprintln(os.Stdout, flagErr)
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
return errors.Wrap(err, "can't parse CLI flags")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
package config
|
||||
|
||||
type Validator interface {
|
||||
Validate() error
|
||||
}
|
||||
|
|
@ -1,58 +0,0 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"github.com/pkg/errors"
|
||||
"os"
|
||||
)
|
||||
|
||||
// TLS provides TLS configuration options.
|
||||
type TLS struct {
|
||||
Enable bool `yaml:"tls"`
|
||||
Cert string `yaml:"cert"`
|
||||
Key string `yaml:"key"`
|
||||
Ca string `yaml:"ca"`
|
||||
Insecure bool `yaml:"insecure"`
|
||||
}
|
||||
|
||||
// MakeConfig assembles a tls.Config from t and serverName.
|
||||
func (t *TLS) MakeConfig(serverName string) (*tls.Config, error) {
|
||||
if !t.Enable {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{}
|
||||
if t.Cert == "" {
|
||||
if t.Key != "" {
|
||||
return nil, errors.New("private key given, but client certificate missing")
|
||||
}
|
||||
} else if t.Key == "" {
|
||||
return nil, errors.New("client certificate given, but private key missing")
|
||||
} else {
|
||||
crt, err := tls.LoadX509KeyPair(t.Cert, t.Key)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "can't load X.509 key pair")
|
||||
}
|
||||
|
||||
tlsConfig.Certificates = []tls.Certificate{crt}
|
||||
}
|
||||
|
||||
if t.Insecure {
|
||||
tlsConfig.InsecureSkipVerify = true
|
||||
} else if t.Ca != "" {
|
||||
raw, err := os.ReadFile(t.Ca)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "can't read CA file")
|
||||
}
|
||||
|
||||
tlsConfig.RootCAs = x509.NewCertPool()
|
||||
if !tlsConfig.RootCAs.AppendCertsFromPEM(raw) {
|
||||
return nil, errors.New("can't parse CA file")
|
||||
}
|
||||
}
|
||||
|
||||
tlsConfig.ServerName = serverName
|
||||
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
|
@ -1,75 +0,0 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"github.com/jmoiron/sqlx/reflectx"
|
||||
"reflect"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ColumnMap provides a cached mapping of structs exported fields to their database column names.
|
||||
type ColumnMap interface {
|
||||
// Columns returns database column names for a struct's exported fields in a cached manner.
|
||||
// Thus, the returned slice MUST NOT be modified directly.
|
||||
// By default, all exported struct fields are mapped to database column names using snake case notation.
|
||||
// The - (hyphen) directive for the db tag can be used to exclude certain fields.
|
||||
Columns(any) []string
|
||||
}
|
||||
|
||||
// NewColumnMap returns a new ColumnMap.
|
||||
func NewColumnMap(mapper *reflectx.Mapper) ColumnMap {
|
||||
return &columnMap{
|
||||
cache: make(map[reflect.Type][]string),
|
||||
mapper: mapper,
|
||||
}
|
||||
}
|
||||
|
||||
type columnMap struct {
|
||||
mutex sync.Mutex
|
||||
cache map[reflect.Type][]string
|
||||
mapper *reflectx.Mapper
|
||||
}
|
||||
|
||||
func (m *columnMap) Columns(subject any) []string {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
t, ok := subject.(reflect.Type)
|
||||
if !ok {
|
||||
t = reflect.TypeOf(subject)
|
||||
}
|
||||
|
||||
columns, ok := m.cache[t]
|
||||
if !ok {
|
||||
columns = m.getColumns(t)
|
||||
m.cache[t] = columns
|
||||
}
|
||||
|
||||
return columns
|
||||
}
|
||||
|
||||
func (m *columnMap) getColumns(t reflect.Type) []string {
|
||||
fields := m.mapper.TypeMap(t).Names
|
||||
columns := make([]string, 0, len(fields))
|
||||
|
||||
FieldLoop:
|
||||
for _, f := range fields {
|
||||
// If one of the parent fields implements the driver.Valuer interface, the field can be ignored.
|
||||
for parent := f.Parent; parent != nil && parent.Zero.IsValid(); parent = parent.Parent {
|
||||
// Check for pointer types.
|
||||
if _, ok := reflect.New(parent.Field.Type).Interface().(driver.Valuer); ok {
|
||||
continue FieldLoop
|
||||
}
|
||||
// Check for non-pointer types.
|
||||
if _, ok := reflect.Zero(parent.Field.Type).Interface().(driver.Valuer); ok {
|
||||
continue FieldLoop
|
||||
}
|
||||
}
|
||||
|
||||
columns = append(columns, f.Path)
|
||||
}
|
||||
|
||||
// Shrink/reduce slice length and capacity:
|
||||
// For a three-index slice (slice[a:b:c]), the length of the returned slice is b-a and the capacity is c-a.
|
||||
return columns[0:len(columns):len(columns)]
|
||||
}
|
||||
|
|
@ -1,45 +0,0 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"github.com/icinga/icinga-go-library/config"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Config defines database client configuration.
|
||||
type Config struct {
|
||||
Type string `yaml:"type" default:"mysql"`
|
||||
Host string `yaml:"host"`
|
||||
Port int `yaml:"port"`
|
||||
Database string `yaml:"database"`
|
||||
User string `yaml:"user"`
|
||||
Password string `yaml:"password"`
|
||||
TlsOptions config.TLS `yaml:",inline"`
|
||||
Options Options `yaml:"options"`
|
||||
}
|
||||
|
||||
// Validate checks constraints in the supplied database configuration and returns an error if they are violated.
|
||||
func (c *Config) Validate() error {
|
||||
switch c.Type {
|
||||
case "mysql", "pgsql":
|
||||
default:
|
||||
return unknownDbType(c.Type)
|
||||
}
|
||||
|
||||
if c.Host == "" {
|
||||
return errors.New("database host missing")
|
||||
}
|
||||
|
||||
if c.User == "" {
|
||||
return errors.New("database user missing")
|
||||
}
|
||||
|
||||
if c.Database == "" {
|
||||
return errors.New("database name missing")
|
||||
}
|
||||
|
||||
return c.Options.Validate()
|
||||
}
|
||||
|
||||
func unknownDbType(t string) error {
|
||||
return errors.Errorf(`unknown database type %q, must be one of: "mysql", "pgsql"`, t)
|
||||
}
|
||||
|
|
@ -1,56 +0,0 @@
|
|||
package database
|
||||
|
||||
// Entity is implemented by each type that works with the database package.
|
||||
type Entity interface {
|
||||
Fingerprinter
|
||||
IDer
|
||||
}
|
||||
|
||||
// Fingerprinter is implemented by every entity that uniquely identifies itself.
|
||||
type Fingerprinter interface {
|
||||
// Fingerprint returns the value that uniquely identifies the entity.
|
||||
Fingerprint() Fingerprinter
|
||||
}
|
||||
|
||||
// ID is a unique identifier of an entity.
|
||||
type ID interface {
|
||||
// String returns the string representation form of the ID.
|
||||
// The String method is used to use the ID in functions
|
||||
// where it needs to be compared or hashed.
|
||||
String() string
|
||||
}
|
||||
|
||||
// IDer is implemented by every entity that uniquely identifies itself.
|
||||
type IDer interface {
|
||||
ID() ID // ID returns the ID.
|
||||
SetID(ID) // SetID sets the ID.
|
||||
}
|
||||
|
||||
// EntityFactoryFunc knows how to create an Entity.
|
||||
type EntityFactoryFunc func() Entity
|
||||
|
||||
// Upserter implements the Upsert method,
|
||||
// which returns a part of the object for ON DUPLICATE KEY UPDATE.
|
||||
type Upserter interface {
|
||||
Upsert() any // Upsert partitions the object.
|
||||
}
|
||||
|
||||
// TableNamer implements the TableName method,
|
||||
// which returns the table of the object.
|
||||
type TableNamer interface {
|
||||
TableName() string // TableName tells the table.
|
||||
}
|
||||
|
||||
// Scoper implements the Scope method,
|
||||
// which returns a struct specifying the WHERE conditions that
|
||||
// entities must satisfy in order to be SELECTed.
|
||||
type Scoper interface {
|
||||
Scope() any
|
||||
}
|
||||
|
||||
// PgsqlOnConflictConstrainter implements the PgsqlOnConflictConstraint method,
|
||||
// which returns the primary or unique key constraint name of the PostgreSQL table.
|
||||
type PgsqlOnConflictConstrainter interface {
|
||||
// PgsqlOnConflictConstraint returns the primary or unique key constraint name of the PostgreSQL table.
|
||||
PgsqlOnConflictConstraint() string
|
||||
}
|
||||
|
|
@ -1,800 +0,0 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"github.com/icinga/icinga-go-library/backoff"
|
||||
"github.com/icinga/icinga-go-library/com"
|
||||
"github.com/icinga/icinga-go-library/logging"
|
||||
"github.com/icinga/icinga-go-library/periodic"
|
||||
"github.com/icinga/icinga-go-library/retry"
|
||||
"github.com/icinga/icinga-go-library/strcase"
|
||||
"github.com/icinga/icinga-go-library/utils"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/jmoiron/sqlx/reflectx"
|
||||
"github.com/lib/pq"
|
||||
"github.com/pkg/errors"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/sync/semaphore"
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DB is a wrapper around sqlx.DB with bulk execution,
|
||||
// statement building, streaming and logging capabilities.
|
||||
type DB struct {
|
||||
*sqlx.DB
|
||||
|
||||
Options *Options
|
||||
|
||||
addr string
|
||||
columnMap ColumnMap
|
||||
logger *logging.Logger
|
||||
tableSemaphores map[string]*semaphore.Weighted
|
||||
tableSemaphoresMu sync.Mutex
|
||||
}
|
||||
|
||||
// Options define user configurable database options.
|
||||
type Options struct {
|
||||
// Maximum number of open connections to the database.
|
||||
MaxConnections int `yaml:"max_connections" default:"16"`
|
||||
|
||||
// Maximum number of connections per table,
|
||||
// regardless of what the connection is actually doing,
|
||||
// e.g. INSERT, UPDATE, DELETE.
|
||||
MaxConnectionsPerTable int `yaml:"max_connections_per_table" default:"8"`
|
||||
|
||||
// MaxPlaceholdersPerStatement defines the maximum number of placeholders in an
|
||||
// INSERT, UPDATE or DELETE statement. Theoretically, MySQL can handle up to 2^16-1 placeholders,
|
||||
// but this increases the execution time of queries and thus reduces the number of queries
|
||||
// that can be executed in parallel in a given time.
|
||||
// The default is 2^13, which in our tests showed the best performance in terms of execution time and parallelism.
|
||||
MaxPlaceholdersPerStatement int `yaml:"max_placeholders_per_statement" default:"8192"`
|
||||
|
||||
// MaxRowsPerTransaction defines the maximum number of rows per transaction.
|
||||
// The default is 2^13, which in our tests showed the best performance in terms of execution time and parallelism.
|
||||
MaxRowsPerTransaction int `yaml:"max_rows_per_transaction" default:"8192"`
|
||||
|
||||
// WsrepSyncWait enforces Galera cluster nodes to perform strict cluster-wide causality checks
|
||||
// before executing specific SQL queries determined by the number you provided.
|
||||
// Please refer to the below link for a detailed description.
|
||||
// https://icinga.com/docs/icinga-db/latest/doc/03-Configuration/#galera-cluster
|
||||
WsrepSyncWait int `yaml:"wsrep_sync_wait" default:"7"`
|
||||
}
|
||||
|
||||
// Validate checks constraints in the supplied database options and returns an error if they are violated.
|
||||
func (o *Options) Validate() error {
|
||||
if o.MaxConnections == 0 {
|
||||
return errors.New("max_connections cannot be 0. Configure a value greater than zero, or use -1 for no connection limit")
|
||||
}
|
||||
if o.MaxConnectionsPerTable < 1 {
|
||||
return errors.New("max_connections_per_table must be at least 1")
|
||||
}
|
||||
if o.MaxPlaceholdersPerStatement < 1 {
|
||||
return errors.New("max_placeholders_per_statement must be at least 1")
|
||||
}
|
||||
if o.MaxRowsPerTransaction < 1 {
|
||||
return errors.New("max_rows_per_transaction must be at least 1")
|
||||
}
|
||||
if o.WsrepSyncWait < 0 || o.WsrepSyncWait > 15 {
|
||||
return errors.New("wsrep_sync_wait can only be set to a number between 0 and 15")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewDbFromConfig returns a new DB from Config.
|
||||
func NewDbFromConfig(c *Config, logger *logging.Logger, connectorCallbacks RetryConnectorCallbacks) (*DB, error) {
|
||||
var addr string
|
||||
var db *sqlx.DB
|
||||
|
||||
switch c.Type {
|
||||
case "mysql":
|
||||
config := mysql.NewConfig()
|
||||
|
||||
config.User = c.User
|
||||
config.Passwd = c.Password
|
||||
config.Logger = MysqlFuncLogger(logger.Debug)
|
||||
|
||||
if utils.IsUnixAddr(c.Host) {
|
||||
config.Net = "unix"
|
||||
config.Addr = c.Host
|
||||
} else {
|
||||
config.Net = "tcp"
|
||||
port := c.Port
|
||||
if port == 0 {
|
||||
port = 3306
|
||||
}
|
||||
config.Addr = net.JoinHostPort(c.Host, fmt.Sprint(port))
|
||||
}
|
||||
|
||||
config.DBName = c.Database
|
||||
config.Timeout = time.Minute
|
||||
config.Params = map[string]string{"sql_mode": "'TRADITIONAL,ANSI_QUOTES'"}
|
||||
|
||||
tlsConfig, err := c.TlsOptions.MakeConfig(c.Host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config.TLS = tlsConfig
|
||||
|
||||
connector, err := mysql.NewConnector(config)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "can't open mysql database")
|
||||
}
|
||||
|
||||
onInitConn := connectorCallbacks.OnInitConn
|
||||
connectorCallbacks.OnInitConn = func(ctx context.Context, conn driver.Conn) error {
|
||||
if onInitConn != nil {
|
||||
if err := onInitConn(ctx, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return setGaleraOpts(ctx, conn, int64(c.Options.WsrepSyncWait))
|
||||
}
|
||||
|
||||
addr = config.Addr
|
||||
db = sqlx.NewDb(sql.OpenDB(NewConnector(connector, logger, connectorCallbacks)), MySQL)
|
||||
case "pgsql":
|
||||
uri := &url.URL{
|
||||
Scheme: "postgres",
|
||||
User: url.UserPassword(c.User, c.Password),
|
||||
Path: "/" + url.PathEscape(c.Database),
|
||||
}
|
||||
|
||||
query := url.Values{
|
||||
"connect_timeout": {"60"},
|
||||
"binary_parameters": {"yes"},
|
||||
|
||||
// Host and port can alternatively be specified in the query string. lib/pq can't parse the connection URI
|
||||
// if a Unix domain socket path is specified in the host part of the URI, therefore always use the query
|
||||
// string. See also https://github.com/lib/pq/issues/796
|
||||
"host": {c.Host},
|
||||
}
|
||||
|
||||
port := c.Port
|
||||
if port == 0 {
|
||||
port = 5432
|
||||
}
|
||||
query["port"] = []string{strconv.FormatInt(int64(port), 10)}
|
||||
|
||||
if _, err := c.TlsOptions.MakeConfig(c.Host); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if c.TlsOptions.Enable {
|
||||
if c.TlsOptions.Insecure {
|
||||
query["sslmode"] = []string{"require"}
|
||||
} else {
|
||||
query["sslmode"] = []string{"verify-full"}
|
||||
}
|
||||
|
||||
if c.TlsOptions.Cert != "" {
|
||||
query["sslcert"] = []string{c.TlsOptions.Cert}
|
||||
}
|
||||
|
||||
if c.TlsOptions.Key != "" {
|
||||
query["sslkey"] = []string{c.TlsOptions.Key}
|
||||
}
|
||||
|
||||
if c.TlsOptions.Ca != "" {
|
||||
query["sslrootcert"] = []string{c.TlsOptions.Ca}
|
||||
}
|
||||
} else {
|
||||
query["sslmode"] = []string{"disable"}
|
||||
}
|
||||
|
||||
uri.RawQuery = query.Encode()
|
||||
|
||||
connector, err := pq.NewConnector(uri.String())
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "can't open pgsql database")
|
||||
}
|
||||
|
||||
addr = utils.JoinHostPort(c.Host, port)
|
||||
db = sqlx.NewDb(sql.OpenDB(NewConnector(connector, logger, connectorCallbacks)), PostgreSQL)
|
||||
default:
|
||||
return nil, unknownDbType(c.Type)
|
||||
}
|
||||
|
||||
db.SetMaxIdleConns(c.Options.MaxConnections / 3)
|
||||
db.SetMaxOpenConns(c.Options.MaxConnections)
|
||||
|
||||
db.Mapper = reflectx.NewMapperFunc("db", strcase.Snake)
|
||||
|
||||
return &DB{
|
||||
DB: db,
|
||||
Options: &c.Options,
|
||||
columnMap: NewColumnMap(db.Mapper),
|
||||
addr: addr,
|
||||
logger: logger,
|
||||
tableSemaphores: make(map[string]*semaphore.Weighted),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetAddr returns the database host:port or Unix socket address.
|
||||
func (db *DB) GetAddr() string {
|
||||
return db.addr
|
||||
}
|
||||
|
||||
// BuildDeleteStmt returns a DELETE statement for the given struct.
|
||||
func (db *DB) BuildDeleteStmt(from interface{}) string {
|
||||
return fmt.Sprintf(
|
||||
`DELETE FROM "%s" WHERE id IN (?)`,
|
||||
TableName(from),
|
||||
)
|
||||
}
|
||||
|
||||
// BuildInsertStmt returns an INSERT INTO statement for the given struct.
|
||||
func (db *DB) BuildInsertStmt(into interface{}) (string, int) {
|
||||
columns := db.columnMap.Columns(into)
|
||||
|
||||
return fmt.Sprintf(
|
||||
`INSERT INTO "%s" ("%s") VALUES (%s)`,
|
||||
TableName(into),
|
||||
strings.Join(columns, `", "`),
|
||||
fmt.Sprintf(":%s", strings.Join(columns, ", :")),
|
||||
), len(columns)
|
||||
}
|
||||
|
||||
// BuildInsertIgnoreStmt returns an INSERT statement for the specified struct for
|
||||
// which the database ignores rows that have already been inserted.
|
||||
func (db *DB) BuildInsertIgnoreStmt(into interface{}) (string, int) {
|
||||
table := TableName(into)
|
||||
columns := db.columnMap.Columns(into)
|
||||
var clause string
|
||||
|
||||
switch db.DriverName() {
|
||||
case MySQL:
|
||||
// MySQL treats UPDATE id = id as a no-op.
|
||||
clause = fmt.Sprintf(`ON DUPLICATE KEY UPDATE "%s" = "%s"`, columns[0], columns[0])
|
||||
case PostgreSQL:
|
||||
var constraint string
|
||||
if constrainter, ok := into.(PgsqlOnConflictConstrainter); ok {
|
||||
constraint = constrainter.PgsqlOnConflictConstraint()
|
||||
} else {
|
||||
constraint = "pk_" + table
|
||||
}
|
||||
|
||||
clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT %s DO NOTHING", constraint)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
`INSERT INTO "%s" ("%s") VALUES (%s) %s`,
|
||||
table,
|
||||
strings.Join(columns, `", "`),
|
||||
fmt.Sprintf(":%s", strings.Join(columns, ", :")),
|
||||
clause,
|
||||
), len(columns)
|
||||
}
|
||||
|
||||
// BuildSelectStmt returns a SELECT query that creates the FROM part from the given table struct
|
||||
// and the column list from the specified columns struct.
|
||||
func (db *DB) BuildSelectStmt(table interface{}, columns interface{}) string {
|
||||
q := fmt.Sprintf(
|
||||
`SELECT "%s" FROM "%s"`,
|
||||
strings.Join(db.columnMap.Columns(columns), `", "`),
|
||||
TableName(table),
|
||||
)
|
||||
|
||||
if scoper, ok := table.(Scoper); ok {
|
||||
where, _ := db.BuildWhere(scoper.Scope())
|
||||
q += ` WHERE ` + where
|
||||
}
|
||||
|
||||
return q
|
||||
}
|
||||
|
||||
// BuildUpdateStmt returns an UPDATE statement for the given struct.
|
||||
func (db *DB) BuildUpdateStmt(update interface{}) (string, int) {
|
||||
columns := db.columnMap.Columns(update)
|
||||
set := make([]string, 0, len(columns))
|
||||
|
||||
for _, col := range columns {
|
||||
set = append(set, fmt.Sprintf(`"%s" = :%s`, col, col))
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
`UPDATE "%s" SET %s WHERE id = :id`,
|
||||
TableName(update),
|
||||
strings.Join(set, ", "),
|
||||
), len(columns) + 1 // +1 because of WHERE id = :id
|
||||
}
|
||||
|
||||
// BuildUpsertStmt returns an upsert statement for the given struct.
|
||||
func (db *DB) BuildUpsertStmt(subject interface{}) (stmt string, placeholders int) {
|
||||
insertColumns := db.columnMap.Columns(subject)
|
||||
table := TableName(subject)
|
||||
var updateColumns []string
|
||||
|
||||
if upserter, ok := subject.(Upserter); ok {
|
||||
updateColumns = db.columnMap.Columns(upserter.Upsert())
|
||||
} else {
|
||||
updateColumns = insertColumns
|
||||
}
|
||||
|
||||
var clause, setFormat string
|
||||
switch db.DriverName() {
|
||||
case MySQL:
|
||||
clause = "ON DUPLICATE KEY UPDATE"
|
||||
setFormat = `"%[1]s" = VALUES("%[1]s")`
|
||||
case PostgreSQL:
|
||||
var constraint string
|
||||
if constrainter, ok := subject.(PgsqlOnConflictConstrainter); ok {
|
||||
constraint = constrainter.PgsqlOnConflictConstraint()
|
||||
} else {
|
||||
constraint = "pk_" + table
|
||||
}
|
||||
|
||||
clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT %s DO UPDATE SET", constraint)
|
||||
setFormat = `"%[1]s" = EXCLUDED."%[1]s"`
|
||||
}
|
||||
|
||||
set := make([]string, 0, len(updateColumns))
|
||||
|
||||
for _, col := range updateColumns {
|
||||
set = append(set, fmt.Sprintf(setFormat, col))
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
`INSERT INTO "%s" ("%s") VALUES (%s) %s %s`,
|
||||
table,
|
||||
strings.Join(insertColumns, `", "`),
|
||||
fmt.Sprintf(":%s", strings.Join(insertColumns, ",:")),
|
||||
clause,
|
||||
strings.Join(set, ","),
|
||||
), len(insertColumns)
|
||||
}
|
||||
|
||||
// BuildWhere returns a WHERE clause with named placeholder conditions built from the specified struct
|
||||
// combined with the AND operator.
|
||||
func (db *DB) BuildWhere(subject interface{}) (string, int) {
|
||||
columns := db.columnMap.Columns(subject)
|
||||
where := make([]string, 0, len(columns))
|
||||
for _, col := range columns {
|
||||
where = append(where, fmt.Sprintf(`"%s" = :%s`, col, col))
|
||||
}
|
||||
|
||||
return strings.Join(where, ` AND `), len(columns)
|
||||
}
|
||||
|
||||
// OnSuccess is a callback for successful (bulk) DML operations.
|
||||
type OnSuccess[T any] func(ctx context.Context, affectedRows []T) (err error)
|
||||
|
||||
func OnSuccessIncrement[T any](counter *com.Counter) OnSuccess[T] {
|
||||
return func(_ context.Context, rows []T) error {
|
||||
counter.Add(uint64(len(rows)))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func OnSuccessSendTo[T any](ch chan<- T) OnSuccess[T] {
|
||||
return func(ctx context.Context, rows []T) error {
|
||||
for _, row := range rows {
|
||||
select {
|
||||
case ch <- row:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// BulkExec bulk executes queries with a single slice placeholder in the form of `IN (?)`.
|
||||
// Takes in up to the number of arguments specified in count from the arg stream,
|
||||
// derives and expands a query and executes it with this set of arguments until the arg stream has been processed.
|
||||
// The derived queries are executed in a separate goroutine with a weighting of 1
|
||||
// and can be executed concurrently to the extent allowed by the semaphore passed in sem.
|
||||
// Arguments for which the query ran successfully will be passed to onSuccess.
|
||||
func (db *DB) BulkExec(
|
||||
ctx context.Context, query string, count int, sem *semaphore.Weighted, arg <-chan any, onSuccess ...OnSuccess[any],
|
||||
) error {
|
||||
var counter com.Counter
|
||||
defer db.Log(ctx, query, &counter).Stop()
|
||||
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
// Use context from group.
|
||||
bulk := com.Bulk(ctx, arg, count, com.NeverSplit[any])
|
||||
|
||||
g.Go(func() error {
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
for b := range bulk {
|
||||
if err := sem.Acquire(ctx, 1); err != nil {
|
||||
return errors.Wrap(err, "can't acquire semaphore")
|
||||
}
|
||||
|
||||
g.Go(func(b []interface{}) func() error {
|
||||
return func() error {
|
||||
defer sem.Release(1)
|
||||
|
||||
return retry.WithBackoff(
|
||||
ctx,
|
||||
func(context.Context) error {
|
||||
stmt, args, err := sqlx.In(query, b)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "can't build placeholders for %q", query)
|
||||
}
|
||||
|
||||
stmt = db.Rebind(stmt)
|
||||
_, err = db.ExecContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return CantPerformQuery(err, query)
|
||||
}
|
||||
|
||||
counter.Add(uint64(len(b)))
|
||||
|
||||
for _, onSuccess := range onSuccess {
|
||||
if err := onSuccess(ctx, b); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
retry.Retryable,
|
||||
backoff.NewExponentialWithJitter(1*time.Millisecond, 1*time.Second),
|
||||
db.GetDefaultRetrySettings(),
|
||||
)
|
||||
}
|
||||
}(b))
|
||||
}
|
||||
|
||||
return g.Wait()
|
||||
})
|
||||
|
||||
return g.Wait()
|
||||
}
|
||||
|
||||
// NamedBulkExec bulk executes queries with named placeholders in a VALUES clause most likely
|
||||
// in the format INSERT ... VALUES. Takes in up to the number of entities specified in count
|
||||
// from the arg stream, derives and executes a new query with the VALUES clause expanded to
|
||||
// this set of arguments, until the arg stream has been processed.
|
||||
// The queries are executed in a separate goroutine with a weighting of 1
|
||||
// and can be executed concurrently to the extent allowed by the semaphore passed in sem.
|
||||
// Entities for which the query ran successfully will be passed to onSuccess.
|
||||
func (db *DB) NamedBulkExec(
|
||||
ctx context.Context, query string, count int, sem *semaphore.Weighted, arg <-chan Entity,
|
||||
splitPolicyFactory com.BulkChunkSplitPolicyFactory[Entity], onSuccess ...OnSuccess[Entity],
|
||||
) error {
|
||||
var counter com.Counter
|
||||
defer db.Log(ctx, query, &counter).Stop()
|
||||
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
bulk := com.Bulk(ctx, arg, count, splitPolicyFactory)
|
||||
|
||||
g.Go(func() error {
|
||||
for {
|
||||
select {
|
||||
case b, ok := <-bulk:
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := sem.Acquire(ctx, 1); err != nil {
|
||||
return errors.Wrap(err, "can't acquire semaphore")
|
||||
}
|
||||
|
||||
g.Go(func(b []Entity) func() error {
|
||||
return func() error {
|
||||
defer sem.Release(1)
|
||||
|
||||
return retry.WithBackoff(
|
||||
ctx,
|
||||
func(ctx context.Context) error {
|
||||
_, err := db.NamedExecContext(ctx, query, b)
|
||||
if err != nil {
|
||||
return CantPerformQuery(err, query)
|
||||
}
|
||||
|
||||
counter.Add(uint64(len(b)))
|
||||
|
||||
for _, onSuccess := range onSuccess {
|
||||
if err := onSuccess(ctx, b); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
retry.Retryable,
|
||||
backoff.NewExponentialWithJitter(1*time.Millisecond, 1*time.Second),
|
||||
db.GetDefaultRetrySettings(),
|
||||
)
|
||||
}
|
||||
}(b))
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return g.Wait()
|
||||
}
|
||||
|
||||
// NamedBulkExecTx bulk executes queries with named placeholders in separate transactions.
|
||||
// Takes in up to the number of entities specified in count from the arg stream and
|
||||
// executes a new transaction that runs a new query for each entity in this set of arguments,
|
||||
// until the arg stream has been processed.
|
||||
// The transactions are executed in a separate goroutine with a weighting of 1
|
||||
// and can be executed concurrently to the extent allowed by the semaphore passed in sem.
|
||||
func (db *DB) NamedBulkExecTx(
|
||||
ctx context.Context, query string, count int, sem *semaphore.Weighted, arg <-chan Entity,
|
||||
) error {
|
||||
var counter com.Counter
|
||||
defer db.Log(ctx, query, &counter).Stop()
|
||||
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
bulk := com.Bulk(ctx, arg, count, com.NeverSplit[Entity])
|
||||
|
||||
g.Go(func() error {
|
||||
for {
|
||||
select {
|
||||
case b, ok := <-bulk:
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := sem.Acquire(ctx, 1); err != nil {
|
||||
return errors.Wrap(err, "can't acquire semaphore")
|
||||
}
|
||||
|
||||
g.Go(func(b []Entity) func() error {
|
||||
return func() error {
|
||||
defer sem.Release(1)
|
||||
|
||||
return retry.WithBackoff(
|
||||
ctx,
|
||||
func(ctx context.Context) error {
|
||||
tx, err := db.BeginTxx(ctx, nil)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "can't start transaction")
|
||||
}
|
||||
|
||||
stmt, err := tx.PrepareNamedContext(ctx, query)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "can't prepare named statement with context in transaction")
|
||||
}
|
||||
|
||||
for _, arg := range b {
|
||||
if _, err := stmt.ExecContext(ctx, arg); err != nil {
|
||||
return errors.Wrap(err, "can't execute statement in transaction")
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return errors.Wrap(err, "can't commit transaction")
|
||||
}
|
||||
|
||||
counter.Add(uint64(len(b)))
|
||||
|
||||
return nil
|
||||
},
|
||||
retry.Retryable,
|
||||
backoff.NewExponentialWithJitter(1*time.Millisecond, 1*time.Second),
|
||||
db.GetDefaultRetrySettings(),
|
||||
)
|
||||
}
|
||||
}(b))
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return g.Wait()
|
||||
}
|
||||
|
||||
// BatchSizeByPlaceholders returns how often the specified number of placeholders fits
|
||||
// into Options.MaxPlaceholdersPerStatement, but at least 1.
|
||||
func (db *DB) BatchSizeByPlaceholders(n int) int {
|
||||
s := db.Options.MaxPlaceholdersPerStatement / n
|
||||
if s > 0 {
|
||||
return s
|
||||
}
|
||||
|
||||
return 1
|
||||
}
|
||||
|
||||
// YieldAll executes the query with the supplied scope,
|
||||
// scans each resulting row into an entity returned by the factory function,
|
||||
// and streams them into a returned channel.
|
||||
func (db *DB) YieldAll(ctx context.Context, factoryFunc EntityFactoryFunc, query string, scope interface{}) (<-chan Entity, <-chan error) {
|
||||
entities := make(chan Entity, 1)
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
g.Go(func() error {
|
||||
var counter com.Counter
|
||||
defer db.Log(ctx, query, &counter).Stop()
|
||||
defer close(entities)
|
||||
|
||||
rows, err := db.NamedQueryContext(ctx, query, scope)
|
||||
if err != nil {
|
||||
return CantPerformQuery(err, query)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
e := factoryFunc()
|
||||
|
||||
if err := rows.StructScan(e); err != nil {
|
||||
return errors.Wrapf(err, "can't store query result into a %T: %s", e, query)
|
||||
}
|
||||
|
||||
select {
|
||||
case entities <- e:
|
||||
counter.Inc()
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return entities, com.WaitAsync(g)
|
||||
}
|
||||
|
||||
// CreateStreamed bulk creates the specified entities via NamedBulkExec.
|
||||
// The insert statement is created using BuildInsertStmt with the first entity from the entities stream.
|
||||
// Bulk size is controlled via Options.MaxPlaceholdersPerStatement and
|
||||
// concurrency is controlled via Options.MaxConnectionsPerTable.
|
||||
// Entities for which the query ran successfully will be passed to onSuccess.
|
||||
func (db *DB) CreateStreamed(
|
||||
ctx context.Context, entities <-chan Entity, onSuccess ...OnSuccess[Entity],
|
||||
) error {
|
||||
first, forward, err := com.CopyFirst(ctx, entities)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "can't copy first entity")
|
||||
}
|
||||
|
||||
sem := db.GetSemaphoreForTable(TableName(first))
|
||||
stmt, placeholders := db.BuildInsertStmt(first)
|
||||
|
||||
return db.NamedBulkExec(
|
||||
ctx, stmt, db.BatchSizeByPlaceholders(placeholders), sem,
|
||||
forward, com.NeverSplit[Entity], onSuccess...,
|
||||
)
|
||||
}
|
||||
|
||||
// CreateIgnoreStreamed bulk creates the specified entities via NamedBulkExec.
|
||||
// The insert statement is created using BuildInsertIgnoreStmt with the first entity from the entities stream.
|
||||
// Bulk size is controlled via Options.MaxPlaceholdersPerStatement and
|
||||
// concurrency is controlled via Options.MaxConnectionsPerTable.
|
||||
// Entities for which the query ran successfully will be passed to onSuccess.
|
||||
func (db *DB) CreateIgnoreStreamed(
|
||||
ctx context.Context, entities <-chan Entity, onSuccess ...OnSuccess[Entity],
|
||||
) error {
|
||||
first, forward, err := com.CopyFirst(ctx, entities)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "can't copy first entity")
|
||||
}
|
||||
|
||||
sem := db.GetSemaphoreForTable(TableName(first))
|
||||
stmt, placeholders := db.BuildInsertIgnoreStmt(first)
|
||||
|
||||
return db.NamedBulkExec(
|
||||
ctx, stmt, db.BatchSizeByPlaceholders(placeholders), sem,
|
||||
forward, SplitOnDupId[Entity], onSuccess...,
|
||||
)
|
||||
}
|
||||
|
||||
// UpsertStreamed bulk upserts the specified entities via NamedBulkExec.
|
||||
// The upsert statement is created using BuildUpsertStmt with the first entity from the entities stream.
|
||||
// Bulk size is controlled via Options.MaxPlaceholdersPerStatement and
|
||||
// concurrency is controlled via Options.MaxConnectionsPerTable.
|
||||
// Entities for which the query ran successfully will be passed to onSuccess.
|
||||
func (db *DB) UpsertStreamed(
|
||||
ctx context.Context, entities <-chan Entity, onSuccess ...OnSuccess[Entity],
|
||||
) error {
|
||||
first, forward, err := com.CopyFirst(ctx, entities)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "can't copy first entity")
|
||||
}
|
||||
|
||||
sem := db.GetSemaphoreForTable(TableName(first))
|
||||
stmt, placeholders := db.BuildUpsertStmt(first)
|
||||
|
||||
return db.NamedBulkExec(
|
||||
ctx, stmt, db.BatchSizeByPlaceholders(placeholders), sem,
|
||||
forward, SplitOnDupId[Entity], onSuccess...,
|
||||
)
|
||||
}
|
||||
|
||||
// UpdateStreamed bulk updates the specified entities via NamedBulkExecTx.
|
||||
// The update statement is created using BuildUpdateStmt with the first entity from the entities stream.
|
||||
// Bulk size is controlled via Options.MaxRowsPerTransaction and
|
||||
// concurrency is controlled via Options.MaxConnectionsPerTable.
|
||||
func (db *DB) UpdateStreamed(ctx context.Context, entities <-chan Entity) error {
|
||||
first, forward, err := com.CopyFirst(ctx, entities)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "can't copy first entity")
|
||||
}
|
||||
sem := db.GetSemaphoreForTable(TableName(first))
|
||||
stmt, _ := db.BuildUpdateStmt(first)
|
||||
|
||||
return db.NamedBulkExecTx(ctx, stmt, db.Options.MaxRowsPerTransaction, sem, forward)
|
||||
}
|
||||
|
||||
// DeleteStreamed bulk deletes the specified ids via BulkExec.
|
||||
// The delete statement is created using BuildDeleteStmt with the passed entityType.
|
||||
// Bulk size is controlled via Options.MaxPlaceholdersPerStatement and
|
||||
// concurrency is controlled via Options.MaxConnectionsPerTable.
|
||||
// IDs for which the query ran successfully will be passed to onSuccess.
|
||||
func (db *DB) DeleteStreamed(
|
||||
ctx context.Context, entityType Entity, ids <-chan interface{}, onSuccess ...OnSuccess[any],
|
||||
) error {
|
||||
sem := db.GetSemaphoreForTable(TableName(entityType))
|
||||
return db.BulkExec(
|
||||
ctx, db.BuildDeleteStmt(entityType), db.Options.MaxPlaceholdersPerStatement, sem, ids, onSuccess...,
|
||||
)
|
||||
}
|
||||
|
||||
// Delete creates a channel from the specified ids and
|
||||
// bulk deletes them by passing the channel along with the entityType to DeleteStreamed.
|
||||
// IDs for which the query ran successfully will be passed to onSuccess.
|
||||
func (db *DB) Delete(
|
||||
ctx context.Context, entityType Entity, ids []interface{}, onSuccess ...OnSuccess[any],
|
||||
) error {
|
||||
idsCh := make(chan interface{}, len(ids))
|
||||
for _, id := range ids {
|
||||
idsCh <- id
|
||||
}
|
||||
close(idsCh)
|
||||
|
||||
return db.DeleteStreamed(ctx, entityType, idsCh, onSuccess...)
|
||||
}
|
||||
|
||||
func (db *DB) GetSemaphoreForTable(table string) *semaphore.Weighted {
|
||||
db.tableSemaphoresMu.Lock()
|
||||
defer db.tableSemaphoresMu.Unlock()
|
||||
|
||||
if sem, ok := db.tableSemaphores[table]; ok {
|
||||
return sem
|
||||
} else {
|
||||
sem = semaphore.NewWeighted(int64(db.Options.MaxConnectionsPerTable))
|
||||
db.tableSemaphores[table] = sem
|
||||
return sem
|
||||
}
|
||||
}
|
||||
|
||||
func (db *DB) GetDefaultRetrySettings() retry.Settings {
|
||||
return retry.Settings{
|
||||
Timeout: retry.DefaultTimeout,
|
||||
OnRetryableError: func(_ time.Duration, _ uint64, err, lastErr error) {
|
||||
if lastErr == nil || err.Error() != lastErr.Error() {
|
||||
db.logger.Warnw("Can't execute query. Retrying", zap.Error(err))
|
||||
}
|
||||
},
|
||||
OnSuccess: func(elapsed time.Duration, attempt uint64, lastErr error) {
|
||||
if attempt > 1 {
|
||||
db.logger.Infow("Query retried successfully after error",
|
||||
zap.Duration("after", elapsed),
|
||||
zap.Uint64("attempts", attempt),
|
||||
zap.NamedError("recovered_error", lastErr))
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (db *DB) Log(ctx context.Context, query string, counter *com.Counter) periodic.Stopper {
|
||||
return periodic.Start(ctx, db.logger.Interval(), func(tick periodic.Tick) {
|
||||
if count := counter.Reset(); count > 0 {
|
||||
db.logger.Debugf("Executed %q with %d rows", query, count)
|
||||
}
|
||||
}, periodic.OnStop(func(tick periodic.Tick) {
|
||||
db.logger.Debugf("Finished executing %q with %d rows in %s", query, counter.Total(), tick.Elapsed)
|
||||
}))
|
||||
}
|
||||
|
|
@ -1,100 +0,0 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"github.com/icinga/icinga-go-library/backoff"
|
||||
"github.com/icinga/icinga-go-library/logging"
|
||||
"github.com/icinga/icinga-go-library/retry"
|
||||
"github.com/pkg/errors"
|
||||
"go.uber.org/zap"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Driver names as automatically registered in the database/sql package by themselves.
|
||||
const (
|
||||
MySQL string = "mysql"
|
||||
PostgreSQL string = "postgres"
|
||||
)
|
||||
|
||||
// OnInitConnFunc can be used to execute post Connect() arbitrary actions.
|
||||
// It will be called after successfully initiated a new connection using the connector's Connect method.
|
||||
type OnInitConnFunc func(context.Context, driver.Conn) error
|
||||
|
||||
// RetryConnectorCallbacks specifies callbacks that are executed upon certain events.
|
||||
type RetryConnectorCallbacks struct {
|
||||
OnInitConn OnInitConnFunc
|
||||
OnRetryableError retry.OnRetryableErrorFunc
|
||||
OnSuccess retry.OnSuccessFunc
|
||||
}
|
||||
|
||||
// RetryConnector wraps driver.Connector with retry logic.
|
||||
type RetryConnector struct {
|
||||
driver.Connector
|
||||
|
||||
logger *logging.Logger
|
||||
|
||||
callbacks RetryConnectorCallbacks
|
||||
}
|
||||
|
||||
// NewConnector creates a fully initialized RetryConnector from the given args.
|
||||
func NewConnector(c driver.Connector, logger *logging.Logger, callbacks RetryConnectorCallbacks) *RetryConnector {
|
||||
return &RetryConnector{Connector: c, logger: logger, callbacks: callbacks}
|
||||
}
|
||||
|
||||
// Connect implements part of the driver.Connector interface.
|
||||
func (c RetryConnector) Connect(ctx context.Context) (driver.Conn, error) {
|
||||
var conn driver.Conn
|
||||
err := errors.Wrap(retry.WithBackoff(
|
||||
ctx,
|
||||
func(ctx context.Context) (err error) {
|
||||
conn, err = c.Connector.Connect(ctx)
|
||||
if err == nil && c.callbacks.OnInitConn != nil {
|
||||
if err = c.callbacks.OnInitConn(ctx, conn); err != nil {
|
||||
// We're going to retry this, so just don't bother whether Close() fails!
|
||||
_ = conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
},
|
||||
retry.Retryable,
|
||||
backoff.NewExponentialWithJitter(128*time.Millisecond, 1*time.Minute),
|
||||
retry.Settings{
|
||||
Timeout: retry.DefaultTimeout,
|
||||
OnRetryableError: func(elapsed time.Duration, attempt uint64, err, lastErr error) {
|
||||
if c.callbacks.OnRetryableError != nil {
|
||||
c.callbacks.OnRetryableError(elapsed, attempt, err, lastErr)
|
||||
}
|
||||
|
||||
if lastErr == nil || err.Error() != lastErr.Error() {
|
||||
c.logger.Warnw("Can't connect to database. Retrying", zap.Error(err))
|
||||
}
|
||||
},
|
||||
OnSuccess: func(elapsed time.Duration, attempt uint64, lastErr error) {
|
||||
if c.callbacks.OnSuccess != nil {
|
||||
c.callbacks.OnSuccess(elapsed, attempt, lastErr)
|
||||
}
|
||||
|
||||
if attempt > 1 {
|
||||
c.logger.Infow("Reconnected to database",
|
||||
zap.Duration("after", elapsed), zap.Uint64("attempts", attempt))
|
||||
}
|
||||
},
|
||||
},
|
||||
), "can't connect to database")
|
||||
return conn, err
|
||||
}
|
||||
|
||||
// Driver implements part of the driver.Connector interface.
|
||||
func (c RetryConnector) Driver() driver.Driver {
|
||||
return c.Connector.Driver()
|
||||
}
|
||||
|
||||
// MysqlFuncLogger is an adapter that allows ordinary functions to be used as a logger for mysql.SetLogger.
|
||||
type MysqlFuncLogger func(v ...interface{})
|
||||
|
||||
// Print implements the mysql.Logger interface.
|
||||
func (log MysqlFuncLogger) Print(v ...interface{}) {
|
||||
log(v)
|
||||
}
|
||||
|
|
@ -1,81 +0,0 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"github.com/icinga/icinga-go-library/com"
|
||||
"github.com/icinga/icinga-go-library/strcase"
|
||||
"github.com/icinga/icinga-go-library/types"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// CantPerformQuery wraps the given error with the specified query that cannot be executed.
|
||||
func CantPerformQuery(err error, q string) error {
|
||||
return errors.Wrapf(err, "can't perform %q", q)
|
||||
}
|
||||
|
||||
// TableName returns the table of t.
|
||||
func TableName(t interface{}) string {
|
||||
if tn, ok := t.(TableNamer); ok {
|
||||
return tn.TableName()
|
||||
} else {
|
||||
return strcase.Snake(types.Name(t))
|
||||
}
|
||||
}
|
||||
|
||||
// SplitOnDupId returns a state machine which tracks the inputs' IDs.
|
||||
// Once an already seen input arrives, it demands splitting.
|
||||
func SplitOnDupId[T IDer]() com.BulkChunkSplitPolicy[T] {
|
||||
seenIds := map[string]struct{}{}
|
||||
|
||||
return func(ider T) bool {
|
||||
id := ider.ID().String()
|
||||
|
||||
_, ok := seenIds[id]
|
||||
if ok {
|
||||
seenIds = map[string]struct{}{id: {}}
|
||||
} else {
|
||||
seenIds[id] = struct{}{}
|
||||
}
|
||||
|
||||
return ok
|
||||
}
|
||||
}
|
||||
|
||||
// setGaleraOpts sets the "wsrep_sync_wait" variable for each session ensures that causality checks are performed
|
||||
// before execution and that each statement is executed on a fully synchronized node. Doing so prevents foreign key
|
||||
// violation when inserting into dependent tables on different MariaDB/MySQL nodes. When using MySQL single nodes,
|
||||
// the "SET SESSION" command will fail with "Unknown system variable (1193)" and will therefore be silently dropped.
|
||||
//
|
||||
// https://mariadb.com/kb/en/galera-cluster-system-variables/#wsrep_sync_wait
|
||||
func setGaleraOpts(ctx context.Context, conn driver.Conn, wsrepSyncWait int64) error {
|
||||
const galeraOpts = "SET SESSION wsrep_sync_wait=?"
|
||||
|
||||
stmt, err := conn.(driver.ConnPrepareContext).PrepareContext(ctx, galeraOpts)
|
||||
if err != nil {
|
||||
if errors.Is(err, &mysql.MySQLError{Number: 1193}) { // Unknown system variable
|
||||
return nil
|
||||
}
|
||||
|
||||
return errors.Wrap(err, "cannot prepare "+galeraOpts)
|
||||
}
|
||||
// This is just for an unexpected exit and any returned error can safely be ignored and in case
|
||||
// of the normal function exit, the stmt is closed manually, and its error is handled gracefully.
|
||||
defer func() { _ = stmt.Close() }()
|
||||
|
||||
_, err = stmt.(driver.StmtExecContext).ExecContext(ctx, []driver.NamedValue{{Value: wsrepSyncWait}})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "cannot execute "+galeraOpts)
|
||||
}
|
||||
|
||||
if err = stmt.Close(); err != nil {
|
||||
return errors.Wrap(err, "cannot close prepared statement "+galeraOpts)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
_ com.BulkChunkSplitPolicyFactory[Entity] = SplitOnDupId[Entity]
|
||||
)
|
||||
|
|
@ -1,46 +0,0 @@
|
|||
package flatten
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/icinga/icinga-go-library/types"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// Flatten creates flat, one-dimensional maps from arbitrarily nested values, e.g. JSON.
|
||||
func Flatten(value interface{}, prefix string) map[string]types.String {
|
||||
var flatten func(string, interface{})
|
||||
flattened := make(map[string]types.String)
|
||||
|
||||
flatten = func(key string, value interface{}) {
|
||||
switch value := value.(type) {
|
||||
case map[string]interface{}:
|
||||
if len(value) == 0 {
|
||||
flattened[key] = types.String{}
|
||||
break
|
||||
}
|
||||
|
||||
for k, v := range value {
|
||||
flatten(key+"."+k, v)
|
||||
}
|
||||
case []interface{}:
|
||||
if len(value) == 0 {
|
||||
flattened[key] = types.String{}
|
||||
break
|
||||
}
|
||||
|
||||
for i, v := range value {
|
||||
flatten(key+"["+strconv.Itoa(i)+"]", v)
|
||||
}
|
||||
case nil:
|
||||
flattened[key] = types.MakeString("null")
|
||||
case float64:
|
||||
flattened[key] = types.MakeString(strconv.FormatFloat(value, 'f', -1, 64))
|
||||
default:
|
||||
flattened[key] = types.MakeString(fmt.Sprintf("%v", value))
|
||||
}
|
||||
}
|
||||
|
||||
flatten(prefix, value)
|
||||
|
||||
return flattened
|
||||
}
|
||||
|
|
@ -1,45 +0,0 @@
|
|||
package flatten
|
||||
|
||||
import (
|
||||
"github.com/icinga/icinga-go-library/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFlatten(t *testing.T) {
|
||||
for _, st := range []struct {
|
||||
name string
|
||||
prefix string
|
||||
value any
|
||||
output map[string]types.String
|
||||
}{
|
||||
{"nil", "a", nil, map[string]types.String{"a": types.MakeString("null")}},
|
||||
{"bool", "b", true, map[string]types.String{"b": types.MakeString("true")}},
|
||||
{"int", "c", 42, map[string]types.String{"c": types.MakeString("42")}},
|
||||
{"float", "d", 77.7, map[string]types.String{"d": types.MakeString("77.7")}},
|
||||
{"large_float", "e", 1e23, map[string]types.String{"e": types.MakeString("100000000000000000000000")}},
|
||||
{"string", "f", "\x00", map[string]types.String{"f": types.MakeString("\x00")}},
|
||||
{"nil_slice", "g", []any(nil), map[string]types.String{"g": {}}},
|
||||
{"empty_slice", "h", []any{}, map[string]types.String{"h": {}}},
|
||||
{"slice", "i", []any{nil}, map[string]types.String{"i[0]": types.MakeString("null")}},
|
||||
{"nil_map", "j", map[string]any(nil), map[string]types.String{"j": {}}},
|
||||
{"empty_map", "k", map[string]any{}, map[string]types.String{"k": {}}},
|
||||
{"map", "l", map[string]any{" ": nil}, map[string]types.String{"l. ": types.MakeString("null")}},
|
||||
{"map_with_slice", "m", map[string]any{"\t": []any{"ä", "ö", "ü"}, "ß": "s"}, map[string]types.String{
|
||||
"m.\t[0]": types.MakeString("ä"),
|
||||
"m.\t[1]": types.MakeString("ö"),
|
||||
"m.\t[2]": types.MakeString("ü"),
|
||||
"m.ß": types.MakeString("s"),
|
||||
}},
|
||||
{"slice_with_map", "n", []any{map[string]any{"ä": "a", "ö": "o", "ü": "u"}, "ß"}, map[string]types.String{
|
||||
"n[0].ä": types.MakeString("a"),
|
||||
"n[0].ö": types.MakeString("o"),
|
||||
"n[0].ü": types.MakeString("u"),
|
||||
"n[1]": types.MakeString("ß"),
|
||||
}},
|
||||
} {
|
||||
t.Run(st.name, func(t *testing.T) {
|
||||
assert.Equal(t, st.output, Flatten(st.value, st.prefix))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -1,60 +0,0 @@
|
|||
package logging
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/pkg/errors"
|
||||
"go.uber.org/zap/zapcore"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Options define child loggers with their desired log level.
|
||||
type Options map[string]zapcore.Level
|
||||
|
||||
// Config defines Logger configuration.
|
||||
type Config struct {
|
||||
// zapcore.Level at 0 is for info level.
|
||||
Level zapcore.Level `yaml:"level" default:"0"`
|
||||
Output string `yaml:"output"`
|
||||
// Interval for periodic logging.
|
||||
Interval time.Duration `yaml:"interval" default:"20s"`
|
||||
|
||||
Options `yaml:"options"`
|
||||
}
|
||||
|
||||
// Validate checks constraints in the configuration and returns an error if they are violated.
|
||||
// Also configures the log output if it is not configured:
|
||||
// systemd-journald is used when Icinga DB is running under systemd, otherwise stderr.
|
||||
func (l *Config) Validate() error {
|
||||
if l.Interval <= 0 {
|
||||
return errors.New("periodic logging interval must be positive")
|
||||
}
|
||||
|
||||
if l.Output == "" {
|
||||
if _, ok := os.LookupEnv("NOTIFY_SOCKET"); ok {
|
||||
// When started by systemd, NOTIFY_SOCKET is set by systemd for Type=notify supervised services,
|
||||
// which is the default setting for the Icinga DB service.
|
||||
// This assumes that Icinga DB is running under systemd, so set output to systemd-journald.
|
||||
l.Output = JOURNAL
|
||||
} else {
|
||||
// Otherwise set it to console, i.e. write log messages to stderr.
|
||||
l.Output = CONSOLE
|
||||
}
|
||||
}
|
||||
|
||||
// To be on the safe side, always call AssertOutput.
|
||||
return AssertOutput(l.Output)
|
||||
}
|
||||
|
||||
// AssertOutput returns an error if output is not a valid logger output.
|
||||
func AssertOutput(o string) error {
|
||||
if o == CONSOLE || o == JOURNAL {
|
||||
return nil
|
||||
}
|
||||
|
||||
return invalidOutput(o)
|
||||
}
|
||||
|
||||
func invalidOutput(o string) error {
|
||||
return fmt.Errorf("%s is not a valid logger output. Must be either %q or %q", o, CONSOLE, JOURNAL)
|
||||
}
|
||||
|
|
@ -1,84 +0,0 @@
|
|||
package logging
|
||||
|
||||
import (
|
||||
"github.com/icinga/icinga-go-library/strcase"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/ssgreg/journald"
|
||||
"go.uber.org/zap/zapcore"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// priorities maps zapcore.Level to journal.Priority.
|
||||
var priorities = map[zapcore.Level]journald.Priority{
|
||||
zapcore.DebugLevel: journald.PriorityDebug,
|
||||
zapcore.InfoLevel: journald.PriorityInfo,
|
||||
zapcore.WarnLevel: journald.PriorityWarning,
|
||||
zapcore.ErrorLevel: journald.PriorityErr,
|
||||
zapcore.FatalLevel: journald.PriorityCrit,
|
||||
zapcore.PanicLevel: journald.PriorityCrit,
|
||||
zapcore.DPanicLevel: journald.PriorityCrit,
|
||||
}
|
||||
|
||||
// NewJournaldCore returns a zapcore.Core that sends log entries to systemd-journald and
|
||||
// uses the given identifier as a prefix for structured logging context that is sent as journal fields.
|
||||
func NewJournaldCore(identifier string, enab zapcore.LevelEnabler) zapcore.Core {
|
||||
return &journaldCore{
|
||||
LevelEnabler: enab,
|
||||
identifier: identifier,
|
||||
identifierU: strings.ToUpper(identifier),
|
||||
}
|
||||
}
|
||||
|
||||
type journaldCore struct {
|
||||
zapcore.LevelEnabler
|
||||
context []zapcore.Field
|
||||
identifier string
|
||||
identifierU string
|
||||
}
|
||||
|
||||
func (c *journaldCore) Check(ent zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry {
|
||||
if c.Enabled(ent.Level) {
|
||||
return ce.AddCore(ent, c)
|
||||
}
|
||||
|
||||
return ce
|
||||
}
|
||||
|
||||
func (c *journaldCore) Sync() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *journaldCore) With(fields []zapcore.Field) zapcore.Core {
|
||||
cc := *c
|
||||
cc.context = append(cc.context[:len(cc.context):len(cc.context)], fields...)
|
||||
|
||||
return &cc
|
||||
}
|
||||
|
||||
func (c *journaldCore) Write(ent zapcore.Entry, fields []zapcore.Field) error {
|
||||
pri, ok := priorities[ent.Level]
|
||||
if !ok {
|
||||
return errors.Errorf("unknown log level %q", ent.Level)
|
||||
}
|
||||
|
||||
enc := zapcore.NewMapObjectEncoder()
|
||||
c.addFields(enc, fields)
|
||||
c.addFields(enc, c.context)
|
||||
enc.Fields["SYSLOG_IDENTIFIER"] = c.identifier
|
||||
|
||||
message := ent.Message
|
||||
if ent.LoggerName != c.identifier {
|
||||
message = ent.LoggerName + ": " + message
|
||||
}
|
||||
|
||||
return journald.Send(message, pri, enc.Fields)
|
||||
}
|
||||
|
||||
func (c *journaldCore) addFields(enc zapcore.ObjectEncoder, fields []zapcore.Field) {
|
||||
for _, field := range fields {
|
||||
field.Key = c.identifierU +
|
||||
"_" +
|
||||
strcase.ScreamingSnake(field.Key)
|
||||
field.AddTo(enc)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,26 +0,0 @@
|
|||
package logging
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Logger wraps zap.SugaredLogger and
|
||||
// allows to get the interval for periodic logging.
|
||||
type Logger struct {
|
||||
*zap.SugaredLogger
|
||||
interval time.Duration
|
||||
}
|
||||
|
||||
// NewLogger returns a new Logger.
|
||||
func NewLogger(base *zap.SugaredLogger, interval time.Duration) *Logger {
|
||||
return &Logger{
|
||||
SugaredLogger: base,
|
||||
interval: interval,
|
||||
}
|
||||
}
|
||||
|
||||
// Interval returns the interval for periodic logging.
|
||||
func (l *Logger) Interval() time.Duration {
|
||||
return l.interval
|
||||
}
|
||||
|
|
@ -1,119 +0,0 @@
|
|||
package logging
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
CONSOLE = "console"
|
||||
JOURNAL = "systemd-journald"
|
||||
)
|
||||
|
||||
// defaultEncConfig defines the default zapcore.EncoderConfig for the logging package.
|
||||
var defaultEncConfig = zapcore.EncoderConfig{
|
||||
TimeKey: "ts",
|
||||
LevelKey: "level",
|
||||
NameKey: "logger",
|
||||
CallerKey: "caller",
|
||||
MessageKey: "msg",
|
||||
StacktraceKey: "stacktrace",
|
||||
LineEnding: zapcore.DefaultLineEnding,
|
||||
EncodeLevel: zapcore.CapitalLevelEncoder,
|
||||
EncodeTime: zapcore.ISO8601TimeEncoder,
|
||||
EncodeDuration: zapcore.StringDurationEncoder,
|
||||
EncodeCaller: zapcore.ShortCallerEncoder,
|
||||
}
|
||||
|
||||
// Logging implements access to a default logger and named child loggers.
|
||||
// Log levels can be configured per named child via Options which, if not configured,
|
||||
// fall back on a default log level.
|
||||
// Logs either to the console or to systemd-journald.
|
||||
type Logging struct {
|
||||
logger *Logger
|
||||
output string
|
||||
verbosity zap.AtomicLevel
|
||||
interval time.Duration
|
||||
|
||||
// coreFactory creates zapcore.Core based on the log level and the log output.
|
||||
coreFactory func(zap.AtomicLevel) zapcore.Core
|
||||
|
||||
mu sync.Mutex
|
||||
loggers map[string]*Logger
|
||||
|
||||
options Options
|
||||
}
|
||||
|
||||
// NewLogging takes the name and log level for the default logger,
|
||||
// output where log messages are written to,
|
||||
// options having log levels for named child loggers
|
||||
// and returns a new Logging.
|
||||
func NewLogging(name string, level zapcore.Level, output string, options Options, interval time.Duration) (*Logging, error) {
|
||||
verbosity := zap.NewAtomicLevelAt(level)
|
||||
|
||||
var coreFactory func(zap.AtomicLevel) zapcore.Core
|
||||
switch output {
|
||||
case CONSOLE:
|
||||
enc := zapcore.NewConsoleEncoder(defaultEncConfig)
|
||||
ws := zapcore.Lock(os.Stderr)
|
||||
coreFactory = func(verbosity zap.AtomicLevel) zapcore.Core {
|
||||
return zapcore.NewCore(enc, ws, verbosity)
|
||||
}
|
||||
case JOURNAL:
|
||||
coreFactory = func(verbosity zap.AtomicLevel) zapcore.Core {
|
||||
return NewJournaldCore(name, verbosity)
|
||||
}
|
||||
default:
|
||||
return nil, invalidOutput(output)
|
||||
}
|
||||
|
||||
logger := NewLogger(zap.New(coreFactory(verbosity)).Named(name).Sugar(), interval)
|
||||
|
||||
return &Logging{
|
||||
logger: logger,
|
||||
output: output,
|
||||
verbosity: verbosity,
|
||||
interval: interval,
|
||||
coreFactory: coreFactory,
|
||||
loggers: make(map[string]*Logger),
|
||||
options: options,
|
||||
},
|
||||
nil
|
||||
}
|
||||
|
||||
// NewLoggingFromConfig returns a new Logging from Config.
|
||||
func NewLoggingFromConfig(name string, c Config) (*Logging, error) {
|
||||
return NewLogging(name, c.Level, c.Output, c.Options, c.Interval)
|
||||
}
|
||||
|
||||
// GetChildLogger returns a named child logger.
|
||||
// Log levels for named child loggers are obtained from the logging options and, if not found,
|
||||
// set to the default log level.
|
||||
func (l *Logging) GetChildLogger(name string) *Logger {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
if logger, ok := l.loggers[name]; ok {
|
||||
return logger
|
||||
}
|
||||
|
||||
var verbosity zap.AtomicLevel
|
||||
if level, found := l.options[name]; found {
|
||||
verbosity = zap.NewAtomicLevelAt(level)
|
||||
} else {
|
||||
verbosity = l.verbosity
|
||||
}
|
||||
|
||||
logger := NewLogger(zap.New(l.coreFactory(verbosity)).Named(name).Sugar(), l.interval)
|
||||
l.loggers[name] = logger
|
||||
|
||||
return logger
|
||||
}
|
||||
|
||||
// GetLogger returns the default logger.
|
||||
func (l *Logging) GetLogger() *Logger {
|
||||
return l.logger
|
||||
}
|
||||
|
|
@ -1,213 +0,0 @@
|
|||
package objectpacker
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"github.com/pkg/errors"
|
||||
"io"
|
||||
"reflect"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// MustPackSlice calls PackAny using items and panics if there was an error.
|
||||
func MustPackSlice(items ...interface{}) []byte {
|
||||
var buf bytes.Buffer
|
||||
|
||||
if err := PackAny(items, &buf); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// PackAny packs any JSON-encodable value (ex. structs, also ignores interfaces like encoding.TextMarshaler)
|
||||
// to a BSON-similar format suitable for consistent hashing. Spec:
|
||||
//
|
||||
// PackAny(nil) => 0x0
|
||||
// PackAny(false) => 0x1
|
||||
// PackAny(true) => 0x2
|
||||
// PackAny(float64(42)) => 0x3 ieee754_binary64_bigendian(42)
|
||||
// PackAny("exämple") => 0x4 uint64_bigendian(len([]byte("exämple"))) []byte("exämple")
|
||||
// PackAny([]uint8{0x42}) => 0x4 uint64_bigendian(len([]uint8{0x42})) []uint8{0x42}
|
||||
// PackAny([1]uint8{0x42}) => 0x4 uint64_bigendian(len([1]uint8{0x42})) [1]uint8{0x42}
|
||||
// PackAny([]T{x,y}) => 0x5 uint64_bigendian(len([]T{x,y})) PackAny(x) PackAny(y)
|
||||
// PackAny(map[K]V{x:y}) => 0x6 uint64_bigendian(len(map[K]V{x:y})) len(map_key(x)) map_key(x) PackAny(y)
|
||||
// PackAny((*T)(nil)) => 0x0
|
||||
// PackAny((*T)(0x42)) => PackAny(*(*T)(0x42))
|
||||
// PackAny(x) => panic()
|
||||
//
|
||||
// map_key([1]uint8{0x42}) => [1]uint8{0x42}
|
||||
// map_key(x) => []byte(fmt.Sprint(x))
|
||||
func PackAny(in interface{}, out io.Writer) error {
|
||||
return errors.Wrapf(packValue(reflect.ValueOf(in), out), "can't pack %#v", in)
|
||||
}
|
||||
|
||||
var tByte = reflect.TypeOf(byte(0))
|
||||
var tBytes = reflect.TypeOf([]uint8(nil))
|
||||
|
||||
// packValue does the actual job of packAny and just exists for recursion w/o unnecessary reflect.ValueOf calls.
|
||||
func packValue(in reflect.Value, out io.Writer) error {
|
||||
switch kind := in.Kind(); kind {
|
||||
case reflect.Invalid: // nil
|
||||
_, err := out.Write([]byte{0})
|
||||
return err
|
||||
case reflect.Bool:
|
||||
if in.Bool() {
|
||||
_, err := out.Write([]byte{2})
|
||||
return err
|
||||
} else {
|
||||
_, err := out.Write([]byte{1})
|
||||
return err
|
||||
}
|
||||
case reflect.Float64:
|
||||
if _, err := out.Write([]byte{3}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return binary.Write(out, binary.BigEndian, in.Float())
|
||||
case reflect.Array, reflect.Slice:
|
||||
if typ := in.Type(); typ.Elem() == tByte {
|
||||
if kind == reflect.Array {
|
||||
if !in.CanAddr() {
|
||||
vNewElem := reflect.New(typ).Elem()
|
||||
vNewElem.Set(in)
|
||||
in = vNewElem
|
||||
}
|
||||
|
||||
in = in.Slice(0, in.Len())
|
||||
}
|
||||
|
||||
// Pack []byte as string, not array of numbers.
|
||||
return packString(in.Convert(tBytes). // Support types.Binary
|
||||
Interface().([]uint8), out)
|
||||
}
|
||||
|
||||
if _, err := out.Write([]byte{5}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
l := in.Len()
|
||||
if err := binary.Write(out, binary.BigEndian, uint64(l)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i := 0; i < l; i++ {
|
||||
if err := packValue(in.Index(i), out); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// If there aren't any values to pack, ...
|
||||
if l < 1 {
|
||||
// ... create one and pack it - panics on disallowed type.
|
||||
_ = packValue(reflect.Zero(in.Type().Elem()), io.Discard)
|
||||
}
|
||||
|
||||
return nil
|
||||
case reflect.Interface:
|
||||
return packValue(in.Elem(), out)
|
||||
case reflect.Map:
|
||||
type kv struct {
|
||||
key []byte
|
||||
value reflect.Value
|
||||
}
|
||||
|
||||
if _, err := out.Write([]byte{6}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
l := in.Len()
|
||||
if err := binary.Write(out, binary.BigEndian, uint64(l)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sorted := make([]kv, 0, l)
|
||||
|
||||
{
|
||||
iter := in.MapRange()
|
||||
for iter.Next() {
|
||||
var packedKey []byte
|
||||
if key := iter.Key(); key.Kind() == reflect.Array {
|
||||
if typ := key.Type(); typ.Elem() == tByte {
|
||||
if !key.CanAddr() {
|
||||
vNewElem := reflect.New(typ).Elem()
|
||||
vNewElem.Set(key)
|
||||
key = vNewElem
|
||||
}
|
||||
|
||||
packedKey = key.Slice(0, key.Len()).Interface().([]byte)
|
||||
} else {
|
||||
// Not just stringify the key (below), but also pack it (here) - panics on disallowed type.
|
||||
_ = packValue(iter.Key(), io.Discard)
|
||||
|
||||
packedKey = []byte(fmt.Sprint(key.Interface()))
|
||||
}
|
||||
} else {
|
||||
// Not just stringify the key (below), but also pack it (here) - panics on disallowed type.
|
||||
_ = packValue(iter.Key(), io.Discard)
|
||||
|
||||
packedKey = []byte(fmt.Sprint(key.Interface()))
|
||||
}
|
||||
|
||||
sorted = append(sorted, kv{packedKey, iter.Value()})
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(sorted, func(i, j int) bool { return bytes.Compare(sorted[i].key, sorted[j].key) < 0 })
|
||||
|
||||
for _, kv := range sorted {
|
||||
if err := binary.Write(out, binary.BigEndian, uint64(len(kv.key))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := out.Write(kv.key); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := packValue(kv.value, out); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// If there aren't any key-value pairs to pack, ...
|
||||
if l < 1 {
|
||||
typ := in.Type()
|
||||
|
||||
// ... create one and pack it - panics on disallowed type.
|
||||
_ = packValue(reflect.Zero(typ.Key()), io.Discard)
|
||||
_ = packValue(reflect.Zero(typ.Elem()), io.Discard)
|
||||
}
|
||||
|
||||
return nil
|
||||
case reflect.Ptr:
|
||||
if in.IsNil() {
|
||||
err := packValue(reflect.Value{}, out)
|
||||
|
||||
// Create a fictive referenced value and pack it - panics on disallowed type.
|
||||
_ = packValue(reflect.Zero(in.Type().Elem()), io.Discard)
|
||||
|
||||
return err
|
||||
} else {
|
||||
return packValue(in.Elem(), out)
|
||||
}
|
||||
case reflect.String:
|
||||
return packString([]byte(in.String()), out)
|
||||
default:
|
||||
panic("bad type: " + in.Kind().String())
|
||||
}
|
||||
}
|
||||
|
||||
// packString deduplicates string packing of multiple locations in packValue.
|
||||
func packString(in []byte, out io.Writer) error {
|
||||
if _, err := out.Write([]byte{4}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(out, binary.BigEndian, uint64(len(in))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := out.Write(in)
|
||||
return err
|
||||
}
|
||||
|
|
@ -1,195 +0,0 @@
|
|||
package objectpacker
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/icinga/icinga-go-library/types"
|
||||
"github.com/pkg/errors"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// limitedWriter allows writing a specific amount of data.
|
||||
type limitedWriter struct {
|
||||
// limit specifies how many bytes to allow to write.
|
||||
limit int
|
||||
}
|
||||
|
||||
var _ io.Writer = (*limitedWriter)(nil)
|
||||
|
||||
// Write returns io.EOF once lw.limit is exceeded, nil otherwise.
|
||||
func (lw *limitedWriter) Write(p []byte) (n int, err error) {
|
||||
if len(p) <= lw.limit {
|
||||
lw.limit -= len(p)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
n = lw.limit
|
||||
err = io.EOF
|
||||
|
||||
lw.limit = 0
|
||||
return
|
||||
}
|
||||
|
||||
func TestLimitedWriter_Write(t *testing.T) {
|
||||
assertLimitedWriter_Write(t, 3, []byte{1, 2}, 2, nil, 1)
|
||||
assertLimitedWriter_Write(t, 3, []byte{1, 2, 3}, 3, nil, 0)
|
||||
assertLimitedWriter_Write(t, 3, []byte{1, 2, 3, 4}, 3, io.EOF, 0)
|
||||
assertLimitedWriter_Write(t, 0, []byte{1}, 0, io.EOF, 0)
|
||||
assertLimitedWriter_Write(t, 0, nil, 0, nil, 0)
|
||||
}
|
||||
|
||||
func assertLimitedWriter_Write(t *testing.T, limitBefore int, p []byte, n int, err error, limitAfter int) {
|
||||
t.Helper()
|
||||
|
||||
lw := limitedWriter{limitBefore}
|
||||
actualN, actualErr := lw.Write(p)
|
||||
|
||||
if !errors.Is(actualErr, err) {
|
||||
t.Errorf("_, err := (&limitedWriter{%d}).Write(%#v); err != %#v", limitBefore, p, err)
|
||||
}
|
||||
|
||||
if actualN != n {
|
||||
t.Errorf("n, _ := (&limitedWriter{%d}).Write(%#v); n != %d", limitBefore, p, n)
|
||||
}
|
||||
|
||||
if lw.limit != limitAfter {
|
||||
t.Errorf("lw := limitedWriter{%d}; lw.Write(%#v); lw.limit != %d", limitBefore, p, limitAfter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPackAny(t *testing.T) {
|
||||
assertPackAny(t, nil, []byte{0})
|
||||
assertPackAny(t, false, []byte{1})
|
||||
assertPackAny(t, true, []byte{2})
|
||||
|
||||
assertPackAnyPanic(t, -42, 0)
|
||||
assertPackAnyPanic(t, int8(-42), 0)
|
||||
assertPackAnyPanic(t, int16(-42), 0)
|
||||
assertPackAnyPanic(t, int32(-42), 0)
|
||||
assertPackAnyPanic(t, int64(-42), 0)
|
||||
|
||||
assertPackAnyPanic(t, uint(42), 0)
|
||||
assertPackAnyPanic(t, uint8(42), 0)
|
||||
assertPackAnyPanic(t, uint16(42), 0)
|
||||
assertPackAnyPanic(t, uint32(42), 0)
|
||||
assertPackAnyPanic(t, uint64(42), 0)
|
||||
assertPackAnyPanic(t, uintptr(42), 0)
|
||||
|
||||
assertPackAnyPanic(t, float32(-42.5), 0)
|
||||
assertPackAny(t, -42.5, []byte{3, 0xc0, 0x45, 0x40, 0, 0, 0, 0, 0})
|
||||
|
||||
assertPackAnyPanic(t, []struct{}(nil), 9)
|
||||
assertPackAnyPanic(t, []struct{}{}, 9)
|
||||
|
||||
assertPackAny(t, []interface{}{nil, true, -42.5}, []byte{
|
||||
5, 0, 0, 0, 0, 0, 0, 0, 3,
|
||||
0,
|
||||
2,
|
||||
3, 0xc0, 0x45, 0x40, 0, 0, 0, 0, 0,
|
||||
})
|
||||
|
||||
assertPackAny(t, []string{"", "a"}, []byte{
|
||||
5, 0, 0, 0, 0, 0, 0, 0, 2,
|
||||
4, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
4, 0, 0, 0, 0, 0, 0, 0, 1, 'a',
|
||||
})
|
||||
|
||||
assertPackAnyPanic(t, []interface{}{0 + 0i}, 9)
|
||||
|
||||
assertPackAnyPanic(t, map[struct{}]struct{}(nil), 9)
|
||||
assertPackAnyPanic(t, map[struct{}]struct{}{}, 9)
|
||||
|
||||
assertPackAny(t, map[interface{}]interface{}{true: "", "nil": -42.5}, []byte{
|
||||
6, 0, 0, 0, 0, 0, 0, 0, 2,
|
||||
0, 0, 0, 0, 0, 0, 0, 3, 'n', 'i', 'l',
|
||||
3, 0xc0, 0x45, 0x40, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 4, 't', 'r', 'u', 'e',
|
||||
4, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
})
|
||||
|
||||
assertPackAny(t, map[string]float64{"": 42}, []byte{
|
||||
6, 0, 0, 0, 0, 0, 0, 0, 1,
|
||||
0, 0, 0, 0, 0, 0, 0, 0,
|
||||
3, 0x40, 0x45, 0, 0, 0, 0, 0, 0,
|
||||
})
|
||||
|
||||
assertPackAny(t, map[[1]byte]bool{{42}: true}, []byte{
|
||||
6, 0, 0, 0, 0, 0, 0, 0, 1,
|
||||
0, 0, 0, 0, 0, 0, 0, 1, 42,
|
||||
2,
|
||||
})
|
||||
|
||||
assertPackAnyPanic(t, map[struct{}]struct{}{{}: {}}, 9)
|
||||
|
||||
assertPackAny(t, (*string)(nil), []byte{0})
|
||||
assertPackAnyPanic(t, (*int)(nil), 0)
|
||||
assertPackAny(t, new(float64), []byte{3, 0, 0, 0, 0, 0, 0, 0, 0})
|
||||
|
||||
assertPackAny(t, "", []byte{4, 0, 0, 0, 0, 0, 0, 0, 0})
|
||||
assertPackAny(t, "a", []byte{4, 0, 0, 0, 0, 0, 0, 0, 1, 'a'})
|
||||
assertPackAny(t, "ä", []byte{4, 0, 0, 0, 0, 0, 0, 0, 2, 0xc3, 0xa4})
|
||||
|
||||
{
|
||||
var binary [256]byte
|
||||
for i := range binary {
|
||||
binary[i] = byte(i)
|
||||
}
|
||||
|
||||
assertPackAny(t, binary, append([]byte{4, 0, 0, 0, 0, 0, 0, 1, 0}, binary[:]...))
|
||||
assertPackAny(t, binary[:], append([]byte{4, 0, 0, 0, 0, 0, 0, 1, 0}, binary[:]...))
|
||||
assertPackAny(t, types.Binary(binary[:]), append([]byte{4, 0, 0, 0, 0, 0, 0, 1, 0}, binary[:]...))
|
||||
}
|
||||
|
||||
{
|
||||
type myByte byte
|
||||
assertPackAnyPanic(t, []myByte(nil), 9)
|
||||
}
|
||||
|
||||
assertPackAnyPanic(t, complex64(0+0i), 0)
|
||||
assertPackAnyPanic(t, 0+0i, 0)
|
||||
assertPackAnyPanic(t, make(chan struct{}), 0)
|
||||
assertPackAnyPanic(t, func() {}, 0)
|
||||
assertPackAnyPanic(t, struct{}{}, 0)
|
||||
assertPackAnyPanic(t, uintptr(0), 0)
|
||||
}
|
||||
|
||||
func assertPackAny(t *testing.T, in interface{}, out []byte) {
|
||||
t.Helper()
|
||||
|
||||
{
|
||||
buf := &bytes.Buffer{}
|
||||
if err := PackAny(in, buf); err == nil {
|
||||
if !bytes.Equal(buf.Bytes(), out) {
|
||||
t.Errorf("buf := &bytes.Buffer{}; packAny(%#v, buf); !bytes.Equal(buf.Bytes(), %#v)", in, out)
|
||||
}
|
||||
} else {
|
||||
t.Errorf("packAny(%#v, &bytes.Buffer{}) != nil", in)
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < len(out); i++ {
|
||||
if !errors.Is(PackAny(in, &limitedWriter{i}), io.EOF) {
|
||||
t.Errorf("packAny(%#v, &limitedWriter{%d}) != io.EOF", in, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func assertPackAnyPanic(t *testing.T, in interface{}, allowToWrite int) {
|
||||
t.Helper()
|
||||
|
||||
for i := 0; i < allowToWrite; i++ {
|
||||
if !errors.Is(PackAny(in, &limitedWriter{i}), io.EOF) {
|
||||
t.Errorf("packAny(%#v, &limitedWriter{%d}) != io.EOF", in, i)
|
||||
}
|
||||
}
|
||||
|
||||
defer func() {
|
||||
t.Helper()
|
||||
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("packAny(%#v, &limitedWriter{%d}) didn't panic", in, allowToWrite)
|
||||
}
|
||||
}()
|
||||
|
||||
_ = PackAny(in, &limitedWriter{allowToWrite})
|
||||
}
|
||||
|
|
@ -1,123 +0,0 @@
|
|||
package periodic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Option configures Start.
|
||||
type Option interface {
|
||||
apply(*periodic)
|
||||
}
|
||||
|
||||
// Stopper implements the Stop method,
|
||||
// which stops a periodic task from Start().
|
||||
type Stopper interface {
|
||||
Stop() // Stops a periodic task.
|
||||
}
|
||||
|
||||
// Tick is the value for periodic task callbacks that
|
||||
// contains the time of the tick and
|
||||
// the time elapsed since the start of the periodic task.
|
||||
type Tick struct {
|
||||
Elapsed time.Duration
|
||||
Time time.Time
|
||||
}
|
||||
|
||||
// Immediate starts the periodic task immediately instead of after the first tick.
|
||||
func Immediate() Option {
|
||||
return optionFunc(func(p *periodic) {
|
||||
p.immediate = true
|
||||
})
|
||||
}
|
||||
|
||||
// OnStop configures a callback that is executed when a periodic task is stopped or canceled.
|
||||
func OnStop(f func(Tick)) Option {
|
||||
return optionFunc(func(p *periodic) {
|
||||
p.onStop = f
|
||||
})
|
||||
}
|
||||
|
||||
// Start starts a periodic task with a ticker at the specified interval,
|
||||
// which executes the given callback after each tick.
|
||||
// Pending tasks do not overlap, but could start immediately if
|
||||
// the previous task(s) takes longer than the interval.
|
||||
// Call Stop() on the return value in order to stop the ticker and to release associated resources.
|
||||
// The interval must be greater than zero.
|
||||
func Start(ctx context.Context, interval time.Duration, callback func(Tick), options ...Option) Stopper {
|
||||
t := &periodic{
|
||||
interval: interval,
|
||||
callback: callback,
|
||||
}
|
||||
|
||||
for _, option := range options {
|
||||
option.apply(t)
|
||||
}
|
||||
|
||||
ctx, cancelCtx := context.WithCancel(ctx)
|
||||
|
||||
start := time.Now()
|
||||
|
||||
go func() {
|
||||
done := false
|
||||
|
||||
if !t.immediate {
|
||||
select {
|
||||
case <-time.After(interval):
|
||||
case <-ctx.Done():
|
||||
done = true
|
||||
}
|
||||
}
|
||||
|
||||
if !done {
|
||||
ticker := time.NewTicker(t.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for tickTime := time.Now(); !done; {
|
||||
t.callback(Tick{
|
||||
Elapsed: tickTime.Sub(start),
|
||||
Time: tickTime,
|
||||
})
|
||||
|
||||
select {
|
||||
case tickTime = <-ticker.C:
|
||||
case <-ctx.Done():
|
||||
done = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if t.onStop != nil {
|
||||
now := time.Now()
|
||||
t.onStop(Tick{
|
||||
Elapsed: now.Sub(start),
|
||||
Time: now,
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
return stoperFunc(func() {
|
||||
t.stop.Do(cancelCtx)
|
||||
})
|
||||
}
|
||||
|
||||
type optionFunc func(*periodic)
|
||||
|
||||
func (f optionFunc) apply(p *periodic) {
|
||||
f(p)
|
||||
}
|
||||
|
||||
type stoperFunc func()
|
||||
|
||||
func (f stoperFunc) Stop() {
|
||||
f()
|
||||
}
|
||||
|
||||
type periodic struct {
|
||||
interval time.Duration
|
||||
callback func(Tick)
|
||||
immediate bool
|
||||
stop sync.Once
|
||||
onStop func(Tick)
|
||||
}
|
||||
|
|
@ -1,14 +0,0 @@
|
|||
package redis
|
||||
|
||||
import "github.com/redis/go-redis/v9"
|
||||
|
||||
// Alias definitions of commonly used go-redis exports,
|
||||
// so that only this redis package needs to be imported and not go-redis additionally.
|
||||
|
||||
type IntCmd = redis.IntCmd
|
||||
type Pipeliner = redis.Pipeliner
|
||||
type XAddArgs = redis.XAddArgs
|
||||
type XMessage = redis.XMessage
|
||||
type XReadArgs = redis.XReadArgs
|
||||
|
||||
var NewScript = redis.NewScript
|
||||
|
|
@ -1,277 +0,0 @@
|
|||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"github.com/icinga/icinga-go-library/backoff"
|
||||
"github.com/icinga/icinga-go-library/com"
|
||||
"github.com/icinga/icinga-go-library/logging"
|
||||
"github.com/icinga/icinga-go-library/periodic"
|
||||
"github.com/icinga/icinga-go-library/retry"
|
||||
"github.com/icinga/icinga-go-library/utils"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/sync/semaphore"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Client is a wrapper around redis.Client with
|
||||
// streaming and logging capabilities.
|
||||
type Client struct {
|
||||
*redis.Client
|
||||
|
||||
Options *Options
|
||||
|
||||
logger *logging.Logger
|
||||
}
|
||||
|
||||
// NewClient returns a new Client wrapper for a pre-existing redis.Client.
|
||||
func NewClient(client *redis.Client, logger *logging.Logger, options *Options) *Client {
|
||||
return &Client{Client: client, logger: logger, Options: options}
|
||||
}
|
||||
|
||||
// NewClientFromConfig returns a new Client from Config.
|
||||
func NewClientFromConfig(c *Config, logger *logging.Logger) (*Client, error) {
|
||||
tlsConfig, err := c.TlsOptions.MakeConfig(c.Host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var dialer ctxDialerFunc
|
||||
dl := &net.Dialer{Timeout: 15 * time.Second}
|
||||
|
||||
if tlsConfig == nil {
|
||||
dialer = dl.DialContext
|
||||
} else {
|
||||
dialer = (&tls.Dialer{NetDialer: dl, Config: tlsConfig}).DialContext
|
||||
}
|
||||
|
||||
options := &redis.Options{
|
||||
Dialer: dialWithLogging(dialer, logger),
|
||||
Password: c.Password,
|
||||
DB: 0, // Use default DB,
|
||||
ReadTimeout: c.Options.Timeout,
|
||||
TLSConfig: tlsConfig,
|
||||
}
|
||||
|
||||
if utils.IsUnixAddr(c.Host) {
|
||||
options.Network = "unix"
|
||||
options.Addr = c.Host
|
||||
} else {
|
||||
port := c.Port
|
||||
if port == 0 {
|
||||
port = 6379
|
||||
}
|
||||
options.Network = "tcp"
|
||||
options.Addr = net.JoinHostPort(c.Host, fmt.Sprint(port))
|
||||
}
|
||||
|
||||
client := redis.NewClient(options)
|
||||
options = client.Options()
|
||||
options.PoolSize = utils.MaxInt(32, options.PoolSize)
|
||||
options.MaxRetries = options.PoolSize + 1 // https://github.com/go-redis/redis/issues/1737
|
||||
|
||||
return NewClient(redis.NewClient(options), logger, &c.Options), nil
|
||||
}
|
||||
|
||||
// GetAddr returns the Redis host:port or Unix socket address.
|
||||
func (c *Client) GetAddr() string {
|
||||
return c.Client.Options().Addr
|
||||
}
|
||||
|
||||
// HPair defines Redis hashes field-value pairs.
|
||||
type HPair struct {
|
||||
Field string
|
||||
Value string
|
||||
}
|
||||
|
||||
// HYield yields HPair field-value pairs for all fields in the hash stored at key.
|
||||
func (c *Client) HYield(ctx context.Context, key string) (<-chan HPair, <-chan error) {
|
||||
pairs := make(chan HPair, c.Options.HScanCount)
|
||||
|
||||
return pairs, com.WaitAsync(com.WaiterFunc(func() error {
|
||||
var counter com.Counter
|
||||
defer c.log(ctx, key, &counter).Stop()
|
||||
defer close(pairs)
|
||||
|
||||
seen := make(map[string]struct{})
|
||||
|
||||
var cursor uint64
|
||||
var err error
|
||||
var page []string
|
||||
|
||||
for {
|
||||
cmd := c.HScan(ctx, key, cursor, "", int64(c.Options.HScanCount))
|
||||
page, cursor, err = cmd.Result()
|
||||
|
||||
if err != nil {
|
||||
return WrapCmdErr(cmd)
|
||||
}
|
||||
|
||||
for i := 0; i < len(page); i += 2 {
|
||||
if _, ok := seen[page[i]]; ok {
|
||||
// Ignore duplicate returned by HSCAN.
|
||||
continue
|
||||
}
|
||||
|
||||
seen[page[i]] = struct{}{}
|
||||
|
||||
select {
|
||||
case pairs <- HPair{
|
||||
Field: page[i],
|
||||
Value: page[i+1],
|
||||
}:
|
||||
counter.Inc()
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
// HMYield yields HPair field-value pairs for the specified fields in the hash stored at key.
|
||||
func (c *Client) HMYield(ctx context.Context, key string, fields ...string) (<-chan HPair, <-chan error) {
|
||||
pairs := make(chan HPair)
|
||||
|
||||
return pairs, com.WaitAsync(com.WaiterFunc(func() error {
|
||||
var counter com.Counter
|
||||
defer c.log(ctx, key, &counter).Stop()
|
||||
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
defer func() {
|
||||
// Wait until the group is done so that we can safely close the pairs channel,
|
||||
// because on error, sem.Acquire will return before calling g.Wait(),
|
||||
// which can result in goroutines working on a closed channel.
|
||||
_ = g.Wait()
|
||||
close(pairs)
|
||||
}()
|
||||
|
||||
// Use context from group.
|
||||
batches := utils.BatchSliceOfStrings(ctx, fields, c.Options.HMGetCount)
|
||||
|
||||
sem := semaphore.NewWeighted(int64(c.Options.MaxHMGetConnections))
|
||||
|
||||
for batch := range batches {
|
||||
if err := sem.Acquire(ctx, 1); err != nil {
|
||||
return errors.Wrap(err, "can't acquire semaphore")
|
||||
}
|
||||
|
||||
batch := batch
|
||||
g.Go(func() error {
|
||||
defer sem.Release(1)
|
||||
|
||||
cmd := c.HMGet(ctx, key, batch...)
|
||||
vals, err := cmd.Result()
|
||||
|
||||
if err != nil {
|
||||
return WrapCmdErr(cmd)
|
||||
}
|
||||
|
||||
for i, v := range vals {
|
||||
if v == nil {
|
||||
c.logger.Warnf("HMGET %s: field %#v missing", key, batch[i])
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case pairs <- HPair{
|
||||
Field: batch[i],
|
||||
Value: v.(string),
|
||||
}:
|
||||
counter.Inc()
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
return g.Wait()
|
||||
}))
|
||||
}
|
||||
|
||||
// XReadUntilResult (repeatedly) calls XREAD with the specified arguments until a result is returned.
|
||||
// Each call blocks at most for the duration specified in Options.BlockTimeout until data
|
||||
// is available before it times out and the next call is made.
|
||||
// This also means that an already set block timeout is overridden.
|
||||
func (c *Client) XReadUntilResult(ctx context.Context, a *redis.XReadArgs) ([]redis.XStream, error) {
|
||||
a.Block = c.Options.BlockTimeout
|
||||
|
||||
for {
|
||||
cmd := c.XRead(ctx, a)
|
||||
streams, err := cmd.Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
continue
|
||||
}
|
||||
|
||||
return streams, WrapCmdErr(cmd)
|
||||
}
|
||||
|
||||
return streams, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) log(ctx context.Context, key string, counter *com.Counter) periodic.Stopper {
|
||||
return periodic.Start(ctx, c.logger.Interval(), func(tick periodic.Tick) {
|
||||
// We may never get to progress logging here,
|
||||
// as fetching should be completed before the interval expires,
|
||||
// but if it does, it is good to have this log message.
|
||||
if count := counter.Reset(); count > 0 {
|
||||
c.logger.Debugf("Fetched %d items from %s", count, key)
|
||||
}
|
||||
}, periodic.OnStop(func(tick periodic.Tick) {
|
||||
c.logger.Debugf("Finished fetching from %s with %d items in %s", key, counter.Total(), tick.Elapsed)
|
||||
}))
|
||||
}
|
||||
|
||||
type ctxDialerFunc = func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
|
||||
// dialWithLogging returns a Redis Dialer with logging capabilities.
|
||||
func dialWithLogging(dialer ctxDialerFunc, logger *logging.Logger) ctxDialerFunc {
|
||||
// dial behaves like net.Dialer#DialContext,
|
||||
// but re-tries on common errors that are considered retryable.
|
||||
return func(ctx context.Context, network, addr string) (conn net.Conn, err error) {
|
||||
err = retry.WithBackoff(
|
||||
ctx,
|
||||
func(ctx context.Context) (err error) {
|
||||
conn, err = dialer(ctx, network, addr)
|
||||
return
|
||||
},
|
||||
retry.Retryable,
|
||||
backoff.NewExponentialWithJitter(1*time.Millisecond, 1*time.Second),
|
||||
retry.Settings{
|
||||
Timeout: retry.DefaultTimeout,
|
||||
OnRetryableError: func(_ time.Duration, _ uint64, err, lastErr error) {
|
||||
if lastErr == nil || err.Error() != lastErr.Error() {
|
||||
logger.Warnw("Can't connect to Redis. Retrying", zap.Error(err))
|
||||
}
|
||||
},
|
||||
OnSuccess: func(elapsed time.Duration, attempt uint64, _ error) {
|
||||
if attempt > 1 {
|
||||
logger.Infow("Reconnected to Redis",
|
||||
zap.Duration("after", elapsed), zap.Uint64("attempts", attempt))
|
||||
}
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
err = errors.Wrap(err, "can't connect to Redis")
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
@ -1,59 +0,0 @@
|
|||
package redis
|
||||
|
||||
import (
|
||||
"github.com/icinga/icinga-go-library/config"
|
||||
"github.com/pkg/errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Options define user configurable Redis options.
|
||||
type Options struct {
|
||||
BlockTimeout time.Duration `yaml:"block_timeout" default:"1s"`
|
||||
HMGetCount int `yaml:"hmget_count" default:"4096"`
|
||||
HScanCount int `yaml:"hscan_count" default:"4096"`
|
||||
MaxHMGetConnections int `yaml:"max_hmget_connections" default:"8"`
|
||||
Timeout time.Duration `yaml:"timeout" default:"30s"`
|
||||
XReadCount int `yaml:"xread_count" default:"4096"`
|
||||
}
|
||||
|
||||
// Validate checks constraints in the supplied Redis options and returns an error if they are violated.
|
||||
func (o *Options) Validate() error {
|
||||
if o.BlockTimeout <= 0 {
|
||||
return errors.New("block_timeout must be positive")
|
||||
}
|
||||
if o.HMGetCount < 1 {
|
||||
return errors.New("hmget_count must be at least 1")
|
||||
}
|
||||
if o.HScanCount < 1 {
|
||||
return errors.New("hscan_count must be at least 1")
|
||||
}
|
||||
if o.MaxHMGetConnections < 1 {
|
||||
return errors.New("max_hmget_connections must be at least 1")
|
||||
}
|
||||
if o.Timeout == 0 {
|
||||
return errors.New("timeout cannot be 0. Configure a value greater than zero, or use -1 for no timeout")
|
||||
}
|
||||
if o.XReadCount < 1 {
|
||||
return errors.New("xread_count must be at least 1")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Config defines Config client configuration.
|
||||
type Config struct {
|
||||
Host string `yaml:"host"`
|
||||
Port int `yaml:"port"`
|
||||
Password string `yaml:"password"`
|
||||
TlsOptions config.TLS `yaml:",inline"`
|
||||
Options Options `yaml:"options"`
|
||||
}
|
||||
|
||||
// Validate checks constraints in the supplied Config configuration and returns an error if they are violated.
|
||||
func (r *Config) Validate() error {
|
||||
if r.Host == "" {
|
||||
return errors.New("Redis host missing")
|
||||
}
|
||||
|
||||
return r.Options.Validate()
|
||||
}
|
||||
|
|
@ -1,20 +0,0 @@
|
|||
package redis
|
||||
|
||||
// Streams represents a Redis stream key to ID mapping.
|
||||
type Streams map[string]string
|
||||
|
||||
// Option returns the Redis stream key to ID mapping
|
||||
// as a slice of stream keys followed by their IDs
|
||||
// that is compatible for the Redis STREAMS option.
|
||||
func (s Streams) Option() []string {
|
||||
// len*2 because we're appending the IDs later.
|
||||
streams := make([]string, 0, len(s)*2)
|
||||
ids := make([]string, 0, len(s))
|
||||
|
||||
for key, id := range s {
|
||||
streams = append(streams, key)
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
return append(streams, ids...)
|
||||
}
|
||||
|
|
@ -1,22 +0,0 @@
|
|||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/icinga/icinga-go-library/utils"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// WrapCmdErr adds the command itself and
|
||||
// the stack of the current goroutine to the command's error if any.
|
||||
func WrapCmdErr(cmd redis.Cmder) error {
|
||||
err := cmd.Err()
|
||||
if err != nil {
|
||||
err = errors.Wrapf(err, "can't perform %q", utils.Ellipsize(
|
||||
redis.NewCmd(context.Background(), cmd.Args()).String(), // Omits error in opposite to cmd.String()
|
||||
100,
|
||||
))
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
|
@ -1,201 +0,0 @@
|
|||
package retry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql/driver"
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"github.com/icinga/icinga-go-library/backoff"
|
||||
"github.com/lib/pq"
|
||||
"github.com/pkg/errors"
|
||||
"io"
|
||||
"net"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DefaultTimeout is our opinionated default timeout for retrying database and Redis operations.
|
||||
const DefaultTimeout = 5 * time.Minute
|
||||
|
||||
// RetryableFunc is a retryable function.
|
||||
type RetryableFunc func(context.Context) error
|
||||
|
||||
// IsRetryable checks whether a new attempt can be started based on the error passed.
|
||||
type IsRetryable func(error) bool
|
||||
|
||||
// OnRetryableErrorFunc is called if a retryable error occurs.
|
||||
type OnRetryableErrorFunc func(elapsed time.Duration, attempt uint64, err, lastErr error)
|
||||
|
||||
// OnSuccessFunc is called once the operation succeeds.
|
||||
type OnSuccessFunc func(elapsed time.Duration, attempt uint64, lastErr error)
|
||||
|
||||
// Settings aggregates optional settings for WithBackoff.
|
||||
type Settings struct {
|
||||
// If >0, Timeout lets WithBackoff stop retrying gracefully once elapsed based on the following criteria:
|
||||
// * If the execution of RetryableFunc has taken longer than Timeout, no further attempts are made.
|
||||
// * If Timeout elapses during the sleep phase between retries, one final retry is attempted.
|
||||
// * RetryableFunc is always granted its full execution time and is not canceled if it exceeds Timeout.
|
||||
// This means that WithBackoff may not stop exactly after Timeout expires,
|
||||
// or may not retry at all if the first execution of RetryableFunc already takes longer than Timeout.
|
||||
Timeout time.Duration
|
||||
OnRetryableError OnRetryableErrorFunc
|
||||
OnSuccess OnSuccessFunc
|
||||
}
|
||||
|
||||
// WithBackoff retries the passed function if it fails and the error allows it to retry.
|
||||
// The specified backoff policy is used to determine how long to sleep between attempts.
|
||||
func WithBackoff(
|
||||
ctx context.Context, retryableFunc RetryableFunc, retryable IsRetryable, b backoff.Backoff, settings Settings,
|
||||
) (err error) {
|
||||
// Channel for retry deadline, which is set to the channel of NewTimer() if a timeout is configured,
|
||||
// otherwise nil, so that it blocks forever if there is no timeout.
|
||||
var timeout <-chan time.Time
|
||||
|
||||
if settings.Timeout > 0 {
|
||||
t := time.NewTimer(settings.Timeout)
|
||||
defer t.Stop()
|
||||
timeout = t.C
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
timedOut := false
|
||||
for attempt := uint64(1); ; /* true */ attempt++ {
|
||||
prevErr := err
|
||||
|
||||
if err = retryableFunc(ctx); err == nil {
|
||||
if settings.OnSuccess != nil {
|
||||
settings.OnSuccess(time.Since(start), attempt, prevErr)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Retryable function may have exited prematurely due to context errors.
|
||||
// We explicitly check the context error here, as the error returned by the retryable function can pass the
|
||||
// error.Is() checks even though it is not a real context error, e.g.
|
||||
// https://cs.opensource.google/go/go/+/refs/tags/go1.22.2:src/net/net.go;l=422
|
||||
// https://cs.opensource.google/go/go/+/refs/tags/go1.22.2:src/net/net.go;l=601
|
||||
if errors.Is(ctx.Err(), context.DeadlineExceeded) || errors.Is(ctx.Err(), context.Canceled) {
|
||||
if prevErr != nil {
|
||||
err = errors.Wrap(err, prevErr.Error())
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if !retryable(err) {
|
||||
err = errors.Wrap(err, "can't retry")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-timeout:
|
||||
// Stop retrying immediately if executing the retryable function took longer than the timeout.
|
||||
timedOut = true
|
||||
default:
|
||||
}
|
||||
|
||||
if timedOut {
|
||||
err = errors.Wrap(err, "retry deadline exceeded")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if settings.OnRetryableError != nil {
|
||||
settings.OnRetryableError(time.Since(start), attempt, err, prevErr)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-time.After(b(attempt)):
|
||||
case <-timeout:
|
||||
// Do not stop retrying immediately, but start one last attempt to mitigate timing issues where
|
||||
// the timeout expires while waiting for the next attempt and
|
||||
// therefore no retries have happened during this possibly long period.
|
||||
timedOut = true
|
||||
case <-ctx.Done():
|
||||
err = errors.Wrap(ctx.Err(), err.Error())
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ResetTimeout changes the possibly expired timer t to expire after duration d.
|
||||
//
|
||||
// If the timer has already expired and nothing has been received from its channel,
|
||||
// it is automatically drained as if the timer had never expired.
|
||||
func ResetTimeout(t *time.Timer, d time.Duration) {
|
||||
if !t.Stop() {
|
||||
<-t.C
|
||||
}
|
||||
|
||||
t.Reset(d)
|
||||
}
|
||||
|
||||
// Retryable returns true for common errors that are considered retryable,
|
||||
// i.e. temporary, timeout, DNS, connection refused and reset, host down and unreachable and
|
||||
// network down and unreachable errors. In addition, any database error is considered retryable.
|
||||
func Retryable(err error) bool {
|
||||
var temporary interface {
|
||||
Temporary() bool
|
||||
}
|
||||
if errors.As(err, &temporary) && temporary.Temporary() {
|
||||
return true
|
||||
}
|
||||
|
||||
var timeout interface {
|
||||
Timeout() bool
|
||||
}
|
||||
if errors.As(err, &timeout) && timeout.Timeout() {
|
||||
return true
|
||||
}
|
||||
|
||||
var dnsError *net.DNSError
|
||||
if errors.As(err, &dnsError) {
|
||||
return true
|
||||
}
|
||||
|
||||
var opError *net.OpError
|
||||
if errors.As(err, &opError) {
|
||||
// OpError provides Temporary() and Timeout(), but not Unwrap(),
|
||||
// so we have to extract the underlying error ourselves to also check for ECONNREFUSED,
|
||||
// which is not considered temporary or timed out by Go.
|
||||
err = opError.Err
|
||||
}
|
||||
if errors.Is(err, syscall.ECONNREFUSED) || errors.Is(err, syscall.ENOENT) {
|
||||
// syscall errors provide Temporary() and Timeout(),
|
||||
// which do not include ECONNREFUSED or ENOENT, so we check these ourselves.
|
||||
return true
|
||||
}
|
||||
if errors.Is(err, syscall.ECONNRESET) {
|
||||
// ECONNRESET is treated as a temporary error by Go only if it comes from calling accept.
|
||||
return true
|
||||
}
|
||||
if errors.Is(err, syscall.EHOSTDOWN) || errors.Is(err, syscall.EHOSTUNREACH) {
|
||||
return true
|
||||
}
|
||||
if errors.Is(err, syscall.ENETDOWN) || errors.Is(err, syscall.ENETUNREACH) {
|
||||
return true
|
||||
}
|
||||
if errors.Is(err, syscall.EPIPE) {
|
||||
return true
|
||||
}
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return true
|
||||
}
|
||||
|
||||
if errors.Is(err, driver.ErrBadConn) {
|
||||
return true
|
||||
}
|
||||
if errors.Is(err, mysql.ErrInvalidConn) {
|
||||
return true
|
||||
}
|
||||
|
||||
var mye *mysql.MySQLError
|
||||
var pqe *pq.Error
|
||||
if errors.As(err, &mye) || errors.As(err, &pqe) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
|
@ -1,56 +0,0 @@
|
|||
// Package strcase implements functions to convert a camelCase UTF-8 string into various cases.
|
||||
//
|
||||
// New delimiters will be inserted based on the following transitions:
|
||||
// - On any change from lowercase to uppercase letter.
|
||||
// - On any change from number to uppercase letter.
|
||||
package strcase
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// Delimited converts a string to delimited.lower.case, here using `.` as delimiter.
|
||||
func Delimited(s string, d rune) string {
|
||||
return convert(s, unicode.LowerCase, d)
|
||||
}
|
||||
|
||||
// ScreamingDelimited converts a string to DELIMITED.UPPER.CASE, here using `.` as delimiter.
|
||||
func ScreamingDelimited(s string, d rune) string {
|
||||
return convert(s, unicode.UpperCase, d)
|
||||
}
|
||||
|
||||
// Snake converts a string to snake_case.
|
||||
func Snake(s string) string {
|
||||
return Delimited(s, '_')
|
||||
}
|
||||
|
||||
// ScreamingSnake converts a string to SCREAMING_SNAKE_CASE.
|
||||
func ScreamingSnake(s string) string {
|
||||
return ScreamingDelimited(s, '_')
|
||||
}
|
||||
|
||||
// convert converts a camelCase UTF-8 string into various cases.
|
||||
// _case must be unicode.LowerCase or unicode.UpperCase.
|
||||
func convert(s string, _case int, d rune) string {
|
||||
if len(s) == 0 {
|
||||
return s
|
||||
}
|
||||
|
||||
n := strings.Builder{}
|
||||
n.Grow(len(s) + 2) // Allow adding at least 2 delimiters without another allocation.
|
||||
|
||||
var prevRune rune
|
||||
|
||||
for i, r := range s {
|
||||
if i > 0 && unicode.IsUpper(r) && (unicode.IsNumber(prevRune) || unicode.IsLower(prevRune)) {
|
||||
n.WriteRune(d)
|
||||
}
|
||||
|
||||
n.WriteRune(unicode.To(_case, r))
|
||||
|
||||
prevRune = r
|
||||
}
|
||||
|
||||
return n.String()
|
||||
}
|
||||
|
|
@ -1,58 +0,0 @@
|
|||
package strcase
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var tests = [][]string{
|
||||
{"", ""},
|
||||
{"Test", "test"},
|
||||
{"test", "test"},
|
||||
{"testCase", "test_case"},
|
||||
{"test_case", "test_case"},
|
||||
{"TestCase", "test_case"},
|
||||
{"Test_Case", "test_case"},
|
||||
{"ID", "id"},
|
||||
{"userID", "user_id"},
|
||||
{"UserID", "user_id"},
|
||||
{"ManyManyWords", "many_many_words"},
|
||||
{"manyManyWords", "many_many_words"},
|
||||
{"icinga2", "icinga2"},
|
||||
{"Icinga2Version", "icinga2_version"},
|
||||
{"k8sVersion", "k8s_version"},
|
||||
{"1234", "1234"},
|
||||
{"a1b2c3d4", "a1b2c3d4"},
|
||||
{"with1234digits", "with1234digits"},
|
||||
{"with1234Digits", "with1234_digits"},
|
||||
{"IPv4", "ipv4"},
|
||||
{"IPv4Address", "ipv4_address"},
|
||||
{"caféCrème", "café_crème"},
|
||||
{"0℃", "0℃"},
|
||||
{"~0", "~0"},
|
||||
{"icinga💯points", "icinga💯points"},
|
||||
{"😃🙃😀", "😃🙃😀"},
|
||||
{"こんにちは", "こんにちは"},
|
||||
{"\xff\xfe\xfd", "<22><><EFBFBD>"},
|
||||
{"\xff", "<22>"},
|
||||
}
|
||||
|
||||
func TestSnake(t *testing.T) {
|
||||
for _, test := range tests {
|
||||
s, expected := test[0], test[1]
|
||||
actual := Snake(s)
|
||||
if actual != expected {
|
||||
t.Errorf("%q: %q != %q", s, actual, expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestScreamingSnake(t *testing.T) {
|
||||
for _, test := range tests {
|
||||
s, expected := test[0], strings.ToUpper(test[1])
|
||||
actual := ScreamingSnake(s)
|
||||
if actual != expected {
|
||||
t.Errorf("%q: %q != %q", s, actual, expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,176 +0,0 @@
|
|||
package structify
|
||||
|
||||
import (
|
||||
"encoding"
|
||||
"fmt"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/exp/constraints"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// structBranch represents either a leaf or a subTree.
|
||||
type structBranch struct {
|
||||
// field specifies the struct field index.
|
||||
field int
|
||||
// leaf specifies the map key to parse the struct field from.
|
||||
leaf string
|
||||
// subTree specifies the struct field's inner tree.
|
||||
subTree []structBranch
|
||||
}
|
||||
|
||||
type MapStructifier = func(map[string]interface{}) (interface{}, error)
|
||||
|
||||
// MakeMapStructifier builds a function which parses a map's string values into a new struct of type t
|
||||
// and returns a pointer to it. tag specifies which tag connects struct fields to map keys.
|
||||
// MakeMapStructifier panics if it detects an unsupported type (suitable for usage in init() or global vars).
|
||||
func MakeMapStructifier(t reflect.Type, tag string, initer func(any)) MapStructifier {
|
||||
tree := buildStructTree(t, tag)
|
||||
|
||||
return func(kv map[string]interface{}) (interface{}, error) {
|
||||
vPtr := reflect.New(t)
|
||||
ptr := vPtr.Interface()
|
||||
if initer != nil {
|
||||
initer(ptr)
|
||||
}
|
||||
vPtrElem := vPtr.Elem()
|
||||
err := errors.Wrapf(structifyMapByTree(kv, tree, vPtrElem, vPtrElem, new([]int)), "can't structify map %#v by tree %#v", kv, tree)
|
||||
|
||||
return ptr, err
|
||||
}
|
||||
}
|
||||
|
||||
// buildStructTree assembles a tree which represents the struct t based on tag.
|
||||
func buildStructTree(t reflect.Type, tag string) []structBranch {
|
||||
var tree []structBranch
|
||||
numFields := t.NumField()
|
||||
|
||||
for i := 0; i < numFields; i++ {
|
||||
if field := t.Field(i); field.PkgPath == "" {
|
||||
switch tagValue := field.Tag.Get(tag); tagValue {
|
||||
case "", "-":
|
||||
case ",inline":
|
||||
if subTree := buildStructTree(field.Type, tag); subTree != nil {
|
||||
tree = append(tree, structBranch{i, "", subTree})
|
||||
}
|
||||
default:
|
||||
// If parseString doesn't support *T, it'll panic.
|
||||
_ = parseString("", reflect.New(field.Type).Interface())
|
||||
|
||||
tree = append(tree, structBranch{i, tagValue, nil})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tree
|
||||
}
|
||||
|
||||
// structifyMapByTree parses src's string values into the struct dest according to tree's specification.
|
||||
func structifyMapByTree(src map[string]interface{}, tree []structBranch, dest, root reflect.Value, stack *[]int) error {
|
||||
*stack = append(*stack, 0)
|
||||
defer func() {
|
||||
*stack = (*stack)[:len(*stack)-1]
|
||||
}()
|
||||
|
||||
for _, branch := range tree {
|
||||
(*stack)[len(*stack)-1] = branch.field
|
||||
|
||||
if branch.subTree == nil {
|
||||
if v, ok := src[branch.leaf]; ok {
|
||||
if vs, ok := v.(string); ok {
|
||||
if err := parseString(vs, dest.Field(branch.field).Addr().Interface()); err != nil {
|
||||
rt := root.Type()
|
||||
typ := rt
|
||||
var path []string
|
||||
|
||||
for _, i := range *stack {
|
||||
f := typ.Field(i)
|
||||
path = append(path, f.Name)
|
||||
typ = f.Type
|
||||
}
|
||||
|
||||
return errors.Wrapf(err, "can't parse %s into the %s %s#%s: %s",
|
||||
branch.leaf, typ.Name(), rt.Name(), strings.Join(path, "."), vs)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if err := structifyMapByTree(src, branch.subTree, dest.Field(branch.field), root, stack); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseString parses src into *dest.
|
||||
func parseString(src string, dest interface{}) error {
|
||||
switch ptr := dest.(type) {
|
||||
case encoding.TextUnmarshaler:
|
||||
return ptr.UnmarshalText([]byte(src))
|
||||
case *string:
|
||||
*ptr = src
|
||||
return nil
|
||||
case **string:
|
||||
*ptr = &src
|
||||
return nil
|
||||
case *uint8:
|
||||
return parseUint(src, ptr)
|
||||
case *uint16:
|
||||
return parseUint(src, ptr)
|
||||
case *uint32:
|
||||
return parseUint(src, ptr)
|
||||
case *uint64:
|
||||
return parseUint(src, ptr)
|
||||
case *int8:
|
||||
return parseInt(src, ptr)
|
||||
case *int16:
|
||||
return parseInt(src, ptr)
|
||||
case *int32:
|
||||
return parseInt(src, ptr)
|
||||
case *int64:
|
||||
return parseInt(src, ptr)
|
||||
case *float32:
|
||||
return parseFloat(src, ptr)
|
||||
case *float64:
|
||||
return parseFloat(src, ptr)
|
||||
default:
|
||||
panic(fmt.Sprintf("unsupported type: %T", dest))
|
||||
}
|
||||
}
|
||||
|
||||
// parseUint parses src into *dest.
|
||||
func parseUint[T constraints.Unsigned](src string, dest *T) error {
|
||||
i, err := strconv.ParseUint(src, 10, bitSizeOf[T]())
|
||||
if err == nil {
|
||||
*dest = T(i)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// parseInt parses src into *dest.
|
||||
func parseInt[T constraints.Signed](src string, dest *T) error {
|
||||
i, err := strconv.ParseInt(src, 10, bitSizeOf[T]())
|
||||
if err == nil {
|
||||
*dest = T(i)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// parseFloat parses src into *dest.
|
||||
func parseFloat[T constraints.Float](src string, dest *T) error {
|
||||
f, err := strconv.ParseFloat(src, bitSizeOf[T]())
|
||||
if err == nil {
|
||||
*dest = T(f)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func bitSizeOf[T any]() int {
|
||||
var x T
|
||||
return int(unsafe.Sizeof(x) * 8)
|
||||
}
|
||||
|
|
@ -1,125 +0,0 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Binary nullable byte string. Hex as JSON.
|
||||
type Binary []byte
|
||||
|
||||
// nullBinary for validating whether a Binary is valid.
|
||||
var nullBinary Binary
|
||||
|
||||
// Valid returns whether the Binary is valid.
|
||||
func (binary Binary) Valid() bool {
|
||||
return !bytes.Equal(binary, nullBinary)
|
||||
}
|
||||
|
||||
// String returns the hex string representation form of the Binary.
|
||||
func (binary Binary) String() string {
|
||||
return hex.EncodeToString(binary)
|
||||
}
|
||||
|
||||
// MarshalText implements a custom marhsal function to encode
|
||||
// the Binary as hex. MarshalText implements the
|
||||
// encoding.TextMarshaler interface.
|
||||
func (binary Binary) MarshalText() ([]byte, error) {
|
||||
return []byte(binary.String()), nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements a custom unmarshal function to decode
|
||||
// hex into a Binary. UnmarshalText implements the
|
||||
// encoding.TextUnmarshaler interface.
|
||||
func (binary *Binary) UnmarshalText(text []byte) error {
|
||||
b := make([]byte, hex.DecodedLen(len(text)))
|
||||
_, err := hex.Decode(b, text)
|
||||
if err != nil {
|
||||
return CantDecodeHex(err, string(text))
|
||||
}
|
||||
*binary = b
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements a custom marshal function to encode the Binary
|
||||
// as a hex string. MarshalJSON implements the json.Marshaler interface.
|
||||
// Supports JSON null.
|
||||
func (binary Binary) MarshalJSON() ([]byte, error) {
|
||||
if !binary.Valid() {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
|
||||
return MarshalJSON(binary.String())
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements a custom unmarshal function to decode
|
||||
// a JSON hex string into a Binary. UnmarshalJSON implements the
|
||||
// json.Unmarshaler interface. Supports JSON null.
|
||||
func (binary *Binary) UnmarshalJSON(data []byte) error {
|
||||
if string(data) == "null" || len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var s string
|
||||
if err := UnmarshalJSON(data, &s); err != nil {
|
||||
return err
|
||||
}
|
||||
b, err := hex.DecodeString(s)
|
||||
if err != nil {
|
||||
return CantDecodeHex(err, s)
|
||||
}
|
||||
*binary = b
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Scan implements the sql.Scanner interface.
|
||||
// Supports SQL NULL.
|
||||
func (binary *Binary) Scan(src interface{}) error {
|
||||
switch src := src.(type) {
|
||||
case nil:
|
||||
return nil
|
||||
|
||||
case []byte:
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
b := make([]byte, len(src))
|
||||
copy(b, src)
|
||||
*binary = b
|
||||
|
||||
default:
|
||||
return errors.Errorf("unable to scan type %T into Binary", src)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface.
|
||||
// Supports SQL NULL.
|
||||
func (binary Binary) Value() (driver.Value, error) {
|
||||
if !binary.Valid() {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return []byte(binary), nil
|
||||
}
|
||||
|
||||
// Assert interface compliance.
|
||||
var (
|
||||
_ fmt.Stringer = (*Binary)(nil)
|
||||
_ encoding.TextMarshaler = (*Binary)(nil)
|
||||
_ encoding.TextUnmarshaler = (*Binary)(nil)
|
||||
_ json.Marshaler = (*Binary)(nil)
|
||||
_ json.Unmarshaler = (*Binary)(nil)
|
||||
_ sql.Scanner = (*Binary)(nil)
|
||||
_ driver.Valuer = (*Binary)(nil)
|
||||
)
|
||||
|
|
@ -1,29 +0,0 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
func TestBinary_MarshalJSON(t *testing.T) {
|
||||
subtests := []struct {
|
||||
name string
|
||||
input Binary
|
||||
output string
|
||||
}{
|
||||
{"nil", nil, `null`},
|
||||
{"empty", make(Binary, 0, 1), `null`},
|
||||
{"space", Binary(" "), `"20"`},
|
||||
}
|
||||
|
||||
for _, st := range subtests {
|
||||
t.Run(st.name, func(t *testing.T) {
|
||||
actual, err := st.input.MarshalJSON()
|
||||
|
||||
require.NoError(t, err)
|
||||
require.True(t, utf8.Valid(actual))
|
||||
require.Equal(t, st.output, string(actual))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -1,104 +0,0 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding"
|
||||
"encoding/json"
|
||||
"github.com/pkg/errors"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var (
|
||||
enum = map[bool]string{
|
||||
true: "y",
|
||||
false: "n",
|
||||
}
|
||||
)
|
||||
|
||||
// Bool represents a bool for ENUM ('y', 'n'), which can be NULL.
|
||||
type Bool struct {
|
||||
Bool bool
|
||||
Valid bool // Valid is true if Bool is not NULL
|
||||
}
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface.
|
||||
func (b Bool) MarshalJSON() ([]byte, error) {
|
||||
if !b.Valid {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
|
||||
return MarshalJSON(b.Bool)
|
||||
}
|
||||
|
||||
// UnmarshalText implements the encoding.TextUnmarshaler interface.
|
||||
func (b *Bool) UnmarshalText(text []byte) error {
|
||||
parsed, err := strconv.ParseUint(string(text), 10, 64)
|
||||
if err != nil {
|
||||
return CantParseUint64(err, string(text))
|
||||
}
|
||||
|
||||
*b = Bool{parsed != 0, true}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements the json.Unmarshaler interface.
|
||||
func (b *Bool) UnmarshalJSON(data []byte) error {
|
||||
if string(data) == "null" || len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := UnmarshalJSON(data, &b.Bool); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
b.Valid = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Scan implements the sql.Scanner interface.
|
||||
// Supports SQL NULL.
|
||||
func (b *Bool) Scan(src interface{}) error {
|
||||
if src == nil {
|
||||
b.Bool, b.Valid = false, false
|
||||
return nil
|
||||
}
|
||||
|
||||
v, ok := src.([]byte)
|
||||
if !ok {
|
||||
return errors.Errorf("bad []byte type assertion from %#v", src)
|
||||
}
|
||||
|
||||
switch string(v) {
|
||||
case "y":
|
||||
b.Bool = true
|
||||
case "n":
|
||||
b.Bool = false
|
||||
default:
|
||||
return errors.Errorf("bad bool %#v", v)
|
||||
}
|
||||
|
||||
b.Valid = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface.
|
||||
// Supports SQL NULL.
|
||||
func (b Bool) Value() (driver.Value, error) {
|
||||
if !b.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return enum[b.Bool], nil
|
||||
}
|
||||
|
||||
// Assert interface compliance.
|
||||
var (
|
||||
_ json.Marshaler = (*Bool)(nil)
|
||||
_ encoding.TextUnmarshaler = (*Bool)(nil)
|
||||
_ json.Unmarshaler = (*Bool)(nil)
|
||||
_ sql.Scanner = (*Bool)(nil)
|
||||
_ driver.Valuer = (*Bool)(nil)
|
||||
)
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
func TestBool_MarshalJSON(t *testing.T) {
|
||||
subtests := []struct {
|
||||
input Bool
|
||||
output string
|
||||
}{
|
||||
{Bool{Bool: false, Valid: false}, `null`},
|
||||
{Bool{Bool: false, Valid: true}, `false`},
|
||||
{Bool{Bool: true, Valid: false}, `null`},
|
||||
{Bool{Bool: true, Valid: true}, `true`},
|
||||
}
|
||||
|
||||
for _, st := range subtests {
|
||||
t.Run(fmt.Sprintf("Bool-%#v_Valid-%#v", st.input.Bool, st.input.Valid), func(t *testing.T) {
|
||||
actual, err := st.input.MarshalJSON()
|
||||
|
||||
require.NoError(t, err)
|
||||
require.True(t, utf8.Valid(actual))
|
||||
require.Equal(t, st.output, string(actual))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -1,67 +0,0 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding"
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// Float adds JSON support to sql.NullFloat64.
|
||||
type Float struct {
|
||||
sql.NullFloat64
|
||||
}
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface.
|
||||
// Supports JSON null.
|
||||
func (f Float) MarshalJSON() ([]byte, error) {
|
||||
var v interface{}
|
||||
if f.Valid {
|
||||
v = f.Float64
|
||||
}
|
||||
|
||||
return MarshalJSON(v)
|
||||
}
|
||||
|
||||
// UnmarshalText implements the encoding.TextUnmarshaler interface.
|
||||
func (f *Float) UnmarshalText(text []byte) error {
|
||||
parsed, err := strconv.ParseFloat(string(text), 64)
|
||||
if err != nil {
|
||||
return CantParseFloat64(err, string(text))
|
||||
}
|
||||
|
||||
*f = Float{sql.NullFloat64{
|
||||
Float64: parsed,
|
||||
Valid: true,
|
||||
}}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements the json.Unmarshaler interface.
|
||||
// Supports JSON null.
|
||||
func (f *Float) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if bytes.HasPrefix(data, []byte{'n'}) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := UnmarshalJSON(data, &f.Float64); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
f.Valid = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Assert interface compliance.
|
||||
var (
|
||||
_ json.Marshaler = Float{}
|
||||
_ encoding.TextUnmarshaler = (*Float)(nil)
|
||||
_ json.Unmarshaler = (*Float)(nil)
|
||||
_ driver.Valuer = Float{}
|
||||
_ sql.Scanner = (*Float)(nil)
|
||||
)
|
||||
|
|
@ -1,67 +0,0 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding"
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// Int adds JSON support to sql.NullInt64.
|
||||
type Int struct {
|
||||
sql.NullInt64
|
||||
}
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface.
|
||||
// Supports JSON null.
|
||||
func (i Int) MarshalJSON() ([]byte, error) {
|
||||
var v interface{}
|
||||
if i.Valid {
|
||||
v = i.Int64
|
||||
}
|
||||
|
||||
return MarshalJSON(v)
|
||||
}
|
||||
|
||||
// UnmarshalText implements the encoding.TextUnmarshaler interface.
|
||||
func (i *Int) UnmarshalText(text []byte) error {
|
||||
parsed, err := strconv.ParseInt(string(text), 10, 64)
|
||||
if err != nil {
|
||||
return CantParseInt64(err, string(text))
|
||||
}
|
||||
|
||||
*i = Int{sql.NullInt64{
|
||||
Int64: parsed,
|
||||
Valid: true,
|
||||
}}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements the json.Unmarshaler interface.
|
||||
// Supports JSON null.
|
||||
func (i *Int) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if bytes.HasPrefix(data, []byte{'n'}) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := UnmarshalJSON(data, &i.Int64); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
i.Valid = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Assert interface compliance.
|
||||
var (
|
||||
_ json.Marshaler = Int{}
|
||||
_ json.Unmarshaler = (*Int)(nil)
|
||||
_ encoding.TextUnmarshaler = (*Int)(nil)
|
||||
_ driver.Valuer = Int{}
|
||||
_ sql.Scanner = (*Int)(nil)
|
||||
)
|
||||
|
|
@ -1,81 +0,0 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// String adds JSON support to sql.NullString.
|
||||
type String struct {
|
||||
sql.NullString
|
||||
}
|
||||
|
||||
// MakeString constructs a new non-NULL String from s.
|
||||
func MakeString(s string) String {
|
||||
return String{sql.NullString{
|
||||
String: s,
|
||||
Valid: true,
|
||||
}}
|
||||
}
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface.
|
||||
// Supports JSON null.
|
||||
func (s String) MarshalJSON() ([]byte, error) {
|
||||
var v interface{}
|
||||
if s.Valid {
|
||||
v = s.String
|
||||
}
|
||||
|
||||
return MarshalJSON(v)
|
||||
}
|
||||
|
||||
// UnmarshalText implements the encoding.TextUnmarshaler interface.
|
||||
func (s *String) UnmarshalText(text []byte) error {
|
||||
*s = String{sql.NullString{
|
||||
String: string(text),
|
||||
Valid: true,
|
||||
}}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements the json.Unmarshaler interface.
|
||||
// Supports JSON null.
|
||||
func (s *String) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if bytes.HasPrefix(data, []byte{'n'}) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := UnmarshalJSON(data, &s.String); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.Valid = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface.
|
||||
// Supports SQL NULL.
|
||||
func (s String) Value() (driver.Value, error) {
|
||||
if !s.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// PostgreSQL does not allow null bytes in varchar, char and text fields.
|
||||
return strings.ReplaceAll(s.String, "\x00", ""), nil
|
||||
}
|
||||
|
||||
// Assert interface compliance.
|
||||
var (
|
||||
_ json.Marshaler = String{}
|
||||
_ encoding.TextUnmarshaler = (*String)(nil)
|
||||
_ json.Unmarshaler = (*String)(nil)
|
||||
_ driver.Valuer = String{}
|
||||
_ sql.Scanner = (*String)(nil)
|
||||
)
|
||||
|
|
@ -1,116 +0,0 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding"
|
||||
"encoding/json"
|
||||
"github.com/pkg/errors"
|
||||
"math"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// UnixMilli is a nullable millisecond UNIX timestamp in databases and JSON.
|
||||
type UnixMilli time.Time
|
||||
|
||||
// Time returns the time.Time conversion of UnixMilli.
|
||||
func (t UnixMilli) Time() time.Time {
|
||||
return time.Time(t)
|
||||
}
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface.
|
||||
// Marshals to milliseconds. Supports JSON null.
|
||||
func (t UnixMilli) MarshalJSON() ([]byte, error) {
|
||||
if time.Time(t).IsZero() {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
|
||||
return []byte(strconv.FormatInt(t.Time().UnixMilli(), 10)), nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements the json.Unmarshaler interface.
|
||||
// Unmarshals from milliseconds. Supports JSON null.
|
||||
func (t *UnixMilli) UnmarshalJSON(data []byte) error {
|
||||
if bytes.Equal(data, []byte("null")) || len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return t.fromByteString(data)
|
||||
}
|
||||
|
||||
// MarshalText implements the encoding.TextMarshaler interface.
|
||||
func (t UnixMilli) MarshalText() ([]byte, error) {
|
||||
if time.Time(t).IsZero() {
|
||||
return []byte{}, nil
|
||||
}
|
||||
|
||||
return []byte(strconv.FormatInt(t.Time().UnixMilli(), 10)), nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements the encoding.TextUnmarshaler interface.
|
||||
func (t *UnixMilli) UnmarshalText(text []byte) error {
|
||||
if len(text) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return t.fromByteString(text)
|
||||
}
|
||||
|
||||
// Scan implements the sql.Scanner interface.
|
||||
// Scans from milliseconds. Supports SQL NULL.
|
||||
func (t *UnixMilli) Scan(src interface{}) error {
|
||||
if src == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch v := src.(type) {
|
||||
case []byte:
|
||||
return t.fromByteString(v)
|
||||
// https://github.com/go-sql-driver/mysql/pull/1452
|
||||
case uint64:
|
||||
if v > math.MaxInt64 {
|
||||
return errors.Errorf("value %v out of range for int64", v)
|
||||
}
|
||||
|
||||
*t = UnixMilli(time.UnixMilli(int64(v)))
|
||||
case int64:
|
||||
*t = UnixMilli(time.UnixMilli(v))
|
||||
default:
|
||||
return errors.Errorf("bad (u)int64/[]byte type assertion from %[1]v (%[1]T)", src)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface.
|
||||
// Returns milliseconds. Supports SQL NULL.
|
||||
func (t UnixMilli) Value() (driver.Value, error) {
|
||||
if t.Time().IsZero() {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return t.Time().UnixMilli(), nil
|
||||
}
|
||||
|
||||
func (t *UnixMilli) fromByteString(data []byte) error {
|
||||
i, err := strconv.ParseInt(string(data), 10, 64)
|
||||
if err != nil {
|
||||
return CantParseInt64(err, string(data))
|
||||
}
|
||||
|
||||
*t = UnixMilli(time.UnixMilli(i))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Assert interface compliance.
|
||||
var (
|
||||
_ encoding.TextMarshaler = (*UnixMilli)(nil)
|
||||
_ encoding.TextUnmarshaler = (*UnixMilli)(nil)
|
||||
_ json.Marshaler = (*UnixMilli)(nil)
|
||||
_ json.Unmarshaler = (*UnixMilli)(nil)
|
||||
_ driver.Valuer = (*UnixMilli)(nil)
|
||||
_ sql.Scanner = (*UnixMilli)(nil)
|
||||
)
|
||||
|
|
@ -1,149 +0,0 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
func TestUnixMilli(t *testing.T) {
|
||||
type testCase struct {
|
||||
v UnixMilli
|
||||
json string
|
||||
text string
|
||||
}
|
||||
|
||||
tests := map[string]testCase{
|
||||
"Zero": {UnixMilli{}, "null", ""},
|
||||
"Non-zero": {UnixMilli(time.Unix(1234567890, 0)), "1234567890000", "1234567890000"},
|
||||
"Epoch": {UnixMilli(time.Unix(0, 0)), "0", "0"},
|
||||
"With milliseconds": {UnixMilli(time.Unix(1234567890, 62000000)), "1234567890062", "1234567890062"},
|
||||
}
|
||||
|
||||
var runTests = func(t *testing.T, f func(*testing.T, testCase)) {
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
f(t, test)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("MarshalJSON", func(t *testing.T) {
|
||||
runTests(t, func(t *testing.T, test testCase) {
|
||||
actual, err := test.v.MarshalJSON()
|
||||
require.NoError(t, err)
|
||||
require.True(t, utf8.Valid(actual))
|
||||
require.Equal(t, test.json, string(actual))
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("UnmarshalJSON", func(t *testing.T) {
|
||||
runTests(t, func(t *testing.T, test testCase) {
|
||||
var actual UnixMilli
|
||||
err := actual.UnmarshalJSON([]byte(test.json))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, test.v, actual)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("MarshalText", func(t *testing.T) {
|
||||
runTests(t, func(t *testing.T, test testCase) {
|
||||
actual, err := test.v.MarshalText()
|
||||
require.NoError(t, err)
|
||||
require.True(t, utf8.Valid(actual))
|
||||
require.Equal(t, test.text, string(actual))
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("UnmarshalText", func(t *testing.T) {
|
||||
runTests(t, func(t *testing.T, test testCase) {
|
||||
var actual UnixMilli
|
||||
err := actual.UnmarshalText([]byte(test.text))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, test.v, actual)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnixMilli_Scan(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
v any
|
||||
expected UnixMilli
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "Nil",
|
||||
v: nil,
|
||||
expected: UnixMilli{},
|
||||
},
|
||||
{
|
||||
name: "Epoch",
|
||||
v: int64(0),
|
||||
expected: UnixMilli(time.Unix(0, 0)),
|
||||
},
|
||||
{
|
||||
name: "bytes",
|
||||
v: []byte("1234567890062"),
|
||||
expected: UnixMilli(time.Unix(1234567890, 62000000)),
|
||||
},
|
||||
{
|
||||
name: "Invalid bytes",
|
||||
v: []byte("invalid"),
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "int64",
|
||||
v: int64(1234567890062),
|
||||
expected: UnixMilli(time.Unix(1234567890, 62000000)),
|
||||
},
|
||||
{
|
||||
name: "uint64",
|
||||
v: uint64(1234567890062),
|
||||
expected: UnixMilli(time.Unix(1234567890, 62000000)),
|
||||
},
|
||||
{
|
||||
name: "uint64 out of range for int64",
|
||||
v: uint64(math.MaxInt64) + 1,
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid type",
|
||||
v: "invalid",
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var actual UnixMilli
|
||||
err := actual.Scan(test.v)
|
||||
if test.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, test.expected, actual)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnixMilli_Value(t *testing.T) {
|
||||
t.Run("Zero", func(t *testing.T) {
|
||||
var zero UnixMilli
|
||||
actual, err := zero.Value()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, actual)
|
||||
})
|
||||
|
||||
t.Run("Non-zero", func(t *testing.T) {
|
||||
withMilliseconds := time.Unix(1234567890, 62000000)
|
||||
expected := withMilliseconds.UnixMilli()
|
||||
actual, err := UnixMilli(withMilliseconds).Value()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, actual)
|
||||
})
|
||||
}
|
||||
|
|
@ -1,52 +0,0 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/pkg/errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Name returns the declared name of type t.
|
||||
func Name(t any) string {
|
||||
s := strings.TrimLeft(fmt.Sprintf("%T", t), "*")
|
||||
|
||||
return s[strings.LastIndex(s, ".")+1:]
|
||||
}
|
||||
|
||||
// CantDecodeHex wraps the given error with the given string that cannot be hex-decoded.
|
||||
func CantDecodeHex(err error, s string) error {
|
||||
return errors.Wrapf(err, "can't decode hex %q", s)
|
||||
}
|
||||
|
||||
// CantParseFloat64 wraps the given error with the specified string that cannot be parsed into float64.
|
||||
func CantParseFloat64(err error, s string) error {
|
||||
return errors.Wrapf(err, "can't parse %q into float64", s)
|
||||
}
|
||||
|
||||
// CantParseInt64 wraps the given error with the specified string that cannot be parsed into int64.
|
||||
func CantParseInt64(err error, s string) error {
|
||||
return errors.Wrapf(err, "can't parse %q into int64", s)
|
||||
}
|
||||
|
||||
// CantParseUint64 wraps the given error with the specified string that cannot be parsed into uint64.
|
||||
func CantParseUint64(err error, s string) error {
|
||||
return errors.Wrapf(err, "can't parse %q into uint64", s)
|
||||
}
|
||||
|
||||
// CantUnmarshalYAML wraps the given error with the designated value, which cannot be unmarshalled into.
|
||||
func CantUnmarshalYAML(err error, v interface{}) error {
|
||||
return errors.Wrapf(err, "can't unmarshal YAML into %T", v)
|
||||
}
|
||||
|
||||
// MarshalJSON calls json.Marshal and wraps any resulting errors.
|
||||
func MarshalJSON(v interface{}) ([]byte, error) {
|
||||
b, err := json.Marshal(v)
|
||||
|
||||
return b, errors.Wrapf(err, "can't marshal JSON from %T", v)
|
||||
}
|
||||
|
||||
// UnmarshalJSON calls json.Unmarshal and wraps any resulting errors.
|
||||
func UnmarshalJSON(data []byte, v interface{}) error {
|
||||
return errors.Wrapf(json.Unmarshal(data, v), "can't unmarshal JSON into %T", v)
|
||||
}
|
||||
|
|
@ -1,24 +0,0 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// UUID is like uuid.UUID, but marshals itself binarily (not like xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx) in SQL context.
|
||||
type UUID struct {
|
||||
uuid.UUID
|
||||
}
|
||||
|
||||
// Value implements driver.Valuer.
|
||||
func (uuid UUID) Value() (driver.Value, error) {
|
||||
return uuid.UUID[:], nil
|
||||
}
|
||||
|
||||
// Assert interface compliance.
|
||||
var (
|
||||
_ encoding.TextUnmarshaler = (*UUID)(nil)
|
||||
_ driver.Valuer = UUID{}
|
||||
_ driver.Valuer = (*UUID)(nil)
|
||||
)
|
||||
|
|
@ -1,167 +0,0 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha1"
|
||||
"fmt"
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"github.com/lib/pq"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/exp/utf8string"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Timed calls the given callback with the time that has elapsed since the start.
|
||||
//
|
||||
// Timed should be installed by defer:
|
||||
//
|
||||
// func TimedExample(logger *zap.SugaredLogger) {
|
||||
// defer utils.Timed(time.Now(), func(elapsed time.Duration) {
|
||||
// logger.Debugf("Executed job in %s", elapsed)
|
||||
// })
|
||||
// job()
|
||||
// }
|
||||
func Timed(start time.Time, callback func(elapsed time.Duration)) {
|
||||
callback(time.Since(start))
|
||||
}
|
||||
|
||||
// BatchSliceOfStrings groups the given keys into chunks of size count and streams them into a returned channel.
|
||||
func BatchSliceOfStrings(ctx context.Context, keys []string, count int) <-chan []string {
|
||||
batches := make(chan []string)
|
||||
|
||||
go func() {
|
||||
defer close(batches)
|
||||
|
||||
for i := 0; i < len(keys); i += count {
|
||||
end := i + count
|
||||
if end > len(keys) {
|
||||
end = len(keys)
|
||||
}
|
||||
|
||||
select {
|
||||
case batches <- keys[i:end]:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return batches
|
||||
}
|
||||
|
||||
// IsContextCanceled returns whether the given error is context.Canceled.
|
||||
func IsContextCanceled(err error) bool {
|
||||
return errors.Is(err, context.Canceled)
|
||||
}
|
||||
|
||||
// Checksum returns the SHA-1 checksum of the data.
|
||||
func Checksum(data interface{}) []byte {
|
||||
var chksm [sha1.Size]byte
|
||||
|
||||
switch data := data.(type) {
|
||||
case string:
|
||||
chksm = sha1.Sum([]byte(data))
|
||||
case []byte:
|
||||
chksm = sha1.Sum(data)
|
||||
default:
|
||||
panic(fmt.Sprintf("Unable to create checksum for type %T", data))
|
||||
}
|
||||
|
||||
return chksm[:]
|
||||
}
|
||||
|
||||
// IsDeadlock returns whether the given error signals serialization failure.
|
||||
func IsDeadlock(err error) bool {
|
||||
var e *mysql.MySQLError
|
||||
if errors.As(err, &e) {
|
||||
switch e.Number {
|
||||
case 1205, 1213:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
var pe *pq.Error
|
||||
if errors.As(err, &pe) {
|
||||
switch pe.Code {
|
||||
case "40001", "40P01":
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
var ellipsis = utf8string.NewString("...")
|
||||
|
||||
// Ellipsize shortens s to <=limit runes and indicates shortening by "...".
|
||||
func Ellipsize(s string, limit int) string {
|
||||
utf8 := utf8string.NewString(s)
|
||||
switch {
|
||||
case utf8.RuneCount() <= limit:
|
||||
return s
|
||||
case utf8.RuneCount() <= ellipsis.RuneCount():
|
||||
return ellipsis.String()
|
||||
default:
|
||||
return utf8.Slice(0, limit-ellipsis.RuneCount()) + ellipsis.String()
|
||||
}
|
||||
}
|
||||
|
||||
// AppName returns the name of the executable that started this program (process).
|
||||
func AppName() string {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
exe = os.Args[0]
|
||||
}
|
||||
|
||||
return filepath.Base(exe)
|
||||
}
|
||||
|
||||
// MaxInt returns the larger of the given integers.
|
||||
func MaxInt(x, y int) int {
|
||||
if x > y {
|
||||
return x
|
||||
}
|
||||
|
||||
return y
|
||||
}
|
||||
|
||||
// IsUnixAddr indicates whether the given host string represents a Unix socket address.
|
||||
//
|
||||
// A host string that begins with a forward slash ('/') is considered Unix socket address.
|
||||
func IsUnixAddr(host string) bool {
|
||||
return strings.HasPrefix(host, "/")
|
||||
}
|
||||
|
||||
// JoinHostPort is like its equivalent in net., but handles UNIX sockets as well.
|
||||
func JoinHostPort(host string, port int) string {
|
||||
if IsUnixAddr(host) {
|
||||
return host
|
||||
}
|
||||
|
||||
return net.JoinHostPort(host, fmt.Sprint(port))
|
||||
}
|
||||
|
||||
// ChanFromSlice takes a slice of values and returns a channel from which these values can be received.
|
||||
// This channel is closed after the last value was sent.
|
||||
func ChanFromSlice[T any](values []T) <-chan T {
|
||||
ch := make(chan T, len(values))
|
||||
for _, value := range values {
|
||||
ch <- value
|
||||
}
|
||||
|
||||
close(ch)
|
||||
|
||||
return ch
|
||||
}
|
||||
|
||||
// PrintErrorThenExit prints the given error to [os.Stderr] and exits with the specified error code.
|
||||
func PrintErrorThenExit(err error, exitCode int) {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
os.Exit(exitCode)
|
||||
}
|
||||
|
|
@ -1,54 +0,0 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestChanFromSlice(t *testing.T) {
|
||||
t.Run("Nil", func(t *testing.T) {
|
||||
ch := ChanFromSlice[int](nil)
|
||||
require.NotNil(t, ch)
|
||||
requireClosedEmpty(t, ch)
|
||||
})
|
||||
|
||||
t.Run("Empty", func(t *testing.T) {
|
||||
ch := ChanFromSlice([]int{})
|
||||
require.NotNil(t, ch)
|
||||
requireClosedEmpty(t, ch)
|
||||
})
|
||||
|
||||
t.Run("NonEmpty", func(t *testing.T) {
|
||||
ch := ChanFromSlice([]int{42, 23, 1337})
|
||||
require.NotNil(t, ch)
|
||||
requireReceive(t, ch, 42)
|
||||
requireReceive(t, ch, 23)
|
||||
requireReceive(t, ch, 1337)
|
||||
requireClosedEmpty(t, ch)
|
||||
})
|
||||
}
|
||||
|
||||
// requireReceive is a helper function to check if a value can immediately be received from a channel.
|
||||
func requireReceive(t *testing.T, ch <-chan int, expected int) {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case v, ok := <-ch:
|
||||
require.True(t, ok, "receiving should return a value")
|
||||
require.Equal(t, expected, v)
|
||||
default:
|
||||
require.Fail(t, "receiving should not block")
|
||||
}
|
||||
}
|
||||
|
||||
// requireReceive is a helper function to check if the channel is closed and empty.
|
||||
func requireClosedEmpty(t *testing.T, ch <-chan int) {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case _, ok := <-ch:
|
||||
require.False(t, ok, "receiving from channel should not return anything")
|
||||
default:
|
||||
require.Fail(t, "receiving should not block")
|
||||
}
|
||||
}
|
||||
|
|
@ -1,180 +0,0 @@
|
|||
package version
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type VersionInfo struct {
|
||||
Version string
|
||||
Commit string
|
||||
}
|
||||
|
||||
// Version determines version and commit information based on multiple data sources:
|
||||
// - Version information dynamically added by `git archive` in the remaining to parameters.
|
||||
// - A hardcoded version number passed as first parameter.
|
||||
// - Commit information added to the binary by `go build`.
|
||||
//
|
||||
// It's supposed to be called like this in combination with setting the `export-subst` attribute for the corresponding
|
||||
// file in .gitattributes:
|
||||
//
|
||||
// var Version = version.Version("1.0.0-rc2", "$Format:%(describe)$", "$Format:%H$")
|
||||
//
|
||||
// When exported using `git archive`, the placeholders are replaced in the file and this version information is
|
||||
// preferred. Otherwise the hardcoded version is used and augmented with commit information from the build metadata.
|
||||
func Version(version, gitDescribe, gitHash string) *VersionInfo {
|
||||
const hashLen = 7 // Same truncation length for the commit hash as used by git describe.
|
||||
|
||||
if !strings.HasPrefix(gitDescribe, "$") && !strings.HasPrefix(gitHash, "$") {
|
||||
if strings.HasPrefix(gitDescribe, "%") {
|
||||
// Only Git 2.32+ supports %(describe), older versions don't expand it but keep it as-is.
|
||||
// Fall back to the hardcoded version augmented with the commit hash.
|
||||
gitDescribe = version
|
||||
|
||||
if len(gitHash) >= hashLen {
|
||||
gitDescribe += "-g" + gitHash[:hashLen]
|
||||
}
|
||||
}
|
||||
|
||||
return &VersionInfo{
|
||||
Version: gitDescribe,
|
||||
Commit: gitHash,
|
||||
}
|
||||
} else {
|
||||
commit := ""
|
||||
|
||||
if info, ok := debug.ReadBuildInfo(); ok {
|
||||
modified := false
|
||||
|
||||
for _, setting := range info.Settings {
|
||||
switch setting.Key {
|
||||
case "vcs.revision":
|
||||
commit = setting.Value
|
||||
case "vcs.modified":
|
||||
modified, _ = strconv.ParseBool(setting.Value)
|
||||
}
|
||||
}
|
||||
|
||||
if len(commit) >= hashLen {
|
||||
version += "-g" + commit[:hashLen]
|
||||
|
||||
if modified {
|
||||
version += "-dirty"
|
||||
commit += " (modified)"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &VersionInfo{
|
||||
Version: version,
|
||||
Commit: commit,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Print writes verbose version output to stdout.
|
||||
func (v *VersionInfo) Print() {
|
||||
fmt.Println("Icinga DB version:", v.Version)
|
||||
fmt.Println()
|
||||
|
||||
fmt.Println("Build information:")
|
||||
fmt.Printf(" Go version: %s (%s, %s)\n", runtime.Version(), runtime.GOOS, runtime.GOARCH)
|
||||
if v.Commit != "" {
|
||||
fmt.Println(" Git commit:", v.Commit)
|
||||
}
|
||||
|
||||
if r, err := readOsRelease(); err == nil {
|
||||
fmt.Println()
|
||||
fmt.Println("System information:")
|
||||
fmt.Println(" Platform:", r.Name)
|
||||
fmt.Println(" Platform version:", r.DisplayVersion())
|
||||
}
|
||||
}
|
||||
|
||||
// osRelease contains the information obtained from the os-release file.
|
||||
type osRelease struct {
|
||||
Name string
|
||||
Version string
|
||||
VersionId string
|
||||
BuildId string
|
||||
}
|
||||
|
||||
// DisplayVersion returns the most suitable version information for display purposes.
|
||||
func (o *osRelease) DisplayVersion() string {
|
||||
if o.Version != "" {
|
||||
// Most distributions set VERSION
|
||||
return o.Version
|
||||
} else if o.VersionId != "" {
|
||||
// Some only set VERSION_ID (Alpine Linux for example)
|
||||
return o.VersionId
|
||||
} else if o.BuildId != "" {
|
||||
// Others only set BUILD_ID (Arch Linux for example)
|
||||
return o.BuildId
|
||||
} else {
|
||||
return "(unknown)"
|
||||
}
|
||||
}
|
||||
|
||||
// readOsRelease reads and parses the os-release file.
|
||||
func readOsRelease() (*osRelease, error) {
|
||||
for _, path := range []string{"/etc/os-release", "/usr/lib/os-release"} {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
continue // Try next path.
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
o := &osRelease{
|
||||
Name: "Linux", // Suggested default as per os-release(5) man page.
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(f)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.HasPrefix(line, "#") {
|
||||
continue // Ignore comment.
|
||||
}
|
||||
|
||||
parts := strings.SplitN(line, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
continue // Ignore empty or possibly malformed line.
|
||||
}
|
||||
|
||||
key := parts[0]
|
||||
val := parts[1]
|
||||
|
||||
// Unquote strings. This isn't fully compliant with the specification which allows using some shell escape
|
||||
// sequences. However, typically quotes are only used to allow whitespace within the value.
|
||||
if len(val) >= 2 && (val[0] == '"' || val[0] == '\'') && val[0] == val[len(val)-1] {
|
||||
val = val[1 : len(val)-1]
|
||||
}
|
||||
|
||||
switch key {
|
||||
case "NAME":
|
||||
o.Name = val
|
||||
case "VERSION":
|
||||
o.Version = val
|
||||
case "VERSION_ID":
|
||||
o.VersionId = val
|
||||
case "BUILD_ID":
|
||||
o.BuildId = val
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return o, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("os-release file not found")
|
||||
}
|
||||
Loading…
Reference in a new issue