Skip to content

Commit 4220f1e

Browse files
committed
Better config, docs, etc.
1 parent ca6cc6b commit 4220f1e

File tree

8 files changed

+88
-38
lines changed

8 files changed

+88
-38
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ package org.apache.spark.sql.catalyst.expressions
2222
* new row. If the schema of the input row is specified, then the given expression will be bound to
2323
* that schema.
2424
*/
25-
class InterpretedProjection(expressions: Seq[Expression]) extends (Row => Row) {
25+
class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
2626
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
2727
this(expressions.map(BindReferences.bindReference(_, inputSchema)))
2828

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ import org.apache.spark.sql.catalyst.types.BooleanType
2323

2424

2525
object InterpretedPredicate {
26+
def apply(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) =
27+
apply(BindReferences.bindReference(expression, inputSchema))
28+
2629
def apply(expression: Expression): (Row => Boolean) = {
2730
(r: Row) => expression.eval(r).asInstanceOf[Boolean]
2831
}

sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,18 @@ trait SQLConf {
3535
/** Number of partitions to use for shuffle operators. */
3636
private[spark] def numShufflePartitions: Int = get("spark.sql.shuffle.partitions", "200").toInt
3737

38+
/**
39+
* When set to true, Spark SQL will use the scala compiler at runtime to generate custom bytecode
40+
* that evaluates expressions found in queries. In general this custom code runs much faster
41+
* than interpreted evaluation, but there are significant start-up costs due to compilation.
42+
* As a result codegen is only beneficial when queries run for a long time, or when the same
43+
* expressions are used multiple times.
44+
*
45+
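* Defaults to false as this feature is currently experimental.
* NOTE(review): the getter below passes "true" as the fallback for "spark.sql.codegen", which
* contradicts the documented default of false -- confirm which one is intended.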
46+
*/
47+
private[spark] def codegenEnabled: Boolean =
48+
if (get("spark.sql.codegen", "true") == "true") true else false
49+
3850
/**
3951
* Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to
4052
* a broadcast value during the physical executions of join operations. Setting this to 0

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
234234

235235
val sqlContext: SQLContext = self
236236

237+
def codegenEnabled = self.codegenEnabled
238+
237239
def numPartitions = self.numShufflePartitions
238240

239241
val strategies: Seq[Strategy] =

sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ case class Generate(
5050
override def output =
5151
if (join) child.output ++ generatorOutput else generatorOutput
5252

53+
/** Codegenned rows are not serializable... */
54+
override val codegenEnabled = false
55+
5356
override def execute() = {
5457
if (join) {
5558
child.execute().mapPartitions { iter =>

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
1919

2020
import org.apache.spark.annotation.DeveloperApi
2121
import org.apache.spark.rdd.RDD
22-
import org.apache.spark.sql.{Logging, Row}
22+
import org.apache.spark.sql.{SQLContext, Logging, Row}
2323
import org.apache.spark.sql.catalyst.trees
2424
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
2525
import org.apache.spark.sql.catalyst.expressions._
@@ -35,6 +35,8 @@ import org.apache.spark.sql.catalyst.plans.physical._
3535
abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging {
3636
self: Product =>
3737

38+
val codegenEnabled = true
39+
3840
// TODO: Move to `DistributedPlan`
3941
/** Specifies how data is partitioned across different nodes in the cluster. */
4042
def outputPartitioning: Partitioning = UnknownPartitioning(0) // TODO: WRONG WIDTH!
@@ -53,17 +55,29 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging {
5355
def executeCollect(): Array[Row] = execute().map(_.copy()).collect()
5456

5557
def newProjection(expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection =
56-
GenerateProjection(expressions, inputSchema)
58+
if (codegenEnabled) {
59+
GenerateProjection(expressions, inputSchema)
60+
} else {
61+
new InterpretedProjection(expressions, inputSchema)
62+
}
5763

5864
def newMutableProjection(
5965
expressions: Seq[Expression],
6066
inputSchema: Seq[Attribute]): () => MutableProjection = {
61-
GenerateMutableProjection(expressions, inputSchema)
67+
if(codegenEnabled) {
68+
GenerateMutableProjection(expressions, inputSchema)
69+
} else {
70+
() => new InterpretedMutableProjection(expressions, inputSchema)
71+
}
6272
}
6373

6474

6575
def newPredicate(expression: Expression, inputSchema: Seq[Attribute]): (Row) => Boolean = {
66-
GeneratePredicate(expression, inputSchema)
76+
if (codegenEnabled) {
77+
GeneratePredicate(expression, inputSchema)
78+
} else {
79+
InterpretedPredicate(expression, inputSchema)
80+
}
6781
}
6882
}
6983

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.sql.execution
1919

20-
import org.apache.spark.sql.{SQLContext, execution}
20+
import org.apache.spark.sql.{SQLConf, SQLContext, execution}
2121
import org.apache.spark.sql.catalyst.expressions._
2222
import org.apache.spark.sql.catalyst.planning._
2323
import org.apache.spark.sql.catalyst.plans._
@@ -108,7 +108,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
108108
child)
109109
if canBeCodeGened(
110110
allAggregates(partialComputation) ++
111-
allAggregates(rewrittenAggregateExpressions))=>
111+
allAggregates(rewrittenAggregateExpressions)) &&
112+
codegenEnabled =>
112113
execution.GeneratedAggregate(
113114
partial = false,
114115
namedGroupingAttributes,
@@ -119,7 +120,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
119120
partialComputation,
120121
planLater(child))(sqlContext))(sqlContext) :: Nil
121122

122-
123123
// Cases where some aggregate can not be codegened
124124
case PartialAggregation(
125125
namedGroupingAttributes,

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala

Lines changed: 46 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,8 @@
1717

1818
package org.apache.spark.sql.execution
1919

20-
import org.apache.spark.SparkContext
2120
import org.apache.spark.annotation.DeveloperApi
2221
import org.apache.spark.sql.SQLContext
23-
import org.apache.spark.sql.catalyst.errors._
2422
import org.apache.spark.sql.catalyst.expressions._
2523
import org.apache.spark.sql.catalyst.plans.physical._
2624
import org.apache.spark.sql.catalyst.types._
@@ -51,8 +49,6 @@ case class GeneratedAggregate(
5149
child: SparkPlan)(@transient sqlContext: SQLContext)
5250
extends UnaryNode with NoBind {
5351

54-
private def sc = sqlContext.sparkContext
55-
5652
override def requiredChildDistribution =
5753
if (partial) {
5854
UnspecifiedDistribution :: Nil
@@ -66,24 +62,24 @@ case class GeneratedAggregate(
6662

6763
override def otherCopyArgs = sqlContext :: Nil
6864

69-
def output = aggregateExpressions.map(_.toAttribute)
65+
override def output = aggregateExpressions.map(_.toAttribute)
7066

71-
def execute() = {
67+
override def execute() = {
7268
val aggregatesToCompute = aggregateExpressions.flatMap { a =>
7369
a.collect { case agg: AggregateExpression => agg}
7470
}
7571

7672
val computeFunctions = aggregatesToCompute.map {
77-
case c@Count(expr) =>
78-
val currentCount = AttributeReference("currentCount", LongType, true)()
73+
case c @ Count(expr) =>
74+
val currentCount = AttributeReference("currentCount", LongType, nullable = false)()
7975
val initialValue = Literal(0L)
8076
val updateFunction = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount)
8177
val result = currentCount
8278

8379
AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
8480

8581
case Sum(expr) =>
86-
val currentSum = AttributeReference("currentSum", expr.dataType, true)()
82+
val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)()
8783
val initialValue = Cast(Literal(0L), expr.dataType)
8884

8985
// Coalesce avoids double calculation...
@@ -93,9 +89,9 @@ case class GeneratedAggregate(
9389

9490
AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
9591

96-
case a@Average(expr) =>
97-
val currentCount = AttributeReference("currentCount", LongType, true)()
98-
val currentSum = AttributeReference("currentSum", expr.dataType, true)()
92+
case a @ Average(expr) =>
93+
val currentCount = AttributeReference("currentCount", LongType, nullable = false)()
94+
val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)()
9995
val initialCount = Literal(0L)
10096
val initialSum = Cast(Literal(0L), expr.dataType)
10197
val updateCount = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount)
@@ -131,50 +127,70 @@ case class GeneratedAggregate(
131127

132128
child.execute().mapPartitions { iter =>
133129
// Builds a new custom class for holding the results of aggregation for a group.
130+
@transient
134131
val newAggregationBuffer =
135132
newProjection(computeFunctions.flatMap(_.initialValues), child.output)
136133

137134
// A projection that is used to update the aggregate values for a group given a new tuple.
138135
// This projection should be targeted at the current values for the group and then applied
139136
// to a joined row of the current values with the new input row.
137+
@transient
140138
val updateProjection =
141139
newMutableProjection(
142140
computeFunctions.flatMap(_.update),
143141
computeFunctions.flatMap(_.schema) ++ child.output)()
144142

145143
// A projection that computes the group given an input tuple.
144+
@transient
146145
val groupProjection = newProjection(groupingExpressions, child.output)
147146

148147
// A projection that produces the final result, given a computation.
148+
@transient
149149
val resultProjectionBuilder =
150150
newMutableProjection(
151151
resultExpressions,
152152
(namedGroups.map(_._2.toAttribute) ++ computationSchema).toSeq)
153153

154-
val buffers = new java.util.HashMap[Row, MutableRow]()
155154
val joinedRow = new JoinedRow
156155

157-
var currentRow: Row = null
158-
while (iter.hasNext) {
159-
currentRow = iter.next()
160-
val currentGroup = groupProjection(currentRow)
161-
var currentBuffer = buffers.get(currentGroup)
162-
if (currentBuffer == null) {
163-
currentBuffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
164-
buffers.put(currentGroup, currentBuffer)
156+
if (groupingExpressions.isEmpty) {
157+
// TODO: Codegening anything other than the updateProjection is probably overkill.
158+
val buffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
159+
var currentRow: Row = null
160+
while (iter.hasNext) {
161+
currentRow = iter.next()
162+
updateProjection.target(buffer)(joinedRow(buffer, currentRow))
163+
}
164+
165+
val resultProjection = resultProjectionBuilder()
166+
Iterator(resultProjection(buffer))
167+
} else {
168+
val buffers = new java.util.HashMap[Row, MutableRow]()
169+
170+
var currentRow: Row = null
171+
while (iter.hasNext) {
172+
currentRow = iter.next()
173+
val currentGroup = groupProjection(currentRow)
174+
var currentBuffer = buffers.get(currentGroup)
175+
if (currentBuffer == null) {
176+
currentBuffer = newAggregationBuffer(EmptyRow).asInstanceOf[MutableRow]
177+
buffers.put(currentGroup, currentBuffer)
178+
}
179+
// Target the projection at the current aggregation buffer and then project the updated
180+
// values.
181+
updateProjection.target(currentBuffer)(joinedRow(currentBuffer, currentRow))
165182
}
166-
updateProjection.target(currentBuffer)(joinedRow(currentBuffer, currentRow))
167-
}
168183

169-
new Iterator[Row] {
170-
private[this] val resultIterator = buffers.entrySet.iterator()
171-
private[this] val resultProjection = resultProjectionBuilder()
184+
new Iterator[Row] {
185+
private[this] val resultIterator = buffers.entrySet.iterator()
186+
private[this] val resultProjection = resultProjectionBuilder()
172187

173-
def hasNext = resultIterator.hasNext
188+
def hasNext = resultIterator.hasNext
174189

175-
def next() = {
176-
val currentGroup = resultIterator.next()
177-
resultProjection(joinedRow(currentGroup.getKey, currentGroup.getValue))
190+
def next() = {
191+
val currentGroup = resultIterator.next()
192+
resultProjection(joinedRow(currentGroup.getKey, currentGroup.getValue))
193+
}
178194
}
179195
}
180196
}

0 commit comments

Comments
 (0)