Remove library code

This commit is contained in:
Eric Lippmann 2024-05-22 11:48:41 +02:00
parent 7c068d4adf
commit be4b450f5c
48 changed files with 5 additions and 4830 deletions

10
go.mod
View file

@ -4,23 +4,18 @@ go 1.22
require (
github.com/creasty/defaults v1.7.0
github.com/go-sql-driver/mysql v1.8.1
github.com/goccy/go-yaml v1.11.3
github.com/google/go-cmp v0.6.0
github.com/google/uuid v1.6.0
github.com/icinga/icinga-go-library v0.1.0
github.com/jessevdk/go-flags v1.5.0
github.com/jmoiron/sqlx v1.4.0
github.com/lib/pq v1.10.9
github.com/mattn/go-sqlite3 v1.14.22
github.com/okzk/sdnotify v0.0.0-20180710141335-d9becc38acbd
github.com/pkg/errors v0.9.1
github.com/redis/go-redis/v9 v9.5.1
github.com/ssgreg/journald v1.0.0
github.com/stretchr/testify v1.9.0
github.com/vbauerster/mpb/v6 v6.0.4
go.uber.org/zap v1.27.0
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
golang.org/x/sync v0.7.0
)
@ -32,12 +27,17 @@ require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/fatih/color v1.16.0 // indirect
github.com/go-sql-driver/mysql v1.8.1 // indirect
github.com/lib/pq v1.10.9 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-runewidth v0.0.12 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/redis/go-redis/v9 v9.5.1 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/ssgreg/journald v1.0.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect
golang.org/x/sys v0.14.0 // indirect
golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect

View file

@ -1,43 +0,0 @@
package backoff
import (
"math/rand"
"time"
)
// Backoff returns the backoff duration for a specific retry attempt.
type Backoff func(uint64) time.Duration
// NewExponentialWithJitter returns a backoff implementation that
// exponentially increases the backoff duration for each retry from min,
// never exceeding max. Some randomization is added to the backoff duration.
// It panics if min >= max.
func NewExponentialWithJitter(min, max time.Duration) Backoff {
if min <= 0 {
min = 100 * time.Millisecond
}
if max <= 0 {
max = 10 * time.Second
}
if min >= max {
panic("max must be larger than min")
}
return func(attempt uint64) time.Duration {
e := min << attempt
if e <= 0 || e > max {
e = max
}
return time.Duration(jitter(int64(e)))
}
}
// jitter returns a random integer distributed in the range [n/2..n).
func jitter(n int64) int64 {
if n == 0 {
return 0
}
return n/2 + rand.Int63n(n/2)
}

View file

@ -1,38 +0,0 @@
package com
import "sync/atomic"
// Atomic is a type-safe wrapper around atomic.Value.
type Atomic[T any] struct {
v atomic.Value
}
func (a *Atomic[T]) Load() (_ T, ok bool) {
if v, ok := a.v.Load().(box[T]); ok {
return v.v, true
}
return
}
func (a *Atomic[T]) Store(v T) {
a.v.Store(box[T]{v})
}
func (a *Atomic[T]) Swap(new T) (old T, ok bool) {
if old, ok := a.v.Swap(box[T]{new}).(box[T]); ok {
return old.v, true
}
return
}
func (a *Atomic[T]) CompareAndSwap(old, new T) (swapped bool) {
return a.v.CompareAndSwap(box[T]{old}, box[T]{new})
}
// box allows, for the case T is an interface, nil values and values of different specific types implementing T
// to be stored in Atomic[T]#v (bypassing atomic.Value#Store()'s policy) by wrapping it (into a non-interface).
type box[T any] struct {
v T
}

View file

@ -1,166 +0,0 @@
package com
import (
"context"
"golang.org/x/sync/errgroup"
"sync"
"time"
)
// BulkChunkSplitPolicy is a state machine which tracks the items of a chunk a bulker assembles.
// A call takes an item for the current chunk into account.
// Output true indicates that the state machine was reset first and the bulker
// shall finish the current chunk now (not e.g. once $size is reached) without the given item.
type BulkChunkSplitPolicy[T any] func(T) bool
type BulkChunkSplitPolicyFactory[T any] func() BulkChunkSplitPolicy[T]
// NeverSplit returns a pseudo state machine which never demands splitting.
func NeverSplit[T any]() BulkChunkSplitPolicy[T] {
return neverSplit[T]
}
func neverSplit[T any](T) bool {
return false
}
// Bulker reads all values from a channel and streams them in chunks into a Bulk channel.
type Bulker[T any] struct {
ch chan []T
ctx context.Context
mu sync.Mutex
}
// NewBulker returns a new Bulker and starts streaming.
func NewBulker[T any](
ctx context.Context, ch <-chan T, count int, splitPolicyFactory BulkChunkSplitPolicyFactory[T],
) *Bulker[T] {
b := &Bulker[T]{
ch: make(chan []T),
ctx: ctx,
mu: sync.Mutex{},
}
go b.run(ch, count, splitPolicyFactory)
return b
}
// Bulk returns the channel on which the bulks are delivered.
func (b *Bulker[T]) Bulk() <-chan []T {
return b.ch
}
func (b *Bulker[T]) run(ch <-chan T, count int, splitPolicyFactory BulkChunkSplitPolicyFactory[T]) {
defer close(b.ch)
bufCh := make(chan T, count)
splitPolicy := splitPolicyFactory()
g, ctx := errgroup.WithContext(b.ctx)
g.Go(func() error {
defer close(bufCh)
for {
select {
case v, ok := <-ch:
if !ok {
return nil
}
bufCh <- v
case <-ctx.Done():
return ctx.Err()
}
}
})
g.Go(func() error {
for done := false; !done; {
buf := make([]T, 0, count)
timeout := time.After(256 * time.Millisecond)
for drain := true; drain && len(buf) < count; {
select {
case v, ok := <-bufCh:
if !ok {
drain = false
done = true
break
}
if splitPolicy(v) {
if len(buf) > 0 {
b.ch <- buf
buf = make([]T, 0, count)
}
timeout = time.After(256 * time.Millisecond)
}
buf = append(buf, v)
case <-timeout:
drain = false
case <-ctx.Done():
return ctx.Err()
}
}
if len(buf) > 0 {
b.ch <- buf
}
splitPolicy = splitPolicyFactory()
}
return nil
})
// We don't expect an error here.
// We only use errgroup for the encapsulated use of sync.WaitGroup.
_ = g.Wait()
}
// Bulk reads all values from a channel and streams them in chunks into a returned channel.
func Bulk[T any](
ctx context.Context, ch <-chan T, count int, splitPolicyFactory BulkChunkSplitPolicyFactory[T],
) <-chan []T {
if count <= 1 {
return oneBulk(ctx, ch)
}
return NewBulker(ctx, ch, count, splitPolicyFactory).Bulk()
}
// oneBulk operates just as NewBulker(ctx, ch, 1, splitPolicy).Bulk(),
// but without the overhead of the actual bulk creation with a buffer channel, timeout and BulkChunkSplitPolicy.
func oneBulk[T any](ctx context.Context, ch <-chan T) <-chan []T {
out := make(chan []T)
go func() {
defer close(out)
for {
select {
case item, ok := <-ch:
if !ok {
return
}
select {
case out <- []T{item}:
case <-ctx.Done():
return
}
case <-ctx.Done():
return
}
}
}()
return out
}
var (
_ BulkChunkSplitPolicyFactory[struct{}] = NeverSplit[struct{}]
)

View file

@ -1,101 +0,0 @@
package com
import (
"context"
"github.com/pkg/errors"
"golang.org/x/sync/errgroup"
)
// Waiter implements the Wait method,
// which blocks until execution is complete.
type Waiter interface {
Wait() error // Wait waits for execution to complete.
}
// The WaiterFunc type is an adapter to allow the use of ordinary functions as Waiter.
// If f is a function with the appropriate signature, WaiterFunc(f) is a Waiter that calls f.
type WaiterFunc func() error
// Wait implements the Waiter interface.
func (f WaiterFunc) Wait() error {
return f()
}
// WaitAsync calls Wait() on the passed Waiter in a new goroutine and
// sends the first non-nil error (if any) to the returned channel.
// The returned channel is always closed when the Waiter is done.
func WaitAsync(w Waiter) <-chan error {
errs := make(chan error, 1)
go func() {
defer close(errs)
if e := w.Wait(); e != nil {
errs <- e
}
}()
return errs
}
// ErrgroupReceive adds a goroutine to the specified group that
// returns the first non-nil error (if any) from the specified channel.
// If the channel is closed, it will return nil.
func ErrgroupReceive(g *errgroup.Group, err <-chan error) {
g.Go(func() error {
if e := <-err; e != nil {
return e
}
return nil
})
}
// CopyFirst asynchronously forwards all items from input to forward and synchronously returns the first item.
func CopyFirst[T any](
ctx context.Context, input <-chan T,
) (first T, forward <-chan T, err error) {
var ok bool
select {
case <-ctx.Done():
var zero T
return zero, nil, ctx.Err()
case first, ok = <-input:
}
if !ok {
err = errors.New("can't copy from closed channel")
return
}
// Buffer of one because we receive an entity and send it back immediately.
fwd := make(chan T, 1)
fwd <- first
forward = fwd
go func() {
defer close(fwd)
for {
select {
case <-ctx.Done():
return
case e, ok := <-input:
if !ok {
return
}
select {
case <-ctx.Done():
return
case fwd <- e:
}
}
}
}()
return
}

View file

@ -1,90 +0,0 @@
package com
import (
"context"
"github.com/pkg/errors"
)
// Cond implements a channel-based synchronization for goroutines that wait for signals or send them.
// Internally based on a controller loop that handles the synchronization of new listeners and signal propagation,
// which is only started when NewCond is called. Thus the zero value cannot be used.
type Cond struct {
broadcast chan struct{}
done chan struct{}
cancel context.CancelFunc
listeners chan chan struct{}
}
// NewCond returns a new Cond and starts the controller loop.
func NewCond(ctx context.Context) *Cond {
ctx, cancel := context.WithCancel(ctx)
c := &Cond{
broadcast: make(chan struct{}),
cancel: cancel,
done: make(chan struct{}),
listeners: make(chan chan struct{}),
}
go c.controller(ctx)
return c
}
// Broadcast sends a signal to all current listeners by closing the previously returned channel from Wait.
// Panics if the controller loop has already ended.
func (c *Cond) Broadcast() {
select {
case c.broadcast <- struct{}{}:
case <-c.done:
panic(errors.New("condition closed"))
}
}
// Close stops the controller loop, waits for it to finish, and returns an error if any.
// Implements the io.Closer interface.
func (c *Cond) Close() error {
c.cancel()
<-c.done
return nil
}
// Done returns a channel that will be closed when the controller loop has ended.
func (c *Cond) Done() <-chan struct{} {
return c.done
}
// Wait returns a channel that is closed with the next signal.
// Panics if the controller loop has already ended.
func (c *Cond) Wait() <-chan struct{} {
select {
case l := <-c.listeners:
return l
case <-c.done:
panic(errors.New("condition closed"))
}
}
// controller loop.
func (c *Cond) controller(ctx context.Context) {
defer close(c.done)
// Note that the notify channel does not close when the controller loop ends
// in order not to notify pending listeners.
notify := make(chan struct{})
for {
select {
case <-c.broadcast:
// Close channel to notify all current listeners.
close(notify)
// Create a new channel for the next listeners.
notify = make(chan struct{})
case c.listeners <- notify:
// A new listener received the channel.
case <-ctx.Done():
return
}
}
}

View file

@ -1,48 +0,0 @@
package com
import (
"sync"
"sync/atomic"
)
// Counter implements an atomic counter.
type Counter struct {
value uint64
mu sync.Mutex // Protects total.
total uint64
}
// Add adds the given delta to the counter.
func (c *Counter) Add(delta uint64) {
atomic.AddUint64(&c.value, delta)
}
// Inc increments the counter by one.
func (c *Counter) Inc() {
c.Add(1)
}
// Reset resets the counter to 0 and returns its previous value.
// Does not reset the total value returned from Total.
func (c *Counter) Reset() uint64 {
c.mu.Lock()
defer c.mu.Unlock()
v := atomic.SwapUint64(&c.value, 0)
c.total += v
return v
}
// Total returns the total counter value.
func (c *Counter) Total() uint64 {
c.mu.Lock()
defer c.mu.Unlock()
return c.total + c.Val()
}
// Val returns the current counter value.
func (c *Counter) Val() uint64 {
return atomic.LoadUint64(&c.value)
}

View file

@ -1,80 +0,0 @@
package config
import (
stderrors "errors"
"fmt"
"github.com/creasty/defaults"
"github.com/goccy/go-yaml"
"github.com/jessevdk/go-flags"
"github.com/pkg/errors"
"os"
"reflect"
)
// ErrInvalidArgument is the error returned by [ParseFlags] or [FromYAMLFile] if
// its parsing result cannot be stored in the value pointed to by the designated passed argument which
// must be a non-nil pointer.
var ErrInvalidArgument = stderrors.New("invalid argument")
// FromYAMLFile parses the given YAML file and stores the result
// in the value pointed to by v. If v is nil or not a pointer,
// FromYAMLFile returns an [ErrInvalidArgument] error.
func FromYAMLFile(name string, v Validator) error {
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Pointer || rv.IsNil() {
return errors.Wrapf(ErrInvalidArgument, "non-nil pointer expected, got %T", v)
}
f, err := os.Open(name)
if err != nil {
return errors.Wrap(err, "can't open YAML file "+name)
}
defer func(f *os.File) {
_ = f.Close()
}(f)
if err := defaults.Set(v); err != nil {
return errors.Wrap(err, "can't set config defaults")
}
d := yaml.NewDecoder(f, yaml.DisallowUnknownField())
if err := d.Decode(v); err != nil {
return errors.Wrap(err, "can't parse YAML file "+name)
}
if err := v.Validate(); err != nil {
return errors.Wrap(err, "invalid configuration")
}
return nil
}
// ParseFlags parses CLI flags and stores the result
// in the value pointed to by v. If v is nil or not a pointer,
// ParseFlags returns an [ErrInvalidArgument] error.
// ParseFlags adds a default Help Options group,
// which contains the options -h and --help.
// If either option is specified on the command line,
// ParseFlags prints the help message to [os.Stdout] and exits.
// Note that errors are not printed automatically,
// so error handling is the sole responsibility of the caller.
func ParseFlags(v any) error {
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Pointer || rv.IsNil() {
return errors.Wrapf(ErrInvalidArgument, "non-nil pointer expected, got %T", v)
}
parser := flags.NewParser(v, flags.Default^flags.PrintErrors)
if _, err := parser.Parse(); err != nil {
var flagErr *flags.Error
if errors.As(err, &flagErr) && flagErr.Type == flags.ErrHelp {
fmt.Fprintln(os.Stdout, flagErr)
os.Exit(0)
}
return errors.Wrap(err, "can't parse CLI flags")
}
return nil
}

View file

@ -1,5 +0,0 @@
package config
type Validator interface {
Validate() error
}

View file

@ -1,58 +0,0 @@
package config
import (
"crypto/tls"
"crypto/x509"
"github.com/pkg/errors"
"os"
)
// TLS provides TLS configuration options.
type TLS struct {
Enable bool `yaml:"tls"`
Cert string `yaml:"cert"`
Key string `yaml:"key"`
Ca string `yaml:"ca"`
Insecure bool `yaml:"insecure"`
}
// MakeConfig assembles a tls.Config from t and serverName.
func (t *TLS) MakeConfig(serverName string) (*tls.Config, error) {
if !t.Enable {
return nil, nil
}
tlsConfig := &tls.Config{}
if t.Cert == "" {
if t.Key != "" {
return nil, errors.New("private key given, but client certificate missing")
}
} else if t.Key == "" {
return nil, errors.New("client certificate given, but private key missing")
} else {
crt, err := tls.LoadX509KeyPair(t.Cert, t.Key)
if err != nil {
return nil, errors.Wrap(err, "can't load X.509 key pair")
}
tlsConfig.Certificates = []tls.Certificate{crt}
}
if t.Insecure {
tlsConfig.InsecureSkipVerify = true
} else if t.Ca != "" {
raw, err := os.ReadFile(t.Ca)
if err != nil {
return nil, errors.Wrap(err, "can't read CA file")
}
tlsConfig.RootCAs = x509.NewCertPool()
if !tlsConfig.RootCAs.AppendCertsFromPEM(raw) {
return nil, errors.New("can't parse CA file")
}
}
tlsConfig.ServerName = serverName
return tlsConfig, nil
}

View file

