Skip to content

Commit b5b1985

Browse files
c21 authored and cloud-fan committed
[SPARK-34620][SQL] Code-gen broadcast nested loop join (inner/cross)
### What changes were proposed in this pull request? `BroadcastNestedLoopJoinExec` does not have code-gen, and we can potentially boost the CPU performance for this operator if we add code-gen for it. https://databricks.com/blog/2017/02/16/processing-trillion-rows-per-second-single-machine-can-nested-loop-joins-fast.html also showed the evidence in one fork. The codegen for `BroadcastNestedLoopJoinExec` shared some code with `HashJoin`, and the interface `JoinCodegenSupport` is created to hold those common logic. This PR is only supporting inner and cross join. Other join types will be added later in followup PRs. Example query and generated code: ``` val df1 = spark.range(4).select($"id".as("k1")) val df2 = spark.range(3).select($"id".as("k2")) df1.join(df2, $"k1" + 1 =!= $"k2").explain("codegen") ``` ``` == Subtree 2 / 2 (maxMethodCodeSize:282; maxConstantPoolSize:203(0.31% used); numInnerClasses:0) == *(2) BroadcastNestedLoopJoin BuildRight, Inner, NOT ((k1#2L + 1) = k2#6L) :- *(2) Project [id#0L AS k1#2L] : +- *(2) Range (0, 4, step=1, splits=2) +- BroadcastExchange IdentityBroadcastMode, [id=#22] +- *(1) Project [id#4L AS k2#6L] +- *(1) Range (0, 3, step=1, splits=2) Generated code: /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIteratorForCodegenStage2(references); /* 003 */ } /* 004 */ /* 005 */ // codegenStageId=2 /* 006 */ final class GeneratedIteratorForCodegenStage2 extends org.apache.spark.sql.execution.BufferedRowIterator { /* 007 */ private Object[] references; /* 008 */ private scala.collection.Iterator[] inputs; /* 009 */ private boolean range_initRange_0; /* 010 */ private long range_nextIndex_0; /* 011 */ private TaskContext range_taskContext_0; /* 012 */ private InputMetrics range_inputMetrics_0; /* 013 */ private long range_batchEnd_0; /* 014 */ private long range_numElementsTodo_0; /* 015 */ private InternalRow[] bnlj_buildRowArray_0; /* 016 */ private 
org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] range_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[4]; /* 017 */ /* 018 */ public GeneratedIteratorForCodegenStage2(Object[] references) { /* 019 */ this.references = references; /* 020 */ } /* 021 */ /* 022 */ public void init(int index, scala.collection.Iterator[] inputs) { /* 023 */ partitionIndex = index; /* 024 */ this.inputs = inputs; /* 025 */ /* 026 */ range_taskContext_0 = TaskContext.get(); /* 027 */ range_inputMetrics_0 = range_taskContext_0.taskMetrics().inputMetrics(); /* 028 */ range_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0); /* 029 */ range_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0); /* 030 */ range_mutableStateArray_0[2] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0); /* 031 */ bnlj_buildRowArray_0 = (InternalRow[]) ((org.apache.spark.broadcast.TorrentBroadcast) references[1] /* broadcastTerm */).value(); /* 032 */ range_mutableStateArray_0[3] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(2, 0); /* 033 */ /* 034 */ } /* 035 */ /* 036 */ private void bnlj_doConsume_0(long bnlj_expr_0_0) throws java.io.IOException { /* 037 */ for (int bnlj_arrayIndex_0 = 0; bnlj_arrayIndex_0 < bnlj_buildRowArray_0.length; bnlj_arrayIndex_0++) { /* 038 */ UnsafeRow bnlj_buildRow_0 = (UnsafeRow) bnlj_buildRowArray_0[bnlj_arrayIndex_0]; /* 039 */ /* 040 */ long bnlj_value_1 = bnlj_buildRow_0.getLong(0); /* 041 */ /* 042 */ long bnlj_value_4 = -1L; /* 043 */ /* 044 */ bnlj_value_4 = bnlj_expr_0_0 + 1L; /* 045 */ /* 046 */ boolean bnlj_value_3 = false; /* 047 */ bnlj_value_3 = bnlj_value_4 == bnlj_value_1; /* 048 */ boolean bnlj_value_2 = false; /* 049 */ bnlj_value_2 = !(bnlj_value_3); /* 050 */ if (!(false || !bnlj_value_2)) /* 051 */ { /* 052 */ 
((org.apache.spark.sql.execution.metric.SQLMetric) references[2] /* numOutputRows */).add(1); /* 053 */ /* 054 */ range_mutableStateArray_0[3].reset(); /* 055 */ /* 056 */ range_mutableStateArray_0[3].write(0, bnlj_expr_0_0); /* 057 */ /* 058 */ range_mutableStateArray_0[3].write(1, bnlj_value_1); /* 059 */ append((range_mutableStateArray_0[3].getRow()).copy()); /* 060 */ /* 061 */ } /* 062 */ } /* 063 */ /* 064 */ } /* 065 */ /* 066 */ private void initRange(int idx) { /* 067 */ java.math.BigInteger index = java.math.BigInteger.valueOf(idx); /* 068 */ java.math.BigInteger numSlice = java.math.BigInteger.valueOf(2L); /* 069 */ java.math.BigInteger numElement = java.math.BigInteger.valueOf(4L); /* 070 */ java.math.BigInteger step = java.math.BigInteger.valueOf(1L); /* 071 */ java.math.BigInteger start = java.math.BigInteger.valueOf(0L); /* 072 */ long partitionEnd; /* 073 */ /* 074 */ java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start); /* 075 */ if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) { /* 076 */ range_nextIndex_0 = Long.MAX_VALUE; /* 077 */ } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) { /* 078 */ range_nextIndex_0 = Long.MIN_VALUE; /* 079 */ } else { /* 080 */ range_nextIndex_0 = st.longValue(); /* 081 */ } /* 082 */ range_batchEnd_0 = range_nextIndex_0; /* 083 */ /* 084 */ java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice) /* 085 */ .multiply(step).add(start); /* 086 */ if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) { /* 087 */ partitionEnd = Long.MAX_VALUE; /* 088 */ } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) { /* 089 */ partitionEnd = Long.MIN_VALUE; /* 090 */ } else { /* 091 */ partitionEnd = end.longValue(); /* 092 */ } /* 093 */ /* 094 */ java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract( /* 095 */ 
java.math.BigInteger.valueOf(range_nextIndex_0)); /* 096 */ range_numElementsTodo_0 = startToEnd.divide(step).longValue(); /* 097 */ if (range_numElementsTodo_0 < 0) { /* 098 */ range_numElementsTodo_0 = 0; /* 099 */ } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) { /* 100 */ range_numElementsTodo_0++; /* 101 */ } /* 102 */ } /* 103 */ /* 104 */ protected void processNext() throws java.io.IOException { /* 105 */ // initialize Range /* 106 */ if (!range_initRange_0) { /* 107 */ range_initRange_0 = true; /* 108 */ initRange(partitionIndex); /* 109 */ } /* 110 */ /* 111 */ while (true) { /* 112 */ if (range_nextIndex_0 == range_batchEnd_0) { /* 113 */ long range_nextBatchTodo_0; /* 114 */ if (range_numElementsTodo_0 > 1000L) { /* 115 */ range_nextBatchTodo_0 = 1000L; /* 116 */ range_numElementsTodo_0 -= 1000L; /* 117 */ } else { /* 118 */ range_nextBatchTodo_0 = range_numElementsTodo_0; /* 119 */ range_numElementsTodo_0 = 0; /* 120 */ if (range_nextBatchTodo_0 == 0) break; /* 121 */ } /* 122 */ range_batchEnd_0 += range_nextBatchTodo_0 * 1L; /* 123 */ } /* 124 */ /* 125 */ int range_localEnd_0 = (int)((range_batchEnd_0 - range_nextIndex_0) / 1L); /* 126 */ for (int range_localIdx_0 = 0; range_localIdx_0 < range_localEnd_0; range_localIdx_0++) { /* 127 */ long range_value_0 = ((long)range_localIdx_0 * 1L) + range_nextIndex_0; /* 128 */ /* 129 */ // common sub-expressions /* 130 */ /* 131 */ bnlj_doConsume_0(range_value_0); /* 132 */ /* 133 */ if (shouldStop()) { /* 134 */ range_nextIndex_0 = range_value_0 + 1L; /* 135 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localIdx_0 + 1); /* 136 */ range_inputMetrics_0.incRecordsRead(range_localIdx_0 + 1); /* 137 */ return; /* 138 */ } /* 139 */ /* 140 */ } /* 141 */ range_nextIndex_0 = range_batchEnd_0; /* 142 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localEnd_0); /* 143 */ 
range_inputMetrics_0.incRecordsRead(range_localEnd_0); /* 144 */ range_taskContext_0.killTaskIfInterrupted(); /* 145 */ } /* 146 */ } /* 147 */ /* 148 */ } ``` ### Why are the changes needed? Improve query CPU performance. Added a micro benchmark query in `JoinBenchmark.scala`. Saw 1x of run time improvement: ``` OpenJDK 64-Bit Server VM 11.0.9+11-LTS on Linux 4.14.219-161.340.amzn2.x86_64 Intel(R) Xeon(R) CPU E5-2670 v2 2.50GHz broadcast nested loop join: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------- broadcast nested loop join wholestage off 62922 63052 184 0.3 3000.3 1.0X broadcast nested loop join wholestage on 30946 30972 26 0.7 1475.6 2.0X ``` ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? * Added unit test in `WholeStageCodegenSuite.scala`, and existing unit tests for `BroadcastNestedLoopJoinExec`. * Updated golden files for several TCPDS query plans, as whole stage code-gen for `BroadcastNestedLoopJoinExec` is triggered. * Updated `JoinBenchmark-jdk11-results.txt ` and `JoinBenchmark-results.txt` with new benchmark result. Followed previous benchmark PRs - #27078 and #26003 to use same type of machine: ``` Amazon AWS EC2 type: r3.xlarge region: us-west-2 (Oregon) OS: Linux ``` Closes #31736 from c21/nested-join-exec. Authored-by: Cheng Su <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 43b23fd commit b5b1985

36 files changed

+1557
-1378
lines changed

sql/core/benchmarks/JoinBenchmark-jdk11-results.txt

Lines changed: 39 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,74 +2,81 @@
22
Join Benchmark
33
================================================================================================
44

5-
OpenJDK 64-Bit Server VM 11.0.5+10-post-Ubuntu-0ubuntu1.118.04 on Linux 4.15.0-1044-aws
5+
OpenJDK 64-Bit Server VM 11.0.9+11-LTS on Linux 4.14.219-161.340.amzn2.x86_64
66
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
77
Join w long: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
88
------------------------------------------------------------------------------------------------------------------------
9-
Join w long wholestage off 4441 4572 185 4.7 211.8 1.0X
10-
Join w long wholestage on 1409 1500 96 14.9 67.2 3.2X
9+
Join w long wholestage off 3931 3998 95 5.3 187.4 1.0X
10+
Join w long wholestage on 1507 1769 178 13.9 71.9 2.6X
1111

12-
OpenJDK 64-Bit Server VM 11.0.5+10-post-Ubuntu-0ubuntu1.118.04 on Linux 4.15.0-1044-aws
12+
OpenJDK 64-Bit Server VM 11.0.9+11-LTS on Linux 4.14.219-161.340.amzn2.x86_64
1313
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
1414
Join w long duplicated: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
1515
------------------------------------------------------------------------------------------------------------------------
16-
Join w long duplicated wholestage off 5111 5116 7 4.1 243.7 1.0X
17-
Join w long duplicated wholestage on 1493 1518 22 14.0 71.2 3.4X
16+
Join w long duplicated wholestage off 5582 5617 50 3.8 266.2 1.0X
17+
Join w long duplicated wholestage on 1435 1451 19 14.6 68.4 3.9X
1818

19-
OpenJDK 64-Bit Server VM 11.0.5+10-post-Ubuntu-0ubuntu1.118.04 on Linux 4.15.0-1044-aws
19+
OpenJDK 64-Bit Server VM 11.0.9+11-LTS on Linux 4.14.219-161.340.amzn2.x86_64
2020
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
2121
Join w 2 ints: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
2222
------------------------------------------------------------------------------------------------------------------------
23-
Join w 2 ints wholestage off 171821 171906 121 0.1 8193.0 1.0X
24-
Join w 2 ints wholestage on 166559 166975 263 0.1 7942.1 1.0X
23+
Join w 2 ints wholestage off 171470 171478 11 0.1 8176.3 1.0X
24+
Join w 2 ints wholestage on 166612 166762 123 0.1 7944.7 1.0X
2525

26-
OpenJDK 64-Bit Server VM 11.0.5+10-post-Ubuntu-0ubuntu1.118.04 on Linux 4.15.0-1044-aws
26+
OpenJDK 64-Bit Server VM 11.0.9+11-LTS on Linux 4.14.219-161.340.amzn2.x86_64
2727
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
2828
Join w 2 longs: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
2929
------------------------------------------------------------------------------------------------------------------------
30-
Join w 2 longs wholestage off 7511 7555 62 2.8 358.2 1.0X
31-
Join w 2 longs wholestage on 3776 4119 232 5.6 180.1 2.0X
30+
Join w 2 longs wholestage off 6065 6093 40 3.5 289.2 1.0X
31+
Join w 2 longs wholestage on 3285 3375 97 6.4 156.7 1.8X
3232

33-
OpenJDK 64-Bit Server VM 11.0.5+10-post-Ubuntu-0ubuntu1.118.04 on Linux 4.15.0-1044-aws
33+
OpenJDK 64-Bit Server VM 11.0.9+11-LTS on Linux 4.14.219-161.340.amzn2.x86_64
3434
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
3535
Join w 2 longs duplicated: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
3636
------------------------------------------------------------------------------------------------------------------------
37-
Join w 2 longs duplicated wholestage off 13563 13617 77 1.5 646.7 1.0X
38-
Join w 2 longs duplicated wholestage on 7947 8053 71 2.6 378.9 1.7X
37+
Join w 2 longs duplicated wholestage off 14969 15027 82 1.4 713.8 1.0X
38+
Join w 2 longs duplicated wholestage on 7902 8151 406 2.7 376.8 1.9X
3939

40-
OpenJDK 64-Bit Server VM 11.0.5+10-post-Ubuntu-0ubuntu1.118.04 on Linux 4.15.0-1044-aws
40+
OpenJDK 64-Bit Server VM 11.0.9+11-LTS on Linux 4.14.219-161.340.amzn2.x86_64
4141
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
4242
outer join w long: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
4343
------------------------------------------------------------------------------------------------------------------------
44-
outer join w long wholestage off 3915 3923 12 5.4 186.7 1.0X
45-
outer join w long wholestage on 1421 1461 30 14.8 67.8 2.8X
44+
outer join w long wholestage off 2822 2823 1 7.4 134.6 1.0X
45+
outer join w long wholestage on 1419 1436 19 14.8 67.7 2.0X
4646

47-
OpenJDK 64-Bit Server VM 11.0.5+10-post-Ubuntu-0ubuntu1.118.04 on Linux 4.15.0-1044-aws
47+
OpenJDK 64-Bit Server VM 11.0.9+11-LTS on Linux 4.14.219-161.340.amzn2.x86_64
4848
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
4949
semi join w long: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
5050
------------------------------------------------------------------------------------------------------------------------
51-
semi join w long wholestage off 2310 2332 30 9.1 110.2 1.0X
52-
semi join w long wholestage on 835 860 34 25.1 39.8 2.8X
51+
semi join w long wholestage off 1821 1832 15 11.5 86.8 1.0X
52+
semi join w long wholestage on 828 853 36 25.3 39.5 2.2X
5353

54-
OpenJDK 64-Bit Server VM 11.0.5+10-post-Ubuntu-0ubuntu1.118.04 on Linux 4.15.0-1044-aws
54+
OpenJDK 64-Bit Server VM 11.0.9+11-LTS on Linux 4.14.219-161.340.amzn2.x86_64
5555
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
5656
sort merge join: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
5757
------------------------------------------------------------------------------------------------------------------------
58-
sort merge join wholestage off 1846 1886 56 1.1 880.5 1.0X
59-
sort merge join wholestage on 1402 1654 234 1.5 668.3 1.3X
58+
sort merge join wholestage off 1371 1380 13 1.5 653.7 1.0X
59+
sort merge join wholestage on 1197 1244 37 1.8 570.9 1.1X
6060

61-
OpenJDK 64-Bit Server VM 11.0.5+10-post-Ubuntu-0ubuntu1.118.04 on Linux 4.15.0-1044-aws
61+
OpenJDK 64-Bit Server VM 11.0.9+11-LTS on Linux 4.14.219-161.340.amzn2.x86_64
6262
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
63-
sort merge join with duplicates: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
64-
------------------------------------------------------------------------------------------------------------------------
65-
sort merge join with duplicates wholestage off 2852 2879 38 0.7 1360.0 1.0X
66-
sort merge join with duplicates wholestage on 2645 2742 156 0.8 1261.0 1.1X
63+
sort merge join with duplicates: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
64+
------------------------------------------------------------------------------------------------------------------------------
65+
sort merge join with duplicates wholestage off 1920 1933 20 1.1 915.3 1.0X
66+
sort merge join with duplicates wholestage on 1871 1912 27 1.1 892.0 1.0X
6767

68-
OpenJDK 64-Bit Server VM 11.0.5+10-post-Ubuntu-0ubuntu1.118.04 on Linux 4.15.0-1044-aws
68+
OpenJDK 64-Bit Server VM 11.0.9+11-LTS on Linux 4.14.219-161.340.amzn2.x86_64
6969
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
7070
shuffle hash join: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
7171
------------------------------------------------------------------------------------------------------------------------
72-
shuffle hash join wholestage off 1506 1564 82 2.8 359.1 1.0X
73-
shuffle hash join wholestage on 1303 1330 23 3.2 310.6 1.2X
72+
shuffle hash join wholestage off 1102 1122 28 3.8 262.8 1.0X
73+
shuffle hash join wholestage on 657 674 13 6.4 156.6 1.7X
74+
75+
OpenJDK 64-Bit Server VM 11.0.9+11-LTS on Linux 4.14.219-161.340.amzn2.x86_64
76+
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
77+
broadcast nested loop join: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
78+
-------------------------------------------------------------------------------------------------------------------------
79+
broadcast nested loop join wholestage off 62922 63052 184 0.3 3000.3 1.0X
80+
broadcast nested loop join wholestage on 30946 30972 26 0.7 1475.6 2.0X
7481

7582

0 commit comments

Comments
 (0)