diff --git a/physical/physical_test.go b/physical/physical_test.go
index da7e97b122..0fc805e84e 100644
--- a/physical/physical_test.go
+++ b/physical/physical_test.go
@@ -264,4 +264,6 @@ func testHABackend(t *testing.T, b HABackend, b2 HABackend) {
 	if val != "baz" {
 		t.Fatalf("bad value: %v", err)
 	}
+	// Cleanup
+	lock2.Unlock()
 }
diff --git a/physical/zookeeper.go b/physical/zookeeper.go
index 8051cb02fc..ad87935c94 100644
--- a/physical/zookeeper.go
+++ b/physical/zookeeper.go
@@ -4,6 +4,7 @@ import (
 	"fmt"
 	"sort"
 	"strings"
+	"sync"
 	"time"
 
 	"github.com/armon/go-metrics"
@@ -194,3 +195,159 @@ func (c *ZookeeperBackend) List(prefix string) ([]string, error) {
 	sort.Strings(children)
 	return children, nil
 }
+
+// LockWith is used for mutual exclusion based on the given key.
+func (c *ZookeeperBackend) LockWith(key, value string) (Lock, error) {
+	l := &ZookeeperHALock{
+		in:    c,
+		key:   key,
+		value: value,
+	}
+	return l, nil
+}
+
+// ZookeeperHALock is a Zookeeper Lock implementation for the HABackend
+type ZookeeperHALock struct {
+	in    *ZookeeperBackend
+	key   string
+	value string
+
+	held      bool
+	localLock sync.Mutex
+	leaderCh  chan struct{}
+	zkLock    *zk.Lock
+}
+
+// Lock blocks until the ZK lock for the key is acquired, acquisition fails,
+// or stopCh is signalled. On success it returns a channel that monitorLock
+// closes when a watch event suggests the lock may have been lost. If stopCh
+// fires first it returns (nil, nil).
+func (i *ZookeeperHALock) Lock(stopCh <-chan struct{}) (<-chan struct{}, error) {
+	i.localLock.Lock()
+	defer i.localLock.Unlock()
+	if i.held {
+		return nil, fmt.Errorf("lock already held")
+	}
+
+	// Attempt an async acquisition
+	didLock := make(chan struct{})
+	failLock := make(chan error, 1)
+	releaseCh := make(chan bool, 1)
+	lockpath := i.in.path + i.key
+	go i.attemptLock(lockpath, didLock, failLock, releaseCh)
+
+	// Wait for lock acquisition, failure, or shutdown
+	select {
+	case <-didLock:
+		releaseCh <- false
+	case err := <-failLock:
+		return nil, err
+	case <-stopCh:
+		releaseCh <- true
+		return nil, nil
+	}
+
+	// Watch for Events which could result in loss of our zkLock and close(i.leaderCh)
+	currentVal, _, lockeventCh, err := i.in.client.GetW(lockpath)
+	if err != nil {
+		// Release the ZK lock so a failed watch does not leave it held forever
+		i.zkLock.Unlock()
+		return nil, fmt.Errorf("unable to watch HA lock: %v", err)
+	}
+	if i.value != string(currentVal) {
+		i.zkLock.Unlock()
+		return nil, fmt.Errorf("lost HA lock immediately before watch")
+	}
+
+	// Only record the lock as held once the watch is in place, so that the
+	// error paths above leave this object reusable
+	i.held = true
+	i.leaderCh = make(chan struct{})
+	go i.monitorLock(lockeventCh, i.leaderCh)
+
+	return i.leaderCh, nil
+}
+
+// attemptLock acquires the ZK lock at lockpath and passes i.value to
+// ensurePath for that path. It reports failures on failLock, closes didLock
+// on success, and then waits on releaseCh to learn whether to release again.
+func (i *ZookeeperHALock) attemptLock(lockpath string, didLock chan struct{}, failLock chan error, releaseCh chan bool) {
+	// Wait to acquire the lock in ZK
+	acl := zk.WorldACL(zk.PermAll)
+	lock := zk.NewLock(i.in.client, lockpath, acl)
+	err := lock.Lock()
+	if err != nil {
+		failLock <- err
+		return
+	}
+	// Set node value
+	data := []byte(i.value)
+	err = i.in.ensurePath(lockpath, data)
+	if err != nil {
+		failLock <- err
+		lock.Unlock()
+		return
+	}
+	i.zkLock = lock
+
+	// Signal that lock is held
+	close(didLock)
+
+	// Handle an early abort
+	release := <-releaseCh
+	if release {
+		lock.Unlock()
+	}
+}
+
+// monitorLock consumes events from the watch on the lock node and closes
+// leaderCh on the first event whose session state or type is not one of the
+// tolerated values listed below.
+// NOTE(review): ZooKeeper watches are one-shot, so after a tolerated event
+// this loop may never see another one; confirm whether a re-watch is needed.
+func (i *ZookeeperHALock) monitorLock(lockeventCh <-chan zk.Event, leaderCh chan struct{}) {
+	for {
+		select {
+		case event := <-lockeventCh:
+			// Lost connection?
+			switch event.State {
+			case zk.StateConnected:
+			case zk.StateSyncConnected:
+			case zk.StateHasSession:
+			default:
+				close(leaderCh)
+				return
+			}
+
+			// Lost lock?
+			switch event.Type {
+			case zk.EventNodeChildrenChanged:
+			case zk.EventSession:
+			default:
+				close(leaderCh)
+				return
+			}
+		}
+	}
+}
+
+// Unlock releases the ZK lock if this object currently holds it and returns
+// any error from that release; it is a no-op when the lock is not held.
+func (i *ZookeeperHALock) Unlock() error {
+	i.localLock.Lock()
+	defer i.localLock.Unlock()
+	if !i.held {
+		return nil
+	}
+
+	i.held = false
+	return i.zkLock.Unlock()
+}
+
+// Value reads the lock node and reports whether it returned data, along
+// with that data as a string.
+func (i *ZookeeperHALock) Value() (bool, string, error) {
+	lockpath := i.in.path + i.key
+	value, _, err := i.in.client.Get(lockpath)
+	return (value != nil), string(value), err
+}
diff --git a/physical/zookeeper_test.go b/physical/zookeeper_test.go
index 31fd7f3813..9c76277dd4 100644
--- a/physical/zookeeper_test.go
+++ b/physical/zookeeper_test.go
@@ -30,7 +30,11 @@ func TestZookeeperBackend(t *testing.T) {
 	}
 
 	defer func() {
+		client.Delete(randPath+"/foo/bar/baz", -1)
+		client.Delete(randPath+"/foo/bar", -1)
+		client.Delete(randPath+"/foo", -1)
 		client.Delete(randPath, -1)
+		client.Close()
 	}()
 
 	b, err := NewBackend("zookeeper", map[string]string{
@@ -44,3 +48,50 @@ func TestZookeeperBackend(t *testing.T) {
 	testBackend(t, b)
 	testBackend_ListPrefix(t, b)
 }
+
+func TestZookeeperHABackend(t *testing.T) {
+	addr := os.Getenv("ZOOKEEPER_ADDR")
+	if addr == "" {
+		t.SkipNow()
+	}
+
+	client, _, err := zk.Connect([]string{addr}, time.Second)
+
+	if err != nil {
+		t.Fatalf("err: %v", err)
+	}
+
+	randPath := fmt.Sprintf("/vault-ha-%d", time.Now().Unix())
+	acl := zk.WorldACL(zk.PermAll)
+	_, err = client.Create(randPath, []byte("hi"), int32(0), acl)
+
+	if err != nil {
+		t.Fatalf("err: %v", err)
+	}
+
+	defer func() {
+		client.Delete(randPath+"/foo", -1)
+		client.Delete(randPath, -1)
+		client.Close()
+	}()
+
+	b, err := NewBackend("zookeeper", map[string]string{
+		"address": addr + "," + addr,
+		"path":    randPath,
+	})
+	if err != nil {
+		t.Fatalf("err: %s", err)
+	}
+
+	ha, ok := b.(HABackend)
+	if !ok {
+		t.Fatalf("zookeeper does not implement HABackend")
+	}
+	testHABackend(t, ha, ha)
+
+	err = client.Delete(randPath+"/foo", -1)
+	if err != nil {
+		t.Fatalf("err: failed to cleanup! %s", err)
+	}
+
+}