
Commit 20b97a6

concretevitamin authored and Davies Liu committed

Merge pull request #234 from hqzizania/assist

[SPARKR-163] Support sampleByKey()

Conflicts:
	pkg/R/pairRDD.R

1 parent ba54e34 · commit 20b97a6

File tree: 4 files changed (+177 −61 lines)

R/pkg/NAMESPACE

Lines changed: 1 addition & 0 deletions

@@ -54,6 +54,7 @@ exportMethods(
   "repartition",
   "rightOuterJoin",
   "sampleRDD",
+  "sampleByKey",
   "saveAsTextFile",
   "saveAsObjectFile",
   "sortBy",

R/pkg/R/generics.R

Lines changed: 6 additions & 0 deletions

@@ -262,6 +262,12 @@ setGeneric("mapValues", function(X, FUN) { standardGeneric("mapValues") })
 #' @export
 setGeneric("values", function(x) { standardGeneric("values") })
 
+#' @rdname sampleByKey
+#' @export
+setGeneric("sampleByKey",
+           function(x, withReplacement, fractions, seed) {
+             standardGeneric("sampleByKey")
+           })
 
 
 ############ Shuffle Functions ############
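For readers unfamiliar with SparkR's S4 style: the generic above only declares the argument list, and the concrete behaviour comes from a setMethod() call in pairRDD.R, shown further down. A minimal standalone sketch of that dispatch pattern, not part of the commit; the ToyRDD class and sampleByKeyToy name are hypothetical stand-ins:

library(methods)

# Hypothetical stand-in for SparkR's RDD class (illustration only).
setClass("ToyRDD", slots = c(data = "list"))

# Declare the generic: only the signature, no behaviour.
setGeneric("sampleByKeyToy",
           function(x, withReplacement, fractions, seed) {
             standardGeneric("sampleByKeyToy")
           })

# Provide the behaviour for ToyRDD; dispatch selects this method by signature.
setMethod("sampleByKeyToy",
          signature(x = "ToyRDD", withReplacement = "logical",
                    fractions = "vector", seed = "integer"),
          function(x, withReplacement, fractions, seed) {
            # A real implementation would sample x@data per key; this stub only
            # shows where dispatch lands.
            x@data
          })

toy <- new("ToyRDD", data = list(list("a", 1), list("b", 2)))
sampleByKeyToy(toy, FALSE, c(a = 0.5, b = 0.5), 42L)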

R/pkg/R/pairRDD.R

Lines changed: 119 additions & 34 deletions

Most of the hunks below only strip trailing whitespace from existing comment and code lines (the removed and added lines are otherwise identical). The substantive change is the new sampleByKey() method added at the end of the file.
@@ -450,19 +450,19 @@ setMethod("combineByKey",
           })
 
 #' Aggregate a pair RDD by each key.
-#' 
+#'
 #' Aggregate the values of each key in an RDD, using given combine functions
 #' and a neutral "zero value". This function can return a different result type,
 #' U, than the type of the values in this RDD, V. Thus, we need one operation
-#' for merging a V into a U and one operation for merging two U's, The former 
-#' operation is used for merging values within a partition, and the latter is 
-#' used for merging values between partitions. To avoid memory allocation, both 
-#' of these functions are allowed to modify and return their first argument 
+#' for merging a V into a U and one operation for merging two U's, The former
+#' operation is used for merging values within a partition, and the latter is
+#' used for merging values between partitions. To avoid memory allocation, both
+#' of these functions are allowed to modify and return their first argument
 #' instead of creating a new U.
-#' 
+#'
 #' @param x An RDD.
 #' @param zeroValue A neutral "zero value".
-#' @param seqOp A function to aggregate the values of each key. It may return 
+#' @param seqOp A function to aggregate the values of each key. It may return
 #'              a different result type from the type of the values.
 #' @param combOp A function to aggregate results of seqOp.
 #' @return An RDD containing the aggregation result.
@@ -474,7 +474,7 @@ setMethod("combineByKey",
 #' zeroValue <- list(0, 0)
 #' seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) }
 #' combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) }
-#' aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) 
+#' aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L)
 #' # list(list(1, list(3, 2)), list(2, list(7, 2)))
 #'}
 #' @rdname aggregateByKey
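To make the documented result concrete, here is a plain-R (non-Spark) emulation of the example above. It assumes the example's input pairs (1,1), (1,2), (2,3), (2,4), which are consistent with the documented output; aggregateByKeyLocal is a hypothetical helper, not SparkR API:

pairs <- list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))
zeroValue <- list(0, 0)
seqOp  <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) }            # fold one V into U = (sum, count)
combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) }  # merge two U's across partitions

aggregateByKeyLocal <- function(pairs, zeroValue, seqOp, combOp) {
  acc <- list()
  for (p in pairs) {
    k <- as.character(p[[1]])
    cur <- if (k %in% names(acc)) acc[[k]] else zeroValue
    acc[[k]] <- seqOp(cur, p[[2]])
  }
  # With a single local "partition" there is nothing left for combOp to merge.
  lapply(names(acc), function(k) list(as.numeric(k), acc[[k]]))
}

aggregateByKeyLocal(pairs, zeroValue, seqOp, combOp)
# list(list(1, list(3, 2)), list(2, list(7, 2)))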
@@ -491,12 +491,12 @@ setMethod("aggregateByKey",
           })
 
 #' Fold a pair RDD by each key.
-#' 
+#'
 #' Aggregate the values of each key in an RDD, using an associative function "func"
-#' and a neutral "zero value" which may be added to the result an arbitrary 
-#' number of times, and must not change the result (e.g., 0 for addition, or 
+#' and a neutral "zero value" which may be added to the result an arbitrary
+#' number of times, and must not change the result (e.g., 0 for addition, or
 #' 1 for multiplication.).
-#' 
+#'
 #' @param x An RDD.
 #' @param zeroValue A neutral "zero value".
 #' @param func An associative function for folding values of each key.
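A short local illustration (not part of the diff) of why the zero value must be neutral: folding each key's values with 0 and `+` yields per-key sums, and applying the zero value any number of extra times leaves those sums unchanged.

pairs <- list(list("a", 1), list("a", 2), list("b", 3))
vals  <- split(sapply(pairs, function(p) p[[2]]), sapply(pairs, function(p) p[[1]]))
lapply(vals, function(v) Reduce(`+`, v, 0))   # $a = 3, $b = 3, however often 0 is folded in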
@@ -546,11 +546,11 @@ setMethod("join",
           function(x, y, numPartitions) {
             xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) })
             yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) })
-            
+
             doJoin <- function(v) {
               joinTaggedList(v, list(FALSE, FALSE))
             }
-            
+
             joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numToInt(numPartitions)),
                                     doJoin)
           })
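The join above relies on a tag-and-group trick: elements of x are tagged 1L, elements of y are tagged 2L, both are unioned and grouped by key, and joinTaggedList then pairs up the two sides within each group. A plain-R sketch of the same idea, not part of the commit and without the SparkR helpers:

x <- list(list(1, "x1"), list(2, "x2"))
y <- list(list(1, "y1"), list(1, "y2"))

# Tag each element with the RDD it came from, then group by key.
tagged <- c(lapply(x, function(i) list(i[[1]], list(1L, i[[2]]))),
            lapply(y, function(i) list(i[[1]], list(2L, i[[2]]))))
keys <- sapply(tagged, function(t) t[[1]])

joined <- list()
for (k in unique(keys)) {
  group <- lapply(tagged[keys == k], function(t) t[[2]])
  xs <- lapply(Filter(function(v) v[[1]] == 1L, group), function(v) v[[2]])
  ys <- lapply(Filter(function(v) v[[1]] == 2L, group), function(v) v[[2]])
  # Inner join: emit the cross product of the two sides for this key.
  for (a in xs) for (b in ys) joined[[length(joined) + 1]] <- list(k, list(a, b))
}
joined
# list(list(1, list("x1", "y1")), list(1, list("x1", "y2")))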
@@ -566,8 +566,8 @@ setMethod("join",
 #' @param y An RDD to be joined. Should be an RDD where each element is
 #'          list(K, V).
 #' @param numPartitions Number of partitions to create.
-#' @return For each element (k, v) in x, the resulting RDD will either contain 
-#'         all pairs (k, (v, w)) for (k, w) in rdd2, or the pair (k, (v, NULL)) 
+#' @return For each element (k, v) in x, the resulting RDD will either contain
+#'         all pairs (k, (v, w)) for (k, w) in rdd2, or the pair (k, (v, NULL))
 #'         if no elements in rdd2 have key k.
 #' @examples
 #'\dontrun{
@@ -584,11 +584,11 @@ setMethod("leftOuterJoin",
           function(x, y, numPartitions) {
             xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) })
             yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) })
-            
+
             doJoin <- function(v) {
               joinTaggedList(v, list(FALSE, TRUE))
             }
-            
+
             joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin)
           })
 
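The left-outer variant differs from join only in the joinTaggedList flags: every key of x is kept, and NULL fills in where y has no match, as the @return above describes. A plain-R sketch of that behaviour (illustrative only; leftOuterLocal is not SparkR API):

leftOuterLocal <- function(x, y) {
  ykeys <- sapply(y, function(i) i[[1]])
  out <- list()
  for (e in x) {
    matches <- y[ykeys == e[[1]]]
    if (length(matches) == 0) {
      # No match in y: keep the x element and pad with NULL.
      out[[length(out) + 1]] <- list(e[[1]], list(e[[2]], NULL))
    } else {
      for (m in matches) out[[length(out) + 1]] <- list(e[[1]], list(e[[2]], m[[2]]))
    }
  }
  out
}

leftOuterLocal(list(list(1, "x1"), list(3, "x3")), list(list(1, "y1")))
# list(list(1, list("x1", "y1")), list(3, list("x3", NULL)))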

@@ -621,18 +621,18 @@ setMethod("rightOuterJoin",
           function(x, y, numPartitions) {
             xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) })
             yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) })
-            
+
             doJoin <- function(v) {
               joinTaggedList(v, list(TRUE, FALSE))
             }
-            
+
             joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin)
           })
 
 #' Full outer join two RDDs
 #'
 #' @description
-#' \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of the form list(K, V).
+#' \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of the form list(K, V).
 #' The key types of the two RDDs should be the same.
 #'
 #' @param x An RDD to be joined. Should be an RDD where each element is
@@ -642,7 +642,7 @@ setMethod("rightOuterJoin",
 #' @param numPartitions Number of partitions to create.
 #' @return For each element (k, v) in x and (k, w) in y, the resulting RDD
 #'         will contain all pairs (k, (v, w)) for both (k, v) in x and
-#'         (k, w) in y, or the pair (k, (NULL, w))/(k, (v, NULL)) if no elements 
+#'         (k, w) in y, or the pair (k, (NULL, w))/(k, (v, NULL)) if no elements
 #'         in x/y have key k.
 #' @examples
 #'\dontrun{
@@ -681,7 +681,7 @@ setMethod("fullOuterJoin",
 #' sc <- sparkR.init()
 #' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4)))
 #' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3)))
-#' cogroup(rdd1, rdd2, numPartitions = 2L) 
+#' cogroup(rdd1, rdd2, numPartitions = 2L)
 #' # list(list(1, list(1, list(2, 3))), list(2, list(list(4), list()))
 #'}
 #' @rdname cogroup
@@ -692,7 +692,7 @@ setMethod("cogroup",
             rdds <- list(...)
             rddsLen <- length(rdds)
             for (i in 1:rddsLen) {
-              rdds[[i]] <- lapply(rdds[[i]], 
+              rdds[[i]] <- lapply(rdds[[i]],
                                   function(x) { list(x[[1]], list(i, x[[2]])) })
               # TODO(hao): As issue [SparkR-142] mentions, the right value of i
               # will not be captured into UDF if getJRDD is not invoked.
@@ -721,7 +721,7 @@ setMethod("cogroup",
               }
             })
           }
-          cogroup.rdd <- mapValues(groupByKey(union.rdd, numPartitions), 
+          cogroup.rdd <- mapValues(groupByKey(union.rdd, numPartitions),
                                    group.func)
         })
 
@@ -743,18 +743,18 @@ setMethod("sortByKey",
           signature(x = "RDD"),
           function(x, ascending = TRUE, numPartitions = SparkR::numPartitions(x)) {
             rangeBounds <- list()
-            
+
             if (numPartitions > 1) {
               rddSize <- count(x)
               # constant from Spark's RangePartitioner
               maxSampleSize <- numPartitions * 20
               fraction <- min(maxSampleSize / max(rddSize, 1), 1.0)
-              
+
               samples <- collect(keys(sampleRDD(x, FALSE, fraction, 1L)))
-              
+
               # Note: the built-in R sort() function only works on atomic vectors
               samples <- sort(unlist(samples, recursive = FALSE), decreasing = !ascending)
-              
+
               if (length(samples) > 0) {
                 rangeBounds <- lapply(seq_len(numPartitions - 1),
                                       function(i) {
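This hunk and the next one together touch sortByKey's range partitioning: a small sample of keys is collected, sorted, and turned into partition range bounds, and each key is then placed by a linear scan over those bounds. A standalone sketch of that mechanism, not part of the diff; the boundary-picking formula here is an assumption, since the body of the lapply is cut off at the hunk boundary:

numPartitions <- 3
samples <- sort(c(15, 3, 42, 8, 23, 4))   # stand-in for the collected key sample

rangeBounds <- lapply(seq_len(numPartitions - 1),
                      function(i) {
                        # Assumed boundary rule: pick evenly spaced elements of the sorted sample.
                        samples[round(length(samples) * i / numPartitions)]
                      })

rangePartitionFunc <- function(key) {
  partition <- 0
  # Linear scan over the bounds, as in the TODO note in the next hunk.
  while (partition < length(rangeBounds) && key > rangeBounds[[partition + 1]]) {
    partition <- partition + 1
  }
  partition
}

sapply(c(1, 9, 100), rangePartitionFunc)
# 0 1 2  -> the three keys land in the first, middle, and last range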
@@ -766,24 +766,109 @@ setMethod("sortByKey",
 
             rangePartitionFunc <- function(key) {
               partition <- 0
-              
+
               # TODO: Use binary search instead of linear search, similar with Spark
               while (partition < length(rangeBounds) && key > rangeBounds[[partition + 1]]) {
                 partition <- partition + 1
               }
-              
+
               if (ascending) {
                 partition
               } else {
                 numPartitions - partition - 1
               }
             }
-            
+
             partitionFunc <- function(part) {
               sortKeyValueList(part, decreasing = !ascending)
             }
-            
+
             newRDD <- partitionBy(x, numPartitions, rangePartitionFunc)
             lapplyPartition(newRDD, partitionFunc)
           })
-
+
+#' @description
+#' \code{sampleByKey} return a subset RDD of the given RDD sampled by key
+#'
+#' @param x The RDD to sample elements by key, where each element is
+#'          list(K, V) or c(K, V).
+#' @param withReplacement Sampling with replacement or not
+#' @param fraction The (rough) sample target fraction
+#' @param seed Randomness seed value
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, 1:3000)
+#' pairs <- lapply(rdd, function(x) { if (x %% 3 == 0) list("a", x)
+#'                                    else { if (x %% 3 == 1) list("b", x) else list("c", x) }})
+#' fractions <- list(a = 0.2, b = 0.1, c = 0.3)
+#' sample <- sampleByKey(pairs, FALSE, fractions, 1618L)
+#' 100 < length(lookup(sample, "a")) && 300 > length(lookup(sample, "a")) # TRUE
+#' 50 < length(lookup(sample, "b")) && 150 > length(lookup(sample, "b")) # TRUE
+#' 200 < length(lookup(sample, "c")) && 400 > length(lookup(sample, "c")) # TRUE
+#' lookup(sample, "a")[which.min(lookup(sample, "a"))] >= 0 # TRUE
+#' lookup(sample, "a")[which.max(lookup(sample, "a"))] <= 2000 # TRUE
+#' lookup(sample, "b")[which.min(lookup(sample, "b"))] >= 0 # TRUE
+#' lookup(sample, "b")[which.max(lookup(sample, "b"))] <= 2000 # TRUE
+#' lookup(sample, "c")[which.min(lookup(sample, "c"))] >= 0 # TRUE
+#' lookup(sample, "c")[which.max(lookup(sample, "c"))] <= 2000 # TRUE
+#' fractions <- list(a = 0.2, b = 0.1, c = 0.3, d = 0.4)
+#' sample <- sampleByKey(pairs, FALSE, fractions, 1618L) # Key "d" will be ignored
+#' fractions <- list(a = 0.2, b = 0.1)
+#' sample <- sampleByKey(pairs, FALSE, fractions, 1618L) # KeyError: "c"
+#'}
+#' @rdname sampleByKey
+#' @aliases sampleByKey,RDD-method
+setMethod("sampleByKey",
+          signature(x = "RDD", withReplacement = "logical",
+                    fractions = "vector", seed = "integer"),
+          function(x, withReplacement, fractions, seed) {
+
+            for (elem in fractions) {
+              if (elem < 0.0) {
+                stop(paste("Negative fraction value ", fractions[which(fractions == elem)]))
+              }
+            }
+
+            # The sampler: takes a partition and returns its sampled version.
+            samplingFunc <- function(split, part) {
+              set.seed(bitwXor(seed, split))
+              res <- vector("list", length(part))
+              len <- 0
+
+              # mixing because the initial seeds are close to each other
+              runif(10)
+
+              for (elem in part) {
+                if (elem[[1]] %in% names(fractions)) {
+                  frac <- as.numeric(fractions[which(elem[[1]] == names(fractions))])
+                  if (withReplacement) {
+                    count <- rpois(1, frac)
+                    if (count > 0) {
+                      res[(len + 1):(len + count)] <- rep(list(elem), count)
+                      len <- len + count
+                    }
+                  } else {
+                    if (runif(1) < frac) {
+                      len <- len + 1
+                      res[[len]] <- elem
+                    }
+                  }
+                } else {
+                  stop("KeyError: \"", elem[[1]], "\"")
+                }
+              }
+
+              # TODO(zongheng): look into the performance of the current
+              # implementation. Look into some iterator package? Note that
+              # Scala avoids many calls to creating an empty list and PySpark
+              # similarly achieves this using `yield'. (duplicated from sampleRDD)
+              if (len > 0) {
+                res[1:len]
+              } else {
+                list()
+              }
+            }
+
+            lapplyPartitionsWithIndex(x, samplingFunc)
+          })
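The core of the new method is samplingFunc: each partition gets a deterministic RNG seed derived from the global seed and the partition index, and every element whose key appears in fractions is kept according to that key's fraction. With replacement, the number of copies of an element is a Poisson draw with mean equal to the fraction; without replacement, the element survives a single uniform draw. A standalone illustration of that per-element decision (not SparkR code, just the same sampling rule in isolation):

fractions <- list(a = 0.2, b = 0.1, c = 0.3)
elem <- list("a", 123)
withReplacement <- TRUE
set.seed(1618L)

frac <- as.numeric(fractions[[elem[[1]]]])
if (withReplacement) {
  count <- rpois(1, frac)                 # 0, 1, 2, ... copies; expected count equals frac
  sampled <- rep(list(elem), count)
} else {
  sampled <- if (runif(1) < frac) list(elem) else list()
}
length(sampled)

Keys listed in fractions but absent from the data are simply never consulted (the "d" case in the roxygen example), while a data key missing from fractions triggers the KeyError stop.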