@ -1,75 +0,0 @@
package database
import (
"database/sql/driver"
"github.com/jmoiron/sqlx/reflectx"
"reflect"
"sync"
)
// ColumnMap provides a cached mapping of structs exported fields to their database column names.
type ColumnMap interface {
// Columns returns database column names for a struct's exported fields in a cached manner.
// Thus, the returned slice MUST NOT be modified directly.
// By default, all exported struct fields are mapped to database column names using snake case notation.
// The - (hyphen) directive for the db tag can be used to exclude certain fields.
Columns(any) []string
}
// NewColumnMap returns a new ColumnMap.
func NewColumnMap(mapper *reflectx.Mapper) ColumnMap {
return &columnMap{
cache: make(map[reflect.Type][]string),
mapper: mapper,
}
}
type columnMap struct {
mutex sync.Mutex
cache map[reflect.Type][]string
mapper *reflectx.Mapper
}
func (m *columnMap) Columns(subject any) []string {
m.mutex.Lock()
defer m.mutex.Unlock()
t, ok := subject.(reflect.Type)
if !ok {
t = reflect.TypeOf(subject)
}
columns, ok := m.cache[t]
if !ok {
columns = m.getColumns(t)
m.cache[t] = columns
}
return columns
}
func (m *columnMap) getColumns(t reflect.Type) []string {
fields := m.mapper.TypeMap(t).Names
columns := make([]string, 0, len(fields))
FieldLoop:
for _, f := range fields {
// If one of the parent fields implements the driver.Valuer interface, the field can be ignored.
for parent := f.Parent; parent != nil && parent.Zero.IsValid(); parent = parent.Parent {
// Check for pointer types.
if _, ok := reflect.New(parent.Field.Type).Interface().(driver.Valuer); ok {
continue FieldLoop
}
// Check for non-pointer types.
if _, ok := reflect.Zero(parent.Field.Type).Interface().(driver.Valuer); ok {
continue FieldLoop
}
}
columns = append(columns, f.Path)
}
// Shrink/reduce slice length and capacity:
// For a three-index slice (slice[a:b:c]), the length of the returned slice is b-a and the capacity is c-a.
return columns[0:len(columns):len(columns)]
}

View file

@ -1,45 +0,0 @@
package database
import (
"github.com/icinga/icinga-go-library/config"
"github.com/pkg/errors"
)
// Config defines database client configuration.
type Config struct {
Type string `yaml:"type" default:"mysql"`
Host string `yaml:"host"`
Port int `yaml:"port"`
Database string `yaml:"database"`
User string `yaml:"user"`
Password string `yaml:"password"`
TlsOptions config.TLS `yaml:",inline"`
Options Options `yaml:"options"`
}
// Validate checks constraints in the supplied database configuration and returns an error if they are violated.
func (c *Config) Validate() error {
switch c.Type {
case "mysql", "pgsql":
default:
return unknownDbType(c.Type)
}
if c.Host == "" {
return errors.New("database host missing")
}
if c.User == "" {
return errors.New("database user missing")
}
if c.Database == "" {
return errors.New("database name missing")
}
return c.Options.Validate()
}
func unknownDbType(t string) error {
return errors.Errorf(`unknown database type %q, must be one of: "mysql", "pgsql"`, t)
}

View file

@ -1,56 +0,0 @@
package database
// Entity is implemented by each type that works with the database package.
type Entity interface {
Fingerprinter
IDer
}
// Fingerprinter is implemented by every entity that uniquely identifies itself.
type Fingerprinter interface {
// Fingerprint returns the value that uniquely identifies the entity.
Fingerprint() Fingerprinter
}
// ID is a unique identifier of an entity.
type ID interface {
// String returns the string representation form of the ID.
// The String method is used to use the ID in functions
// where it needs to be compared or hashed.
String() string
}
// IDer is implemented by every entity that uniquely identifies itself.
type IDer interface {
ID() ID // ID returns the ID.
SetID(ID) // SetID sets the ID.
}
// EntityFactoryFunc knows how to create an Entity.
type EntityFactoryFunc func() Entity
// Upserter implements the Upsert method,
// which returns a part of the object for ON DUPLICATE KEY UPDATE.
type Upserter interface {
Upsert() any // Upsert partitions the object.
}
// TableNamer implements the TableName method,
// which returns the table of the object.
type TableNamer interface {
TableName() string // TableName tells the table.
}
// Scoper implements the Scope method,
// which returns a struct specifying the WHERE conditions that
// entities must satisfy in order to be SELECTed.
type Scoper interface {
Scope() any
}
// PgsqlOnConflictConstrainter implements the PgsqlOnConflictConstraint method,
// which returns the primary or unique key constraint name of the PostgreSQL table.
type PgsqlOnConflictConstrainter interface {
// PgsqlOnConflictConstraint returns the primary or unique key constraint name of the PostgreSQL table.
PgsqlOnConflictConstraint() string
}

View file

