General-purpose unrolled ASCII tolower() loops

When converting a string to lower case, the compiler is able to
autovectorize the loop well, so a simple implementation is also very
fast, comparable to memcpy().

Comparisons are more difficult for the compiler, so we convert eight
bytes at a time using "SIMD within a register" tricks. Experiments
indicate it's best to stick to simple loops for shorter strings and
the remainder of long strings.
This commit is contained in:
Tony Finch 2022-06-27 12:57:28 +01:00
parent 27a561273e
commit 21a383a8fd
7 changed files with 177 additions and 223 deletions

View file

@ -1464,9 +1464,11 @@ gcov:
# source files from lib/dns/rdata/*/, using an even nastier trick.
- find lib/dns/rdata/* -name "*.c" -execdir cp -f "{}" ../../ \;
# Help gcovr process inline functions in headers
- cp -f lib/isc/include/isc/*.h lib/dns/
- cp -f lib/dns/include/dns/*.h lib/dns/
- cp -f lib/dns/include/dns/*.h lib/ns/
- cp -f lib/isc/include/isc/*.h lib/isc/
- cp -f lib/isc/include/isc/*.h lib/dns/
- cp -f lib/isc/include/isc/*.h lib/ns/
# Generate XML file in the Cobertura XML format suitable for use by GitLab
# for the purpose of displaying code coverage information in the diff view
# of a given merge request.

View file

@ -237,7 +237,6 @@ dns_compress_find(dns_compress_t *cctx, const dns_name_t *name,
for (node = cctx->table[i]; node != NULL;
node = node->next) {
unsigned int l, count;
unsigned char c;
unsigned char *p1, *p2;
if (node->name.length != length) {
@ -260,39 +259,12 @@ dns_compress_find(dns_compress_t *cctx, const dns_name_t *name,
/* no bitstring support */
INSIST(count <= 63);
/* Loop unrolled for performance */
while (count > 3) {
c = isc_ascii_tolower(p1[0]);
if (c !=
isc_ascii_tolower(p2[0])) {
goto cont1;
}
c = isc_ascii_tolower(p1[1]);
if (c !=
isc_ascii_tolower(p2[1])) {
goto cont1;
}
c = isc_ascii_tolower(p1[2]);
if (c !=
isc_ascii_tolower(p2[2])) {
goto cont1;
}
c = isc_ascii_tolower(p1[3]);
if (c !=
isc_ascii_tolower(p2[3])) {
goto cont1;
}
count -= 4;
p1 += 4;
p2 += 4;
}
while (count-- > 0) {
c = isc_ascii_tolower(*p1++);
if (c !=
isc_ascii_tolower(*p2++)) {
goto cont1;
}
if (!isc_ascii_lowerequal(p1, p2,
count)) {
goto cont1;
}
p1 += count;
p2 += count;
}
break;
cont1:

View file

