Skip common prefixes during radix sort

During the counting step, keep track of the bits that are the same
for the entire input.  If we counted only a single distinct byte,
the next recursion will start at the next byte position that has
more than one distinct byte in the input. This allows us to skip over
multiple passes where the byte is the same for the entire input.

This provides a significant speedup for integers that have some upper
bytes with all-zeros or all-ones, which is common.

Reviewed-by: Chengpeng Yan <chengpeng_yan@outlook.com>
Reviewed-by: ChangAo Chen <cca5507@qq.com>
Discussion: https://postgr.es/m/CANWCAZYpGMDSSwAa18fOxJGXaPzVdyPsWpOkfCX32DWh3Qznzw@mail.gmail.com
This commit is contained in:
John Naylor 2026-04-01 14:18:57 +07:00
parent 21b018e7ea
commit f6bd9f0fe2

View file

@ -104,6 +104,7 @@
#include "commands/tablespace.h"
#include "miscadmin.h"
#include "pg_trace.h"
#include "port/pg_bitutils.h"
#include "storage/shmem.h"
#include "utils/guc.h"
#include "utils/memutils.h"
@ -2659,17 +2660,25 @@ radix_sort_recursive(SortTuple *begin, size_t n_elems, int level, Tuplesortstate
int num_partitions = 0;
int num_remaining;
SortSupport ssup = &state->base.sortKeys[0];
Datum ref_datum;
Datum common_upper_bits = 0;
size_t start_offset = 0;
SortTuple *partition_begin = begin;
int next_level;
/* count number of occurrences of each byte */
ref_datum = normalize_datum(begin[0].datum1, ssup);
for (SortTuple *st = begin; st < begin + n_elems; st++)
{
Datum this_datum;
uint8 this_partition;
this_datum = normalize_datum(st->datum1, ssup);
/* accumulate bits different from the reference datum */
common_upper_bits |= ref_datum ^ this_datum;
/* extract the byte for this level from the normalized datum */
this_partition = current_byte(normalize_datum(st->datum1, ssup),
level);
this_partition = current_byte(this_datum, level);
/* save it for the permutation step */
st->curbyte = this_partition;
@ -2747,6 +2756,33 @@ radix_sort_recursive(SortTuple *begin, size_t n_elems, int level, Tuplesortstate
}
/* recurse */
if (num_partitions == 1)
{
/*
* There is only one distinct byte at the current level. It can happen
* that some subsequent bytes are also the same for all input values,
* such as the upper bytes of small integers. To skip unproductive
* passes for that case, we compute the level where the input has more
* than one distinct byte, so that the next recursion can start there.
*/
if (common_upper_bits == 0)
next_level = sizeof(Datum);
else
{
int diffpos;
/*
* The upper bits of common_upper_bits are zero where all datums
* have the same bits.
*/
diffpos = pg_leftmost_one_pos64(DatumGetUInt64(common_upper_bits));
next_level = sizeof(Datum) - 1 - (diffpos / BITS_PER_BYTE);
}
}
else
next_level = level + 1;
for (uint8 *rp = remaining_partitions;
rp < remaining_partitions + num_partitions;
rp++)
@ -2757,7 +2793,7 @@ radix_sort_recursive(SortTuple *begin, size_t n_elems, int level, Tuplesortstate
if (num_elements > 1)
{
if (level < sizeof(Datum) - 1)
if (next_level < sizeof(Datum))
{
if (num_elements < QSORT_THRESHOLD)
{
@ -2770,7 +2806,7 @@ radix_sort_recursive(SortTuple *begin, size_t n_elems, int level, Tuplesortstate
{
radix_sort_recursive(partition_begin,
num_elements,
level + 1,
next_level,
state);
}
}