@ -1,800 +0,0 @@
package database
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"github.com/go-sql-driver/mysql"
"github.com/icinga/icinga-go-library/backoff"
"github.com/icinga/icinga-go-library/com"
"github.com/icinga/icinga-go-library/logging"
"github.com/icinga/icinga-go-library/periodic"
"github.com/icinga/icinga-go-library/retry"
"github.com/icinga/icinga-go-library/strcase"
"github.com/icinga/icinga-go-library/utils"
"github.com/jmoiron/sqlx"
"github.com/jmoiron/sqlx/reflectx"
"github.com/lib/pq"
"github.com/pkg/errors"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
"golang.org/x/sync/semaphore"
"net"
"net/url"
"strconv"
"strings"
"sync"
"time"
)
// DB is a wrapper around sqlx.DB with bulk execution,
// statement building, streaming and logging capabilities.
type DB struct {
*sqlx.DB
Options *Options
addr string
columnMap ColumnMap
logger *logging.Logger
tableSemaphores map[string]*semaphore.Weighted
tableSemaphoresMu sync.Mutex
}
// Options define user configurable database options.
type Options struct {
// Maximum number of open connections to the database.
MaxConnections int `yaml:"max_connections" default:"16"`
// Maximum number of connections per table,
// regardless of what the connection is actually doing,
// e.g. INSERT, UPDATE, DELETE.
MaxConnectionsPerTable int `yaml:"max_connections_per_table" default:"8"`
// MaxPlaceholdersPerStatement defines the maximum number of placeholders in an
// INSERT, UPDATE or DELETE statement. Theoretically, MySQL can handle up to 2^16-1 placeholders,
// but this increases the execution time of queries and thus reduces the number of queries
// that can be executed in parallel in a given time.
// The default is 2^13, which in our tests showed the best performance in terms of execution time and parallelism.
MaxPlaceholdersPerStatement int `yaml:"max_placeholders_per_statement" default:"8192"`
// MaxRowsPerTransaction defines the maximum number of rows per transaction.
// The default is 2^13, which in our tests showed the best performance in terms of execution time and parallelism.
MaxRowsPerTransaction int `yaml:"max_rows_per_transaction" default:"8192"`
// WsrepSyncWait enforces Galera cluster nodes to perform strict cluster-wide causality checks
// before executing specific SQL queries determined by the number you provided.
// Please refer to the below link for a detailed description.
// https://icinga.com/docs/icinga-db/latest/doc/03-Configuration/#galera-cluster
WsrepSyncWait int `yaml:"wsrep_sync_wait" default:"7"`
}
// Validate checks constraints in the supplied database options and returns an error if they are violated.
func (o *Options) Validate() error {
if o.MaxConnections == 0 {
return errors.New("max_connections cannot be 0. Configure a value greater than zero, or use -1 for no connection limit")
}
if o.MaxConnectionsPerTable < 1 {
return errors.New("max_connections_per_table must be at least 1")
}
if o.MaxPlaceholdersPerStatement < 1 {
return errors.New("max_placeholders_per_statement must be at least 1")
}
if o.MaxRowsPerTransaction < 1 {
return errors.New("max_rows_per_transaction must be at least 1")
}
if o.WsrepSyncWait < 0 || o.WsrepSyncWait > 15 {
return errors.New("wsrep_sync_wait can only be set to a number between 0 and 15")
}
return nil
}
// NewDbFromConfig returns a new DB from Config.
func NewDbFromConfig(c *Config, logger *logging.Logger, connectorCallbacks RetryConnectorCallbacks) (*DB, error) {
var addr string
var db *sqlx.DB
switch c.Type {
case "mysql":
config := mysql.NewConfig()
config.User = c.User
config.Passwd = c.Password
config.Logger = MysqlFuncLogger(logger.Debug)
if utils.IsUnixAddr(c.Host) {
config.Net = "unix"
config.Addr = c.Host
} else {
config.Net = "tcp"
port := c.Port
if port == 0 {
port = 3306
}
config.Addr = net.JoinHostPort(c.Host, fmt.Sprint(port))
}
config.DBName = c.Database
config.Timeout = time.Minute
config.Params = map[string]string{"sql_mode": "'TRADITIONAL,ANSI_QUOTES'"}
tlsConfig, err := c.TlsOptions.MakeConfig(c.Host)
if err != nil {
return nil, err
}
config.TLS = tlsConfig
connector, err := mysql.NewConnector(config)
if err != nil {
return nil, errors.Wrap(err, "can't open mysql database")
}
onInitConn := connectorCallbacks.OnInitConn
connectorCallbacks.OnInitConn = func(ctx context.Context, conn driver.Conn) error {
if onInitConn != nil {
if err := onInitConn(ctx, conn); err != nil {
return err
}
}
return setGaleraOpts(ctx, conn, int64(c.Options.WsrepSyncWait))
}
addr = config.Addr
db = sqlx.NewDb(sql.OpenDB(NewConnector(connector, logger, connectorCallbacks)), MySQL)
case "pgsql":
uri := &url.URL{
Scheme: "postgres",
User: url.UserPassword(c.User, c.Password),
Path: "/" + url.PathEscape(c.Database),
}
query := url.Values{
"connect_timeout": {"60"},
"binary_parameters": {"yes"},
// Host and port can alternatively be specified in the query string. lib/pq can't parse the connection URI
// if a Unix domain socket path is specified in the host part of the URI, therefore always use the query
// string. See also https://github.com/lib/pq/issues/796
"host": {c.Host},
}
port := c.Port
if port == 0 {
port = 5432
}
query["port"] = []string{strconv.FormatInt(int64(port), 10)}
if _, err := c.TlsOptions.MakeConfig(c.Host); err != nil {
return nil, err
}
if c.TlsOptions.Enable {
if c.TlsOptions.Insecure {
query["sslmode"] = []string{"require"}
} else {
query["sslmode"] = []string{"verify-full"}
}
if c.TlsOptions.Cert != "" {
query["sslcert"] = []string{c.TlsOptions.Cert}
}
if c.TlsOptions.Key != "" {
query["sslkey"] = []string{c.TlsOptions.Key}
}
if c.TlsOptions.Ca != "" {
query["sslrootcert"] = []string{c.TlsOptions.Ca}
}
} else {
query["sslmode"] = []string{"disable"}
}
uri.RawQuery = query.Encode()
connector, err := pq.NewConnector(uri.String())
if err != nil {
return nil, errors.Wrap(err, "can't open pgsql database")
}
addr = utils.JoinHostPort(c.Host, port)
db = sqlx.NewDb(sql.OpenDB(NewConnector(connector, logger, connectorCallbacks)), PostgreSQL)
default:
return nil, unknownDbType(c.Type)
}
db.SetMaxIdleConns(c.Options.MaxConnections / 3)
db.SetMaxOpenConns(c.Options.MaxConnections)
db.Mapper = reflectx.NewMapperFunc("db", strcase.Snake)
return &DB{
DB: db,
Options: &c.Options,
columnMap: NewColumnMap(db.Mapper),
addr: addr,
logger: logger,
tableSemaphores: make(map[string]*semaphore.Weighted),
}, nil
}
// GetAddr returns the database host:port or Unix socket address.
func (db *DB) GetAddr() string {
return db.addr
}
// BuildDeleteStmt returns a DELETE statement for the given struct.
func (db *DB) BuildDeleteStmt(from interface{}) string {
return fmt.Sprintf(
`DELETE FROM "%s" WHERE id IN (?)`,
TableName(from),
)
}
// BuildInsertStmt returns an INSERT INTO statement for the given struct.
func (db *DB) BuildInsertStmt(into interface{}) (string, int) {
columns := db.columnMap.Columns(into)
return fmt.Sprintf(
`INSERT INTO "%s" ("%s") VALUES (%s)`,
TableName(into),
strings.Join(columns, `", "`),
fmt.Sprintf(":%s", strings.Join(columns, ", :")),
), len(columns)
}
// BuildInsertIgnoreStmt returns an INSERT statement for the specified struct for
// which the database ignores rows that have already been inserted.
func (db *DB) BuildInsertIgnoreStmt(into interface{}) (string, int) {
table := TableName(into)
columns := db.columnMap.Columns(into)
var clause string
switch db.DriverName() {
case MySQL:
// MySQL treats UPDATE id = id as a no-op.
clause = fmt.Sprintf(`ON DUPLICATE KEY UPDATE "%s" = "%s"`, columns[0], columns[0])
case PostgreSQL:
var constraint string
if constrainter, ok := into.(PgsqlOnConflictConstrainter); ok {
constraint = constrainter.PgsqlOnConflictConstraint()
} else {
constraint = "pk_" + table
}
clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT %s DO NOTHING", constraint)
}
return fmt.Sprintf(
`INSERT INTO "%s" ("%s") VALUES (%s) %s`,
table,
strings.Join(columns, `", "`),
fmt.Sprintf(":%s", strings.Join(columns, ", :")),
clause,
), len(columns)
}
// BuildSelectStmt returns a SELECT query that creates the FROM part from the given table struct
// and the column list from the specified columns struct.
func (db *DB) BuildSelectStmt(table interface{}, columns interface{}) string {
q := fmt.Sprintf(
`SELECT "%s" FROM "%s"`,
strings.Join(db.columnMap.Columns(columns), `", "`),
TableName(table),
)
if scoper, ok := table.(Scoper); ok {
where, _ := db.BuildWhere(scoper.Scope())
q += ` WHERE ` + where
}
return q
}
// BuildUpdateStmt returns an UPDATE statement for the given struct.
func (db *DB) BuildUpdateStmt(update interface{}) (string, int) {
columns := db.columnMap.Columns(update)
set := make([]string, 0, len(columns))
for _, col := range columns {
set = append(set, fmt.Sprintf(`"%s" = :%s`, col, col))
}
return fmt.Sprintf(
`UPDATE "%s" SET %s WHERE id = :id`,
TableName(update),
strings.Join(set, ", "),
), len(columns) + 1 // +1 because of WHERE id = :id
}
// BuildUpsertStmt returns an upsert statement for the given struct.
func (db *DB) BuildUpsertStmt(subject interface{}) (stmt string, placeholders int) {
insertColumns := db.columnMap.Columns(subject)
table := TableName(subject)
var updateColumns []string
if upserter, ok := subject.(Upserter); ok {
updateColumns = db.columnMap.Columns(upserter.Upsert())
} else {
updateColumns = insertColumns
}
var clause, setFormat string
switch db.DriverName() {
case MySQL:
clause = "ON DUPLICATE KEY UPDATE"
setFormat = `"%[1]s" = VALUES("%[1]s")`
case PostgreSQL:
var constraint string
if constrainter, ok := subject.(PgsqlOnConflictConstrainter); ok {
constraint = constrainter.PgsqlOnConflictConstraint()
} else {
constraint = "pk_" + table
}
clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT %s DO UPDATE SET", constraint)
setFormat = `"%[1]s" = EXCLUDED."%[1]s"`
}
set := make([]string, 0, len(updateColumns))
for _, col := range updateColumns {
set = append(set, fmt.Sprintf(setFormat, col))
}
return fmt.Sprintf(
`INSERT INTO "%s" ("%s") VALUES (%s) %s %s`,
table,
strings.Join(insertColumns, `", "`),
fmt.Sprintf(":%s", strings.Join(insertColumns, ",:")),
clause,
strings.Join(set, ","),
), len(insertColumns)
}
// BuildWhere returns a WHERE clause with named placeholder conditions built from the specified struct
// combined with the AND operator.
func (db *DB) BuildWhere(subject interface{}) (string, int) {
columns := db.columnMap.Columns(subject)
where := make([]string, 0, len(columns))
for _, col := range columns {
where = append(where, fmt.Sprintf(`"%s" = :%s`, col, col))
}
return strings.Join(where, ` AND `), len(columns)
}
// OnSuccess is a callback for successful (bulk) DML operations.
type OnSuccess[T any] func(ctx context.Context, affectedRows []T) (err error)
func OnSuccessIncrement[T any](counter *com.Counter) OnSuccess[T] {
return func(_ context.Context, rows []T) error {
counter.Add(uint64(len(rows)))
return nil
}
}
func OnSuccessSendTo[T any](ch chan<- T) OnSuccess[T] {
return func(ctx context.Context, rows []T) error {
for _, row := range rows {
select {
case ch <- row:
case <-ctx.Done():
return ctx.Err()
}
}
return nil
}
}
// BulkExec bulk executes queries with a single slice placeholder in the form of `IN (?)`.
// Takes in up to the number of arguments specified in count from the arg stream,
// derives and expands a query and executes it with this set of arguments until the arg stream has been processed.
// The derived queries are executed in a separate goroutine with a weighting of 1
// and can be executed concurrently to the extent allowed by the semaphore passed in sem.
// Arguments for which the query ran successfully will be passed to onSuccess.
func (db *DB) BulkExec(
ctx context.Context, query string, count int, sem *semaphore.Weighted, arg <-chan any, onSuccess ...OnSuccess[any],
) error {
var counter com.Counter
defer db.Log(ctx, query, &counter).Stop()
g, ctx := errgroup.WithContext(ctx)
// Use context from group.
bulk := com.Bulk(ctx, arg, count, com.NeverSplit[any])
g.Go(func() error {
g, ctx := errgroup.WithContext(ctx)
for b := range bulk {
if err := sem.Acquire(ctx, 1); err != nil {
return errors.Wrap(err, "can't acquire semaphore")
}
g.Go(func(b []interface{}) func() error {
return func() error {
defer sem.Release(1)
return retry.WithBackoff(
ctx,
func(context.Context) error {
stmt, args, err := sqlx.In(query, b)
if err != nil {
return errors.Wrapf(err, "can't build placeholders for %q", query)
}
stmt = db.Rebind(stmt)
_, err = db.ExecContext(ctx, stmt, args...)
if err != nil {
return CantPerformQuery(err, query)
}
counter.Add(uint64(len(b)))
for _, onSuccess := range onSuccess {
if err := onSuccess(ctx, b); err != nil {
return err
}
}
return nil
},
retry.Retryable,
backoff.NewExponentialWithJitter(1*time.Millisecond, 1*time.Second),
db.GetDefaultRetrySettings(),
)
}
}(b))
}
return g.Wait()
})
return g.Wait()
}
// NamedBulkExec bulk executes queries with named placeholders in a VALUES clause most likely
// in the format INSERT ... VALUES. Takes in up to the number of entities specified in count
// from the arg stream, derives and executes a new query with the VALUES clause expanded to
// this set of arguments, until the arg stream has been processed.
// The queries are executed in a separate goroutine with a weighting of 1
// and can be executed concurrently to the extent allowed by the semaphore passed in sem.
// Entities for which the query ran successfully will be passed to onSuccess.
func (db *DB) NamedBulkExec(
ctx context.Context, query string, count int, sem *semaphore.Weighted, arg <-chan Entity,
splitPolicyFactory com.BulkChunkSplitPolicyFactory[Entity], onSuccess ...OnSuccess[Entity],
) error {
var counter com.Counter
defer db.Log(ctx, query, &counter).Stop()
g, ctx := errgroup.WithContext(ctx)
bulk := com.Bulk(ctx, arg, count, splitPolicyFactory)
g.Go(func() error {
for {
select {
case b, ok := <-bulk:
if !ok {
return nil
}
if err := sem.Acquire(ctx, 1); err != nil {
return errors.Wrap(err, "can't acquire semaphore")
}
g.Go(func(b []Entity) func() error {
return func() error {
defer sem.Release(1)
return retry.WithBackoff(
ctx,
func(ctx context.Context) error {
_, err := db.NamedExecContext(ctx, query, b)
if err != nil {
return CantPerformQuery(err, query)
}
counter.Add(uint64(len(b)))
for _, onSuccess := range onSuccess {
if err := onSuccess(ctx, b); err != nil {
return err
}
}
return nil
},
retry.Retryable,
backoff.NewExponentialWithJitter(1*time.Millisecond, 1*time.Second),
db.GetDefaultRetrySettings(),
)
}
}(b))
case <-ctx.Done():
return ctx.Err()
}
}
})
return g.Wait()
}
// NamedBulkExecTx bulk executes queries with named placeholders in separate transactions.
// Takes in up to the number of entities specified in count from the arg stream and
// executes a new transaction that runs a new query for each entity in this set of arguments,
// until the arg stream has been processed.
// The transactions are executed in a separate goroutine with a weighting of 1
// and can be executed concurrently to the extent allowed by the semaphore passed in sem.
func (db *DB) NamedBulkExecTx(
ctx context.Context, query string, count int, sem *semaphore.Weighted, arg <-chan Entity,
) error {
var counter com.Counter
defer db.Log(ctx, query, &counter).Stop()
g, ctx := errgroup.WithContext(ctx)
bulk := com.Bulk(ctx, arg, count, com.NeverSplit[Entity])
g.Go(func() error {
for {
select {
case b, ok := <-bulk:
if !ok {
return nil
}
if err := sem.Acquire(ctx, 1); err != nil {
return errors.Wrap(err, "can't acquire semaphore")
}
g.Go(func(b []Entity) func() error {
return func() error {
defer sem.Release(1)
return retry.WithBackoff(
ctx,
func(ctx context.Context) error {
tx, err := db.BeginTxx(ctx, nil)
if err != nil {
return errors.Wrap(err, "can't start transaction")
}
stmt, err := tx.PrepareNamedContext(ctx, query)
if err != nil {
return errors.Wrap(err, "can't prepare named statement with context in transaction")
}
for _, arg := range b {
if _, err := stmt.ExecContext(ctx, arg); err != nil {
return errors.Wrap(err, "can't execute statement in transaction")
}
}
if err := tx.Commit(); err != nil {
return errors.Wrap(err, "can't commit transaction")
}
counter.Add(uint64(len(b)))
return nil
},
retry.Retryable,
backoff.NewExponentialWithJitter(1*time.Millisecond, 1*time.Second),
db.GetDefaultRetrySettings(),
)
}
}(b))
case <-ctx.Done():
return ctx.Err()
}
}
})
return g.Wait()
}
// BatchSizeByPlaceholders returns how often the specified number of placeholders fits
// into Options.MaxPlaceholdersPerStatement, but at least 1.
func (db *DB) BatchSizeByPlaceholders(n int) int {
s := db.Options.MaxPlaceholdersPerStatement / n
if s > 0 {
return s
}
return 1
}
// YieldAll executes the query with the supplied scope,
// scans each resulting row into an entity returned by the factory function,
// and streams them into a returned channel.
func (db *DB) YieldAll(ctx context.Context, factoryFunc EntityFactoryFunc, query string, scope interface{}) (<-chan Entity, <-chan error) {
entities := make(chan Entity, 1)
g, ctx := errgroup.WithContext(ctx)
g.Go(func() error {
var counter com.Counter
defer db.Log(ctx, query, &counter).Stop()
defer close(entities)
rows, err := db.NamedQueryContext(ctx, query, scope)
if err != nil {
return CantPerformQuery(err, query)
}
defer rows.Close()
for rows.Next() {
e := factoryFunc()
if err := rows.StructScan(e); err != nil {
return errors.Wrapf(err, "can't store query result into a %T: %s", e, query)
}
select {
case entities <- e:
counter.Inc()
case <-ctx.Done():
return ctx.Err()
}
}
return nil
})
return entities, com.WaitAsync(g)
}
// CreateStreamed bulk creates the specified entities via NamedBulkExec.
// The insert statement is created using BuildInsertStmt with the first entity from the entities stream.
// Bulk size is controlled via Options.MaxPlaceholdersPerStatement and
// concurrency is controlled via Options.MaxConnectionsPerTable.
// Entities for which the query ran successfully will be passed to onSuccess.
func (db *DB) CreateStreamed(
ctx context.Context, entities <-chan Entity, onSuccess ...OnSuccess[Entity],
) error {
first, forward, err := com.CopyFirst(ctx, entities)
if err != nil {
return errors.Wrap(err, "can't copy first entity")
}
sem := db.GetSemaphoreForTable(TableName(first))
stmt, placeholders := db.BuildInsertStmt(first)
return db.NamedBulkExec(
ctx, stmt, db.BatchSizeByPlaceholders(placeholders), sem,
forward, com.NeverSplit[Entity], onSuccess...,
)
}
// CreateIgnoreStreamed bulk creates the specified entities via NamedBulkExec.
// The insert statement is created using BuildInsertIgnoreStmt with the first entity from the entities stream.
// Bulk size is controlled via Options.MaxPlaceholdersPerStatement and
// concurrency is controlled via Options.MaxConnectionsPerTable.
// Entities for which the query ran successfully will be passed to onSuccess.
func (db *DB) CreateIgnoreStreamed(
ctx context.Context, entities <-chan Entity, onSuccess ...OnSuccess[Entity],
) error {
first, forward, err := com.CopyFirst(ctx, entities)
if err != nil {
return errors.Wrap(err, "can't copy first entity")
}
sem := db.GetSemaphoreForTable(TableName(first))
stmt, placeholders := db.BuildInsertIgnoreStmt(first)
return db.NamedBulkExec(
ctx, stmt, db.BatchSizeByPlaceholders(placeholders), sem,
forward, SplitOnDupId[Entity], onSuccess...,
)
}
// UpsertStreamed bulk upserts the specified entities via NamedBulkExec.
// The upsert statement is created using BuildUpsertStmt with the first entity from the entities stream.
// Bulk size is controlled via Options.MaxPlaceholdersPerStatement and
// concurrency is controlled via Options.MaxConnectionsPerTable.
// Entities for which the query ran successfully will be passed to onSuccess.
func (db *DB) UpsertStreamed(
ctx context.Context, entities <-chan Entity, onSuccess ...OnSuccess[Entity],
) error {
first, forward, err := com.CopyFirst(ctx, entities)
if err != nil {
return errors.Wrap(err, "can't copy first entity")
}
sem := db.GetSemaphoreForTable(TableName(first))
stmt, placeholders := db.BuildUpsertStmt(first)
return db.NamedBulkExec(
ctx, stmt, db.BatchSizeByPlaceholders(placeholders), sem,
forward, SplitOnDupId[Entity], onSuccess...,
)
}
// UpdateStreamed bulk updates the specified entities via NamedBulkExecTx.
// The update statement is created using BuildUpdateStmt with the first entity from the entities stream.
// Bulk size is controlled via Options.MaxRowsPerTransaction and
// concurrency is controlled via Options.MaxConnectionsPerTable.
func (db *DB) UpdateStreamed(ctx context.Context, entities <-chan Entity) error {
first, forward, err := com.CopyFirst(ctx, entities)
if err != nil {
return errors.Wrap(err, "can't copy first entity")
}
sem := db.GetSemaphoreForTable(TableName(first))
stmt, _ := db.BuildUpdateStmt(first)
return db.NamedBulkExecTx(ctx, stmt, db.Options.MaxRowsPerTransaction, sem, forward)
}
// DeleteStreamed bulk deletes the specified ids via BulkExec.
// The delete statement is created using BuildDeleteStmt with the passed entityType.
// Bulk size is controlled via Options.MaxPlaceholdersPerStatement and
// concurrency is controlled via Options.MaxConnectionsPerTable.
// IDs for which the query ran successfully will be passed to onSuccess.
func (db *DB) DeleteStreamed(
ctx context.Context, entityType Entity, ids <-chan interface{}, onSuccess ...OnSuccess[any],
) error {
sem := db.GetSemaphoreForTable(TableName(entityType))
return db.BulkExec(
ctx, db.BuildDeleteStmt(entityType), db.Options.MaxPlaceholdersPerStatement, sem, ids, onSuccess...,
)
}
// Delete creates a channel from the specified ids and
// bulk deletes them by passing the channel along with the entityType to DeleteStreamed.
// IDs for which the query ran successfully will be passed to onSuccess.
func (db *DB) Delete(
ctx context.Context, entityType Entity, ids []interface{}, onSuccess ...OnSuccess[any],
) error {
idsCh := make(chan interface{}, len(ids))
for _, id := range ids {
idsCh <- id
}
close(idsCh)
return db.DeleteStreamed(ctx, entityType, idsCh, onSuccess...)
}
func (db *DB) GetSemaphoreForTable(table string) *semaphore.Weighted {
db.tableSemaphoresMu.Lock()
defer db.tableSemaphoresMu.Unlock()
if sem, ok := db.tableSemaphores[table]; ok {
return sem
} else {
sem = semaphore.NewWeighted(int64(db.Options.MaxConnectionsPerTable))
db.tableSemaphores[table] = sem
return sem
}
}
func (db *DB) GetDefaultRetrySettings() retry.Settings {
return retry.Settings{
Timeout: retry.DefaultTimeout,
OnRetryableError: func(_ time.Duration, _ uint64, err, lastErr error) {
if lastErr == nil || err.Error() != lastErr.Error() {
db.logger.Warnw("Can't execute query. Retrying", zap.Error(err))
}
},
OnSuccess: func(elapsed time.Duration, attempt uint64, lastErr error) {
if attempt > 1 {
db.logger.Infow("Query retried successfully after error",
zap.Duration("after", elapsed),
zap.Uint64("attempts", attempt),
zap.NamedError("recovered_error", lastErr))
}
},
}
}
func (db *DB) Log(ctx context.Context, query string, counter *com.Counter) periodic.Stopper {
return periodic.Start(ctx, db.logger.Interval(), func(tick periodic.Tick) {
if count := counter.Reset(); count > 0 {
db.logger.Debugf("Executed %q with %d rows", query, count)
}
}, periodic.OnStop(func(tick periodic.Tick) {
db.logger.Debugf("Finished executing %q with %d rows in %s", query, counter.Total(), tick.Elapsed)
}))
}

View file

@ -1,100 +0,0 @@
package database
import (
"context"
"database/sql/driver"
"github.com/icinga/icinga-go-library/backoff"
"github.com/icinga/icinga-go-library/logging"
"github.com/icinga/icinga-go-library/retry"
"github.com/pkg/errors"
"go.uber.org/zap"
"time"
)
// Driver names as automatically registered in the database/sql package by themselves.
const (
MySQL string = "mysql"
PostgreSQL string = "postgres"
)
// OnInitConnFunc can be used to execute post Connect() arbitrary actions.
// It will be called after successfully initiated a new connection using the connector's Connect method.
type OnInitConnFunc func(context.Context, driver.Conn) error
// RetryConnectorCallbacks specifies callbacks that are executed upon certain events.
type RetryConnectorCallbacks struct {
OnInitConn OnInitConnFunc
OnRetryableError retry.OnRetryableErrorFunc
OnSuccess retry.OnSuccessFunc
}
// RetryConnector wraps driver.Connector with retry logic.
type RetryConnector struct {
driver.Connector
logger *logging.Logger
callbacks RetryConnectorCallbacks
}
// NewConnector creates a fully initialized RetryConnector from the given args.
func NewConnector(c driver.Connector, logger *logging.Logger, callbacks RetryConnectorCallbacks) *RetryConnector {
return &RetryConnector{Connector: c, logger: logger, callbacks: callbacks}
}
// Connect implements part of the driver.Connector interface.
func (c RetryConnector) Connect(ctx context.Context) (driver.Conn, error) {
var conn driver.Conn
err := errors.Wrap(retry.WithBackoff(
ctx,
func(ctx context.Context) (err error) {
conn, err = c.Connector.Connect(ctx)
if err == nil && c.callbacks.OnInitConn != nil {
if err = c.callbacks.OnInitConn(ctx, conn); err != nil {
// We're going to retry this, so just don't bother whether Close() fails!
_ = conn.Close()
}
}
return
},
retry.Retryable,
backoff.NewExponentialWithJitter(128*time.Millisecond, 1*time.Minute),
retry.Settings{
Timeout: retry.DefaultTimeout,
OnRetryableError: func(elapsed time.Duration, attempt uint64, err, lastErr error) {
if c.callbacks.OnRetryableError != nil {
c.callbacks.OnRetryableError(elapsed, attempt, err, lastErr)
}
if lastErr == nil || err.Error() != lastErr.Error() {
c.logger.Warnw("Can't connect to database. Retrying", zap.Error(err))
}
},
OnSuccess: func(elapsed time.Duration, attempt uint64, lastErr error) {
if c.callbacks.OnSuccess != nil {
c.callbacks.OnSuccess(elapsed, attempt, lastErr)
}
if attempt > 1 {
c.logger.Infow("Reconnected to database",
zap.Duration("after", elapsed), zap.Uint64("attempts", attempt))
}
},
},
), "can't connect to database")
return conn, err
}
// Driver implements part of the driver.Connector interface.
func (c RetryConnector) Driver() driver.Driver {
return c.Connector.Driver()
}
// MysqlFuncLogger is an adapter that allows ordinary functions to be used as a logger for mysql.SetLogger.
type MysqlFuncLogger func(v ...interface{})
// Print implements the mysql.Logger interface.
func (log MysqlFuncLogger) Print(v ...interface{}) {
log(v)
}

View file

