
Commit 497b0f4 (merge of parents 4a2e36d and 9eb49d4)
176 files changed: +1391 additions, -162 deletions. Only a subset of the changed files is shown below.

core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala

Lines changed: 7 additions & 5 deletions
@@ -78,16 +78,18 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
       // greater than totalParts because we actually cap it at totalParts in runJob.
       var numPartsToTry = 1
       if (partsScanned > 0) {
-        // If we didn't find any rows after the first iteration, just try all partitions next.
+        // If we didn't find any rows after the previous iteration, quadruple and retry.
         // Otherwise, interpolate the number of partitions we need to try, but overestimate it
-        // by 50%.
+        // by 50%. We also cap the estimation in the end.
         if (results.size == 0) {
-          numPartsToTry = totalParts - 1
+          numPartsToTry = partsScanned * 4
         } else {
-          numPartsToTry = (1.5 * num * partsScanned / results.size).toInt
+          // the left side of max is >=1 whenever partsScanned >= 2
+          numPartsToTry = Math.max(1,
+            (1.5 * num * partsScanned / results.size).toInt - partsScanned)
+          numPartsToTry = Math.min(numPartsToTry, partsScanned * 4)
         }
       }
-      numPartsToTry = math.max(0, numPartsToTry)  // guard against negative num of partitions

       val left = num - results.size
       val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
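
The estimation policy above (the same change is applied to RDD.take and to PySpark's take below) can be summarized as a small standalone sketch. This is not code from the commit, just the arithmetic it introduces, with illustrative names:

    // Hedged sketch of the new partition-count estimator; names are illustrative.
    def nextPartsToTry(partsScanned: Int, resultsSoFar: Int, num: Int): Int = {
      if (partsScanned == 0) {
        1                        // first pass always tries a single partition
      } else if (resultsSoFar == 0) {
        partsScanned * 4         // nothing found yet: quadruple and retry
      } else {
        // Interpolate how many partitions should hold `num` rows, overestimate by 50%,
        // subtract the partitions already scanned, keep at least 1 and at most 4x scanned.
        val estimate = (1.5 * num * partsScanned / resultsSoFar).toInt - partsScanned
        math.min(math.max(estimate, 1), partsScanned * 4)
      }
    }

    // Worked example: num = 10, partsScanned = 4, resultsSoFar = 2
    // (1.5 * 10 * 4 / 2).toInt - 4 = 26, capped at 4 * 4 = 16

The practical effect is that an RDD with many partitions no longer jumps straight to scanning almost all of them after one empty pass; growth is geometric and bounded.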

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 5 additions & 3 deletions
@@ -1079,15 +1079,17 @@ abstract class RDD[T: ClassTag](
       // greater than totalParts because we actually cap it at totalParts in runJob.
       var numPartsToTry = 1
       if (partsScanned > 0) {
-        // If we didn't find any rows after the previous iteration, quadruple and retry. Otherwise,
+        // If we didn't find any rows after the previous iteration, quadruple and retry. Otherwise,
         // interpolate the number of partitions we need to try, but overestimate it by 50%.
+        // We also cap the estimation in the end.
         if (buf.size == 0) {
           numPartsToTry = partsScanned * 4
         } else {
-          numPartsToTry = (1.5 * num * partsScanned / buf.size).toInt
+          // the left side of max is >=1 whenever partsScanned >= 2
+          numPartsToTry = Math.max((1.5 * num * partsScanned / buf.size).toInt - partsScanned, 1)
+          numPartsToTry = Math.min(numPartsToTry, partsScanned * 4)
         }
       }
-      numPartsToTry = math.max(0, numPartsToTry)  // guard against negative num of partitions

       val left = num - buf.size
       val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)

python/pyspark/rdd.py

Lines changed: 4 additions & 1 deletion
@@ -1070,10 +1070,13 @@ def take(self, num):
             # If we didn't find any rows after the previous iteration,
             # quadruple and retry. Otherwise, interpolate the number of
             # partitions we need to try, but overestimate it by 50%.
+            # We also cap the estimation in the end.
             if len(items) == 0:
                 numPartsToTry = partsScanned * 4
             else:
-                numPartsToTry = int(1.5 * num * partsScanned / len(items))
+                # the first parameter of max is >=1 whenever partsScanned >= 2
+                numPartsToTry = int(1.5 * num * partsScanned / len(items)) - partsScanned
+                numPartsToTry = min(max(numPartsToTry, 1), partsScanned * 4)

             left = num - len(items)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 3 additions & 2 deletions
@@ -17,7 +17,7 @@

 package org.apache.spark.sql.catalyst

-import java.sql.Timestamp
+import java.sql.{Date, Timestamp}

 import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
@@ -77,8 +77,9 @@ object ScalaReflection {
       val Schema(valueDataType, valueNullable) = schemaFor(valueType)
       Schema(MapType(schemaFor(keyType).dataType,
         valueDataType, valueContainsNull = valueNullable), nullable = true)
-    case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
+    case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
     case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
+    case t if t <:< typeOf[Date] => Schema(DateType, nullable = true)
     case t if t <:< typeOf[BigDecimal] => Schema(DecimalType, nullable = true)
     case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
     case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
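
With the added case, Scala case classes that carry java.sql.Date fields can be reflected into Catalyst schemas. A hedged usage sketch, assuming the schemaFor[T] entry point of this era (the exact signature may differ between versions), with a made-up case class:

    import java.sql.Date
    import org.apache.spark.sql.catalyst.ScalaReflection

    // Hypothetical case class, not part of the commit.
    case class Event(name: String, day: Date)

    // The `day` field should now map to DateType instead of falling through unmatched.
    val schema = ScalaReflection.schemaFor[Event]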

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 31 additions & 5 deletions
@@ -63,7 +63,8 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
       typeCoercionRules ++
       extendedRules : _*),
     Batch("Check Analysis", Once,
-      CheckResolution),
+      CheckResolution,
+      CheckAggregation),
     Batch("AnalysisOperators", fixedPoint,
       EliminateAnalysisOperators)
   )
@@ -88,6 +89,32 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
     }
   }

+  /**
+   * Checks for non-aggregated attributes with aggregation
+   */
+  object CheckAggregation extends Rule[LogicalPlan] {
+    def apply(plan: LogicalPlan): LogicalPlan = {
+      plan.transform {
+        case aggregatePlan @ Aggregate(groupingExprs, aggregateExprs, child) =>
+          def isValidAggregateExpression(expr: Expression): Boolean = expr match {
+            case _: AggregateExpression => true
+            case e: Attribute => groupingExprs.contains(e)
+            case e if groupingExprs.contains(e) => true
+            case e if e.references.isEmpty => true
+            case e => e.children.forall(isValidAggregateExpression)
+          }
+
+          aggregateExprs.foreach { e =>
+            if (!isValidAggregateExpression(e)) {
+              throw new TreeNodeException(plan, s"Expression not in GROUP BY: $e")
+            }
+          }
+
+          aggregatePlan
+      }
+    }
+  }
+
   /**
    * Replaces [[UnresolvedRelation]]s with concrete relations from the catalog.
    */
@@ -204,18 +231,17 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
    */
   object UnresolvedHavingClauseAttributes extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
-      case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _))
+      case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _))
         if aggregate.resolved && containsAggregate(havingCondition) => {
         val evaluatedCondition = Alias(havingCondition, "havingCondition")()
         val aggExprsWithHaving = evaluatedCondition +: originalAggExprs

         Project(aggregate.output,
           Filter(evaluatedCondition.toAttribute,
             aggregate.copy(aggregateExpressions = aggExprsWithHaving)))
       }
     }

   protected def containsAggregate(condition: Expression): Boolean =
     condition
       .collect { case ae: AggregateExpression => ae }
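
The new CheckAggregation batch turns a common class of silently wrong queries into analysis-time errors. A hedged illustration, assuming a SQLContext named sqlContext with a registered table t(key, value); none of these names come from the commit:

    import org.apache.spark.sql.catalyst.errors.TreeNodeException

    try {
      // `value` is neither in the GROUP BY list nor wrapped in an aggregate,
      // so the CheckAggregation rule rejects the plan during analysis.
      sqlContext.sql("SELECT key, value FROM t GROUP BY key").collect()
    } catch {
      case e: TreeNodeException[_] =>
        println(e.getMessage)   // "Expression not in GROUP BY: ..."
    }

Rewriting the query as SELECT key, COUNT(value) FROM t GROUP BY key, or grouping by both columns, passes the check.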

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 25 additions & 4 deletions
@@ -220,20 +220,39 @@ trait HiveTypeCoercion {
       case a: BinaryArithmetic if a.right.dataType == StringType =>
         a.makeCopy(Array(a.left, Cast(a.right, DoubleType)))

+      // we should cast all timestamp/date/string comparisons into string comparisons
+      case p: BinaryPredicate if p.left.dataType == StringType
+          && p.right.dataType == DateType =>
+        p.makeCopy(Array(p.left, Cast(p.right, StringType)))
+      case p: BinaryPredicate if p.left.dataType == DateType
+          && p.right.dataType == StringType =>
+        p.makeCopy(Array(Cast(p.left, StringType), p.right))
       case p: BinaryPredicate if p.left.dataType == StringType
           && p.right.dataType == TimestampType =>
-        p.makeCopy(Array(Cast(p.left, TimestampType), p.right))
+        p.makeCopy(Array(p.left, Cast(p.right, StringType)))
       case p: BinaryPredicate if p.left.dataType == TimestampType
           && p.right.dataType == StringType =>
-        p.makeCopy(Array(p.left, Cast(p.right, TimestampType)))
+        p.makeCopy(Array(Cast(p.left, StringType), p.right))
+      case p: BinaryPredicate if p.left.dataType == TimestampType
+          && p.right.dataType == DateType =>
+        p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType)))
+      case p: BinaryPredicate if p.left.dataType == DateType
+          && p.right.dataType == TimestampType =>
+        p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType)))

       case p: BinaryPredicate if p.left.dataType == StringType && p.right.dataType != StringType =>
         p.makeCopy(Array(Cast(p.left, DoubleType), p.right))
       case p: BinaryPredicate if p.left.dataType != StringType && p.right.dataType == StringType =>
         p.makeCopy(Array(p.left, Cast(p.right, DoubleType)))

-      case i @ In(a,b) if a.dataType == TimestampType && b.forall(_.dataType == StringType) =>
-        i.makeCopy(Array(a,b.map(Cast(_,TimestampType))))
+      case i @ In(a, b) if a.dataType == DateType && b.forall(_.dataType == StringType) =>
+        i.makeCopy(Array(Cast(a, StringType), b))
+      case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == StringType) =>
+        i.makeCopy(Array(Cast(a, StringType), b))
+      case i @ In(a, b) if a.dataType == DateType && b.forall(_.dataType == TimestampType) =>
+        i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
+      case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == DateType) =>
+        i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))

       case Sum(e) if e.dataType == StringType =>
         Sum(Cast(e, DoubleType))
@@ -283,6 +302,8 @@ trait HiveTypeCoercion {
       // Skip if the type is boolean type already. Note that this extra cast should be removed
       // by optimizer.SimplifyCasts.
       case Cast(e, BooleanType) if e.dataType == BooleanType => e
+      // DateType should be null if cast to boolean.
+      case Cast(e, BooleanType) if e.dataType == DateType => Cast(e, BooleanType)
       // If the data type is not boolean and is being cast boolean, turn it into a comparison
       // with the numeric value, i.e. x != 0. This will coerce the type into numeric type.
       case Cast(e, BooleanType) if e.dataType != BooleanType => Not(EqualTo(e, Literal(0)))
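
The direction of the coercion changes here: instead of casting a string operand to TimestampType, comparisons that mix DateType, TimestampType, and StringType are now rewritten so both sides are compared as strings. A hedged sketch of the resulting expression shape, assuming the pre-1.3 Catalyst package layout and expression constructors:

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.types.{DateType, StringType}

    // A date-typed column compared to a string literal...
    val d = AttributeReference("d", DateType, nullable = true)()

    // ...d > "2014-09-27" is coerced into a string comparison rather than a
    // timestamp cast of the literal:
    val coerced = GreaterThan(Cast(d, StringType), Literal("2014-09-27"))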

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ case class Star(
     mapFunction: Attribute => Expression = identity[Attribute])
   extends Attribute with trees.LeafNode[Expression] {

-  override def name = throw new UnresolvedException(this, "exprId")
+  override def name = throw new UnresolvedException(this, "name")
   override def exprId = throw new UnresolvedException(this, "exprId")
   override def dataType = throw new UnresolvedException(this, "dataType")
   override def nullable = throw new UnresolvedException(this, "nullable")

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 5 additions & 1 deletion
@@ -17,7 +17,7 @@

 package org.apache.spark.sql.catalyst

-import java.sql.Timestamp
+import java.sql.{Date, Timestamp}

 import scala.language.implicitConversions

@@ -119,6 +119,7 @@ package object dsl {
     implicit def floatToLiteral(f: Float) = Literal(f)
     implicit def doubleToLiteral(d: Double) = Literal(d)
     implicit def stringToLiteral(s: String) = Literal(s)
+    implicit def dateToLiteral(d: Date) = Literal(d)
     implicit def decimalToLiteral(d: BigDecimal) = Literal(d)
     implicit def timestampToLiteral(t: Timestamp) = Literal(t)
     implicit def binaryToLiteral(a: Array[Byte]) = Literal(a)
@@ -174,6 +175,9 @@ package object dsl {
     /** Creates a new AttributeReference of type string */
     def string = AttributeReference(s, StringType, nullable = true)()

+    /** Creates a new AttributeReference of type date */
+    def date = AttributeReference(s, DateType, nullable = true)()
+
     /** Creates a new AttributeReference of type decimal */
     def decimal = AttributeReference(s, DecimalType, nullable = true)()
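
A hedged sketch of the DSL additions in use; the import path and the symbol-to-attribute syntax mirror the existing string/decimal helpers but are written from memory, not taken from the commit:

    import java.sql.Date
    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.expressions.Literal

    // New implicit conversion: a java.sql.Date can be used where a Literal is expected.
    val cutoff: Literal = Date.valueOf("2014-09-27")

    // New attribute helper, alongside 'name.string and 'price.decimal:
    val birthday = 'birthday.date   // AttributeReference("birthday", DateType, nullable = true)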

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala

Lines changed: 16 additions & 7 deletions
@@ -17,19 +17,26 @@

 package org.apache.spark.sql.catalyst.expressions

+import org.apache.spark.sql.catalyst.analysis.Star
+
 protected class AttributeEquals(val a: Attribute) {
   override def hashCode() = a.exprId.hashCode()
-  override def equals(other: Any) = other match {
-    case otherReference: AttributeEquals => a.exprId == otherReference.a.exprId
-    case otherAttribute => false
+  override def equals(other: Any) = (a, other.asInstanceOf[AttributeEquals].a) match {
+    case (a1: AttributeReference, a2: AttributeReference) => a1.exprId == a2.exprId
+    case (a1, a2) => a1 == a2
   }
 }

 object AttributeSet {
-  /** Constructs a new [[AttributeSet]] given a sequence of [[Attribute Attributes]]. */
-  def apply(baseSet: Seq[Attribute]) = {
-    new AttributeSet(baseSet.map(new AttributeEquals(_)).toSet)
-  }
+  def apply(a: Attribute) =
+    new AttributeSet(Set(new AttributeEquals(a)))
+
+  /** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */
+  def apply(baseSet: Seq[Expression]) =
+    new AttributeSet(
+      baseSet
+        .flatMap(_.references)
+        .map(new AttributeEquals(_)).toSet)
 }

 /**
@@ -103,4 +110,6 @@ class AttributeSet private (val baseSet: Set[AttributeEquals])
   // We must force toSeq to not be strict otherwise we end up with a [[Stream]] that captures all
   // sorts of things in its closure.
   override def toSeq: Seq[Attribute] = baseSet.map(_.a).toArray.toSeq
+
+  override def toString = "{" + baseSet.map(_.a).mkString(", ") + "}"
 }
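
A hedged sketch of the reworked constructors and the new toString, assuming the pre-1.3 package layout (catalyst.types); the attribute names and exprId values shown are illustrative:

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.types.{IntegerType, StringType}

    val a = AttributeReference("a", IntegerType, nullable = true)()
    val b = AttributeReference("b", StringType, nullable = true)()

    // New single-attribute constructor.
    val single = AttributeSet(a)

    // Seq[Expression] constructor: references are collected from each expression.
    val fromExprs = AttributeSet(Seq(Add(a, Literal(1)), b))

    println(fromExprs)   // e.g. "{a#0, b#1}" via the new toString (exprIds will vary)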
