Skip to content

Commit 83a2394

Browse files
fix aggregation to not push partial distinct - all or none
Signed-off-by: Harshit Gangal <harshit@planetscale.com>
1 parent 254e401 commit 83a2394

5 files changed

Lines changed: 132 additions & 48 deletions

File tree

go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -535,9 +535,11 @@ func TestDistinctAggregation(t *testing.T) {
535535
mcmp.Exec("insert into t1(t1_id, `name`, `value`, shardkey) values(1,'a1','foo',100), (2,'b1','foo',200), (3,'c1','foo',300), (4,'a1','foo',100), (5,'d1','toto',200), (6,'c1','tata',893), (7,'a1','titi',2380), (8,'b1','tete',12833), (9,'e1','yoyo',783493)")
536536

537537
for _, query := range []string{
538-
`SELECT /*vt+ PLANNER=gen4 */ COUNT(DISTINCT value), SUM(DISTINCT shardkey) FROM t1`,
538+
// `SELECT /*vt+ PLANNER=gen4 */ COUNT(DISTINCT value), SUM(DISTINCT shardkey) FROM t1`, - fails because the two DISTINCT expressions are different.
539539
`SELECT /*vt+ PLANNER=gen4 */ a.t1_id, SUM(DISTINCT b.shardkey) FROM t1 a, t1 b group by a.t1_id`,
540540
`SELECT /*vt+ PLANNER=gen4 */ a.value, SUM(DISTINCT b.shardkey) FROM t1 a, t1 b group by a.value`,
541+
// `SELECT /*vt+ PLANNER=gen4 */ count(distinct a.value), SUM(DISTINCT b.t1_id) FROM t1 a, t1 b`, - fails because the two DISTINCT expressions are different.
542+
`SELECT /*vt+ PLANNER=gen4 */ a.value, SUM(DISTINCT b.t1_id) FROM t1 a, t1 b group by a.value`,
541543
} {
542544
mcmp.Run(query, func(mcmp *utils.MySQLCompare) {
543545
mcmp.Exec(query)

go/vt/vtgate/engine/scalar_aggregation_test.go

Lines changed: 60 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@ import (
2424
"github.com/stretchr/testify/assert"
2525
"github.com/stretchr/testify/require"
2626

27-
"vitess.io/vitess/go/mysql/collations"
28-
2927
"vitess.io/vitess/go/test/utils"
3028

3129
"vitess.io/vitess/go/sqltypes"
@@ -258,33 +256,76 @@ func TestScalarGroupConcatWithAggrOnEngine(t *testing.T) {
258256
}
259257

260258
// TestScalarDistinctAggrOnEngine tests distinct aggregation on engine.
261-
func TestScalarDistinctAggr(t *testing.T) {
259+
func TestScalarDistinctAggrOnEngine(t *testing.T) {
260+
fields := sqltypes.MakeTestFields(
261+
"value|value",
262+
"int64|int64",
263+
)
264+
265+
fp := &fakePrimitive{results: []*sqltypes.Result{sqltypes.MakeTestResult(
266+
fields,
267+
"100|100",
268+
"200|200",
269+
"200|200",
270+
"400|400",
271+
"400|400",
272+
"600|600",
273+
)}}
274+
275+
oa := &ScalarAggregate{
276+
Aggregates: []*AggregateParams{
277+
NewAggregateParam(AggregateCountDistinct, 0, "count(distinct value)"),
278+
NewAggregateParam(AggregateSumDistinct, 1, "sum(distinct value)"),
279+
},
280+
Input: fp,
281+
}
282+
qr, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false)
283+
require.NoError(t, err)
284+
require.Equal(t, `[[INT64(4) DECIMAL(1300)]]`, fmt.Sprintf("%v", qr.Rows))
285+
286+
fp.rewind()
287+
results := &sqltypes.Result{}
288+
err = oa.TryStreamExecute(context.Background(), &noopVCursor{}, nil, true, func(qr *sqltypes.Result) error {
289+
if qr.Fields != nil {
290+
results.Fields = qr.Fields
291+
}
292+
results.Rows = append(results.Rows, qr.Rows...)
293+
return nil
294+
})
295+
require.NoError(t, err)
296+
require.Equal(t, `[[INT64(4) DECIMAL(1300)]]`, fmt.Sprintf("%v", results.Rows))
297+
}
298+
299+
func TestScalarDistinctPushedDown(t *testing.T) {
262300
fields := sqltypes.MakeTestFields(
263-
"value|sum(distinct shardkey)",
264-
"varchar|decimal",
301+
"count(distinct value)|sum(distinct value)",
302+
"int64|decimal",
265303
)
266304

267305
fp := &fakePrimitive{results: []*sqltypes.Result{sqltypes.MakeTestResult(
268306
fields,
269-
"foo|600",
270-
"tata|893",
271-
"tete|12833",
272-
"titi|2380",
273-
"toto|200",
274-
"yoyo|783493",
307+
"2|200",
308+
"6|400",
309+
"3|700",
310+
"1|10",
311+
"7|30",
312+
"8|90",
275313
)}}
276-
param := NewAggregateParam(AggregateCountDistinct, 0, "count(distinct value)")
277-
param.CollationID = collations.CollationUtf8mb4ID
278314

279-
param2 := NewAggregateParam(AggregateSum, 1, "sum(distinct sharkey)")
280-
param2.OrigOpcode = AggregateSumDistinct
315+
countAggr := NewAggregateParam(AggregateSum, 0, "count(distinct value)")
316+
countAggr.OrigOpcode = AggregateCountDistinct
317+
sumAggr := NewAggregateParam(AggregateSum, 1, "sum(distinct value)")
318+
sumAggr.OrigOpcode = AggregateSumDistinct
281319
oa := &ScalarAggregate{
282-
Aggregates: []*AggregateParams{param, param2},
283-
Input: fp,
320+
Aggregates: []*AggregateParams{
321+
countAggr,
322+
sumAggr,
323+
},
324+
Input: fp,
284325
}
285326
qr, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, false)
286327
require.NoError(t, err)
287-
require.Equal(t, `[INT64(6) DECIMAL(800199)]`, fmt.Sprintf("%v", qr.Rows))
328+
require.Equal(t, `[[INT64(27) DECIMAL(1430)]]`, fmt.Sprintf("%v", qr.Rows))
288329

289330
fp.rewind()
290331
results := &sqltypes.Result{}
@@ -296,7 +337,7 @@ func TestScalarDistinctAggr(t *testing.T) {
296337
return nil
297338
})
298339
require.NoError(t, err)
299-
require.Equal(t, `[INT64(6) DECIMAL(800199)]`, fmt.Sprintf("%v", results.Rows))
340+
require.Equal(t, `[[INT64(27) DECIMAL(1430)]]`, fmt.Sprintf("%v", results.Rows))
300341
}
301342

302343
// TestScalarGroupConcat tests group_concat with partial aggregation on engine.

go/vt/vtgate/planbuilder/operators/aggregation_pushing.go

Lines changed: 64 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -115,35 +115,68 @@ func pushDownAggregationThroughRoute(
115115

116116
// pushDownAggregations splits aggregations between the original aggregator and the one we are pushing down
117117
func pushDownAggregations(ctx *plancontext.PlanningContext, aggregator *Aggregator, aggrBelowRoute *Aggregator) error {
118-
for i, aggregation := range aggregator.Aggregations {
119-
if !aggregation.Distinct || exprHasUniqueVindex(ctx, aggregation.Func.GetArg()) {
120-
aggrBelowRoute.Aggregations = append(aggrBelowRoute.Aggregations, aggregation)
121-
aggregateTheAggregate(aggregator, i)
118+
canPushDownDistinct, distinctExpr, err := checkIfWeCanPushDown(ctx, aggregator)
119+
if err != nil {
120+
return err
121+
}
122+
123+
if !canPushDownDistinct {
124+
aggregator.DistinctExpr = distinctExpr
125+
}
126+
127+
aeDistinctExpr := aeWrap(aggregator.DistinctExpr)
128+
offset := -1
129+
for i, aggr := range aggregator.Aggregations {
130+
if aggr.Distinct && !canPushDownDistinct {
131+
offset = aggr.ColOffset
132+
aggrBelowRoute.Columns[offset] = aeDistinctExpr
122133
continue
123134
}
124-
innerExpr := aggregation.Func.GetArg()
135+
aggrBelowRoute.Aggregations = append(aggrBelowRoute.Aggregations, aggr)
136+
aggregateTheAggregate(aggregator, i)
137+
}
125138

126-
if aggregator.DistinctExpr != nil {
127-
if ctx.SemTable.EqualsExpr(aggregator.DistinctExpr, innerExpr) {
128-
// we can handle multiple distinct aggregations, as long as they are aggregating on the same expression
129-
aggrBelowRoute.Columns[aggregation.ColOffset] = aeWrap(innerExpr)
130-
continue
131-
}
132-
return vterrors.VT12001(fmt.Sprintf("only one DISTINCT aggregation is allowed in a SELECT: %s", sqlparser.String(aggregation.Original)))
133-
}
139+
// everything is pushed below the route.
140+
if canPushDownDistinct {
141+
return nil
142+
}
134143

135-
// We handle a distinct aggregation by turning it into a group by and
136-
// doing the aggregating on the vtgate level instead
137-
aggregator.DistinctExpr = innerExpr
138-
aeDistinctExpr := aeWrap(aggregator.DistinctExpr)
144+
// We handle a distinct aggregation by turning it into a group by and
145+
// doing the aggregating on the vtgate level instead
146+
// Adding to group by can be done only once, even when there are multiple distinct aggregations with the same expression.
147+
groupBy := NewGroupBy(aggregator.DistinctExpr, aggregator.DistinctExpr, aeDistinctExpr)
148+
groupBy.ColOffset = offset
149+
aggrBelowRoute.Grouping = append(aggrBelowRoute.Grouping, groupBy)
139150

140-
aggrBelowRoute.Columns[aggregation.ColOffset] = aeDistinctExpr
151+
return nil
152+
}
153+
154+
func checkIfWeCanPushDown(ctx *plancontext.PlanningContext, aggregator *Aggregator) (bool, sqlparser.Expr, error) {
155+
canPushDown := true
156+
var distinctExpr sqlparser.Expr
157+
var differentExpr *sqlparser.AliasedExpr
141158

142-
groupBy := NewGroupBy(aggregator.DistinctExpr, aggregator.DistinctExpr, aeDistinctExpr)
143-
groupBy.ColOffset = aggregation.ColOffset
144-
aggrBelowRoute.Grouping = append(aggrBelowRoute.Grouping, groupBy)
159+
for _, aggr := range aggregator.Aggregations {
160+
if !aggr.Distinct {
161+
continue
162+
}
163+
innerExpr := aggr.Func.GetArg()
164+
if !exprHasUniqueVindex(ctx, innerExpr) {
165+
canPushDown = false
166+
}
167+
if distinctExpr == nil {
168+
distinctExpr = innerExpr
169+
}
170+
if !ctx.SemTable.EqualsExpr(distinctExpr, innerExpr) {
171+
differentExpr = aggr.Original
172+
}
145173
}
146-
return nil
174+
175+
if !canPushDown && differentExpr != nil {
176+
return false, nil, vterrors.VT12001(fmt.Sprintf("only one DISTINCT aggregation is allowed in a SELECT: %s", sqlparser.String(differentExpr)))
177+
}
178+
179+
return canPushDown, distinctExpr, nil
147180
}
148181

149182
func pushDownAggregationThroughFilter(
@@ -411,6 +444,15 @@ func splitAggrColumnsToLeftAndRight(
411444
outerJoin: join.LeftJoin,
412445
}
413446

447+
canPushDownDistinct, distinctExpr, err := checkIfWeCanPushDown(ctx, aggregator)
448+
if err != nil {
449+
return nil, nil, err
450+
}
451+
if !canPushDownDistinct {
452+
aggregator.DistinctExpr = distinctExpr
453+
return nil, nil, errAbortAggrPushing
454+
}
455+
414456
outer:
415457
// we prefer adding the aggregations in the same order as the columns are declared
416458
for colIdx, col := range aggregator.Columns {
@@ -509,9 +551,6 @@ func (ab *aggBuilder) handleAggr(ctx *plancontext.PlanningContext, aggr Aggr) er
509551
// this is only used for SHOW GTID queries that will never contain joins
510552
return vterrors.VT13001("cannot do join with vgtid")
511553
case opcode.AggregateSumDistinct, opcode.AggregateCountDistinct:
512-
if !exprHasUniqueVindex(ctx, aggr.Func.GetArg()) {
513-
return errAbortAggrPushing
514-
}
515554
return ab.handlePushThroughAggregation(ctx, aggr)
516555
default:
517556
return errHorizonNotPlanned()

go/vt/vtgate/planbuilder/operators/aggregator.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ type (
4242
Grouping []GroupBy
4343
Aggregations []Aggr
4444

45-
// We support a single distinct aggregation per aggregator. It is stored here
45+
// We support a single DISTINCT expression per aggregator (several distinct aggregations may share it). It is stored here.
46+
// When planning the ordering that the OrderedAggregate will require,
47+
// this needs to be the last ORDER BY expression.
4648
DistinctExpr sqlparser.Expr
4749

4850
// Pushed will be set to true once this aggregation has been pushed deeper in the tree

go/vt/vtgate/planbuilder/testdata/aggr_cases.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3854,8 +3854,8 @@
38543854
"Sharded": true
38553855
},
38563856
"FieldQuery": "select u.textcol1, u.val2, weight_string(u.val2) from `user` as u where 1 != 1",
3857-
"OrderBy": "0 ASC COLLATE latin1_swedish_ci",
3858-
"Query": "select u.textcol1, u.val2, weight_string(u.val2) from `user` as u order by u.textcol1 asc",
3857+
"OrderBy": "0 ASC COLLATE latin1_swedish_ci, (1|2) ASC",
3858+
"Query": "select u.textcol1, u.val2, weight_string(u.val2) from `user` as u order by u.textcol1 asc, u.val2 asc",
38593859
"Table": "`user`"
38603860
},
38613861
{

0 commit comments

Comments
 (0)