@ -1,81 +0,0 @@
package database
import (
"context"
"database/sql/driver"
"github.com/go-sql-driver/mysql"
"github.com/icinga/icinga-go-library/com"
"github.com/icinga/icinga-go-library/strcase"
"github.com/icinga/icinga-go-library/types"
"github.com/pkg/errors"
)
// CantPerformQuery wraps the given error with the specified query that cannot be executed.
func CantPerformQuery(err error, q string) error {
return errors.Wrapf(err, "can't perform %q", q)
}
// TableName returns the table of t.
func TableName(t interface{}) string {
if tn, ok := t.(TableNamer); ok {
return tn.TableName()
} else {
return strcase.Snake(types.Name(t))
}
}
// SplitOnDupId returns a state machine which tracks the inputs' IDs.
// Once an already seen input arrives, it demands splitting.
func SplitOnDupId[T IDer]() com.BulkChunkSplitPolicy[T] {
seenIds := map[string]struct{}{}
return func(ider T) bool {
id := ider.ID().String()
_, ok := seenIds[id]
if ok {
seenIds = map[string]struct{}{id: {}}
} else {
seenIds[id] = struct{}{}
}
return ok
}
}
// setGaleraOpts sets the "wsrep_sync_wait" variable for each session ensures that causality checks are performed
// before execution and that each statement is executed on a fully synchronized node. Doing so prevents foreign key
// violation when inserting into dependent tables on different MariaDB/MySQL nodes. When using MySQL single nodes,
// the "SET SESSION" command will fail with "Unknown system variable (1193)" and will therefore be silently dropped.
//
// https://mariadb.com/kb/en/galera-cluster-system-variables/#wsrep_sync_wait
func setGaleraOpts(ctx context.Context, conn driver.Conn, wsrepSyncWait int64) error {
const galeraOpts = "SET SESSION wsrep_sync_wait=?"
stmt, err := conn.(driver.ConnPrepareContext).PrepareContext(ctx, galeraOpts)
if err != nil {
if errors.Is(err, &mysql.MySQLError{Number: 1193}) { // Unknown system variable
return nil
}
return errors.Wrap(err, "cannot prepare "+galeraOpts)
}
// This is just for an unexpected exit and any returned error can safely be ignored and in case
// of the normal function exit, the stmt is closed manually, and its error is handled gracefully.
defer func() { _ = stmt.Close() }()
_, err = stmt.(driver.StmtExecContext).ExecContext(ctx, []driver.NamedValue{{Value: wsrepSyncWait}})
if err != nil {
return errors.Wrap(err, "cannot execute "+galeraOpts)
}
if err = stmt.Close(); err != nil {
return errors.Wrap(err, "cannot close prepared statement "+galeraOpts)
}
return nil
}
var (
_ com.BulkChunkSplitPolicyFactory[Entity] = SplitOnDupId[Entity]
)

View file

@ -1,46 +0,0 @@
package flatten
import (
"fmt"
"github.com/icinga/icinga-go-library/types"
"strconv"
)
// Flatten creates flat, one-dimensional maps from arbitrarily nested values, e.g. JSON.
func Flatten(value interface{}, prefix string) map[string]types.String {
var flatten func(string, interface{})
flattened := make(map[string]types.String)
flatten = func(key string, value interface{}) {
switch value := value.(type) {
case map[string]interface{}:
if len(value) == 0 {
flattened[key] = types.String{}
break
}
for k, v := range value {
flatten(key+"."+k, v)
}
case []interface{}:
if len(value) == 0 {
flattened[key] = types.String{}
break
}
for i, v := range value {
flatten(key+"["+strconv.Itoa(i)+"]", v)
}
case nil:
flattened[key] = types.MakeString("null")
case float64:
flattened[key] = types.MakeString(strconv.FormatFloat(value, 'f', -1, 64))
default:
flattened[key] = types.MakeString(fmt.Sprintf("%v", value))
}
}
flatten(prefix, value)
return flattened
}

View file

@ -1,45 +0,0 @@
package flatten
import (
"github.com/icinga/icinga-go-library/types"
"github.com/stretchr/testify/assert"
"testing"
)
func TestFlatten(t *testing.T) {
for _, st := range []struct {
name string
prefix string
value any
output map[string]types.String
}{
{"nil", "a", nil, map[string]types.String{"a": types.MakeString("null")}},
{"bool", "b", true, map[string]types.String{"b": types.MakeString("true")}},
{"int", "c", 42, map[string]types.String{"c": types.MakeString("42")}},
{"float", "d", 77.7, map[string]types.String{"d": types.MakeString("77.7")}},
{"large_float", "e", 1e23, map[string]types.String{"e": types.MakeString("100000000000000000000000")}},
{"string", "f", "\x00", map[string]types.String{"f": types.MakeString("\x00")}},
{"nil_slice", "g", []any(nil), map[string]types.String{"g": {}}},
{"empty_slice", "h", []any{}, map[string]types.String{"h": {}}},
{"slice", "i", []any{nil}, map[string]types.String{"i[0]": types.MakeString("null")}},
{"nil_map", "j", map[string]any(nil), map[string]types.String{"j": {}}},
{"empty_map", "k", map[string]any{}, map[string]types.String{"k": {}}},
{"map", "l", map[string]any{" ": nil}, map[string]types.String{"l. ": types.MakeString("null")}},
{"map_with_slice", "m", map[string]any{"\t": []any{"ä", "ö", "ü"}, "ß": "s"}, map[string]types.String{
"m.\t[0]": types.MakeString("ä"),
"m.\t[1]": types.MakeString("ö"),
"m.\t[2]": types.MakeString("ü"),
"m.ß": types.MakeString("s"),
}},
{"slice_with_map", "n", []any{map[string]any{"ä": "a", "ö": "o", "ü": "u"}, "ß"}, map[string]types.String{
"n[0].ä": types.MakeString("a"),
"n[0].ö": types.MakeString("o"),
"n[0].ü": types.MakeString("u"),
"n[1]": types.MakeString("ß"),
}},
} {
t.Run(st.name, func(t *testing.T) {
assert.Equal(t, st.output, Flatten(st.value, st.prefix))
})
}
}

View file

@ -1,60 +0,0 @@
package logging
import (
"fmt"
"github.com/pkg/errors"
"go.uber.org/zap/zapcore"
"os"
"time"
)
// Options define child loggers with their desired log level.
type Options map[string]zapcore.Level
// Config defines Logger configuration.
type Config struct {
// zapcore.Level at 0 is for info level.
Level zapcore.Level `yaml:"level" default:"0"`
Output string `yaml:"output"`
// Interval for periodic logging.
Interval time.Duration `yaml:"interval" default:"20s"`
Options `yaml:"options"`
}
// Validate checks constraints in the configuration and returns an error if they are violated.
// Also configures the log output if it is not configured:
// systemd-journald is used when Icinga DB is running under systemd, otherwise stderr.
func (l *Config) Validate() error {
if l.Interval <= 0 {
return errors.New("periodic logging interval must be positive")
}
if l.Output == "" {
if _, ok := os.LookupEnv("NOTIFY_SOCKET"); ok {
// When started by systemd, NOTIFY_SOCKET is set by systemd for Type=notify supervised services,
// which is the default setting for the Icinga DB service.
// This assumes that Icinga DB is running under systemd, so set output to systemd-journald.
l.Output = JOURNAL
} else {
// Otherwise set it to console, i.e. write log messages to stderr.
l.Output = CONSOLE
}
}
// To be on the safe side, always call AssertOutput.
return AssertOutput(l.Output)
}
// AssertOutput returns an error if output is not a valid logger output.
func AssertOutput(o string) error {
if o == CONSOLE || o == JOURNAL {
return nil
}
return invalidOutput(o)
}
func invalidOutput(o string) error {
return fmt.Errorf("%s is not a valid logger output. Must be either %q or %q", o, CONSOLE, JOURNAL)
}

View file