@ -442,7 +442,7 @@ dns_namereln_t
dns_name_fullcompare(const dns_name_t *name1, const dns_name_t *name2,
int *orderp, unsigned int *nlabelsp) {
unsigned int l1, l2, l, count1, count2, count, nlabels;
int cdiff, ldiff, chdiff;
int cdiff, ldiff, diff;
unsigned char *label1, *label2;
unsigned char *offsets1, *offsets2;
dns_offsets_t odata1, odata2;
@ -492,8 +492,7 @@ dns_name_fullcompare(const dns_name_t *name1, const dns_name_t *name2,
offsets1 += l1;
offsets2 += l2;
while (l > 0) {
l--;
while (l-- > 0) {
offsets1--;
offsets2--;
label1 = &name1->ndata[*offsets1];
@ -501,12 +500,6 @@ dns_name_fullcompare(const dns_name_t *name1, const dns_name_t *name2,
count1 = *label1++;
count2 = *label2++;
/*
* We dropped bitstring labels, and we don't support any
* other extended label types.
*/
INSIST(count1 <= 63 && count2 <= 63);
cdiff = (int)count1 - (int)count2;
if (cdiff < 0) {
count = count1;
@ -514,44 +507,12 @@ dns_name_fullcompare(const dns_name_t *name1, const dns_name_t *name2,
count = count2;
}
/* Loop unrolled for performance */
while (count > 3) {
chdiff = (int)isc_ascii_tolower(label1[0]) -
(int)isc_ascii_tolower(label2[0]);
if (chdiff != 0) {
*orderp = chdiff;
goto done;
}
chdiff = (int)isc_ascii_tolower(label1[1]) -
(int)isc_ascii_tolower(label2[1]);
if (chdiff != 0) {
*orderp = chdiff;
goto done;
}
chdiff = (int)isc_ascii_tolower(label1[2]) -
(int)isc_ascii_tolower(label2[2]);
if (chdiff != 0) {
*orderp = chdiff;
goto done;
}
chdiff = (int)isc_ascii_tolower(label1[3]) -
(int)isc_ascii_tolower(label2[3]);
if (chdiff != 0) {
*orderp = chdiff;
goto done;
}
count -= 4;
label1 += 4;
label2 += 4;
}
while (count-- > 0) {
chdiff = (int)isc_ascii_tolower(*label1++) -
(int)isc_ascii_tolower(*label2++);
if (chdiff != 0) {
*orderp = chdiff;
goto done;
}
diff = isc_ascii_lowercmp(label1, label2, count);
if (diff != 0) {
*orderp = diff;
goto done;
}
if (cdiff != 0) {
*orderp = cdiff;
goto done;
@ -601,9 +562,7 @@ dns_name_compare(const dns_name_t *name1, const dns_name_t *name2) {
bool
dns_name_equal(const dns_name_t *name1, const dns_name_t *name2) {
unsigned int l, count;
unsigned char c;
unsigned char *label1, *label2;
unsigned int length;
/*
* Are 'name1' and 'name2' equal?
@ -626,57 +585,13 @@ dns_name_equal(const dns_name_t *name1, const dns_name_t *name2) {
return (true);
}
if (name1->length != name2->length) {
length = name1->length;
if (length != name2->length) {
return (false);
}
l = name1->labels;
if (l != name2->labels) {
return (false);
}
label1 = name1->ndata;
label2 = name2->ndata;
while (l-- > 0) {
count = *label1++;
if (count != *label2++) {
return (false);
}
INSIST(count <= 63); /* no bitstring support */
/* Loop unrolled for performance */
while (count > 3) {
c = isc_ascii_tolower(label1[0]);
if (c != isc_ascii_tolower(label2[0])) {
return (false);
}
c = isc_ascii_tolower(label1[1]);
if (c != isc_ascii_tolower(label2[1])) {
return (false);
}
c = isc_ascii_tolower(label1[2]);
if (c != isc_ascii_tolower(label2[2])) {
return (false);
}
c = isc_ascii_tolower(label1[3]);
if (c != isc_ascii_tolower(label2[3])) {
return (false);
}
count -= 4;
label1 += 4;
label2 += 4;
}
while (count-- > 0) {
c = isc_ascii_tolower(*label1++);
if (c != isc_ascii_tolower(*label2++)) {
return (false);
}
}
}
return (true);
/* label lengths are < 64 so tolower() does not affect them */
return (isc_ascii_lowerequal(name1->ndata, name2->ndata, length));
}
bool
@ -711,10 +626,6 @@ dns_name_caseequal(const dns_name_t *name1, const dns_name_t *name2) {
int
dns_name_rdatacompare(const dns_name_t *name1, const dns_name_t *name2) {
unsigned int l1, l2, l, count1, count2, count;
unsigned char c1, c2;
unsigned char *label1, *label2;
/*
* Compare two absolute names as rdata.
*/
@ -726,47 +637,9 @@ dns_name_rdatacompare(const dns_name_t *name1, const dns_name_t *name2) {
REQUIRE(name2->labels > 0);
REQUIRE((name2->attributes & DNS_NAMEATTR_ABSOLUTE) != 0);
l1 = name1->labels;
l2 = name2->labels;
l = (l1 < l2) ? l1 : l2;
label1 = name1->ndata;
label2 = name2->ndata;
while (l > 0) {
l--;
count1 = *label1++;
count2 = *label2++;
/* no bitstring support */
INSIST(count1 <= 63 && count2 <= 63);
if (count1 != count2) {
return ((count1 < count2) ? -1 : 1);
}
count = count1;
while (count > 0) {
count--;
c1 = isc_ascii_tolower(*label1++);
c2 = isc_ascii_tolower(*label2++);
if (c1 < c2) {
return (-1);
} else if (c1 > c2) {
return (1);
}
}
}
/*
* If one name had more labels than the other, their common
* prefix must have been different because the shorter name
* ended with the root label and the longer one can't have
* a root label in the middle of it. Therefore, if we get
* to this point, the lengths must be equal.
*/
INSIST(l1 == l2);
return (0);
/* label lengths are < 64 so tolower() does not affect them */
return (isc_ascii_lowercmp(name1->ndata, name2->ndata,
ISC_MIN(name1->length, name2->length)));
}
bool
@ -1572,8 +1445,7 @@ dns_name_tofilenametext(const dns_name_t *name, bool omit_final_dot,
isc_result_t
dns_name_downcase(const dns_name_t *source, dns_name_t *name,
isc_buffer_t *target) {
unsigned char *sndata, *ndata;
unsigned int nlen, count, labels;
unsigned char *ndata;
isc_buffer_t buffer;
/*
@ -1599,33 +1471,13 @@ dns_name_downcase(const dns_name_t *source, dns_name_t *name,
name->ndata = ndata;
}
sndata = source->ndata;
nlen = source->length;
labels = source->labels;
if (nlen > (target->length - target->used)) {
if (source->length > (target->length - target->used)) {
MAKE_EMPTY(name);
return (ISC_R_NOSPACE);
}
while (labels > 0 && nlen > 0) {
labels--;
count = *sndata++;
*ndata++ = count;
nlen--;
if (count < 64) {
INSIST(nlen >= count);
while (count > 0) {
*ndata++ = isc_ascii_tolower(*sndata++);
nlen--;
count--;
}
} else {
FATAL_ERROR(__FILE__, __LINE__,
"Unexpected label type %02x", count);
/* Does not return. */
}
}
/* label lengths are < 64 so tolower() does not affect them */
isc_ascii_lowercopy(ndata, source->ndata, source->length);
if (source != name) {
name->labels = source->labels;

View file

@ -9374,9 +9374,7 @@ rdataset_getownercase(const dns_rdataset_t *rdataset, dns_name_t *name) {
}
if (CASEFULLYLOWER(header)) {
for (size_t i = 0; i < name->length; i++) {
name->ndata[i] = isc_ascii_tolower(name->ndata[i]);
}
isc_ascii_lowercopy(name->ndata, name->ndata, name->length);
} else {
uint8_t *nd = name->ndata;
for (size_t i = 0; i < name->length; i++) {

View file

@ -89,12 +89,9 @@ isc_hash64(const void *data, const size_t length, const bool case_sensitive) {
if (case_sensitive) {
isc_siphash24(isc_hash_key, data, length, (uint8_t *)&hval);
} else {
const uint8_t *byte = data;
uint8_t lower[1024];
REQUIRE(length <= 1024);
for (unsigned i = 0; i < length; i++) {
lower[i] = isc_ascii_tolower(byte[i]);
}
REQUIRE(length <= sizeof(lower));
isc_ascii_lowercopy(lower, data, length);
isc_siphash24(isc_hash_key, lower, length, (uint8_t *)&hval);
}
@ -113,12 +110,9 @@ isc_hash32(const void *data, const size_t length, const bool case_sensitive) {
if (case_sensitive) {
isc_halfsiphash24(isc_hash_key, data, length, (uint8_t *)&hval);
} else {
const uint8_t *byte = data;
uint8_t lower[1024];
REQUIRE(length <= 1024);
for (unsigned i = 0; i < length; i++) {
lower[i] = isc_ascii_tolower(byte[i]);
}
REQUIRE(length <= sizeof(lower));
isc_ascii_lowercopy(lower, data, length);
isc_halfsiphash24(isc_hash_key, lower, length,
(uint8_t *)&hval);
}

View file

@ -13,7 +13,11 @@
#pragma once
#include <stdbool.h>
#include <stdint.h>
#include <string.h>
#include <isc/endian.h>
/*
* ASCII case conversion
@ -27,12 +31,144 @@ extern const uint8_t isc__ascii_toupper[256];
#define isc_ascii_tolower(c) isc__ascii_tolower[(uint8_t)(c)]
#define isc_ascii_toupper(c) isc__ascii_toupper[(uint8_t)(c)]
/*
 * Table-free ASCII tolower() for a single byte. It performs no memory
 * accesses, which makes it suitable for loops that the compiler is able
 * to autovectorize.
 */
static inline uint8_t
isc__ascii_tolower1(uint8_t c) {
	/* 1 when `c` is one of 'A'..'Z', otherwise 0 */
	int is_upper = (c >= 'A' && c <= 'Z');

	return (c + is_upper * ('a' - 'A'));
}
/*
 * Store the lower-case form of each of the `len` bytes at `src` into the
 * same offset of `dst`. Every byte is read before the byte at that offset
 * is written.
 */
static inline void
isc_ascii_lowercopy(uint8_t *dst, const uint8_t *src, unsigned len) {
	for (unsigned i = 0; i < len; i++) {
		dst[i] = isc__ascii_tolower1(src[i]);
	}
}
/*
 * Convert the NUL-terminated string `str` to lower case in place.
 *
 * The work is done by a single isc_ascii_lowercopy() pass whose source
 * and destination are both `str`; the terminating NUL is not rewritten.
 * The length is narrowed to `unsigned` because that is the type
 * isc_ascii_lowercopy() accepts.
 */
static inline void
isc_ascii_strtolower(char *str) {
	isc_ascii_lowercopy((uint8_t *)str, (uint8_t *)str,
			    (unsigned)strlen(str));
}
/*
 * Lower-case eight bytes packed into one 64-bit word, using SWAR tricks
 * (SIMD within a register). The range test is based on "Hacker's Delight"
 * by Henry S. Warren, "searching for a value in a given range", p. 95.
 *
 * Eight bytes is wider than many labels in DNS names, so it does not seem
 * worth dealing with the portability issues of wide vector registers. If
 * there were a vector string load instruction (analogous to the memmove()
 * below) the balance might be different.
 */
static inline uint64_t
isc__ascii_tolower8(uint64_t octets) {
	/*
	 * 0x01 in every byte: multiplying a single-byte constant by this
	 * replicates the constant into all eight bytes.
	 */
	const uint64_t ones = 0x0101010101010101;
	const uint64_t low7_mask = 0x7F * ones;
	const uint64_t flag_mask = 0x80 * ones;

	/*
	 * Keep only the low seven bits of each byte, so bit 7 of each byte
	 * is free to receive the carry of the additions below.
	 */
	uint64_t low7 = octets & low7_mask;

	/*
	 * Adding `0x80 - 'A'` carries into bit 7 exactly when the 7-bit
	 * value is greater than or equal to 'A'.
	 */
	uint64_t ge_A = low7 + (0x80 - 'A') * ones;

	/*
	 * Adding `0x7F - 'Z'` carries into bit 7 exactly when the 7-bit
	 * value is greater than 'Z'.
	 */
	uint64_t gt_Z = low7 + (0x7F - 'Z') * ones;

	/*
	 * Bit 7 set only in bytes whose original bit 7 was clear. Masking
	 * with this rejects non-ASCII bytes and discards the leftover sum
	 * bits below the flag.
	 */
	uint64_t ascii = ~octets & flag_mask;

	/* Flag is set where the byte is >= 'A' but not > 'Z' */
	uint64_t upper = (ge_A ^ gt_Z) & ascii;

	/*
	 * Shift each flag from 0x80 down to 0x20 (which is 'a' - 'A') and
	 * set that bit in the upper-case bytes.
	 */
	return (octets | (upper >> 2));
}
/*
 * Read the eight bytes starting at `ptr`, which need not be aligned, and
 * return them as a 64-bit word in host byte order.
 */
static inline uint64_t
isc__ascii_load8(const uint8_t *ptr) {
	uint64_t word = 0;

	memmove(&word, ptr, sizeof(word));
	return (word);
}
/*
 * Report whether the `len` bytes at `a` and the `len` bytes at `b` are
 * equal when ASCII case is ignored.
 */
static inline bool
isc_ascii_lowerequal(const uint8_t *a, const uint8_t *b, unsigned len) {
	/* Whole eight-byte words first; byte order is irrelevant here */
	for (; len >= 8; len -= 8, a += 8, b += 8) {
		uint64_t word_a = isc__ascii_tolower8(isc__ascii_load8(a));
		uint64_t word_b = isc__ascii_tolower8(isc__ascii_load8(b));

		if (word_a != word_b) {
			return (false);
		}
	}

	/* Then the remaining bytes (fewer than eight), one at a time */
	for (; len > 0; len--, a++, b++) {
		if (isc_ascii_tolower(*a) != isc_ascii_tolower(*b)) {
			return (false);
		}
	}

	return (true);
}
/*
 * Order the `len` bytes at `a` against the `len` bytes at `b`, ignoring
 * ASCII case. Returns -1, 0 or +1.
 *
 * The result must be lexicographic, so each eight-byte word is passed
 * through htobe64() before it is compared: the first byte of the string
 * then becomes the most significant byte of the number.
 */
static inline int
isc_ascii_lowercmp(const uint8_t *a, const uint8_t *b, unsigned len) {
	uint64_t lhs = 0, rhs = 0;

	/* Whole eight-byte words, compared as big-endian numbers */
	for (; len >= 8; len -= 8, a += 8, b += 8) {
		lhs = isc__ascii_tolower8(htobe64(isc__ascii_load8(a)));
		rhs = isc__ascii_tolower8(htobe64(isc__ascii_load8(b)));
		if (lhs != rhs) {
			return ((lhs < rhs) ? -1 : +1);
		}
	}

	/* Remaining bytes (fewer than eight), one at a time */
	for (; len > 0; len--, a++, b++) {
		lhs = isc_ascii_tolower(*a);
		rhs = isc_ascii_tolower(*b);
		if (lhs != rhs) {
			return ((lhs < rhs) ? -1 : +1);
		}
	}

	return (0);
}

View file

@ -61,21 +61,21 @@ ISC_RUN_TEST_IMPL(fullcompare) {
{ "", "", dns_namereln_equal, 0, 0 },
{ "foo", "", dns_namereln_subdomain, 1, 0 },
{ "", "foo", dns_namereln_contains, -1, 0 },
{ "foo", "bar", dns_namereln_none, 4, 0 },
{ "bar", "foo", dns_namereln_none, -4, 0 },
{ "foo", "bar", dns_namereln_none, 1, 0 },
{ "bar", "foo", dns_namereln_none, -1, 0 },
{ "bar.foo", "foo", dns_namereln_subdomain, 1, 1 },
{ "foo", "bar.foo", dns_namereln_contains, -1, 1 },
{ "baz.bar.foo", "bar.foo", dns_namereln_subdomain, 1, 2 },
{ "bar.foo", "baz.bar.foo", dns_namereln_contains, -1, 2 },
{ "foo.example", "bar.example", dns_namereln_commonancestor, 4,
{ "foo.example", "bar.example", dns_namereln_commonancestor, 1,
1 },
/* absolute */
{ ".", ".", dns_namereln_equal, 0, 1 },
{ "foo.", "bar.", dns_namereln_commonancestor, 4, 1 },
{ "bar.", "foo.", dns_namereln_commonancestor, -4, 1 },
{ "foo.", "bar.", dns_namereln_commonancestor, 1, 1 },
{ "bar.", "foo.", dns_namereln_commonancestor, -1, 1 },
{ "foo.example.", "bar.example.", dns_namereln_commonancestor,
4, 2 },
1, 2 },
{ "bar.foo.", "foo.", dns_namereln_subdomain, 1, 2 },
{ "foo.", "bar.foo.", dns_namereln_contains, -1, 2 },
{ "baz.bar.foo.", "bar.foo.", dns_namereln_subdomain, 1, 3 },