Move LDAP client and config code to helper (#4532)

This commit is contained in:
Becca Petrin 2018-05-10 14:12:42 -07:00 committed by GitHub
parent 03a48f217e
commit 8ea9efd297
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 616 additions and 407 deletions

View file

@ -1,15 +1,11 @@
package ldap
import (
"bytes"
"context"
"fmt"
"math"
"strings"
"text/template"
"github.com/go-ldap/ldap"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/vault/helper/ldaputil"
"github.com/hashicorp/vault/helper/mfa"
"github.com/hashicorp/vault/helper/strutil"
"github.com/hashicorp/vault/logical"
@ -62,40 +58,6 @@ type backend struct {
*framework.Backend
}
func EscapeLDAPValue(input string) string {
// RFC4514 forbids un-escaped:
// - leading space or hash
// - trailing space
// - special characters '"', '+', ',', ';', '<', '>', '\\'
// - null
for i := 0; i < len(input); i++ {
escaped := false
if input[i] == '\\' {
i++
escaped = true
}
switch input[i] {
case '"', '+', ',', ';', '<', '>', '\\':
if !escaped {
input = input[0:i] + "\\" + input[i:]
i++
}
continue
}
if escaped {
input = input[0:i] + "\\" + input[i:]
i++
}
}
if input[0] == ' ' || input[0] == '#' {
input = "\\" + input
}
if input[len(input)-1] == ' ' {
input = input[0:len(input)-1] + "\\ "
}
return input
}
func (b *backend) Login(ctx context.Context, req *logical.Request, username string, password string) ([]string, *logical.Response, []string, error) {
cfg, err := b.Config(ctx, req)
@ -106,7 +68,12 @@ func (b *backend) Login(ctx context.Context, req *logical.Request, username stri
return nil, logical.ErrorResponse("ldap backend not configured"), nil, nil
}
c, err := cfg.DialLDAP()
ldapClient := ldaputil.Client{
Logger: b.Logger(),
LDAP: ldaputil.NewLDAP(),
}
c, err := ldapClient.DialLDAP(cfg)
if err != nil {
return nil, logical.ErrorResponse(err.Error()), nil, nil
}
@ -117,7 +84,7 @@ func (b *backend) Login(ctx context.Context, req *logical.Request, username stri
// Clean connection
defer c.Close()
userBindDN, err := b.getUserBindDN(cfg, c, username)
userBindDN, err := ldapClient.GetUserBindDN(cfg, c, username)
if err != nil {
return nil, logical.ErrorResponse(err.Error()), nil, nil
}
@ -151,12 +118,12 @@ func (b *backend) Login(ctx context.Context, req *logical.Request, username stri
}
}
userDN, err := b.getUserDN(cfg, c, userBindDN)
userDN, err := ldapClient.GetUserDN(cfg, c, userBindDN)
if err != nil {
return nil, logical.ErrorResponse(err.Error()), nil, nil
}
ldapGroups, err := b.getLdapGroups(cfg, c, userDN, username)
ldapGroups, err := ldapClient.GetLdapGroups(cfg, c, userDN, username)
if err != nil {
return nil, logical.ErrorResponse(err.Error()), nil, nil
}
@ -227,213 +194,6 @@ func (b *backend) Login(ctx context.Context, req *logical.Request, username stri
return policies, ldapResponse, allGroups, nil
}
/*
* Parses a distinguished name and returns the CN portion.
* Given a non-conforming string (such as an already-extracted CN),
* it will be returned as-is.
*/
func (b *backend) getCN(dn string) string {
parsedDN, err := ldap.ParseDN(dn)
if err != nil || len(parsedDN.RDNs) == 0 {
// It was already a CN, return as-is
return dn
}
for _, rdn := range parsedDN.RDNs {
for _, rdnAttr := range rdn.Attributes {
if rdnAttr.Type == "CN" {
return rdnAttr.Value
}
}
}
// Default, return self
return dn
}
/*
* Discover and return the bind string for the user attempting to authenticate.
* This is handled in one of several ways:
*
* 1. If DiscoverDN is set, the user object will be searched for using userdn (base search path)
* and userattr (the attribute that maps to the provided username).
* The bind will either be anonymous or use binddn and bindpassword if they were provided.
* 2. If upndomain is set, the user dn is constructed as 'username@upndomain'. See https://msdn.microsoft.com/en-us/library/cc223499.aspx
*
*/
func (b *backend) getUserBindDN(cfg *ConfigEntry, c *ldap.Conn, username string) (string, error) {
bindDN := ""
if cfg.DiscoverDN || (cfg.BindDN != "" && cfg.BindPassword != "") {
var err error
if cfg.BindPassword != "" {
err = c.Bind(cfg.BindDN, cfg.BindPassword)
} else {
err = c.UnauthenticatedBind(cfg.BindDN)
}
if err != nil {
return bindDN, errwrap.Wrapf("LDAP bind (service) failed: {{err}}", err)
}
filter := fmt.Sprintf("(%s=%s)", cfg.UserAttr, ldap.EscapeFilter(username))
if b.Logger().IsDebug() {
b.Logger().Debug("discovering user", "userdn", cfg.UserDN, "filter", filter)
}
result, err := c.Search(&ldap.SearchRequest{
BaseDN: cfg.UserDN,
Scope: 2, // subtree
Filter: filter,
SizeLimit: math.MaxInt32,
})
if err != nil {
return bindDN, errwrap.Wrapf("LDAP search for binddn failed: {{err}}", err)
}
if len(result.Entries) != 1 {
return bindDN, fmt.Errorf("LDAP search for binddn 0 or not unique")
}
bindDN = result.Entries[0].DN
} else {
if cfg.UPNDomain != "" {
bindDN = fmt.Sprintf("%s@%s", EscapeLDAPValue(username), cfg.UPNDomain)
} else {
bindDN = fmt.Sprintf("%s=%s,%s", cfg.UserAttr, EscapeLDAPValue(username), cfg.UserDN)
}
}
return bindDN, nil
}
/*
* Returns the DN of the object representing the authenticated user.
*/
func (b *backend) getUserDN(cfg *ConfigEntry, c *ldap.Conn, bindDN string) (string, error) {
userDN := ""
if cfg.UPNDomain != "" {
// Find the distinguished name for the user if userPrincipalName used for login
filter := fmt.Sprintf("(userPrincipalName=%s)", ldap.EscapeFilter(bindDN))
if b.Logger().IsDebug() {
b.Logger().Debug("searching upn", "userdn", cfg.UserDN, "filter", filter)
}
result, err := c.Search(&ldap.SearchRequest{
BaseDN: cfg.UserDN,
Scope: 2, // subtree
Filter: filter,
SizeLimit: math.MaxInt32,
})
if err != nil {
return userDN, errwrap.Wrapf("LDAP search failed for detecting user: {{err}}", err)
}
for _, e := range result.Entries {
userDN = e.DN
}
} else {
userDN = bindDN
}
return userDN, nil
}
/*
* getLdapGroups queries LDAP and returns a slice describing the set of groups the authenticated user is a member of.
*
* The search query is constructed according to cfg.GroupFilter, and run in context of cfg.GroupDN.
* Groups will be resolved from the query results by following the attribute defined in cfg.GroupAttr.
*
* cfg.GroupFilter is a go template and is compiled with the following context: [UserDN, Username]
* UserDN - The DN of the authenticated user
* Username - The Username of the authenticated user
*
* Example:
* cfg.GroupFilter = "(&(objectClass=group)(member:1.2.840.113556.1.4.1941:={{.UserDN}}))"
* cfg.GroupDN = "OU=Groups,DC=myorg,DC=com"
* cfg.GroupAttr = "cn"
*
* NOTE - If cfg.GroupFilter is empty, no query is performed and an empty result slice is returned.
*
*/
func (b *backend) getLdapGroups(cfg *ConfigEntry, c *ldap.Conn, userDN string, username string) ([]string, error) {
// retrieve the groups in a string/bool map as a structure to avoid duplicates inside
ldapMap := make(map[string]bool)
if cfg.GroupFilter == "" {
b.Logger().Warn("groupfilter is empty, will not query server")
return make([]string, 0), nil
}
if cfg.GroupDN == "" {
b.Logger().Warn("groupdn is empty, will not query server")
return make([]string, 0), nil
}
// If groupfilter was defined, resolve it as a Go template and use the query for
// returning the user's groups
if b.Logger().IsDebug() {
b.Logger().Debug("compiling group filter", "group_filter", cfg.GroupFilter)
}
// Parse the configuration as a template.
// Example template "(&(objectClass=group)(member:1.2.840.113556.1.4.1941:={{.UserDN}}))"
t, err := template.New("queryTemplate").Parse(cfg.GroupFilter)
if err != nil {
return nil, errwrap.Wrapf("LDAP search failed due to template compilation error: {{err}}", err)
}
// Build context to pass to template - we will be exposing UserDn and Username.
context := struct {
UserDN string
Username string
}{
ldap.EscapeFilter(userDN),
ldap.EscapeFilter(username),
}
var renderedQuery bytes.Buffer
t.Execute(&renderedQuery, context)
if b.Logger().IsDebug() {
b.Logger().Debug("searching", "groupdn", cfg.GroupDN, "rendered_query", renderedQuery.String())
}
result, err := c.Search(&ldap.SearchRequest{
BaseDN: cfg.GroupDN,
Scope: 2, // subtree
Filter: renderedQuery.String(),
Attributes: []string{
cfg.GroupAttr,
},
SizeLimit: math.MaxInt32,
})
if err != nil {
return nil, errwrap.Wrapf("LDAP search failed: {{err}}", err)
}
for _, e := range result.Entries {
dn, err := ldap.ParseDN(e.DN)
if err != nil || len(dn.RDNs) == 0 {
continue
}
// Enumerate attributes of each result, parse out CN and add as group
values := e.GetAttributeValues(cfg.GroupAttr)
if len(values) > 0 {
for _, val := range values {
groupCN := b.getCN(val)
ldapMap[groupCN] = true
}
} else {
// If groupattr didn't resolve, use self (enumerating group objects)
groupCN := b.getCN(e.DN)
ldapMap[groupCN] = true
}
}
ldapGroups := make([]string, 0, len(ldapMap))
for key, _ := range ldapMap {
ldapGroups = append(ldapGroups, key)
}
return ldapGroups, nil
}
const backendHelp = `
The "ldap" credential provider allows authentication querying
a LDAP server, checking username and password, and associating groups

View file

@ -654,23 +654,6 @@ func testAccStepLoginNoGroupDN(t *testing.T, user string, pass string) logicalte
}
}
func TestLDAPEscape(t *testing.T) {
testcases := map[string]string{
"#test": "\\#test",
"test,hello": "test\\,hello",
"test,hel+lo": "test\\,hel\\+lo",
"test\\hello": "test\\\\hello",
" test ": "\\ test \\ ",
}
for test, answer := range testcases {
res := EscapeLDAPValue(test)
if res != answer {
t.Errorf("Failed to escape %s: %s != %s\n", test, res, answer)
}
}
}
func testAccStepGroupList(t *testing.T, groups []string) logicaltest.TestStep {
return logicaltest.TestStep{
Operation: logical.ListOperation,

View file

@ -2,20 +2,15 @@ package ldap
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"net"
"net/url"
"strings"
"text/template"
"github.com/go-ldap/ldap"
"github.com/hashicorp/errwrap"
log "github.com/hashicorp/go-hclog"
multierror "github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/helper/consts"
"github.com/hashicorp/vault/helper/ldaputil"
"github.com/hashicorp/vault/helper/tlsutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
@ -137,7 +132,7 @@ Default: cn`,
/*
* Construct ConfigEntry struct using stored configuration.
*/
func (b *backend) Config(ctx context.Context, req *logical.Request) (*ConfigEntry, error) {
func (b *backend) Config(ctx context.Context, req *logical.Request) (*ldaputil.ConfigEntry, error) {
// Schema for ConfigEntry
fd, err := b.getConfigFieldData()
if err != nil {
@ -187,8 +182,6 @@ func (b *backend) Config(ctx context.Context, req *logical.Request) (*ConfigEntr
}
}
result.logger = b.Logger()
return result, nil
}
@ -228,10 +221,8 @@ func (b *backend) pathConfigRead(ctx context.Context, req *logical.Request, d *f
* Creates and initializes a ConfigEntry object with its default values,
* as specified by the passed schema.
*/
func (b *backend) newConfigEntry(d *framework.FieldData) (*ConfigEntry, error) {
cfg := new(ConfigEntry)
cfg.logger = b.Logger()
func (b *backend) newConfigEntry(d *framework.FieldData) (*ldaputil.ConfigEntry, error) {
cfg := new(ldaputil.ConfigEntry)
url := d.Get("url").(string)
if url != "" {
@ -367,131 +358,6 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, d *
return nil, nil
}
type ConfigEntry struct {
logger log.Logger
Url string `json:"url"`
UserDN string `json:"userdn"`
GroupDN string `json:"groupdn"`
GroupFilter string `json:"groupfilter"`
GroupAttr string `json:"groupattr"`
UPNDomain string `json:"upndomain"`
UserAttr string `json:"userattr"`
Certificate string `json:"certificate"`
InsecureTLS bool `json:"insecure_tls"`
StartTLS bool `json:"starttls"`
BindDN string `json:"binddn"`
BindPassword string `json:"bindpass"`
DenyNullBind bool `json:"deny_null_bind"`
DiscoverDN bool `json:"discoverdn"`
TLSMinVersion string `json:"tls_min_version"`
TLSMaxVersion string `json:"tls_max_version"`
// This json tag deviates from snake case because there was a past issue
// where the tag was being ignored, causing it to be jsonified as "CaseSensitiveNames".
// To continue reading in users' previously stored values,
// we chose to carry that forward.
CaseSensitiveNames *bool `json:"CaseSensitiveNames,omitempty"`
}
func (c *ConfigEntry) GetTLSConfig(host string) (*tls.Config, error) {
tlsConfig := &tls.Config{
ServerName: host,
}
if c.TLSMinVersion != "" {
tlsMinVersion, ok := tlsutil.TLSLookup[c.TLSMinVersion]
if !ok {
return nil, fmt.Errorf("invalid 'tls_min_version' in config")
}
tlsConfig.MinVersion = tlsMinVersion
}
if c.TLSMaxVersion != "" {
tlsMaxVersion, ok := tlsutil.TLSLookup[c.TLSMaxVersion]
if !ok {
return nil, fmt.Errorf("invalid 'tls_max_version' in config")
}
tlsConfig.MaxVersion = tlsMaxVersion
}
if c.InsecureTLS {
tlsConfig.InsecureSkipVerify = true
}
if c.Certificate != "" {
caPool := x509.NewCertPool()
ok := caPool.AppendCertsFromPEM([]byte(c.Certificate))
if !ok {
return nil, fmt.Errorf("could not append CA certificate")
}
tlsConfig.RootCAs = caPool
}
return tlsConfig, nil
}
func (c *ConfigEntry) DialLDAP() (*ldap.Conn, error) {
var retErr *multierror.Error
var conn *ldap.Conn
urls := strings.Split(c.Url, ",")
for _, uut := range urls {
u, err := url.Parse(uut)
if err != nil {
retErr = multierror.Append(retErr, errwrap.Wrapf(fmt.Sprintf("error parsing url %q: {{err}}", uut), err))
continue
}
host, port, err := net.SplitHostPort(u.Host)
if err != nil {
host = u.Host
}
var tlsConfig *tls.Config
switch u.Scheme {
case "ldap":
if port == "" {
port = "389"
}
conn, err = ldap.Dial("tcp", net.JoinHostPort(host, port))
if err != nil {
break
}
if conn == nil {
err = fmt.Errorf("empty connection after dialing")
break
}
if c.StartTLS {
tlsConfig, err = c.GetTLSConfig(host)
if err != nil {
break
}
err = conn.StartTLS(tlsConfig)
}
case "ldaps":
if port == "" {
port = "636"
}
tlsConfig, err = c.GetTLSConfig(host)
if err != nil {
break
}
conn, err = ldap.DialTLS("tcp", net.JoinHostPort(host, port), tlsConfig)
default:
retErr = multierror.Append(retErr, fmt.Errorf("invalid LDAP scheme in url %q", net.JoinHostPort(host, port)))
continue
}
if err == nil {
if retErr != nil {
if c.logger.IsDebug() {
c.logger.Debug("errors connecting to some hosts: %s", retErr.Error())
}
}
retErr = nil
break
}
retErr = multierror.Append(retErr, errwrap.Wrapf(fmt.Sprintf("error connecting to host %q: {{err}}", uut), err))
}
return conn, retErr.ErrorOrNil()
}
/*
* Returns FieldData describing our ConfigEntry struct schema
*/

View file

@ -43,7 +43,7 @@ const (
)
// TLSUsage controls whether the intended usage of a *tls.Config
// returned from ParsedCertBundle.GetTLSConfig is for server use,
// returned from ParsedCertBundle.getTLSConfig is for server use,
// client use, or both, which affects which values are set
type TLSUsage int
@ -523,7 +523,7 @@ func (p *ParsedCSRBundle) SetParsedPrivateKey(privateKey crypto.Signer, privateK
p.PrivateKeyBytes = privateKeyBytes
}
// GetTLSConfig returns a TLS config generally suitable for client
// getTLSConfig returns a TLS config generally suitable for client
// authentication. The returned TLS config can be modified slightly
// to be made suitable for a server requiring client authentication;
// specifically, you should set the value of ClientAuth in the returned

366
helper/ldaputil/client.go Normal file
View file

@ -0,0 +1,366 @@
package ldaputil
import (
"bytes"
"crypto/tls"
"crypto/x509"
"fmt"
"math"
"net"
"net/url"
"strings"
"text/template"
"github.com/go-ldap/ldap"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/helper/tlsutil"
)
type Client struct {
Logger hclog.Logger
LDAP LDAP
}
func (c *Client) DialLDAP(cfg *ConfigEntry) (Connection, error) {
var retErr *multierror.Error
var conn Connection
urls := strings.Split(cfg.Url, ",")
for _, uut := range urls {
u, err := url.Parse(uut)
if err != nil {
retErr = multierror.Append(retErr, errwrap.Wrapf(fmt.Sprintf("error parsing url %q: {{err}}", uut), err))
continue
}
host, port, err := net.SplitHostPort(u.Host)
if err != nil {
host = u.Host
}
var tlsConfig *tls.Config
switch u.Scheme {
case "ldap":
if port == "" {
port = "389"
}
conn, err = c.LDAP.Dial("tcp", net.JoinHostPort(host, port))
if err != nil {
break
}
if conn == nil {
err = fmt.Errorf("empty connection after dialing")
break
}
if cfg.StartTLS {
tlsConfig, err = getTLSConfig(cfg, host)
if err != nil {
break
}
err = conn.StartTLS(tlsConfig)
}
case "ldaps":
if port == "" {
port = "636"
}
tlsConfig, err = getTLSConfig(cfg, host)
if err != nil {
break
}
conn, err = c.LDAP.DialTLS("tcp", net.JoinHostPort(host, port), tlsConfig)
default:
retErr = multierror.Append(retErr, fmt.Errorf("invalid LDAP scheme in url %q", net.JoinHostPort(host, port)))
continue
}
if err == nil {
if retErr != nil {
if c.Logger.IsDebug() {
c.Logger.Debug("errors connecting to some hosts: %s", retErr.Error())
}
}
retErr = nil
break
}
retErr = multierror.Append(retErr, errwrap.Wrapf(fmt.Sprintf("error connecting to host %q: {{err}}", uut), err))
}
return conn, retErr.ErrorOrNil()
}
/*
* Discover and return the bind string for the user attempting to authenticate.
* This is handled in one of several ways:
*
* 1. If DiscoverDN is set, the user object will be searched for using userdn (base search path)
* and userattr (the attribute that maps to the provided username).
* The bind will either be anonymous or use binddn and bindpassword if they were provided.
* 2. If upndomain is set, the user dn is constructed as 'username@upndomain'. See https://msdn.microsoft.com/en-us/library/cc223499.aspx
*
*/
func (c *Client) GetUserBindDN(cfg *ConfigEntry, conn Connection, username string) (string, error) {
bindDN := ""
// Note: The logic below drives the logic in ConfigEntry.Validate().
// If updated, please update there as well.
if cfg.DiscoverDN || (cfg.BindDN != "" && cfg.BindPassword != "") {
var err error
if cfg.BindPassword != "" {
err = conn.Bind(cfg.BindDN, cfg.BindPassword)
} else {
err = conn.UnauthenticatedBind(cfg.BindDN)
}
if err != nil {
return bindDN, errwrap.Wrapf("LDAP bind (service) failed: {{err}}", err)
}
filter := fmt.Sprintf("(%s=%s)", cfg.UserAttr, ldap.EscapeFilter(username))
if c.Logger.IsDebug() {
c.Logger.Debug("discovering user", "userdn", cfg.UserDN, "filter", filter)
}
result, err := conn.Search(&ldap.SearchRequest{
BaseDN: cfg.UserDN,
Scope: 2, // subtree
Filter: filter,
SizeLimit: math.MaxInt32,
})
if err != nil {
return bindDN, errwrap.Wrapf("LDAP search for binddn failed: {{err}}", err)
}
if len(result.Entries) != 1 {
return bindDN, fmt.Errorf("LDAP search for binddn 0 or not unique")
}
bindDN = result.Entries[0].DN
} else {
if cfg.UPNDomain != "" {
bindDN = fmt.Sprintf("%s@%s", escapeLDAPValue(username), cfg.UPNDomain)
} else {
bindDN = fmt.Sprintf("%s=%s,%s", cfg.UserAttr, escapeLDAPValue(username), cfg.UserDN)
}
}
return bindDN, nil
}
/*
* Returns the DN of the object representing the authenticated user.
*/
func (c *Client) GetUserDN(cfg *ConfigEntry, conn Connection, bindDN string) (string, error) {
userDN := ""
if cfg.UPNDomain != "" {
// Find the distinguished name for the user if userPrincipalName used for login
filter := fmt.Sprintf("(userPrincipalName=%s)", ldap.EscapeFilter(bindDN))
if c.Logger.IsDebug() {
c.Logger.Debug("searching upn", "userdn", cfg.UserDN, "filter", filter)
}
result, err := conn.Search(&ldap.SearchRequest{
BaseDN: cfg.UserDN,
Scope: 2, // subtree
Filter: filter,
SizeLimit: math.MaxInt32,
})
if err != nil {
return userDN, errwrap.Wrapf("LDAP search failed for detecting user: {{err}}", err)
}
for _, e := range result.Entries {
userDN = e.DN
}
} else {
userDN = bindDN
}
return userDN, nil
}
/*
* getLdapGroups queries LDAP and returns a slice describing the set of groups the authenticated user is a member of.
*
* The search query is constructed according to cfg.GroupFilter, and run in context of cfg.GroupDN.
* Groups will be resolved from the query results by following the attribute defined in cfg.GroupAttr.
*
* cfg.GroupFilter is a go template and is compiled with the following context: [UserDN, Username]
* UserDN - The DN of the authenticated user
* Username - The Username of the authenticated user
*
* Example:
* cfg.GroupFilter = "(&(objectClass=group)(member:1.2.840.113556.1.4.1941:={{.UserDN}}))"
* cfg.GroupDN = "OU=Groups,DC=myorg,DC=com"
* cfg.GroupAttr = "cn"
*
* NOTE - If cfg.GroupFilter is empty, no query is performed and an empty result slice is returned.
*
*/
func (c *Client) GetLdapGroups(cfg *ConfigEntry, conn Connection, userDN string, username string) ([]string, error) {
// retrieve the groups in a string/bool map as a structure to avoid duplicates inside
ldapMap := make(map[string]bool)
if cfg.GroupFilter == "" {
c.Logger.Warn("groupfilter is empty, will not query server")
return make([]string, 0), nil
}
if cfg.GroupDN == "" {
c.Logger.Warn("groupdn is empty, will not query server")
return make([]string, 0), nil
}
// If groupfilter was defined, resolve it as a Go template and use the query for
// returning the user's groups
if c.Logger.IsDebug() {
c.Logger.Debug("compiling group filter", "group_filter", cfg.GroupFilter)
}
// Parse the configuration as a template.
// Example template "(&(objectClass=group)(member:1.2.840.113556.1.4.1941:={{.UserDN}}))"
t, err := template.New("queryTemplate").Parse(cfg.GroupFilter)
if err != nil {
return nil, errwrap.Wrapf("LDAP search failed due to template compilation error: {{err}}", err)
}
// Build context to pass to template - we will be exposing UserDn and Username.
context := struct {
UserDN string
Username string
}{
ldap.EscapeFilter(userDN),
ldap.EscapeFilter(username),
}
var renderedQuery bytes.Buffer
t.Execute(&renderedQuery, context)
if c.Logger.IsDebug() {
c.Logger.Debug("searching", "groupdn", cfg.GroupDN, "rendered_query", renderedQuery.String())
}
result, err := conn.Search(&ldap.SearchRequest{
BaseDN: cfg.GroupDN,
Scope: 2, // subtree
Filter: renderedQuery.String(),
Attributes: []string{
cfg.GroupAttr,
},
SizeLimit: math.MaxInt32,
})
if err != nil {
return nil, errwrap.Wrapf("LDAP search failed: {{err}}", err)
}
for _, e := range result.Entries {
dn, err := ldap.ParseDN(e.DN)
if err != nil || len(dn.RDNs) == 0 {
continue
}
// Enumerate attributes of each result, parse out CN and add as group
values := e.GetAttributeValues(cfg.GroupAttr)
if len(values) > 0 {
for _, val := range values {
groupCN := getCN(val)
ldapMap[groupCN] = true
}
} else {
// If groupattr didn't resolve, use self (enumerating group objects)
groupCN := getCN(e.DN)
ldapMap[groupCN] = true
}
}
ldapGroups := make([]string, 0, len(ldapMap))
for key, _ := range ldapMap {
ldapGroups = append(ldapGroups, key)
}
return ldapGroups, nil
}
func escapeLDAPValue(input string) string {
// RFC4514 forbids un-escaped:
// - leading space or hash
// - trailing space
// - special characters '"', '+', ',', ';', '<', '>', '\\'
// - null
for i := 0; i < len(input); i++ {
escaped := false
if input[i] == '\\' {
i++
escaped = true
}
switch input[i] {
case '"', '+', ',', ';', '<', '>', '\\':
if !escaped {
input = input[0:i] + "\\" + input[i:]
i++
}
continue
}
if escaped {
input = input[0:i] + "\\" + input[i:]
i++
}
}
if input[0] == ' ' || input[0] == '#' {
input = "\\" + input
}
if input[len(input)-1] == ' ' {
input = input[0:len(input)-1] + "\\ "
}
return input
}
/*
* Parses a distinguished name and returns the CN portion.
* Given a non-conforming string (such as an already-extracted CN),
* it will be returned as-is.
*/
func getCN(dn string) string {
parsedDN, err := ldap.ParseDN(dn)
if err != nil || len(parsedDN.RDNs) == 0 {
// It was already a CN, return as-is
return dn
}
for _, rdn := range parsedDN.RDNs {
for _, rdnAttr := range rdn.Attributes {
if rdnAttr.Type == "CN" {
return rdnAttr.Value
}
}
}
// Default, return self
return dn
}
func getTLSConfig(cfg *ConfigEntry, host string) (*tls.Config, error) {
tlsConfig := &tls.Config{
ServerName: host,
}
if cfg.TLSMinVersion != "" {
tlsMinVersion, ok := tlsutil.TLSLookup[cfg.TLSMinVersion]
if !ok {
return nil, fmt.Errorf("invalid 'tls_min_version' in config")
}
tlsConfig.MinVersion = tlsMinVersion
}
if cfg.TLSMaxVersion != "" {
tlsMaxVersion, ok := tlsutil.TLSLookup[cfg.TLSMaxVersion]
if !ok {
return nil, fmt.Errorf("invalid 'tls_max_version' in config")
}
tlsConfig.MaxVersion = tlsMaxVersion
}
if cfg.InsecureTLS {
tlsConfig.InsecureSkipVerify = true
}
if cfg.Certificate != "" {
caPool := x509.NewCertPool()
ok := caPool.AppendCertsFromPEM([]byte(cfg.Certificate))
if !ok {
return nil, fmt.Errorf("could not append CA certificate")
}
tlsConfig.RootCAs = caPool
}
return tlsConfig, nil
}

View file

@ -0,0 +1,46 @@
package ldaputil
import (
"testing"
)
func TestLDAPEscape(t *testing.T) {
testcases := map[string]string{
"#test": "\\#test",
"test,hello": "test\\,hello",
"test,hel+lo": "test\\,hel\\+lo",
"test\\hello": "test\\\\hello",
" test ": "\\ test \\ ",
}
for test, answer := range testcases {
res := escapeLDAPValue(test)
if res != answer {
t.Errorf("Failed to escape %s: %s != %s\n", test, res, answer)
}
}
}
func TestGetTLSConfigs(t *testing.T) {
config := testConfig()
if err := config.Validate(); err != nil {
t.Fatal(err)
}
tlsConfig, err := getTLSConfig(config, "138.91.247.105")
if err != nil {
t.Fatal(err)
}
if tlsConfig == nil {
t.Fatal("expected 1 TLS config because there's 1 url")
}
if tlsConfig.InsecureSkipVerify {
t.Fatal("InsecureSkipVerify should be false because we should default to the most secure connection")
}
if tlsConfig.ServerName != "138.91.247.105" {
t.Fatalf("expected ServerName of \"138.91.247.105\" but received %q", tlsConfig.ServerName)
}
expected := uint16(771)
if tlsConfig.MinVersion != expected || tlsConfig.MaxVersion != expected {
t.Fatal("expected TLS min and max version of 771 which corresponds with TLS 1.2 since TLS 1.1 and 1.0 have known vulnerabilities")
}
}

68
helper/ldaputil/config.go Normal file
View file

@ -0,0 +1,68 @@
package ldaputil
import (
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"github.com/hashicorp/vault/helper/tlsutil"
)
type ConfigEntry struct {
Url string `json:"url"`
UserDN string `json:"userdn"`
GroupDN string `json:"groupdn"`
GroupFilter string `json:"groupfilter"`
GroupAttr string `json:"groupattr"`
UPNDomain string `json:"upndomain"`
UserAttr string `json:"userattr"`
Certificate string `json:"certificate"`
InsecureTLS bool `json:"insecure_tls"`
StartTLS bool `json:"starttls"`
BindDN string `json:"binddn"`
BindPassword string `json:"bindpass"`
DenyNullBind bool `json:"deny_null_bind"`
DiscoverDN bool `json:"discoverdn"`
TLSMinVersion string `json:"tls_min_version"`
TLSMaxVersion string `json:"tls_max_version"`
// This json tag deviates from snake case because there was a past issue
// where the tag was being ignored, causing it to be jsonified as "CaseSensitiveNames".
// To continue reading in users' previously stored values,
// we chose to carry that forward.
CaseSensitiveNames *bool `json:"CaseSensitiveNames,omitempty"`
}
func (c *ConfigEntry) Validate() error {
if len(c.Url) == 0 {
return errors.New("at least one url must be provided")
}
// Note: This logic is driven by the logic in GetUserBindDN.
// If updating this, please also update the logic there.
if !c.DiscoverDN && (c.BindDN == "" || c.BindPassword == "") && c.UPNDomain == "" && c.UserDN == "" {
return errors.New("cannot derive UserBindDN")
}
tlsMinVersion, ok := tlsutil.TLSLookup[c.TLSMinVersion]
if !ok {
return errors.New("invalid 'tls_min_version' in config")
}
tlsMaxVersion, ok := tlsutil.TLSLookup[c.TLSMaxVersion]
if !ok {
return errors.New("invalid 'tls_max_version' in config")
}
if tlsMaxVersion < tlsMinVersion {
return errors.New("'tls_max_version' must be greater than or equal to 'tls_min_version'")
}
if c.Certificate != "" {
block, _ := pem.Decode([]byte(c.Certificate))
if block == nil || block.Type != "CERTIFICATE" {
return errors.New("failed to decode PEM block in the certificate")
}
_, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return fmt.Errorf("failed to parse certificate %s", err.Error())
}
}
return nil
}

View file

@ -0,0 +1,74 @@
package ldaputil
import "testing"
func TestCertificateValidation(t *testing.T) {
// certificate should default to "" without error if it doesn't exist
config := testConfig()
if err := config.Validate(); err != nil {
t.Fatal(err)
}
if config.Certificate != "" {
t.Fatalf("expected no certificate but received %s", config.Certificate)
}
// certificate should cause an error if a bad one is provided
config.Certificate = "cats"
if err := config.Validate(); err == nil {
t.Fatal("should err due to bad cert")
}
// valid certificates should pass inspection
config.Certificate = validCertificate
if err := config.Validate(); err != nil {
t.Fatal(err)
}
}
func testConfig() *ConfigEntry {
return &ConfigEntry{
Url: "ldap://138.91.247.105",
UserDN: "example,com",
BindDN: "kitty",
BindPassword: "cats",
TLSMaxVersion: "tls12",
TLSMinVersion: "tls12",
}
}
const validCertificate = `
-----BEGIN CERTIFICATE-----
MIIF7zCCA9egAwIBAgIJAOY2qjn64Qq5MA0GCSqGSIb3DQEBCwUAMIGNMQswCQYD
VQQGEwJVUzEQMA4GA1UECAwHTm93aGVyZTERMA8GA1UEBwwIVGltYnVrdHUxEjAQ
BgNVBAoMCVRlc3QgRmFrZTENMAsGA1UECwwETm9uZTEPMA0GA1UEAwwGTm9ib2R5
MSUwIwYJKoZIhvcNAQkBFhZkb25vdHRydXN0QG5vd2hlcmUuY29tMB4XDTE4MDQw
MzIwNDQwOFoXDTE5MDQwMzIwNDQwOFowgY0xCzAJBgNVBAYTAlVTMRAwDgYDVQQI
DAdOb3doZXJlMREwDwYDVQQHDAhUaW1idWt0dTESMBAGA1UECgwJVGVzdCBGYWtl
MQ0wCwYDVQQLDAROb25lMQ8wDQYDVQQDDAZOb2JvZHkxJTAjBgkqhkiG9w0BCQEW
FmRvbm90dHJ1c3RAbm93aGVyZS5jb20wggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAw
ggIKAoICAQDzQPGErqjaoFcuUV6QFpSMU6w8wO8F0othik+rrlKERmrGonUGsoum
WqRe6L4ZnxBvCKB6EWjvf894TXOF2cpUnjDAyBePISyPkRBEJS6VS2SEC4AJzmVu
a+P+fZr4Hf7/bEcUr7Ax37yGVZ5i5ByNHgZkBlPxKiGWSmAqIDRZLp9gbu2EkG9q
NOjNLPU+QI2ov6U/laGS1vbE2LahTYeT5yscu9LpllxzFv4lM1f4wYEaM3HuOxzT
l86cGmEr9Q2N4PZ2T0O/s6D4but7c6Bz2XPXy9nWb5bqu0n5bJEpbRFrkryW1ozh
L9uVVz4dyW10pFBJtE42bqA4PRCDQsUof7UfsQF11D1ThrDfKsQa8PxrYdGUHUG9
GFF1MdTTwaoT90RI582p+6XYV+LNlXcdfyNZO9bMThu9fnCvT7Ey0TKU4MfPrlfT
aIhZmyaHt6mL5p881UPDIvy7paTLgL+C1orLjZAiT//c4Zn+0qG0//Cirxr020UF
3YiEFk2H0bBVwOHoOGw4w5HrvLdyy0ZLDSPQbzkSZ0RusHb5TjiyhtTk/h9vvJv7
u1fKJub4MzgrBRi16ejFdiWoVuMXRC6fu/ERy3+9DH6LURerbPrdroYypUmTe9N6
XPeaF1Tc+WO7O/yW96mV7X/D211qjkOtwboZC5kjogVbaZgGzjHCVwIDAQABo1Aw
TjAdBgNVHQ4EFgQU2zWT3HeiMBzusz7AggVqVEL5g0UwHwYDVR0jBBgwFoAU2zWT
3HeiMBzusz7AggVqVEL5g0UwDAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQsFAAOC
AgEAwTGcppY86mNRE43uOimeApTfqHJv+lGDTjEoJCZZmzmtxFe6O9+Vk4bH/8/i
gVQvqzBpaWXRt9OhqlFMK7OkX4ZvqXmnShmxib1dz1XxGhbwSec9ca8bill59Jqa
bIOq2SXVMcFD0GwFxfJRBVzHHuB6AwV9B2QN61zeB1oxNGJrUOo80jVkB7+MWMyD
bQqiFCHWGMa6BG4N91KGOTveZCGdBvvVw5j6lt731KjbvL2hB1UHioucOweKLfa4
QWDImTEjgV68699wKERNL0DCpeD7PcP/L3SY2RJzdyC1CSR7O8yU4lQK7uZGusgB
Mgup+yUaSjxasIqYMebNDDocr5kdwG0+2r2gQdRwc5zLX6YDBn6NLSWjRnY04ZuK
P1cF68rWteWpzJu8bmkJ5r2cqskqrnVK+zz8xMQyEaj548Bnt51ARLHOftR9jkSU
NJWh7zOLZ1r2UUKdDlrMoh3GQO3rvnCJJ16NBM1dB7TUyhMhtF6UOE62BSKdHtQn
d6TqelcRw9WnDsb9IPxRwaXhvGljnYVAgXXlJEI/6nxj2T4wdmL1LWAr6C7DuWGz
8qIvxc4oAau4DsZs2+BwolCFtYc98OjWGcBStBfZz/YYXM+2hKjbONKFxWdEPxGR
Beq3QOqp2+dga36IzQybzPQ8QtotrpSJ3q82zztEvyWiJ7E=
-----END CERTIFICATE-----
`

View file

@ -0,0 +1,18 @@
package ldaputil
import (
"crypto/tls"
"github.com/go-ldap/ldap"
)
// Connection provides the functionality of an LDAP connection,
// but through an interface.
type Connection interface {
Bind(username, password string) error
Close()
Modify(modifyRequest *ldap.ModifyRequest) error
Search(searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error)
StartTLS(config *tls.Config) error
UnauthenticatedBind(username string) error
}

28
helper/ldaputil/ldap.go Normal file
View file

@ -0,0 +1,28 @@
package ldaputil
import (
"crypto/tls"
"github.com/go-ldap/ldap"
)
func NewLDAP() LDAP {
return &ldapIfc{}
}
// LDAP provides ldap functionality, but through an interface
// rather than statically. This allows faking it for tests.
type LDAP interface {
Dial(network, addr string) (Connection, error)
DialTLS(network, addr string, config *tls.Config) (Connection, error)
}
type ldapIfc struct{}
func (l *ldapIfc) Dial(network, addr string) (Connection, error) {
return ldap.Dial(network, addr)
}
func (l *ldapIfc) DialTLS(network, addr string, config *tls.Config) (Connection, error) {
return ldap.DialTLS(network, addr, config)
}