diff --git a/src/backend/optimizer/plan/initsplan.c b/src/backend/optimizer/plan/initsplan.c index c20e7e49780..b207b8d913b 100644 --- a/src/backend/optimizer/plan/initsplan.c +++ b/src/backend/optimizer/plan/initsplan.c @@ -913,9 +913,17 @@ create_grouping_expr_infos(PlannerInfo *root) tce->btree_opintype, tce->btree_opintype, BTEQUALIMAGE_PROC); + + /* + * If there is no BTEQUALIMAGE_PROC, eager aggregation is assumed to + * be unsafe. Otherwise, we call the procedure to check. We must be + * careful to pass the expression's actual collation, rather than the + * data type's default collation, to ensure that non-deterministic + * collations are correctly handled. + */ if (!OidIsValid(equalimageproc) || !DatumGetBool(OidFunctionCall1Coll(equalimageproc, - tce->typcollation, + exprCollation((Node *) tle->expr), ObjectIdGetDatum(tce->btree_opintype)))) return; diff --git a/src/backend/optimizer/util/relnode.c b/src/backend/optimizer/util/relnode.c index 91bcda34a37..3fc2c2f71d0 100644 --- a/src/backend/optimizer/util/relnode.c +++ b/src/backend/optimizer/util/relnode.c @@ -3004,9 +3004,17 @@ init_grouping_targets(PlannerInfo *root, RelOptInfo *rel, tce->btree_opintype, tce->btree_opintype, BTEQUALIMAGE_PROC); + + /* + * If there is no BTEQUALIMAGE_PROC, eager aggregation is assumed + * to be unsafe. Otherwise, we call the procedure to check. We + * must be careful to pass the expression's actual collation, + * rather than the data type's default collation, to ensure that + * non-deterministic collations are correctly handled. + */ if (!OidIsValid(equalimageproc) || !DatumGetBool(OidFunctionCall1Coll(equalimageproc, - tce->typcollation, + exprCollation((Node *) expr), ObjectIdGetDatum(tce->btree_opintype)))) return false; diff --git a/src/test/regress/expected/collate.icu.utf8.out b/src/test/regress/expected/collate.icu.utf8.out index d170e7da066..fce726029a2 100644 --- a/src/test/regress/expected/collate.icu.utf8.out +++ b/src/test/regress/expected/collate.icu.utf8.out @@ -2454,11 +2454,11 @@ SELECT c collate "C", count(c) FROM pagg_tab3 GROUP BY c collate "C" ORDER BY 1; SET enable_partitionwise_join TO false; EXPLAIN (COSTS OFF) SELECT t1.c, count(t2.c) FROM pagg_tab3 t1 JOIN pagg_tab3 t2 ON t1.c = t2.c GROUP BY 1 ORDER BY t1.c COLLATE "C"; - QUERY PLAN -------------------------------------------------------------------- + QUERY PLAN +------------------------------------------------------------- Sort Sort Key: t1.c COLLATE "C" - -> Finalize HashAggregate + -> HashAggregate Group Key: t1.c -> Hash Join Hash Cond: (t1.c = t2.c) @@ -2466,12 +2466,10 @@ SELECT t1.c, count(t2.c) FROM pagg_tab3 t1 JOIN pagg_tab3 t2 ON t1.c = t2.c GROU -> Seq Scan on pagg_tab3_p2 t1_1 -> Seq Scan on pagg_tab3_p1 t1_2 -> Hash - -> Partial HashAggregate - Group Key: t2.c - -> Append - -> Seq Scan on pagg_tab3_p2 t2_1 - -> Seq Scan on pagg_tab3_p1 t2_2 -(15 rows) + -> Append + -> Seq Scan on pagg_tab3_p2 t2_1 + -> Seq Scan on pagg_tab3_p1 t2_2 +(13 rows) SELECT t1.c, count(t2.c) FROM pagg_tab3 t1 JOIN pagg_tab3 t2 ON t1.c = t2.c GROUP BY 1 ORDER BY t1.c COLLATE "C"; c | count @@ -2483,11 +2481,11 @@ SELECT t1.c, count(t2.c) FROM pagg_tab3 t1 JOIN pagg_tab3 t2 ON t1.c = t2.c GROU SET enable_partitionwise_join TO true; EXPLAIN (COSTS OFF) SELECT t1.c, count(t2.c) FROM pagg_tab3 t1 JOIN pagg_tab3 t2 ON t1.c = t2.c GROUP BY 1 ORDER BY t1.c COLLATE "C"; - QUERY PLAN -------------------------------------------------------------------- + QUERY PLAN +------------------------------------------------------------- Sort Sort Key: t1.c COLLATE "C" - -> Finalize HashAggregate + -> HashAggregate Group Key: t1.c -> Hash Join Hash Cond: (t1.c = t2.c) @@ -2495,12 +2493,10 @@ SELECT t1.c, count(t2.c) FROM pagg_tab3 t1 JOIN pagg_tab3 t2 ON t1.c = t2.c GROU -> Seq Scan on pagg_tab3_p2 t1_1 -> Seq Scan on pagg_tab3_p1 t1_2 -> Hash - -> Partial HashAggregate - Group Key: t2.c - -> Append - -> Seq Scan on pagg_tab3_p2 t2_1 - -> Seq Scan on pagg_tab3_p1 t2_2 -(15 rows) + -> Append + -> Seq Scan on pagg_tab3_p2 t2_1 + -> Seq Scan on pagg_tab3_p1 t2_2 +(13 rows) SELECT t1.c, count(t2.c) FROM pagg_tab3 t1 JOIN pagg_tab3 t2 ON t1.c = t2.c GROUP BY 1 ORDER BY t1.c COLLATE "C"; c | count @@ -2691,6 +2687,72 @@ DROP TABLE pagg_tab6; RESET enable_partitionwise_aggregate; RESET max_parallel_workers_per_gather; RESET enable_incremental_sort; +-- +-- Test for eager aggregation non-deterministic collation bug +-- +CREATE TABLE eager_agg_t1 (id int, val text COLLATE case_insensitive); +CREATE TABLE eager_agg_t2 (val text COLLATE case_insensitive); +INSERT INTO eager_agg_t1 SELECT 1, 'a' FROM generate_series(1, 50); +INSERT INTO eager_agg_t1 SELECT 1, 'A' FROM generate_series(1, 50); +INSERT INTO eager_agg_t2 VALUES ('A'); +ANALYZE eager_agg_t1; +ANALYZE eager_agg_t2; +-- Ensure that eager aggregation is not used for t1.val due to the +-- non-deterministic collation. +EXPLAIN (COSTS OFF) +SELECT t1.id, count(t1.val) + FROM eager_agg_t1 t1 + JOIN eager_agg_t2 t2 ON t1.val = t2.val COLLATE "C" +GROUP BY t1.id; + QUERY PLAN +-------------------------------------------------------- + HashAggregate + Group Key: t1.id + -> Nested Loop + Join Filter: ((t1.val)::text = (t2.val)::text) + -> Seq Scan on eager_agg_t2 t2 + -> Seq Scan on eager_agg_t1 t1 +(6 rows) + +-- Ensure it returns 1 row with count = 50 +SELECT t1.id, count(t1.val) + FROM eager_agg_t1 t1 + JOIN eager_agg_t2 t2 ON t1.val = t2.val COLLATE "C" +GROUP BY t1.id; + id | count +----+------- + 1 | 50 +(1 row) + +-- Ensure that eager aggregation is not used when grouping by a column with +-- non-deterministic collation. +EXPLAIN (COSTS OFF) +SELECT t1.id, t1.val, count(t1.val) + FROM eager_agg_t1 t1 + JOIN eager_agg_t2 t2 ON t1.val = t2.val COLLATE "C" +GROUP BY t1.id, t1.val; + QUERY PLAN +-------------------------------------------------------- + HashAggregate + Group Key: t1.id, t1.val + -> Nested Loop + Join Filter: ((t1.val)::text = (t2.val)::text) + -> Seq Scan on eager_agg_t2 t2 + -> Seq Scan on eager_agg_t1 t1 +(6 rows) + +-- Ensure it returns 1 row with count = 50 +SELECT t1.id, t1.val, count(t1.val) + FROM eager_agg_t1 t1 + JOIN eager_agg_t2 t2 ON t1.val = t2.val COLLATE "C" +GROUP BY t1.id, t1.val; + id | val | count +----+-----+------- + 1 | A | 50 +(1 row) + +DROP TABLE eager_agg_t1; +DROP TABLE eager_agg_t2; -- virtual generated columns CREATE TABLE t5 ( a int, diff --git a/src/test/regress/sql/collate.icu.utf8.sql b/src/test/regress/sql/collate.icu.utf8.sql index 8f0f973f5fa..0bf65a63535 100644 --- a/src/test/regress/sql/collate.icu.utf8.sql +++ b/src/test/regress/sql/collate.icu.utf8.sql @@ -990,6 +990,51 @@ RESET enable_partitionwise_aggregate; RESET max_parallel_workers_per_gather; RESET enable_incremental_sort; +-- +-- Test for eager aggregation non-deterministic collation bug +-- + +CREATE TABLE eager_agg_t1 (id int, val text COLLATE case_insensitive); +CREATE TABLE eager_agg_t2 (val text COLLATE case_insensitive); + +INSERT INTO eager_agg_t1 SELECT 1, 'a' FROM generate_series(1, 50); +INSERT INTO eager_agg_t1 SELECT 1, 'A' FROM generate_series(1, 50); +INSERT INTO eager_agg_t2 VALUES ('A'); + +ANALYZE eager_agg_t1; +ANALYZE eager_agg_t2; + +-- Ensure that eager aggregation is not used for t1.val due to the +-- non-deterministic collation. +EXPLAIN (COSTS OFF) +SELECT t1.id, count(t1.val) + FROM eager_agg_t1 t1 + JOIN eager_agg_t2 t2 ON t1.val = t2.val COLLATE "C" +GROUP BY t1.id; + +-- Ensure it returns 1 row with count = 50 +SELECT t1.id, count(t1.val) + FROM eager_agg_t1 t1 + JOIN eager_agg_t2 t2 ON t1.val = t2.val COLLATE "C" +GROUP BY t1.id; + +-- Ensure that eager aggregation is not used when grouping by a column with +-- non-deterministic collation. +EXPLAIN (COSTS OFF) +SELECT t1.id, t1.val, count(t1.val) + FROM eager_agg_t1 t1 + JOIN eager_agg_t2 t2 ON t1.val = t2.val COLLATE "C" +GROUP BY t1.id, t1.val; + +-- Ensure it returns 1 row with count = 50 +SELECT t1.id, t1.val, count(t1.val) + FROM eager_agg_t1 t1 + JOIN eager_agg_t2 t2 ON t1.val = t2.val COLLATE "C" +GROUP BY t1.id, t1.val; + +DROP TABLE eager_agg_t1; +DROP TABLE eager_agg_t2; + -- virtual generated columns CREATE TABLE t5 ( a int,