diff --git a/sys/netlink/netlink_message_parser.c b/sys/netlink/netlink_message_parser.c index 451d9d49749..dc0c3871261 100644 --- a/sys/netlink/netlink_message_parser.c +++ b/sys/netlink/netlink_message_parser.c @@ -152,17 +152,27 @@ nl_get_attrs_bmask_raw(struct nlattr *nla_head, int len, struct nlattr_bmask *bm { struct nlattr *nla = NULL; - bzero(bm->mask, sizeof(bm->mask)); + BIT_ZERO(NL_ATTR_BMASK_SIZE, bm); NLA_FOREACH(nla, nla_head, len) { if (nla->nla_len < sizeof(struct nlattr)) return; int nla_type = nla->nla_type & NLA_TYPE_MASK; - if (nla_type <= sizeof(bm->mask) * 8) - bm->mask[nla_type / 8] |= 1 << (nla_type % 8); + if (nla_type < NL_ATTR_BMASK_SIZE) + BIT_SET(NL_ATTR_BMASK_SIZE, nla_type, bm); + else + NL_LOG(LOG_DEBUG2, "Skipping type %d in the mask: too short", + nla_type); } } +bool +nl_has_attr(const struct nlattr_bmask *bm, unsigned int nla_type) +{ + MPASS(nla_type < NL_ATTR_BMASK_SIZE); + + return (BIT_ISSET(NL_ATTR_BMASK_SIZE, nla_type, bm)); +} int nlattr_get_flag(struct nlattr *nla, struct nl_pstate *npt, const void *arg, void *target) diff --git a/sys/netlink/netlink_message_parser.h b/sys/netlink/netlink_message_parser.h index 3f64c1967f0..94f0ca5260d 100644 --- a/sys/netlink/netlink_message_parser.h +++ b/sys/netlink/netlink_message_parser.h @@ -29,6 +29,9 @@ #define _NETLINK_NETLINK_MESSAGE_PARSER_H_ #ifdef _KERNEL + +#include + /* * It is not meant to be included directly */ @@ -152,18 +155,11 @@ static const struct nlhdr_parser _name = { \ .np_size = NL_ARRAY_LEN(_np), \ } -struct nlattr_bmask { - uint64_t mask[2]; -}; +#define NL_ATTR_BMASK_SIZE 128 +BITSET_DEFINE(nlattr_bmask, NL_ATTR_BMASK_SIZE); -static inline bool -nl_has_attr(const struct nlattr_bmask *bm, unsigned int attr_type) -{ - MPASS(attr_type < sizeof(bm->mask) * 8); - - return ((bm->mask[attr_type / 8] & (1 << (attr_type % 8)))); -} void nl_get_attrs_bmask_raw(struct nlattr *nla_head, int len, struct nlattr_bmask *bm); +bool nl_has_attr(const struct nlattr_bmask *bm, unsigned int nla_type); int nl_parse_attrs_raw(struct nlattr *nla_head, int len, const struct nlattr_parser *ps, int pslen, struct nl_pstate *npt, void *target);