@ -1,84 +0,0 @@
package logging
import (
"github.com/icinga/icinga-go-library/strcase"
"github.com/pkg/errors"
"github.com/ssgreg/journald"
"go.uber.org/zap/zapcore"
"strings"
)
// priorities maps zapcore.Level to journal.Priority.
var priorities = map[zapcore.Level]journald.Priority{
zapcore.DebugLevel: journald.PriorityDebug,
zapcore.InfoLevel: journald.PriorityInfo,
zapcore.WarnLevel: journald.PriorityWarning,
zapcore.ErrorLevel: journald.PriorityErr,
zapcore.FatalLevel: journald.PriorityCrit,
zapcore.PanicLevel: journald.PriorityCrit,
zapcore.DPanicLevel: journald.PriorityCrit,
}
// NewJournaldCore returns a zapcore.Core that sends log entries to systemd-journald and
// uses the given identifier as a prefix for structured logging context that is sent as journal fields.
func NewJournaldCore(identifier string, enab zapcore.LevelEnabler) zapcore.Core {
return &journaldCore{
LevelEnabler: enab,
identifier: identifier,
identifierU: strings.ToUpper(identifier),
}
}
type journaldCore struct {
zapcore.LevelEnabler
context []zapcore.Field
identifier string
identifierU string
}
func (c *journaldCore) Check(ent zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry {
if c.Enabled(ent.Level) {
return ce.AddCore(ent, c)
}
return ce
}
func (c *journaldCore) Sync() error {
return nil
}
func (c *journaldCore) With(fields []zapcore.Field) zapcore.Core {
cc := *c
cc.context = append(cc.context[:len(cc.context):len(cc.context)], fields...)
return &cc
}
func (c *journaldCore) Write(ent zapcore.Entry, fields []zapcore.Field) error {
pri, ok := priorities[ent.Level]
if !ok {
return errors.Errorf("unknown log level %q", ent.Level)
}
enc := zapcore.NewMapObjectEncoder()
c.addFields(enc, fields)
c.addFields(enc, c.context)
enc.Fields["SYSLOG_IDENTIFIER"] = c.identifier
message := ent.Message
if ent.LoggerName != c.identifier {
message = ent.LoggerName + ": " + message
}
return journald.Send(message, pri, enc.Fields)
}
func (c *journaldCore) addFields(enc zapcore.ObjectEncoder, fields []zapcore.Field) {
for _, field := range fields {
field.Key = c.identifierU +
"_" +
strcase.ScreamingSnake(field.Key)
field.AddTo(enc)
}
}

View file

@ -1,26 +0,0 @@
package logging
import (
"go.uber.org/zap"
"time"
)
// Logger wraps zap.SugaredLogger and
// allows to get the interval for periodic logging.
type Logger struct {
*zap.SugaredLogger
interval time.Duration
}
// NewLogger returns a new Logger.
func NewLogger(base *zap.SugaredLogger, interval time.Duration) *Logger {
return &Logger{
SugaredLogger: base,
interval: interval,
}
}
// Interval returns the interval for periodic logging.
func (l *Logger) Interval() time.Duration {
return l.interval
}

View file

@ -1,119 +0,0 @@
package logging
import (
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"os"
"sync"
"time"
)
const (
CONSOLE = "console"
JOURNAL = "systemd-journald"
)
// defaultEncConfig defines the default zapcore.EncoderConfig for the logging package.
var defaultEncConfig = zapcore.EncoderConfig{
TimeKey: "ts",
LevelKey: "level",
NameKey: "logger",
CallerKey: "caller",
MessageKey: "msg",
StacktraceKey: "stacktrace",
LineEnding: zapcore.DefaultLineEnding,
EncodeLevel: zapcore.CapitalLevelEncoder,
EncodeTime: zapcore.ISO8601TimeEncoder,
EncodeDuration: zapcore.StringDurationEncoder,
EncodeCaller: zapcore.ShortCallerEncoder,
}
// Logging implements access to a default logger and named child loggers.
// Log levels can be configured per named child via Options which, if not configured,
// fall back on a default log level.
// Logs either to the console or to systemd-journald.
type Logging struct {
logger *Logger
output string
verbosity zap.AtomicLevel
interval time.Duration
// coreFactory creates zapcore.Core based on the log level and the log output.
coreFactory func(zap.AtomicLevel) zapcore.Core
mu sync.Mutex
loggers map[string]*Logger
options Options
}
// NewLogging takes the name and log level for the default logger,
// output where log messages are written to,
// options having log levels for named child loggers
// and returns a new Logging.
func NewLogging(name string, level zapcore.Level, output string, options Options, interval time.Duration) (*Logging, error) {
verbosity := zap.NewAtomicLevelAt(level)
var coreFactory func(zap.AtomicLevel) zapcore.Core
switch output {
case CONSOLE:
enc := zapcore.NewConsoleEncoder(defaultEncConfig)
ws := zapcore.Lock(os.Stderr)
coreFactory = func(verbosity zap.AtomicLevel) zapcore.Core {
return zapcore.NewCore(enc, ws, verbosity)
}
case JOURNAL:
coreFactory = func(verbosity zap.AtomicLevel) zapcore.Core {
return NewJournaldCore(name, verbosity)
}
default:
return nil, invalidOutput(output)
}
logger := NewLogger(zap.New(coreFactory(verbosity)).Named(name).Sugar(), interval)
return &Logging{
logger: logger,
output: output,
verbosity: verbosity,
interval: interval,
coreFactory: coreFactory,
loggers: make(map[string]*Logger),
options: options,
},
nil
}
// NewLoggingFromConfig returns a new Logging from Config.
func NewLoggingFromConfig(name string, c Config) (*Logging, error) {
return NewLogging(name, c.Level, c.Output, c.Options, c.Interval)
}
// GetChildLogger returns a named child logger.
// Log levels for named child loggers are obtained from the logging options and, if not found,
// set to the default log level.
func (l *Logging) GetChildLogger(name string) *Logger {
l.mu.Lock()
defer l.mu.Unlock()
if logger, ok := l.loggers[name]; ok {
return logger
}
var verbosity zap.AtomicLevel
if level, found := l.options[name]; found {
verbosity = zap.NewAtomicLevelAt(level)
} else {
verbosity = l.verbosity
}
logger := NewLogger(zap.New(l.coreFactory(verbosity)).Named(name).Sugar(), l.interval)
l.loggers[name] = logger
return logger
}
// GetLogger returns the default logger.
func (l *Logging) GetLogger() *Logger {
return l.logger
}

View file

@ -1,213 +0,0 @@
package objectpacker
import (
"bytes"
"encoding/binary"
"fmt"
"github.com/pkg/errors"
"io"
"reflect"
"sort"
)
// MustPackSlice calls PackAny using items and panics if there was an error.
func MustPackSlice(items ...interface{}) []byte {
var buf bytes.Buffer
if err := PackAny(items, &buf); err != nil {
panic(err)
}
return buf.Bytes()
}
// PackAny packs any JSON-encodable value (ex. structs, also ignores interfaces like encoding.TextMarshaler)
// to a BSON-similar format suitable for consistent hashing. Spec:
//
// PackAny(nil) => 0x0
// PackAny(false) => 0x1
// PackAny(true) => 0x2
// PackAny(float64(42)) => 0x3 ieee754_binary64_bigendian(42)
// PackAny("exämple") => 0x4 uint64_bigendian(len([]byte("exämple"))) []byte("exämple")
// PackAny([]uint8{0x42}) => 0x4 uint64_bigendian(len([]uint8{0x42})) []uint8{0x42}
// PackAny([1]uint8{0x42}) => 0x4 uint64_bigendian(len([1]uint8{0x42})) [1]uint8{0x42}
// PackAny([]T{x,y}) => 0x5 uint64_bigendian(len([]T{x,y})) PackAny(x) PackAny(y)
// PackAny(map[K]V{x:y}) => 0x6 uint64_bigendian(len(map[K]V{x:y})) len(map_key(x)) map_key(x) PackAny(y)
// PackAny((*T)(nil)) => 0x0
// PackAny((*T)(0x42)) => PackAny(*(*T)(0x42))
// PackAny(x) => panic()
//
// map_key([1]uint8{0x42}) => [1]uint8{0x42}
// map_key(x) => []byte(fmt.Sprint(x))
func PackAny(in interface{}, out io.Writer) error {
return errors.Wrapf(packValue(reflect.ValueOf(in), out), "can't pack %#v", in)
}
var tByte = reflect.TypeOf(byte(0))
var tBytes = reflect.TypeOf([]uint8(nil))
// packValue does the actual job of packAny and just exists for recursion w/o unnecessary reflect.ValueOf calls.
func packValue(in reflect.Value, out io.Writer) error {
switch kind := in.Kind(); kind {
case reflect.Invalid: // nil
_, err := out.Write([]byte{0})
return err
case reflect.Bool:
if in.Bool() {
_, err := out.Write([]byte{2})
return err
} else {
_, err := out.Write([]byte{1})
return err
}
case reflect.Float64:
if _, err := out.Write([]byte{3}); err != nil {
return err
}
return binary.Write(out, binary.BigEndian, in.Float())
case reflect.Array, reflect.Slice:
if typ := in.Type(); typ.Elem() == tByte {
if kind == reflect.Array {
if !in.CanAddr() {
vNewElem := reflect.New(typ).Elem()
vNewElem.Set(in)
in = vNewElem
}
in = in.Slice(0, in.Len())
}
// Pack []byte as string, not array of numbers.
return packString(in.Convert(tBytes). // Support types.Binary
Interface().([]uint8), out)
}
if _, err := out.Write([]byte{5}); err != nil {
return err
}
l := in.Len()
if err := binary.Write(out, binary.BigEndian, uint64(l)); err != nil {
return err
}
for i := 0; i < l; i++ {
if err := packValue(in.Index(i), out); err != nil {
return err
}
}
// If there aren't any values to pack, ...
if l < 1 {
// ... create one and pack it - panics on disallowed type.
_ = packValue(reflect.Zero(in.Type().Elem()), io.Discard)
}
return nil
case reflect.Interface:
return packValue(in.Elem(), out)
case reflect.Map:
type kv struct {
key []byte
value reflect.Value
}
if _, err := out.Write([]byte{6}); err != nil {
return err
}
l := in.Len()
if err := binary.Write(out, binary.BigEndian, uint64(l)); err != nil {
return err
}
sorted := make([]kv, 0, l)
{
iter := in.MapRange()
for iter.Next() {
var packedKey []byte
if key := iter.Key(); key.Kind() == reflect.Array {
if typ := key.Type(); typ.Elem() == tByte {
if !key.CanAddr() {
vNewElem := reflect.New(typ).Elem()
vNewElem.Set(key)
key = vNewElem
}
packedKey = key.Slice(0, key.Len()).Interface().([]byte)
} else {
// Not just stringify the key (below), but also pack it (here) - panics on disallowed type.
_ = packValue(iter.Key(), io.Discard)
packedKey = []byte(fmt.Sprint(key.Interface()))
}
} else {
// Not just stringify the key (below), but also pack it (here) - panics on disallowed type.
_ = packValue(iter.Key(), io.Discard)
packedKey = []byte(fmt.Sprint(key.Interface()))
}
sorted = append(sorted, kv{packedKey, iter.Value()})
}
}
sort.Slice(sorted, func(i, j int) bool { return bytes.Compare(sorted[i].key, sorted[j].key) < 0 })
for _, kv := range sorted {
if err := binary.Write(out, binary.BigEndian, uint64(len(kv.key))); err != nil {
return err
}
if _, err := out.Write(kv.key); err != nil {
return err
}
if err := packValue(kv.value, out); err != nil {
return err
}
}
// If there aren't any key-value pairs to pack, ...
if l < 1 {
typ := in.Type()
// ... create one and pack it - panics on disallowed type.
_ = packValue(reflect.Zero(typ.Key()), io.Discard)
_ = packValue(reflect.Zero(typ.Elem()), io.Discard)
}
return nil
case reflect.Ptr:
if in.IsNil() {
err := packValue(reflect.Value{}, out)
// Create a fictive referenced value and pack it - panics on disallowed type.
_ = packValue(reflect.Zero(in.Type().Elem()), io.Discard)
return err
} else {
return packValue(in.Elem(), out)
}
case reflect.String:
return packString([]byte(in.String()), out)
default:
panic("bad type: " + in.Kind().String())
}
}
// packString deduplicates string packing of multiple locations in packValue.
func packString(in []byte, out io.Writer) error {
if _, err := out.Write([]byte{4}); err != nil {
return err
}
if err := binary.Write(out, binary.BigEndian, uint64(len(in))); err != nil {
return err
}
_, err := out.Write(in)
return err
}

View file

@ -1,195 +0,0 @@
package objectpacker
import (
"bytes"
"github.com/icinga/icinga-go-library/types"
"github.com/pkg/errors"
"io"
"testing"
)
// limitedWriter allows writing a specific amount of data.
type limitedWriter struct {
// limit specifies how many bytes to allow to write.
limit int
}
var _ io.Writer = (*limitedWriter)(nil)
// Write returns io.EOF once lw.limit is exceeded, nil otherwise.
func (lw *limitedWriter) Write(p []byte) (n int, err error) {
if len(p) <= lw.limit {
lw.limit -= len(p)
return len(p), nil
}
n = lw.limit
err = io.EOF
lw.limit = 0
return
}
func TestLimitedWriter_Write(t *testing.T) {
assertLimitedWriter_Write(t, 3, []byte{1, 2}, 2, nil, 1)
assertLimitedWriter_Write(t, 3, []byte{1, 2, 3}, 3, nil, 0)
assertLimitedWriter_Write(t, 3, []byte{1, 2, 3, 4}, 3, io.EOF, 0)
assertLimitedWriter_Write(t, 0, []byte{1}, 0, io.EOF, 0)
assertLimitedWriter_Write(t, 0, nil, 0, nil, 0)
}
func assertLimitedWriter_Write(t *testing.T, limitBefore int, p []byte, n int, err error, limitAfter int) {
t.Helper()
lw := limitedWriter{limitBefore}
actualN, actualErr := lw.Write(p)
if !errors.Is(actualErr, err) {
t.Errorf("_, err := (&limitedWriter{%d}).Write(%#v); err != %#v", limitBefore, p, err)
}
if actualN != n {
t.Errorf("n, _ := (&limitedWriter{%d}).Write(%#v); n != %d", limitBefore, p, n)
}
if lw.limit != limitAfter {
t.Errorf("lw := limitedWriter{%d}; lw.Write(%#v); lw.limit != %d", limitBefore, p, limitAfter)
}
}
func TestPackAny(t *testing.T) {
assertPackAny(t, nil, []byte{0})
assertPackAny(t, false, []byte{1})
assertPackAny(t, true, []byte{2})
assertPackAnyPanic(t, -42, 0)
assertPackAnyPanic(t, int8(-42), 0)
assertPackAnyPanic(t, int16(-42), 0)
assertPackAnyPanic(t, int32(-42), 0)
assertPackAnyPanic(t, int64(-42), 0)
assertPackAnyPanic(t, uint(42), 0)
assertPackAnyPanic(t, uint8(42), 0)
assertPackAnyPanic(t, uint16(42), 0)
assertPackAnyPanic(t, uint32(42), 0)
assertPackAnyPanic(t, uint64(42), 0)
assertPackAnyPanic(t, uintptr(42), 0)
assertPackAnyPanic(t, float32(-42.5), 0)
assertPackAny(t, -42.5, []byte{3, 0xc0, 0x45, 0x40, 0, 0, 0, 0, 0})
assertPackAnyPanic(t, []struct{}(nil), 9)
assertPackAnyPanic(t, []struct{}{}, 9)
assertPackAny(t, []interface{}{nil, true, -42.5}, []byte{
5, 0, 0, 0, 0, 0, 0, 0, 3,
0,
2,
3, 0xc0, 0x45, 0x40, 0, 0, 0, 0, 0,
})
assertPackAny(t, []string{"", "a"}, []byte{
5, 0, 0, 0, 0, 0, 0, 0, 2,
4, 0, 0, 0, 0, 0, 0, 0, 0,
4, 0, 0, 0, 0, 0, 0, 0, 1, 'a',
})
assertPackAnyPanic(t, []interface{}{0 + 0i}, 9)
assertPackAnyPanic(t, map[struct{}]struct{}(nil), 9)
assertPackAnyPanic(t, map[struct{}]struct{}{}, 9)
assertPackAny(t, map[interface{}]interface{}{true: "", "nil": -42.5}, []byte{
6, 0, 0, 0, 0, 0, 0, 0, 2,
0, 0, 0, 0, 0, 0, 0, 3, 'n', 'i', 'l',
3, 0xc0, 0x45, 0x40, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 4, 't', 'r', 'u', 'e',
4, 0, 0, 0, 0, 0, 0, 0, 0,
})
assertPackAny(t, map[string]float64{"": 42}, []byte{
6, 0, 0, 0, 0, 0, 0, 0, 1,
0, 0, 0, 0, 0, 0, 0, 0,
3, 0x40, 0x45, 0, 0, 0, 0, 0, 0,
})
assertPackAny(t, map[[1]byte]bool{{42}: true}, []byte{
6, 0, 0, 0, 0, 0, 0, 0, 1,
0, 0, 0, 0, 0, 0, 0, 1, 42,
2,
})
assertPackAnyPanic(t, map[struct{}]struct{}{{}: {}}, 9)
assertPackAny(t, (*string)(nil), []byte{0})
assertPackAnyPanic(t, (*int)(nil), 0)
assertPackAny(t, new(float64), []byte{3, 0, 0, 0, 0, 0, 0, 0, 0})
assertPackAny(t, "", []byte{4, 0, 0, 0, 0, 0, 0, 0, 0})
assertPackAny(t, "a", []byte{4, 0, 0, 0, 0, 0, 0, 0, 1, 'a'})
assertPackAny(t, "ä", []byte{4, 0, 0, 0, 0, 0, 0, 0, 2, 0xc3, 0xa4})
{
var binary [256]byte
for i := range binary {
binary[i] = byte(i)
}
assertPackAny(t, binary, append([]byte{4, 0, 0, 0, 0, 0, 0, 1, 0}, binary[:]...))
assertPackAny(t, binary[:], append([]byte{4, 0, 0, 0, 0, 0, 0, 1, 0}, binary[:]...))
assertPackAny(t, types.Binary(binary[:]), append([]byte{4, 0, 0, 0, 0, 0, 0, 1, 0}, binary[:]...))
}
{
type myByte byte
assertPackAnyPanic(t, []myByte(nil), 9)
}
assertPackAnyPanic(t, complex64(0+0i), 0)
assertPackAnyPanic(t, 0+0i, 0)
assertPackAnyPanic(t, make(chan struct{}), 0)
assertPackAnyPanic(t, func() {}, 0)
assertPackAnyPanic(t, struct{}{}, 0)
assertPackAnyPanic(t, uintptr(0), 0)
}
func assertPackAny(t *testing.T, in interface{}, out []byte) {
t.Helper()
{
buf := &bytes.Buffer{}
if err := PackAny(in, buf); err == nil {
if !bytes.Equal(buf.Bytes(), out) {
t.Errorf("buf := &bytes.Buffer{}; packAny(%#v, buf); !bytes.Equal(buf.Bytes(), %#v)", in, out)
}
} else {
t.Errorf("packAny(%#v, &bytes.Buffer{}) != nil", in)
}
}
for i := 0; i < len(out); i++ {
if !errors.Is(PackAny(in, &limitedWriter{i}), io.EOF) {
t.Errorf("packAny(%#v, &limitedWriter{%d}) != io.EOF", in, i)
}
}
}
func assertPackAnyPanic(t *testing.T, in interface{}, allowToWrite int) {
t.Helper()
for i := 0; i < allowToWrite; i++ {
if !errors.Is(PackAny(in, &limitedWriter{i}), io.EOF) {
t.Errorf("packAny(%#v, &limitedWriter{%d}) != io.EOF", in, i)
}
}
defer func() {
t.Helper()
if r := recover(); r == nil {
t.Errorf("packAny(%#v, &limitedWriter{%d}) didn't panic", in, allowToWrite)
}
}()
_ = PackAny(in, &limitedWriter{allowToWrite})
}

View file

@ -1,123 +0,0 @@
package periodic
import (
"context"
"sync"
"time"
)
// Option configures Start.
type Option interface {
apply(*periodic)
}
// Stopper implements the Stop method,
// which stops a periodic task from Start().
type Stopper interface {
Stop() // Stops a periodic task.
}
// Tick is the value for periodic task callbacks that
// contains the time of the tick and
// the time elapsed since the start of the periodic task.
type Tick struct {
Elapsed time.Duration
Time time.Time
}
// Immediate starts the periodic task immediately instead of after the first tick.
func Immediate() Option {
return optionFunc(func(p *periodic) {
p.immediate = true
})
}
// OnStop configures a callback that is executed when a periodic task is stopped or canceled.
func OnStop(f func(Tick)) Option {
return optionFunc(func(p *periodic) {
p.onStop = f
})
}
// Start starts a periodic task with a ticker at the specified interval,
// which executes the given callback after each tick.
// Pending tasks do not overlap, but could start immediately if
// the previous task(s) takes longer than the interval.
// Call Stop() on the return value in order to stop the ticker and to release associated resources.
// The interval must be greater than zero.
func Start(ctx context.Context, interval time.Duration, callback func(Tick), options ...Option) Stopper {
t := &periodic{
interval: interval,
callback: callback,
}
for _, option := range options {
option.apply(t)
}
ctx, cancelCtx := context.WithCancel(ctx)
start := time.Now()
go func() {
done := false
if !t.immediate {
select {
case <-time.After(interval):
case <-ctx.Done():
done = true
}
}
if !done {
ticker := time.NewTicker(t.interval)
defer ticker.Stop()
for tickTime := time.Now(); !done; {
t.callback(Tick{
Elapsed: tickTime.Sub(start),
Time: tickTime,
})
select {
case tickTime = <-ticker.C:
case <-ctx.Done():
done = true
}
}
}
if t.onStop != nil {
now := time.Now()
t.onStop(Tick{
Elapsed: now.Sub(start),
Time: now,
})
}
}()
return stoperFunc(func() {
t.stop.Do(cancelCtx)
})
}
type optionFunc func(*periodic)
func (f optionFunc) apply(p *periodic) {
f(p)
}
type stoperFunc func()
func (f stoperFunc) Stop() {
f()
}
type periodic struct {
interval time.Duration
callback func(Tick)
immediate bool
stop sync.Once
onStop func(Tick)
}

View file

@ -1,14 +0,0 @@
package redis
import "github.com/redis/go-redis/v9"
// Alias definitions of commonly used go-redis exports,
// so that only this redis package needs to be imported and not go-redis additionally.
type IntCmd = redis.IntCmd
type Pipeliner = redis.Pipeliner
type XAddArgs = redis.XAddArgs
type XMessage = redis.XMessage
type XReadArgs = redis.XReadArgs
var NewScript = redis.NewScript

View file

@ -1,277 +0,0 @@
package redis
import (
"context"
"crypto/tls"
"fmt"
"github.com/icinga/icinga-go-library/backoff"
"github.com/icinga/icinga-go-library/com"
"github.com/icinga/icinga-go-library/logging"
"github.com/icinga/icinga-go-library/periodic"
"github.com/icinga/icinga-go-library/retry"
"github.com/icinga/icinga-go-library/utils"
"github.com/pkg/errors"
"github.com/redis/go-redis/v9"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
"golang.org/x/sync/semaphore"
"net"
"time"
)
// Client is a wrapper around redis.Client with
// streaming and logging capabilities.
type Client struct {
*redis.Client
Options *Options
logger *logging.Logger
}
// NewClient returns a new Client wrapper for a pre-existing redis.Client.
func NewClient(client *redis.Client, logger *logging.Logger, options *Options) *Client {
return &Client{Client: client, logger: logger, Options: options}
}
// NewClientFromConfig returns a new Client from Config.
func NewClientFromConfig(c *Config, logger *logging.Logger) (*Client, error) {
tlsConfig, err := c.TlsOptions.MakeConfig(c.Host)
if err != nil {
return nil, err
}
var dialer ctxDialerFunc
dl := &net.Dialer{Timeout: 15 * time.Second}
if tlsConfig == nil {
dialer = dl.DialContext
} else {
dialer = (&tls.Dialer{NetDialer: dl, Config: tlsConfig}).DialContext
}
options := &redis.Options{
Dialer: dialWithLogging(dialer, logger),
Password: c.Password,
DB: 0, // Use default DB,
ReadTimeout: c.Options.Timeout,
TLSConfig: tlsConfig,
}
if utils.IsUnixAddr(c.Host) {
options.Network = "unix"
options.Addr = c.Host
} else {
port := c.Port
if port == 0 {
port = 6379
}
options.Network = "tcp"
options.Addr = net.JoinHostPort(c.Host, fmt.Sprint(port))
}
client := redis.NewClient(options)
options = client.Options()
options.PoolSize = utils.MaxInt(32, options.PoolSize)
options.MaxRetries = options.PoolSize + 1 // https://github.com/go-redis/redis/issues/1737
return NewClient(redis.NewClient(options), logger, &c.Options), nil
}
// GetAddr returns the Redis host:port or Unix socket address.
func (c *Client) GetAddr() string {
return c.Client.Options().Addr
}
// HPair defines Redis hashes field-value pairs.
type HPair struct {
Field string
Value string
}
// HYield yields HPair field-value pairs for all fields in the hash stored at key.
func (c *Client) HYield(ctx context.Context, key string) (<-chan HPair, <-chan error) {
pairs := make(chan HPair, c.Options.HScanCount)
return pairs, com.WaitAsync(com.WaiterFunc(func() error {
var counter com.Counter
defer c.log(ctx, key, &counter).Stop()
defer close(pairs)
seen := make(map[string]struct{})
var cursor uint64
var err error
var page []string
for {
cmd := c.HScan(ctx, key, cursor, "", int64(c.Options.HScanCount))
page, cursor, err = cmd.Result()
if err != nil {
return WrapCmdErr(cmd)
}
for i := 0; i < len(page); i += 2 {
if _, ok := seen[page[i]]; ok {
// Ignore duplicate returned by HSCAN.
continue
}
seen[page[i]] = struct{}{}
select {
case pairs <- HPair{
Field: page[i],
Value: page[i+1],
}:
counter.Inc()
case <-ctx.Done():
return ctx.Err()
}
}
if cursor == 0 {
break
}
}
return nil
}))
}
// HMYield yields HPair field-value pairs for the specified fields in the hash stored at key.
func (c *Client) HMYield(ctx context.Context, key string, fields ...string) (<-chan HPair, <-chan error) {
pairs := make(chan HPair)
return pairs, com.WaitAsync(com.WaiterFunc(func() error {
var counter com.Counter
defer c.log(ctx, key, &counter).Stop()
g, ctx := errgroup.WithContext(ctx)
defer func() {
// Wait until the group is done so that we can safely close the pairs channel,
// because on error, sem.Acquire will return before calling g.Wait(),
// which can result in goroutines working on a closed channel.
_ = g.Wait()
close(pairs)
}()
// Use context from group.
batches := utils.BatchSliceOfStrings(ctx, fields, c.Options.HMGetCount)
sem := semaphore.NewWeighted(int64(c.Options.MaxHMGetConnections))
for batch := range batches {
if err := sem.Acquire(ctx, 1); err != nil {
return errors.Wrap(err, "can't acquire semaphore")
}
batch := batch
g.Go(func() error {
defer sem.Release(1)
cmd := c.HMGet(ctx, key, batch...)
vals, err := cmd.Result()
if err != nil {
return WrapCmdErr(cmd)
}
for i, v := range vals {
if v == nil {
c.logger.Warnf("HMGET %s: field %#v missing", key, batch[i])
continue
}
select {
case pairs <- HPair{
Field: batch[i],
Value: v.(string),
}:
counter.Inc()
case <-ctx.Done():
return ctx.Err()
}
}
return nil
})
}
return g.Wait()
}))
}
// XReadUntilResult (repeatedly) calls XREAD with the specified arguments until a result is returned.
// Each call blocks at most for the duration specified in Options.BlockTimeout until data
// is available before it times out and the next call is made.
// This also means that an already set block timeout is overridden.
func (c *Client) XReadUntilResult(ctx context.Context, a *redis.XReadArgs) ([]redis.XStream, error) {
a.Block = c.Options.BlockTimeout
for {
cmd := c.XRead(ctx, a)
streams, err := cmd.Result()
if err != nil {
if errors.Is(err, redis.Nil) {
continue
}
return streams, WrapCmdErr(cmd)
}
return streams, nil
}
}
func (c *Client) log(ctx context.Context, key string, counter *com.Counter) periodic.Stopper {
return periodic.Start(ctx, c.logger.Interval(), func(tick periodic.Tick) {
// We may never get to progress logging here,
// as fetching should be completed before the interval expires,
// but if it does, it is good to have this log message.
if count := counter.Reset(); count > 0 {
c.logger.Debugf("Fetched %d items from %s", count, key)
}
}, periodic.OnStop(func(tick periodic.Tick) {
c.logger.Debugf("Finished fetching from %s with %d items in %s", key, counter.Total(), tick.Elapsed)
}))
}
type ctxDialerFunc = func(ctx context.Context, network, addr string) (net.Conn, error)
// dialWithLogging returns a Redis Dialer with logging capabilities.
func dialWithLogging(dialer ctxDialerFunc, logger *logging.Logger) ctxDialerFunc {
// dial behaves like net.Dialer#DialContext,
// but re-tries on common errors that are considered retryable.
return func(ctx context.Context, network, addr string) (conn net.Conn, err error) {
err = retry.WithBackoff(
ctx,
func(ctx context.Context) (err error) {
conn, err = dialer(ctx, network, addr)
return
},
retry.Retryable,
backoff.NewExponentialWithJitter(1*time.Millisecond, 1*time.Second),
retry.Settings{
Timeout: retry.DefaultTimeout,
OnRetryableError: func(_ time.Duration, _ uint64, err, lastErr error) {
if lastErr == nil || err.Error() != lastErr.Error() {
logger.Warnw("Can't connect to Redis. Retrying", zap.Error(err))
}
},
OnSuccess: func(elapsed time.Duration, attempt uint64, _ error) {
if attempt > 1 {
logger.Infow("Reconnected to Redis",
zap.Duration("after", elapsed), zap.Uint64("attempts", attempt))
}
},
},
)
err = errors.Wrap(err, "can't connect to Redis")
return
}
}

View file

@ -1,59 +0,0 @@
package redis
import (
"github.com/icinga/icinga-go-library/config"
"github.com/pkg/errors"
"time"
)
// Options define user configurable Redis options.
type Options struct {
BlockTimeout time.Duration `yaml:"block_timeout" default:"1s"`
HMGetCount int `yaml:"hmget_count" default:"4096"`
HScanCount int `yaml:"hscan_count" default:"4096"`
MaxHMGetConnections int `yaml:"max_hmget_connections" default:"8"`
Timeout time.Duration `yaml:"timeout" default:"30s"`
XReadCount int `yaml:"xread_count" default:"4096"`
}
// Validate checks constraints in the supplied Redis options and returns an error if they are violated.
func (o *Options) Validate() error {
if o.BlockTimeout <= 0 {
return errors.New("block_timeout must be positive")
}
if o.HMGetCount < 1 {
return errors.New("hmget_count must be at least 1")
}
if o.HScanCount < 1 {
return errors.New("hscan_count must be at least 1")
}
if o.MaxHMGetConnections < 1 {
return errors.New("max_hmget_connections must be at least 1")
}
if o.Timeout == 0 {
return errors.New("timeout cannot be 0. Configure a value greater than zero, or use -1 for no timeout")
}
if o.XReadCount < 1 {
return errors.New("xread_count must be at least 1")
}
return nil
}
// Config defines Config client configuration.
type Config struct {
Host string `yaml:"host"`
Port int `yaml:"port"`
Password string `yaml:"password"`
TlsOptions config.TLS `yaml:",inline"`
Options Options `yaml:"options"`
}
// Validate checks constraints in the supplied Config configuration and returns an error if they are violated.
func (r *Config) Validate() error {
if r.Host == "" {
return errors.New("Redis host missing")
}
return r.Options.Validate()
}

View file

@ -1,20 +0,0 @@
package redis
// Streams represents a Redis stream key to ID mapping.
type Streams map[string]string
// Option returns the Redis stream key to ID mapping
// as a slice of stream keys followed by their IDs
// that is compatible for the Redis STREAMS option.
func (s Streams) Option() []string {
// len*2 because we're appending the IDs later.
streams := make([]string, 0, len(s)*2)
ids := make([]string, 0, len(s))
for key, id := range s {
streams = append(streams, key)
ids = append(ids, id)
}
return append(streams, ids...)
}

View file

@ -1,22 +0,0 @@
package redis
import (
"context"
"github.com/icinga/icinga-go-library/utils"
"github.com/pkg/errors"
"github.com/redis/go-redis/v9"
)
// WrapCmdErr adds the command itself and
// the stack of the current goroutine to the command's error if any.
func WrapCmdErr(cmd redis.Cmder) error {
err := cmd.Err()
if err != nil {
err = errors.Wrapf(err, "can't perform %q", utils.Ellipsize(
redis.NewCmd(context.Background(), cmd.Args()).String(), // Omits error in opposite to cmd.String()
100,
))
}
return err
}

View file

@ -1,201 +0,0 @@
package retry
import (
"context"
"database/sql/driver"
"github.com/go-sql-driver/mysql"
"github.com/icinga/icinga-go-library/backoff"
"github.com/lib/pq"
"github.com/pkg/errors"
"io"
"net"
"syscall"
"time"
)
// DefaultTimeout is our opinionated default timeout for retrying database and Redis operations.
const DefaultTimeout = 5 * time.Minute
// RetryableFunc is a retryable function.
type RetryableFunc func(context.Context) error
// IsRetryable checks whether a new attempt can be started based on the error passed.
type IsRetryable func(error) bool
// OnRetryableErrorFunc is called if a retryable error occurs.
type OnRetryableErrorFunc func(elapsed time.Duration, attempt uint64, err, lastErr error)
// OnSuccessFunc is called once the operation succeeds.
type OnSuccessFunc func(elapsed time.Duration, attempt uint64, lastErr error)
// Settings aggregates optional settings for WithBackoff.
type Settings struct {
// If >0, Timeout lets WithBackoff stop retrying gracefully once elapsed based on the following criteria:
// * If the execution of RetryableFunc has taken longer than Timeout, no further attempts are made.
// * If Timeout elapses during the sleep phase between retries, one final retry is attempted.
// * RetryableFunc is always granted its full execution time and is not canceled if it exceeds Timeout.
// This means that WithBackoff may not stop exactly after Timeout expires,
// or may not retry at all if the first execution of RetryableFunc already takes longer than Timeout.
Timeout time.Duration
OnRetryableError OnRetryableErrorFunc
OnSuccess OnSuccessFunc
}
// WithBackoff retries the passed function if it fails and the error allows it to retry.
// The specified backoff policy is used to determine how long to sleep between attempts.
func WithBackoff(
ctx context.Context, retryableFunc RetryableFunc, retryable IsRetryable, b backoff.Backoff, settings Settings,
) (err error) {
// Channel for retry deadline, which is set to the channel of NewTimer() if a timeout is configured,
// otherwise nil, so that it blocks forever if there is no timeout.
var timeout <-chan time.Time
if settings.Timeout > 0 {
t := time.NewTimer(settings.Timeout)
defer t.Stop()
timeout = t.C
}
start := time.Now()
timedOut := false
for attempt := uint64(1); ; /* true */ attempt++ {
prevErr := err
if err = retryableFunc(ctx); err == nil {
if settings.OnSuccess != nil {
settings.OnSuccess(time.Since(start), attempt, prevErr)
}
return
}
// Retryable function may have exited prematurely due to context errors.
// We explicitly check the context error here, as the error returned by the retryable function can pass the
// error.Is() checks even though it is not a real context error, e.g.
// https://cs.opensource.google/go/go/+/refs/tags/go1.22.2:src/net/net.go;l=422
// https://cs.opensource.google/go/go/+/refs/tags/go1.22.2:src/net/net.go;l=601
if errors.Is(ctx.Err(), context.DeadlineExceeded) || errors.Is(ctx.Err(), context.Canceled) {
if prevErr != nil {
err = errors.Wrap(err, prevErr.Error())
}
return
}
if !retryable(err) {
err = errors.Wrap(err, "can't retry")
return
}
select {
case <-timeout:
// Stop retrying immediately if executing the retryable function took longer than the timeout.
timedOut = true
default:
}
if timedOut {
err = errors.Wrap(err, "retry deadline exceeded")
return
}
if settings.OnRetryableError != nil {
settings.OnRetryableError(time.Since(start), attempt, err, prevErr)
}
select {
case <-time.After(b(attempt)):
case <-timeout:
// Do not stop retrying immediately, but start one last attempt to mitigate timing issues where
// the timeout expires while waiting for the next attempt and
// therefore no retries have happened during this possibly long period.
timedOut = true
case <-ctx.Done():
err = errors.Wrap(ctx.Err(), err.Error())
return
}
}
}
// ResetTimeout changes the possibly expired timer t to expire after duration d.
//
// If the timer has already expired and nothing has been received from its channel,
// it is automatically drained as if the timer had never expired.
func ResetTimeout(t *time.Timer, d time.Duration) {
if !t.Stop() {
<-t.C
}
t.Reset(d)
}
// Retryable returns true for common errors that are considered retryable,
// i.e. temporary, timeout, DNS, connection refused and reset, host down and unreachable and
// network down and unreachable errors. In addition, any database error is considered retryable.
func Retryable(err error) bool {
var temporary interface {
Temporary() bool
}
if errors.As(err, &temporary) && temporary.Temporary() {
return true
}
var timeout interface {
Timeout() bool
}
if errors.As(err, &timeout) && timeout.Timeout() {
return true
}
var dnsError *net.DNSError
if errors.As(err, &dnsError) {
return true
}
var opError *net.OpError
if errors.As(err, &opError) {
// OpError provides Temporary() and Timeout(), but not Unwrap(),
// so we have to extract the underlying error ourselves to also check for ECONNREFUSED,
// which is not considered temporary or timed out by Go.
err = opError.Err
}
if errors.Is(err, syscall.ECONNREFUSED) || errors.Is(err, syscall.ENOENT) {
// syscall errors provide Temporary() and Timeout(),
// which do not include ECONNREFUSED or ENOENT, so we check these ourselves.
return true
}
if errors.Is(err, syscall.ECONNRESET) {
// ECONNRESET is treated as a temporary error by Go only if it comes from calling accept.
return true
}
if errors.Is(err, syscall.EHOSTDOWN) || errors.Is(err, syscall.EHOSTUNREACH) {
return true
}
if errors.Is(err, syscall.ENETDOWN) || errors.Is(err, syscall.ENETUNREACH) {
return true
}
if errors.Is(err, syscall.EPIPE) {
return true
}
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
return true
}
if errors.Is(err, driver.ErrBadConn) {
return true
}
if errors.Is(err, mysql.ErrInvalidConn) {
return true
}
var mye *mysql.MySQLError
var pqe *pq.Error
if errors.As(err, &mye) || errors.As(err, &pqe) {
return true
}
return false
}

