
Commit 6084e9c

Resolve conflict
2 parents d2aa2a0 + 085a721

62 files changed: 2583 additions, 952 deletions


R/pkg/R/client.R

Lines changed: 19 additions & 7 deletions

@@ -34,24 +34,36 @@ connectBackend <- function(hostname, port, timeout = 6000) {
   con
 }
 
-launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts) {
+determineSparkSubmitBin <- function() {
   if (.Platform$OS.type == "unix") {
     sparkSubmitBinName = "spark-submit"
   } else {
     sparkSubmitBinName = "spark-submit.cmd"
   }
+  sparkSubmitBinName
+}
+
+generateSparkSubmitArgs <- function(args, sparkHome, jars, sparkSubmitOpts, packages) {
+  if (jars != "") {
+    jars <- paste("--jars", jars)
+  }
+
+  if (packages != "") {
+    packages <- paste("--packages", packages)
+  }
 
+  combinedArgs <- paste(jars, packages, sparkSubmitOpts, args, sep = " ")
+  combinedArgs
+}
+
+launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts, packages) {
+  sparkSubmitBin <- determineSparkSubmitBin()
   if (sparkHome != "") {
     sparkSubmitBin <- file.path(sparkHome, "bin", sparkSubmitBinName)
   } else {
     sparkSubmitBin <- sparkSubmitBinName
   }
-
-  if (jars != "") {
-    jars <- paste("--jars", jars)
-  }
-
-  combinedArgs <- paste(jars, sparkSubmitOpts, args, sep = " ")
+  combinedArgs <- generateSparkSubmitArgs(args, sparkHome, jars, sparkSubmitOpts, packages)
   cat("Launching java with spark-submit command", sparkSubmitBin, combinedArgs, "\n")
   invisible(system2(sparkSubmitBin, combinedArgs, wait = F))
 }
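
For reference, a minimal sketch (not part of the commit) of how the new generateSparkSubmitArgs helper composes the final spark-submit argument string. The jar path and master option below are hypothetical; only the package coordinate also appears in the new test file further down.

# Hypothetical inputs, for illustration only
combinedArgs <- generateSparkSubmitArgs(
  args = "sparkr-shell",
  sparkHome = "",
  jars = "/tmp/extra.jar",
  sparkSubmitOpts = "--master local[2]",
  packages = "holdenk:spark-testing-base:1.3.0_0.0.5")
# combinedArgs is then, up to whitespace:
# "--jars /tmp/extra.jar --packages holdenk:spark-testing-base:1.3.0_0.0.5 --master local[2] sparkr-shell"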

R/pkg/R/sparkR.R

Lines changed: 5 additions & 2 deletions

@@ -81,6 +81,7 @@ sparkR.stop <- function() {
 #' @param sparkExecutorEnv Named list of environment variables to be used when launching executors.
 #' @param sparkJars Character string vector of jar files to pass to the worker nodes.
 #' @param sparkRLibDir The path where R is installed on the worker nodes.
+#' @param sparkPackages Character string vector of packages from spark-packages.org
 #' @export
 #' @examples
 #'\dontrun{
@@ -100,7 +101,8 @@ sparkR.init <- function(
   sparkEnvir = list(),
   sparkExecutorEnv = list(),
   sparkJars = "",
-  sparkRLibDir = "") {
+  sparkRLibDir = "",
+  sparkPackages = "") {
 
   if (exists(".sparkRjsc", envir = .sparkREnv)) {
     cat("Re-using existing Spark Context. Please stop SparkR with sparkR.stop() or restart R to create a new Spark Context\n")
@@ -129,7 +131,8 @@ sparkR.init <- function(
     args = path,
     sparkHome = sparkHome,
     jars = jars,
-    sparkSubmitOpts = Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell"))
+    sparkSubmitOpts = Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell"),
+    sparkPackages = sparkPackages)
   # wait atmost 100 seconds for JVM to launch
   wait <- 0.1
   for (i in 1:25) {
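
A short usage sketch (assumed, not taken from the diff): with this change a SparkR user can pull in a Spark package when creating the context. The master, app name, and package coordinate below are illustrative.

library(SparkR)
# Request a spark-packages.org artifact at context creation time
sc <- sparkR.init(master = "local[2]", appName = "packagesDemo",
                  sparkPackages = "holdenk:spark-testing-base:1.3.0_0.0.5")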

R/pkg/inst/tests/test_client.R

Lines changed: 32 additions & 0 deletions

@@ -0,0 +1,32 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+context("functions in client.R")
+
+test_that("adding spark-testing-base as a package works", {
+  args <- generateSparkSubmitArgs("", "", "", "",
+                                  "holdenk:spark-testing-base:1.3.0_0.0.5")
+  expect_equal(gsub("[[:space:]]", "", args),
+               gsub("[[:space:]]", "",
+                    "--packages holdenk:spark-testing-base:1.3.0_0.0.5"))
+})
+
+test_that("no package specified doesn't add packages flag", {
+  args <- generateSparkSubmitArgs("", "", "", "", "")
+  expect_equal(gsub("[[:space:]]", "", args),
+               "")
+})
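
One possible way to run this test file locally (an assumed workflow; the project's own test-runner script may be the preferred path): load the package so the non-exported helper is visible, then hand the file to testthat.

# Assumed workflow; run from the R/pkg directory
library(testthat)
library(devtools)
load_all(".")   # exposes internal helpers such as generateSparkSubmitArgs
test_file("inst/tests/test_client.R")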

core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala

Lines changed: 23 additions & 36 deletions

@@ -17,29 +17,29 @@
 
 package org.apache.spark.shuffle.hash
 
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.util.{Failure, Success, Try}
+import java.io.InputStream
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.util.{Failure, Success}
 
 import org.apache.spark._
-import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.FetchFailedException
-import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId}
-import org.apache.spark.util.CompletionIterator
+import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator,
+  ShuffleBlockId}
 
 private[hash] object BlockStoreShuffleFetcher extends Logging {
-  def fetch[T](
+  def fetchBlockStreams(
       shuffleId: Int,
       reduceId: Int,
       context: TaskContext,
-      serializer: Serializer)
-    : Iterator[T] =
+      blockManager: BlockManager,
+      mapOutputTracker: MapOutputTracker)
+    : Iterator[(BlockId, InputStream)] =
   {
     logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
-    val blockManager = SparkEnv.get.blockManager
 
     val startTime = System.currentTimeMillis
-    val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
+    val statuses = mapOutputTracker.getServerStatuses(shuffleId, reduceId)
     logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
       shuffleId, reduceId, System.currentTimeMillis - startTime))
 
@@ -53,12 +53,21 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
       (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
     }
 
-    def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = {
+    val blockFetcherItr = new ShuffleBlockFetcherIterator(
+      context,
+      blockManager.shuffleClient,
+      blockManager,
+      blocksByAddress,
+      // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
+      SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
+
+    // Make sure that fetch failures are wrapped inside a FetchFailedException for the scheduler
+    blockFetcherItr.map { blockPair =>
       val blockId = blockPair._1
       val blockOption = blockPair._2
       blockOption match {
-        case Success(block) => {
-          block.asInstanceOf[Iterator[T]]
+        case Success(inputStream) => {
+          (blockId, inputStream)
         }
         case Failure(e) => {
           blockId match {
@@ -72,27 +81,5 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
         }
       }
     }
-
-    val blockFetcherItr = new ShuffleBlockFetcherIterator(
-      context,
-      SparkEnv.get.blockManager.shuffleClient,
-      blockManager,
-      blocksByAddress,
-      serializer,
-      // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
-      SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
-    val itr = blockFetcherItr.flatMap(unpackBlock)
-
-    val completionIter = CompletionIterator[T, Iterator[T]](itr, {
-      context.taskMetrics.updateShuffleReadMetrics()
-    })
-
-    new InterruptibleIterator[T](context, completionIter) {
-      val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
-      override def next(): T = {
-        readMetrics.incRecordsRead(1)
-        delegate.next()
-      }
-    }
   }
 }

core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala

Lines changed: 44 additions & 8 deletions

@@ -17,16 +17,20 @@
 
 package org.apache.spark.shuffle.hash
 
-import org.apache.spark.{InterruptibleIterator, TaskContext}
+import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
+import org.apache.spark.storage.BlockManager
+import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.ExternalSorter
 
 private[spark] class HashShuffleReader[K, C](
     handle: BaseShuffleHandle[K, _, C],
     startPartition: Int,
     endPartition: Int,
-    context: TaskContext)
+    context: TaskContext,
+    blockManager: BlockManager = SparkEnv.get.blockManager,
+    mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
   extends ShuffleReader[K, C]
 {
   require(endPartition == startPartition + 1,
@@ -36,20 +40,52 @@ private[spark] class HashShuffleReader[K, C](
 
   /** Read the combined key-values for this reduce task */
   override def read(): Iterator[Product2[K, C]] = {
+    val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams(
+      handle.shuffleId, startPartition, context, blockManager, mapOutputTracker)
+
+    // Wrap the streams for compression based on configuration
+    val wrappedStreams = blockStreams.map { case (blockId, inputStream) =>
+      blockManager.wrapForCompression(blockId, inputStream)
+    }
+
     val ser = Serializer.getSerializer(dep.serializer)
-    val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser)
+    val serializerInstance = ser.newInstance()
+
+    // Create a key/value iterator for each stream
+    val recordIter = wrappedStreams.flatMap { wrappedStream =>
+      // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
+      // NextIterator. The NextIterator makes sure that close() is called on the
+      // underlying InputStream when all records have been read.
+      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
+    }
+
+    // Update the context task metrics for each record read.
+    val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
+    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
+      recordIter.map(record => {
+        readMetrics.incRecordsRead(1)
+        record
+      }),
+      context.taskMetrics().updateShuffleReadMetrics())
+
+    // An interruptible iterator must be used here in order to support task cancellation
+    val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
 
     val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
      if (dep.mapSideCombine) {
-        new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context))
+        // We are reading values that are already combined
+        val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
+        dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
      } else {
-        new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(iter, context))
+        // We don't know the value type, but also don't care -- the dependency *should*
+        // have made sure its compatible w/ this aggregator, which will convert the value
+        // type to the combined type C
+        val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
+        dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
      }
    } else {
      require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
-
-      // Convert the Product2s to pairs since this is what downstream RDDs currently expect
-      iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2))
+      interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
    }
 
    // Sort the output if there is a sort ordering defined.