View file

@ -1,56 +0,0 @@
// Package strcase implements functions to convert a camelCase UTF-8 string into various cases.
//
// New delimiters will be inserted based on the following transitions:
// - On any change from lowercase to uppercase letter.
// - On any change from number to uppercase letter.
package strcase
import (
"strings"
"unicode"
)
// Delimited converts a string to delimited.lower.case, here using `.` as delimiter.
func Delimited(s string, d rune) string {
return convert(s, unicode.LowerCase, d)
}
// ScreamingDelimited converts a string to DELIMITED.UPPER.CASE, here using `.` as delimiter.
func ScreamingDelimited(s string, d rune) string {
return convert(s, unicode.UpperCase, d)
}
// Snake converts a string to snake_case.
func Snake(s string) string {
return Delimited(s, '_')
}
// ScreamingSnake converts a string to SCREAMING_SNAKE_CASE.
func ScreamingSnake(s string) string {
return ScreamingDelimited(s, '_')
}
// convert converts a camelCase UTF-8 string into various cases.
// _case must be unicode.LowerCase or unicode.UpperCase.
func convert(s string, _case int, d rune) string {
if len(s) == 0 {
return s
}
n := strings.Builder{}
n.Grow(len(s) + 2) // Allow adding at least 2 delimiters without another allocation.
var prevRune rune
for i, r := range s {
if i > 0 && unicode.IsUpper(r) && (unicode.IsNumber(prevRune) || unicode.IsLower(prevRune)) {
n.WriteRune(d)
}
n.WriteRune(unicode.To(_case, r))
prevRune = r
}
return n.String()
}

View file

@ -1,58 +0,0 @@
package strcase
import (
"strings"
"testing"
)
var tests = [][]string{
{"", ""},
{"Test", "test"},
{"test", "test"},
{"testCase", "test_case"},
{"test_case", "test_case"},
{"TestCase", "test_case"},
{"Test_Case", "test_case"},
{"ID", "id"},
{"userID", "user_id"},
{"UserID", "user_id"},
{"ManyManyWords", "many_many_words"},
{"manyManyWords", "many_many_words"},
{"icinga2", "icinga2"},
{"Icinga2Version", "icinga2_version"},
{"k8sVersion", "k8s_version"},
{"1234", "1234"},
{"a1b2c3d4", "a1b2c3d4"},
{"with1234digits", "with1234digits"},
{"with1234Digits", "with1234_digits"},
{"IPv4", "ipv4"},
{"IPv4Address", "ipv4_address"},
{"caféCrème", "café_crème"},
{"0℃", "0℃"},
{"~0", "~0"},
{"icinga💯points", "icinga💯points"},
{"😃🙃😀", "😃🙃😀"},
{"こんにちは", "こんにちは"},
{"\xff\xfe\xfd", "<22><><EFBFBD>"},
{"\xff", "<22>"},
}
func TestSnake(t *testing.T) {
for _, test := range tests {
s, expected := test[0], test[1]
actual := Snake(s)
if actual != expected {
t.Errorf("%q: %q != %q", s, actual, expected)
}
}
}
func TestScreamingSnake(t *testing.T) {
for _, test := range tests {
s, expected := test[0], strings.ToUpper(test[1])
actual := ScreamingSnake(s)
if actual != expected {
t.Errorf("%q: %q != %q", s, actual, expected)
}
}
}

View file

@ -1,176 +0,0 @@
package structify
import (
"encoding"
"fmt"
"github.com/pkg/errors"
"golang.org/x/exp/constraints"
"reflect"
"strconv"
"strings"
"unsafe"
)
// structBranch represents either a leaf or a subTree.
type structBranch struct {
// field specifies the struct field index.
field int
// leaf specifies the map key to parse the struct field from.
leaf string
// subTree specifies the struct field's inner tree.
subTree []structBranch
}
type MapStructifier = func(map[string]interface{}) (interface{}, error)
// MakeMapStructifier builds a function which parses a map's string values into a new struct of type t
// and returns a pointer to it. tag specifies which tag connects struct fields to map keys.
// MakeMapStructifier panics if it detects an unsupported type (suitable for usage in init() or global vars).
func MakeMapStructifier(t reflect.Type, tag string, initer func(any)) MapStructifier {
tree := buildStructTree(t, tag)
return func(kv map[string]interface{}) (interface{}, error) {
vPtr := reflect.New(t)
ptr := vPtr.Interface()
if initer != nil {
initer(ptr)
}
vPtrElem := vPtr.Elem()
err := errors.Wrapf(structifyMapByTree(kv, tree, vPtrElem, vPtrElem, new([]int)), "can't structify map %#v by tree %#v", kv, tree)
return ptr, err
}
}
// buildStructTree assembles a tree which represents the struct t based on tag.
func buildStructTree(t reflect.Type, tag string) []structBranch {
var tree []structBranch
numFields := t.NumField()
for i := 0; i < numFields; i++ {
if field := t.Field(i); field.PkgPath == "" {
switch tagValue := field.Tag.Get(tag); tagValue {
case "", "-":
case ",inline":
if subTree := buildStructTree(field.Type, tag); subTree != nil {
tree = append(tree, structBranch{i, "", subTree})
}
default:
// If parseString doesn't support *T, it'll panic.
_ = parseString("", reflect.New(field.Type).Interface())
tree = append(tree, structBranch{i, tagValue, nil})
}
}
}
return tree
}
// structifyMapByTree parses src's string values into the struct dest according to tree's specification.
func structifyMapByTree(src map[string]interface{}, tree []structBranch, dest, root reflect.Value, stack *[]int) error {
*stack = append(*stack, 0)
defer func() {
*stack = (*stack)[:len(*stack)-1]
}()
for _, branch := range tree {
(*stack)[len(*stack)-1] = branch.field
if branch.subTree == nil {
if v, ok := src[branch.leaf]; ok {
if vs, ok := v.(string); ok {
if err := parseString(vs, dest.Field(branch.field).Addr().Interface()); err != nil {
rt := root.Type()
typ := rt
var path []string
for _, i := range *stack {
f := typ.Field(i)
path = append(path, f.Name)
typ = f.Type
}
return errors.Wrapf(err, "can't parse %s into the %s %s#%s: %s",
branch.leaf, typ.Name(), rt.Name(), strings.Join(path, "."), vs)
}
}
}
} else if err := structifyMapByTree(src, branch.subTree, dest.Field(branch.field), root, stack); err != nil {
return err
}
}
return nil
}
// parseString parses src into *dest.
func parseString(src string, dest interface{}) error {
switch ptr := dest.(type) {
case encoding.TextUnmarshaler:
return ptr.UnmarshalText([]byte(src))
case *string:
*ptr = src
return nil
case **string:
*ptr = &src
return nil
case *uint8:
return parseUint(src, ptr)
case *uint16:
return parseUint(src, ptr)
case *uint32:
return parseUint(src, ptr)
case *uint64:
return parseUint(src, ptr)
case *int8:
return parseInt(src, ptr)
case *int16:
return parseInt(src, ptr)
case *int32:
return parseInt(src, ptr)
case *int64:
return parseInt(src, ptr)
case *float32:
return parseFloat(src, ptr)
case *float64:
return parseFloat(src, ptr)
default:
panic(fmt.Sprintf("unsupported type: %T", dest))
}
}
// parseUint parses src into *dest.
func parseUint[T constraints.Unsigned](src string, dest *T) error {
i, err := strconv.ParseUint(src, 10, bitSizeOf[T]())
if err == nil {
*dest = T(i)
}
return err
}
// parseInt parses src into *dest.
func parseInt[T constraints.Signed](src string, dest *T) error {
i, err := strconv.ParseInt(src, 10, bitSizeOf[T]())
if err == nil {
*dest = T(i)
}
return err
}
// parseFloat parses src into *dest.
func parseFloat[T constraints.Float](src string, dest *T) error {
f, err := strconv.ParseFloat(src, bitSizeOf[T]())
if err == nil {
*dest = T(f)
}
return err
}
func bitSizeOf[T any]() int {
var x T
return int(unsafe.Sizeof(x) * 8)
}

View file

@ -1,125 +0,0 @@
package types
import (
"bytes"
"database/sql"
"database/sql/driver"
"encoding"
"encoding/hex"
"encoding/json"
"fmt"
"github.com/pkg/errors"
)
// Binary nullable byte string. Hex as JSON.
type Binary []byte
// nullBinary for validating whether a Binary is valid.
var nullBinary Binary
// Valid returns whether the Binary is valid.
func (binary Binary) Valid() bool {
return !bytes.Equal(binary, nullBinary)
}
// String returns the hex string representation form of the Binary.
func (binary Binary) String() string {
return hex.EncodeToString(binary)
}
// MarshalText implements a custom marhsal function to encode
// the Binary as hex. MarshalText implements the
// encoding.TextMarshaler interface.
func (binary Binary) MarshalText() ([]byte, error) {
return []byte(binary.String()), nil
}
// UnmarshalText implements a custom unmarshal function to decode
// hex into a Binary. UnmarshalText implements the
// encoding.TextUnmarshaler interface.
func (binary *Binary) UnmarshalText(text []byte) error {
b := make([]byte, hex.DecodedLen(len(text)))
_, err := hex.Decode(b, text)
if err != nil {
return CantDecodeHex(err, string(text))
}
*binary = b
return nil
}
// MarshalJSON implements a custom marshal function to encode the Binary
// as a hex string. MarshalJSON implements the json.Marshaler interface.
// Supports JSON null.
func (binary Binary) MarshalJSON() ([]byte, error) {
if !binary.Valid() {
return []byte("null"), nil
}
return MarshalJSON(binary.String())
}
// UnmarshalJSON implements a custom unmarshal function to decode
// a JSON hex string into a Binary. UnmarshalJSON implements the
// json.Unmarshaler interface. Supports JSON null.
func (binary *Binary) UnmarshalJSON(data []byte) error {
if string(data) == "null" || len(data) == 0 {
return nil
}
var s string
if err := UnmarshalJSON(data, &s); err != nil {
return err
}
b, err := hex.DecodeString(s)
if err != nil {
return CantDecodeHex(err, s)
}
*binary = b
return nil
}
// Scan implements the sql.Scanner interface.
// Supports SQL NULL.
func (binary *Binary) Scan(src interface{}) error {
switch src := src.(type) {
case nil:
return nil
case []byte:
if len(src) == 0 {
return nil
}
b := make([]byte, len(src))
copy(b, src)
*binary = b
default:
return errors.Errorf("unable to scan type %T into Binary", src)
}
return nil
}
// Value implements the driver.Valuer interface.
// Supports SQL NULL.
func (binary Binary) Value() (driver.Value, error) {
if !binary.Valid() {
return nil, nil
}
return []byte(binary), nil
}
// Assert interface compliance.
var (
_ fmt.Stringer = (*Binary)(nil)
_ encoding.TextMarshaler = (*Binary)(nil)
_ encoding.TextUnmarshaler = (*Binary)(nil)
_ json.Marshaler = (*Binary)(nil)
_ json.Unmarshaler = (*Binary)(nil)
_ sql.Scanner = (*Binary)(nil)
_ driver.Valuer = (*Binary)(nil)
)

View file

@ -1,29 +0,0 @@
package types
import (
"github.com/stretchr/testify/require"
"testing"
"unicode/utf8"
)
func TestBinary_MarshalJSON(t *testing.T) {
subtests := []struct {
name string
input Binary
output string
}{
{"nil", nil, `null`},
{"empty", make(Binary, 0, 1), `null`},
{"space", Binary(" "), `"20"`},
}
for _, st := range subtests {
t.Run(st.name, func(t *testing.T) {
actual, err := st.input.MarshalJSON()
require.NoError(t, err)
require.True(t, utf8.Valid(actual))
require.Equal(t, st.output, string(actual))
})
}
}

View file

@ -1,104 +0,0 @@
package types
import (
"database/sql"
"database/sql/driver"
"encoding"
"encoding/json"
"github.com/pkg/errors"
"strconv"
)
var (
enum = map[bool]string{
true: "y",
false: "n",
}
)
// Bool represents a bool for ENUM ('y', 'n'), which can be NULL.
type Bool struct {
Bool bool
Valid bool // Valid is true if Bool is not NULL
}
// MarshalJSON implements the json.Marshaler interface.
func (b Bool) MarshalJSON() ([]byte, error) {
if !b.Valid {
return []byte("null"), nil
}
return MarshalJSON(b.Bool)
}
// UnmarshalText implements the encoding.TextUnmarshaler interface.
func (b *Bool) UnmarshalText(text []byte) error {
parsed, err := strconv.ParseUint(string(text), 10, 64)
if err != nil {
return CantParseUint64(err, string(text))
}
*b = Bool{parsed != 0, true}
return nil
}
// UnmarshalJSON implements the json.Unmarshaler interface.
func (b *Bool) UnmarshalJSON(data []byte) error {
if string(data) == "null" || len(data) == 0 {
return nil
}
if err := UnmarshalJSON(data, &b.Bool); err != nil {
return err
}
b.Valid = true
return nil
}
// Scan implements the sql.Scanner interface.
// Supports SQL NULL.
func (b *Bool) Scan(src interface{}) error {
if src == nil {
b.Bool, b.Valid = false, false
return nil
}
v, ok := src.([]byte)
if !ok {
return errors.Errorf("bad []byte type assertion from %#v", src)
}
switch string(v) {
case "y":
b.Bool = true
case "n":
b.Bool = false
default:
return errors.Errorf("bad bool %#v", v)
}
b.Valid = true
return nil
}
// Value implements the driver.Valuer interface.
// Supports SQL NULL.
func (b Bool) Value() (driver.Value, error) {
if !b.Valid {
return nil, nil
}
return enum[b.Bool], nil
}
// Assert interface compliance.
var (
_ json.Marshaler = (*Bool)(nil)
_ encoding.TextUnmarshaler = (*Bool)(nil)
_ json.Unmarshaler = (*Bool)(nil)
_ sql.Scanner = (*Bool)(nil)
_ driver.Valuer = (*Bool)(nil)
)

View file

@ -1,30 +0,0 @@
package types
import (
"fmt"
"github.com/stretchr/testify/require"
"testing"
"unicode/utf8"
)
func TestBool_MarshalJSON(t *testing.T) {
subtests := []struct {
input Bool
output string
}{
{Bool{Bool: false, Valid: false}, `null`},
{Bool{Bool: false, Valid: true}, `false`},
{Bool{Bool: true, Valid: false}, `null`},
{Bool{Bool: true, Valid: true}, `true`},
}
for _, st := range subtests {
t.Run(fmt.Sprintf("Bool-%#v_Valid-%#v", st.input.Bool, st.input.Valid), func(t *testing.T) {
actual, err := st.input.MarshalJSON()
require.NoError(t, err)
require.True(t, utf8.Valid(actual))
require.Equal(t, st.output, string(actual))
})
}
}

View file

@ -1,67 +0,0 @@
package types
import (
"bytes"
"database/sql"
"database/sql/driver"
"encoding"
"encoding/json"
"strconv"
)
// Float adds JSON support to sql.NullFloat64.
type Float struct {
sql.NullFloat64
}
// MarshalJSON implements the json.Marshaler interface.
// Supports JSON null.
func (f Float) MarshalJSON() ([]byte, error) {
var v interface{}
if f.Valid {
v = f.Float64
}
return MarshalJSON(v)
}
// UnmarshalText implements the encoding.TextUnmarshaler interface.
func (f *Float) UnmarshalText(text []byte) error {
parsed, err := strconv.ParseFloat(string(text), 64)
if err != nil {
return CantParseFloat64(err, string(text))
}
*f = Float{sql.NullFloat64{
Float64: parsed,
Valid: true,
}}
return nil
}
// UnmarshalJSON implements the json.Unmarshaler interface.
// Supports JSON null.
func (f *Float) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if bytes.HasPrefix(data, []byte{'n'}) {
return nil
}
if err := UnmarshalJSON(data, &f.Float64); err != nil {
return err
}
f.Valid = true
return nil
}
// Assert interface compliance.
var (
_ json.Marshaler = Float{}
_ encoding.TextUnmarshaler = (*Float)(nil)
_ json.Unmarshaler = (*Float)(nil)
_ driver.Valuer = Float{}
_ sql.Scanner = (*Float)(nil)
)

View file

@ -1,67 +0,0 @@
package types
import (
"bytes"
"database/sql"
"database/sql/driver"
"encoding"
"encoding/json"
"strconv"
)
// Int adds JSON support to sql.NullInt64.
type Int struct {
sql.NullInt64
}
// MarshalJSON implements the json.Marshaler interface.
// Supports JSON null.
func (i Int) MarshalJSON() ([]byte, error) {
var v interface{}
if i.Valid {
v = i.Int64
}
return MarshalJSON(v)
}
// UnmarshalText implements the encoding.TextUnmarshaler interface.
func (i *Int) UnmarshalText(text []byte) error {
parsed, err := strconv.ParseInt(string(text), 10, 64)
if err != nil {
return CantParseInt64(err, string(text))
}
*i = Int{sql.NullInt64{
Int64: parsed,
Valid: true,
}}
return nil
}
// UnmarshalJSON implements the json.Unmarshaler interface.
// Supports JSON null.
func (i *Int) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if bytes.HasPrefix(data, []byte{'n'}) {
return nil
}
if err := UnmarshalJSON(data, &i.Int64); err != nil {
return err
}
i.Valid = true
return nil
}
// Assert interface compliance.
var (
_ json.Marshaler = Int{}
_ json.Unmarshaler = (*Int)(nil)
_ encoding.TextUnmarshaler = (*Int)(nil)
_ driver.Valuer = Int{}
_ sql.Scanner = (*Int)(nil)
)

View file

@ -1,81 +0,0 @@
package types
import (
"bytes"
"database/sql"
"database/sql/driver"
"encoding"
"encoding/json"
"strings"
)
// String adds JSON support to sql.NullString.
type String struct {
sql.NullString
}
// MakeString constructs a new non-NULL String from s.
func MakeString(s string) String {
return String{sql.NullString{
String: s,
Valid: true,
}}
}
// MarshalJSON implements the json.Marshaler interface.
// Supports JSON null.
func (s String) MarshalJSON() ([]byte, error) {
var v interface{}
if s.Valid {
v = s.String
}
return MarshalJSON(v)
}
// UnmarshalText implements the encoding.TextUnmarshaler interface.
func (s *String) UnmarshalText(text []byte) error {
*s = String{sql.NullString{
String: string(text),
Valid: true,
}}
return nil
}
// UnmarshalJSON implements the json.Unmarshaler interface.
// Supports JSON null.
func (s *String) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if bytes.HasPrefix(data, []byte{'n'}) {
return nil
}
if err := UnmarshalJSON(data, &s.String); err != nil {
return err
}
s.Valid = true
return nil
}
// Value implements the driver.Valuer interface.
// Supports SQL NULL.
func (s String) Value() (driver.Value, error) {
if !s.Valid {
return nil, nil
}
// PostgreSQL does not allow null bytes in varchar, char and text fields.
return strings.ReplaceAll(s.String, "\x00", ""), nil
}
// Assert interface compliance.
var (
_ json.Marshaler = String{}
_ encoding.TextUnmarshaler = (*String)(nil)
_ json.Unmarshaler = (*String)(nil)
_ driver.Valuer = String{}
_ sql.Scanner = (*String)(nil)
)

View file

@ -1,116 +0,0 @@
package types
import (
"bytes"
"database/sql"
"database/sql/driver"
"encoding"
"encoding/json"
"github.com/pkg/errors"
"math"
"strconv"
"time"
)
// UnixMilli is a nullable millisecond UNIX timestamp in databases and JSON.
type UnixMilli time.Time
// Time returns the time.Time conversion of UnixMilli.
func (t UnixMilli) Time() time.Time {
return time.Time(t)
}
// MarshalJSON implements the json.Marshaler interface.
// Marshals to milliseconds. Supports JSON null.
func (t UnixMilli) MarshalJSON() ([]byte, error) {
if time.Time(t).IsZero() {
return []byte("null"), nil
}
return []byte(strconv.FormatInt(t.Time().UnixMilli(), 10)), nil
}
// UnmarshalJSON implements the json.Unmarshaler interface.
// Unmarshals from milliseconds. Supports JSON null.
func (t *UnixMilli) UnmarshalJSON(data []byte) error {
if bytes.Equal(data, []byte("null")) || len(data) == 0 {
return nil
}
return t.fromByteString(data)
}
// MarshalText implements the encoding.TextMarshaler interface.
func (t UnixMilli) MarshalText() ([]byte, error) {
if time.Time(t).IsZero() {
return []byte{}, nil
}
return []byte(strconv.FormatInt(t.Time().UnixMilli(), 10)), nil
}
// UnmarshalText implements the encoding.TextUnmarshaler interface.
func (t *UnixMilli) UnmarshalText(text []byte) error {
if len(text) == 0 {
return nil
}
return t.fromByteString(text)
}
// Scan implements the sql.Scanner interface.
// Scans from milliseconds. Supports SQL NULL.
func (t *UnixMilli) Scan(src interface{}) error {
if src == nil {
return nil
}
switch v := src.(type) {
case []byte:
return t.fromByteString(v)
// https://github.com/go-sql-driver/mysql/pull/1452
case uint64:
if v > math.MaxInt64 {
return errors.Errorf("value %v out of range for int64", v)
}
*t = UnixMilli(time.UnixMilli(int64(v)))
case int64:
*t = UnixMilli(time.UnixMilli(v))
default:
return errors.Errorf("bad (u)int64/[]byte type assertion from %[1]v (%[1]T)", src)
}
return nil
}
// Value implements the driver.Valuer interface.
// Returns milliseconds. Supports SQL NULL.
func (t UnixMilli) Value() (driver.Value, error) {
if t.Time().IsZero() {
return nil, nil
}
return t.Time().UnixMilli(), nil
}
func (t *UnixMilli) fromByteString(data []byte) error {
i, err := strconv.ParseInt(string(data), 10, 64)
if err != nil {
return CantParseInt64(err, string(data))
}
*t = UnixMilli(time.UnixMilli(i))
return nil
}
// Assert interface compliance.
var (
_ encoding.TextMarshaler = (*UnixMilli)(nil)
_ encoding.TextUnmarshaler = (*UnixMilli)(nil)
_ json.Marshaler = (*UnixMilli)(nil)
_ json.Unmarshaler = (*UnixMilli)(nil)
_ driver.Valuer = (*UnixMilli)(nil)
_ sql.Scanner = (*UnixMilli)(nil)
)

View file

@ -1,149 +0,0 @@
package types
import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"math"
"testing"
"time"
"unicode/utf8"
)
func TestUnixMilli(t *testing.T) {
type testCase struct {
v UnixMilli
json string
text string
}
tests := map[string]testCase{
"Zero": {UnixMilli{}, "null", ""},
"Non-zero": {UnixMilli(time.Unix(1234567890, 0)), "1234567890000", "1234567890000"},
"Epoch": {UnixMilli(time.Unix(0, 0)), "0", "0"},
"With milliseconds": {UnixMilli(time.Unix(1234567890, 62000000)), "1234567890062", "1234567890062"},
}
var runTests = func(t *testing.T, f func(*testing.T, testCase)) {
for name, test := range tests {
t.Run(name, func(t *testing.T) {
f(t, test)
})
}
}
t.Run("MarshalJSON", func(t *testing.T) {
runTests(t, func(t *testing.T, test testCase) {
actual, err := test.v.MarshalJSON()
require.NoError(t, err)
require.True(t, utf8.Valid(actual))
require.Equal(t, test.json, string(actual))
})
})
t.Run("UnmarshalJSON", func(t *testing.T) {
runTests(t, func(t *testing.T, test testCase) {
var actual UnixMilli
err := actual.UnmarshalJSON([]byte(test.json))
require.NoError(t, err)
require.Equal(t, test.v, actual)
})
})
t.Run("MarshalText", func(t *testing.T) {
runTests(t, func(t *testing.T, test testCase) {
actual, err := test.v.MarshalText()
require.NoError(t, err)
require.True(t, utf8.Valid(actual))
require.Equal(t, test.text, string(actual))
})
})
t.Run("UnmarshalText", func(t *testing.T) {
runTests(t, func(t *testing.T, test testCase) {
var actual UnixMilli
err := actual.UnmarshalText([]byte(test.text))
require.NoError(t, err)
require.Equal(t, test.v, actual)
})
})
}
func TestUnixMilli_Scan(t *testing.T) {
tests := []struct {
name string
v any
expected UnixMilli
expectErr bool
}{
{
name: "Nil",
v: nil,
expected: UnixMilli{},
},
{
name: "Epoch",
v: int64(0),
expected: UnixMilli(time.Unix(0, 0)),
},
{
name: "bytes",
v: []byte("1234567890062"),
expected: UnixMilli(time.Unix(1234567890, 62000000)),
},
{
name: "Invalid bytes",
v: []byte("invalid"),
expectErr: true,
},
{
name: "int64",
v: int64(1234567890062),
expected: UnixMilli(time.Unix(1234567890, 62000000)),
},
{
name: "uint64",
v: uint64(1234567890062),
expected: UnixMilli(time.Unix(1234567890, 62000000)),
},
{
name: "uint64 out of range for int64",
v: uint64(math.MaxInt64) + 1,
expectErr: true,
},
{
name: "Invalid type",
v: "invalid",
expectErr: true,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var actual UnixMilli
err := actual.Scan(test.v)
if test.expectErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, test.expected, actual)
}
})
}
}
func TestUnixMilli_Value(t *testing.T) {
t.Run("Zero", func(t *testing.T) {
var zero UnixMilli
actual, err := zero.Value()
require.NoError(t, err)
require.Nil(t, actual)
})
t.Run("Non-zero", func(t *testing.T) {
withMilliseconds := time.Unix(1234567890, 62000000)
expected := withMilliseconds.UnixMilli()
actual, err := UnixMilli(withMilliseconds).Value()
assert.NoError(t, err)
assert.Equal(t, expected, actual)
})
}

View file

@ -1,52 +0,0 @@
package types
import (
"encoding/json"
"fmt"
"github.com/pkg/errors"
"strings"
)
// Name returns the declared name of type t.
func Name(t any) string {
s := strings.TrimLeft(fmt.Sprintf("%T", t), "*")
return s[strings.LastIndex(s, ".")+1:]
}
// CantDecodeHex wraps the given error with the given string that cannot be hex-decoded.
func CantDecodeHex(err error, s string) error {
return errors.Wrapf(err, "can't decode hex %q", s)
}
// CantParseFloat64 wraps the given error with the specified string that cannot be parsed into float64.
func CantParseFloat64(err error, s string) error {
return errors.Wrapf(err, "can't parse %q into float64", s)
}
// CantParseInt64 wraps the given error with the specified string that cannot be parsed into int64.
func CantParseInt64(err error, s string) error {
return errors.Wrapf(err, "can't parse %q into int64", s)
}
// CantParseUint64 wraps the given error with the specified string that cannot be parsed into uint64.
func CantParseUint64(err error, s string) error {
return errors.Wrapf(err, "can't parse %q into uint64", s)
}
// CantUnmarshalYAML wraps the given error with the designated value, which cannot be unmarshalled into.
func CantUnmarshalYAML(err error, v interface{}) error {
return errors.Wrapf(err, "can't unmarshal YAML into %T", v)
}
// MarshalJSON calls json.Marshal and wraps any resulting errors.
func MarshalJSON(v interface{}) ([]byte, error) {
b, err := json.Marshal(v)
return b, errors.Wrapf(err, "can't marshal JSON from %T", v)
}
// UnmarshalJSON calls json.Unmarshal and wraps any resulting errors.
func UnmarshalJSON(data []byte, v interface{}) error {
return errors.Wrapf(json.Unmarshal(data, v), "can't unmarshal JSON into %T", v)
}

View file

@ -1,24 +0,0 @@
package types
import (
"database/sql/driver"
"encoding"
"github.com/google/uuid"
)
// UUID is like uuid.UUID, but marshals itself binarily (not like xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx) in SQL context.
type UUID struct {
uuid.UUID
}
// Value implements driver.Valuer.
func (uuid UUID) Value() (driver.Value, error) {
return uuid.UUID[:], nil
}
// Assert interface compliance.
var (
_ encoding.TextUnmarshaler = (*UUID)(nil)
_ driver.Valuer = UUID{}
_ driver.Valuer = (*UUID)(nil)
)

View file

@ -1,167 +0,0 @@
package utils
import (
"context"
"crypto/sha1"
"fmt"
"github.com/go-sql-driver/mysql"
"github.com/lib/pq"
"github.com/pkg/errors"
"golang.org/x/exp/utf8string"
"net"
"os"
"path/filepath"
"strings"
"time"
)
// Timed calls the given callback with the time that has elapsed since the start.
//
// Timed should be installed by defer:
//
// func TimedExample(logger *zap.SugaredLogger) {
// defer utils.Timed(time.Now(), func(elapsed time.Duration) {
// logger.Debugf("Executed job in %s", elapsed)
// })
// job()
// }
func Timed(start time.Time, callback func(elapsed time.Duration)) {
callback(time.Since(start))
}
// BatchSliceOfStrings groups the given keys into chunks of size count and streams them into a returned channel.
func BatchSliceOfStrings(ctx context.Context, keys []string, count int) <-chan []string {
batches := make(chan []string)
go func() {
defer close(batches)
for i := 0; i < len(keys); i += count {
end := i + count
if end > len(keys) {
end = len(keys)
}
select {
case batches <- keys[i:end]:
case <-ctx.Done():
return
}
}
}()
return batches
}
// IsContextCanceled returns whether the given error is context.Canceled.
func IsContextCanceled(err error) bool {
return errors.Is(err, context.Canceled)
}
// Checksum returns the SHA-1 checksum of the data.
func Checksum(data interface{}) []byte {
var chksm [sha1.Size]byte
switch data := data.(type) {
case string:
chksm = sha1.Sum([]byte(data))
case []byte:
chksm = sha1.Sum(data)
default:
panic(fmt.Sprintf("Unable to create checksum for type %T", data))
}
return chksm[:]
}
// IsDeadlock returns whether the given error signals serialization failure.
func IsDeadlock(err error) bool {
var e *mysql.MySQLError
if errors.As(err, &e) {
switch e.Number {
case 1205, 1213:
return true
default:
return false
}
}
var pe *pq.Error
if errors.As(err, &pe) {
switch pe.Code {
case "40001", "40P01":
return true
}
}
return false
}
var ellipsis = utf8string.NewString("...")
// Ellipsize shortens s to <=limit runes and indicates shortening by "...".
func Ellipsize(s string, limit int) string {
utf8 := utf8string.NewString(s)
switch {
case utf8.RuneCount() <= limit:
return s
case utf8.RuneCount() <= ellipsis.RuneCount():
return ellipsis.String()
default:
return utf8.Slice(0, limit-ellipsis.RuneCount()) + ellipsis.String()
}
}
// AppName returns the name of the executable that started this program (process).
func AppName() string {
exe, err := os.Executable()
if err != nil {
exe = os.Args[0]
}
return filepath.Base(exe)
}
// MaxInt returns the larger of the given integers.
func MaxInt(x, y int) int {
if x > y {
return x
}
return y
}
// IsUnixAddr indicates whether the given host string represents a Unix socket address.
//
// A host string that begins with a forward slash ('/') is considered Unix socket address.
func IsUnixAddr(host string) bool {
return strings.HasPrefix(host, "/")
}
// JoinHostPort is like its equivalent in net., but handles UNIX sockets as well.
func JoinHostPort(host string, port int) string {
if IsUnixAddr(host) {
return host
}
return net.JoinHostPort(host, fmt.Sprint(port))
}
// ChanFromSlice takes a slice of values and returns a channel from which these values can be received.
// This channel is closed after the last value was sent.
func ChanFromSlice[T any](values []T) <-chan T {
ch := make(chan T, len(values))
for _, value := range values {
ch <- value
}
close(ch)
return ch
}
// PrintErrorThenExit prints the given error to [os.Stderr] and exits with the specified error code.
func PrintErrorThenExit(err error, exitCode int) {
fmt.Fprintln(os.Stderr, err)
os.Exit(exitCode)
}

View file

@ -1,54 +0,0 @@
package utils
import (
"github.com/stretchr/testify/require"
"testing"
)
func TestChanFromSlice(t *testing.T) {
t.Run("Nil", func(t *testing.T) {
ch := ChanFromSlice[int](nil)
require.NotNil(t, ch)
requireClosedEmpty(t, ch)
})
t.Run("Empty", func(t *testing.T) {
ch := ChanFromSlice([]int{})
require.NotNil(t, ch)
requireClosedEmpty(t, ch)
})
t.Run("NonEmpty", func(t *testing.T) {
ch := ChanFromSlice([]int{42, 23, 1337})
require.NotNil(t, ch)
requireReceive(t, ch, 42)
requireReceive(t, ch, 23)
requireReceive(t, ch, 1337)
requireClosedEmpty(t, ch)
})
}
// requireReceive is a helper function to check if a value can immediately be received from a channel.
func requireReceive(t *testing.T, ch <-chan int, expected int) {
t.Helper()
select {
case v, ok := <-ch:
require.True(t, ok, "receiving should return a value")
require.Equal(t, expected, v)
default:
require.Fail(t, "receiving should not block")
}
}
// requireReceive is a helper function to check if the channel is closed and empty.
func requireClosedEmpty(t *testing.T, ch <-chan int) {
t.Helper()
select {
case _, ok := <-ch:
require.False(t, ok, "receiving from channel should not return anything")
default:
require.Fail(t, "receiving should not block")
}
}

View file

@ -1,180 +0,0 @@
package version
import (
"bufio"
"errors"
"fmt"
"os"
"runtime"
"runtime/debug"
"strconv"
"strings"
)
type VersionInfo struct {
Version string
Commit string
}
// Version determines version and commit information based on multiple data sources:
// - Version information dynamically added by `git archive` in the remaining to parameters.
// - A hardcoded version number passed as first parameter.
// - Commit information added to the binary by `go build`.
//
// It's supposed to be called like this in combination with setting the `export-subst` attribute for the corresponding
// file in .gitattributes:
//
// var Version = version.Version("1.0.0-rc2", "$Format:%(describe)$", "$Format:%H$")
//
// When exported using `git archive`, the placeholders are replaced in the file and this version information is
// preferred. Otherwise the hardcoded version is used and augmented with commit information from the build metadata.
func Version(version, gitDescribe, gitHash string) *VersionInfo {
const hashLen = 7 // Same truncation length for the commit hash as used by git describe.
if !strings.HasPrefix(gitDescribe, "$") && !strings.HasPrefix(gitHash, "$") {
if strings.HasPrefix(gitDescribe, "%") {
// Only Git 2.32+ supports %(describe), older versions don't expand it but keep it as-is.
// Fall back to the hardcoded version augmented with the commit hash.
gitDescribe = version
if len(gitHash) >= hashLen {
gitDescribe += "-g" + gitHash[:hashLen]
}
}
return &VersionInfo{
Version: gitDescribe,
Commit: gitHash,
}
} else {
commit := ""
if info, ok := debug.ReadBuildInfo(); ok {
modified := false
for _, setting := range info.Settings {
switch setting.Key {
case "vcs.revision":
commit = setting.Value
case "vcs.modified":
modified, _ = strconv.ParseBool(setting.Value)
}
}
if len(commit) >= hashLen {
version += "-g" + commit[:hashLen]
if modified {
version += "-dirty"
commit += " (modified)"
}
}
}
return &VersionInfo{
Version: version,
Commit: commit,
}
}
}
// Print writes verbose version output to stdout.
func (v *VersionInfo) Print() {
fmt.Println("Icinga DB version:", v.Version)
fmt.Println()
fmt.Println("Build information:")
fmt.Printf(" Go version: %s (%s, %s)\n", runtime.Version(), runtime.GOOS, runtime.GOARCH)
if v.Commit != "" {
fmt.Println(" Git commit:", v.Commit)
}
if r, err := readOsRelease(); err == nil {
fmt.Println()
fmt.Println("System information:")
fmt.Println(" Platform:", r.Name)
fmt.Println(" Platform version:", r.DisplayVersion())
}
}
// osRelease contains the information obtained from the os-release file.
type osRelease struct {
Name string
Version string
VersionId string
BuildId string
}
// DisplayVersion returns the most suitable version information for display purposes.
func (o *osRelease) DisplayVersion() string {
if o.Version != "" {
// Most distributions set VERSION
return o.Version
} else if o.VersionId != "" {
// Some only set VERSION_ID (Alpine Linux for example)
return o.VersionId
} else if o.BuildId != "" {
// Others only set BUILD_ID (Arch Linux for example)
return o.BuildId
} else {
return "(unknown)"
}
}
// readOsRelease reads and parses the os-release file.
func readOsRelease() (*osRelease, error) {
for _, path := range []string{"/etc/os-release", "/usr/lib/os-release"} {
f, err := os.Open(path)
if err != nil {
if os.IsNotExist(err) {
continue // Try next path.
} else {
return nil, err
}
}
o := &osRelease{
Name: "Linux", // Suggested default as per os-release(5) man page.
}
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "#") {
continue // Ignore comment.
}
parts := strings.SplitN(line, "=", 2)
if len(parts) != 2 {
continue // Ignore empty or possibly malformed line.
}
key := parts[0]
val := parts[1]
// Unquote strings. This isn't fully compliant with the specification which allows using some shell escape
// sequences. However, typically quotes are only used to allow whitespace within the value.
if len(val) >= 2 && (val[0] == '"' || val[0] == '\'') && val[0] == val[len(val)-1] {
val = val[1 : len(val)-1]
}
switch key {
case "NAME":
o.Name = val
case "VERSION":
o.Version = val
case "VERSION_ID":
o.VersionId = val
case "BUILD_ID":
o.BuildId = val
}
}
if err := scanner.Err(); err != nil {
return nil, err
}
return o, nil
}
return nil, errors.New("os-release file not found")
}