Commit 0a901dd

[SPARK-7231] [SPARKR] Changes to make SparkR DataFrame dplyr friendly.
Changes include:

1. Rename sortDF to arrange.
2. Add new aliases `group_by`, `sample_frac`, and `summarize`.
3. Add more user-friendly column addition (`mutate`) and column renaming (`rename`).
4. Support `mean` as an alias for `avg` in Scala, and also support `n_distinct` and `n` as in dplyr.

With these changes we can run essentially all of the examples in http://cran.rstudio.com/web/packages/dplyr/vignettes/introduction.html with the same syntax. The only thing missing in SparkR is auto-resolving column names used in an expression, i.e. making something like `select(flights, delay)` work as it does in dplyr; right now we need `select(flights, flights$delay)` or `select(flights, "delay")`. That is a more complicated change, and I'll file a new issue for it.

cc sun-rui rxin

Author: Shivaram Venkataraman <[email protected]>

Closes #6005 from shivaram/sparkr-df-api and squashes the following commits:

5e0716a [Shivaram Venkataraman] Fix some roxygen bugs
1254953 [Shivaram Venkataraman] Merge branch 'master' of https://github.com/apache/spark into sparkr-df-api
0521149 [Shivaram Venkataraman] Changes to make SparkR DataFrame dplyr friendly. Changes include: 1. Rename sortDF to arrange. 2. Add new aliases `group_by`, `sample_frac`, and `summarize`. 3. Add more user-friendly column addition (mutate) and rename. 4. Support mean as an alias for avg in Scala and also support n_distinct, n as in dplyr.
1 parent b6c797b commit 0a901dd
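
Taken together, these changes let most of the dplyr introduction vignette run against a SparkR DataFrame with the same verbs. Below is a minimal sketch of the intended usage, assuming the SparkR API at this commit; the flights.json path and its columns (arr_delay, dep_delay, day, carrier) are hypothetical stand-ins for the vignette's flights data, not part of the commit:

    library(SparkR)
    sc <- sparkR.init()
    sqlCtx <- sparkRSQL.init(sc)
    flights <- jsonFile(sqlCtx, "path/to/flights.json")

    # dplyr-style verbs added or aliased by this commit:
    withGain <- mutate(flights, gain = flights$arr_delay - flights$dep_delay)  # add a derived column
    renamed  <- rename(flights, delay = flights$dep_delay)                     # new_name = existing column
    sorted   <- arrange(flights, desc(flights$arr_delay))                      # formerly sortDF
    sampled  <- sample_frac(flights, FALSE, 0.1)                               # alias for sampleDF
    daily    <- group_by(flights, flights$day)                                 # alias for groupBy
    counts   <- agg(daily, n(flights$day), n_distinct(flights$carrier))        # dplyr-style counting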

File tree

8 files changed: +249 additions, -29 deletions

R/pkg/NAMESPACE

Lines changed: 9 additions & 2 deletions
@@ -9,7 +9,8 @@ export("print.jobj")
 
 exportClasses("DataFrame")
 
-exportMethods("cache",
+exportMethods("arrange",
+              "cache",
               "collect",
               "columns",
               "count",
@@ -20,6 +21,7 @@ exportMethods("cache",
               "explain",
               "filter",
               "first",
+              "group_by",
               "groupBy",
               "head",
               "insertInto",
@@ -28,12 +30,15 @@ exportMethods("cache",
               "join",
               "limit",
               "orderBy",
+              "mutate",
               "names",
               "persist",
               "printSchema",
               "registerTempTable",
+              "rename",
               "repartition",
               "sampleDF",
+              "sample_frac",
               "saveAsParquetFile",
               "saveAsTable",
               "saveDF",
@@ -42,7 +47,7 @@ exportMethods("cache",
               "selectExpr",
               "show",
               "showDF",
-              "sortDF",
+              "summarize",
               "take",
               "unionAll",
               "unpersist",
@@ -72,6 +77,8 @@ exportMethods("abs",
               "max",
               "mean",
               "min",
+              "n",
+              "n_distinct",
               "rlike",
               "sqrt",
               "startsWith",

R/pkg/R/DataFrame.R

Lines changed: 115 additions & 12 deletions
@@ -480,6 +480,7 @@ setMethod("distinct",
 #' @param withReplacement Sampling with replacement or not
 #' @param fraction The (rough) sample target fraction
 #' @rdname sampleDF
+#' @aliases sample_frac
 #' @export
 #' @examples
 #'\dontrun{
@@ -501,6 +502,15 @@ setMethod("sampleDF",
             dataFrame(sdf)
           })
 
+#' @rdname sampleDF
+#' @aliases sampleDF
+setMethod("sample_frac",
+          signature(x = "DataFrame", withReplacement = "logical",
+                    fraction = "numeric"),
+          function(x, withReplacement, fraction) {
+            sampleDF(x, withReplacement, fraction)
+          })
+
 #' Count
 #'
 #' Returns the number of rows in a DataFrame
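
The alias is a one-line delegation, so sample_frac and sampleDF are interchangeable, and both treat the fraction as a rough target rather than an exact row count. A usage sketch, assuming the sqlCtx and JSON path from the roxygen examples:

    df <- jsonFile(sqlCtx, "path/to/file.json")
    s1 <- sampleDF(df, FALSE, 0.25)      # original SparkR spelling
    s2 <- sample_frac(df, FALSE, 0.25)   # dplyr-style alias; same behavior
    count(s1)                            # roughly, not exactly, 0.25 * count(df)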
@@ -682,7 +692,8 @@ setMethod("toRDD",
 #' @param x a DataFrame
 #' @return a GroupedData
 #' @seealso GroupedData
-#' @rdname DataFrame
+#' @aliases group_by
+#' @rdname groupBy
 #' @export
 #' @examples
 #' \dontrun{
@@ -705,19 +716,36 @@ setMethod("groupBy",
             groupedData(sgd)
           })
 
-#' Agg
+#' @rdname groupBy
+#' @aliases group_by
+setMethod("group_by",
+          signature(x = "DataFrame"),
+          function(x, ...) {
+            groupBy(x, ...)
+          })
+
+#' Summarize data across columns
 #'
 #' Compute aggregates by specifying a list of columns
 #'
 #' @param x a DataFrame
 #' @rdname DataFrame
+#' @aliases summarize
 #' @export
 setMethod("agg",
           signature(x = "DataFrame"),
           function(x, ...) {
            agg(groupBy(x), ...)
          })
 
+#' @rdname DataFrame
+#' @aliases agg
+setMethod("summarize",
+          signature(x = "DataFrame"),
+          function(x, ...) {
+            agg(x, ...)
+          })
+
 
 ############################## RDD Map Functions ##################################
 # All of the following functions mirror the existing RDD map functions, #
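
Both aliases delegate directly: group_by(x, ...) is groupBy(x, ...), and summarize(x, ...) on a DataFrame is agg(x, ...), which wraps the frame in groupBy(x) with no grouping columns, i.e. an aggregate over all rows. A sketch, where df and its dept/salary columns are hypothetical:

    grouped <- group_by(df, df$dept)           # identical to groupBy(df, df$dept)
    perDept <- agg(grouped, avg(df$salary))    # aggregate within each group
    overall <- summarize(df, avg(df$salary))   # whole-table aggregate: agg(groupBy(df), ...)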
@@ -886,7 +914,7 @@ setMethod("select",
           signature(x = "DataFrame", col = "list"),
           function(x, col) {
             cols <- lapply(col, function(c) {
-              if (class(c)== "Column") {
+              if (class(c) == "Column") {
                 c@jc
               } else {
                 col(c)@jc
@@ -946,6 +974,42 @@ setMethod("withColumn",
             select(x, x$"*", alias(col, colName))
           })
 
+#' Mutate
+#'
+#' Return a new DataFrame with the specified columns added.
+#'
+#' @param x A DataFrame
+#' @param col a named argument of the form name = col
+#' @return A new DataFrame with the new columns added.
+#' @rdname withColumn
+#' @aliases withColumn
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' newDF <- mutate(df, newCol = df$col1 * 5, newCol2 = df$col1 * 2)
+#' names(newDF) # Will contain newCol, newCol2
+#' }
+setMethod("mutate",
+          signature(x = "DataFrame"),
+          function(x, ...) {
+            cols <- list(...)
+            stopifnot(length(cols) > 0)
+            stopifnot(class(cols[[1]]) == "Column")
+            ns <- names(cols)
+            if (!is.null(ns)) {
+              for (n in ns) {
+                if (n != "") {
+                  cols[[n]] <- alias(cols[[n]], n)
+                }
+              }
+            }
+            do.call(select, c(x, x$"*", cols))
+          })
+
 #' WithColumnRenamed
 #'
 #' Rename an existing column in a DataFrame.
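
mutate aliases each named Column argument to its name and appends the result to x$"*" through a single select, so every expression is evaluated against the input frame's columns. One consequence worth noting for dplyr users: a column created in a mutate call cannot be referenced later in the same call. A sketch with hypothetical column names:

    df2 <- mutate(df, gain = df$arr - df$dep, ratio = df$arr / df$dep)
    names(df2)   # all original columns plus gain and ratio
    # Not possible in one call: mutate(df, gain = ..., gain2 = gain * 2),
    # because every expression is resolved against the input frame df.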
@@ -977,29 +1041,67 @@ setMethod("withColumnRenamed",
             select(x, cols)
           })
 
+#' Rename
+#'
+#' Rename an existing column in a DataFrame.
+#'
+#' @param x A DataFrame
+#' @param newCol A named pair of the form new_column_name = existing_column
+#' @return A DataFrame with the column name changed.
+#' @rdname withColumnRenamed
+#' @aliases withColumnRenamed
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' newDF <- rename(df, col1 = df$newCol1)
+#' }
+setMethod("rename",
+          signature(x = "DataFrame"),
+          function(x, ...) {
+            renameCols <- list(...)
+            stopifnot(length(renameCols) > 0)
+            stopifnot(class(renameCols[[1]]) == "Column")
+            newNames <- names(renameCols)
+            oldNames <- lapply(renameCols, function(col) {
+              callJMethod(col@jc, "toString")
+            })
+            cols <- lapply(columns(x), function(c) {
+              if (c %in% oldNames) {
+                alias(col(c), newNames[[match(c, oldNames)]])
+              } else {
+                col(c)
+              }
+            })
+            select(x, cols)
+          })
+
 setClassUnion("characterOrColumn", c("character", "Column"))
 
-#' SortDF
+#' Arrange
 #'
 #' Sort a DataFrame by the specified column(s).
 #'
 #' @param x A DataFrame to be sorted.
 #' @param col Either a Column object or character vector indicating the field to sort on
 #' @param ... Additional sorting fields
 #' @return A DataFrame where all elements are sorted.
-#' @rdname sortDF
+#' @rdname arrange
 #' @export
 #' @examples
 #'\dontrun{
 #' sc <- sparkR.init()
 #' sqlCtx <- sparkRSQL.init(sc)
 #' path <- "path/to/file.json"
 #' df <- jsonFile(sqlCtx, path)
-#' sortDF(df, df$col1)
-#' sortDF(df, "col1")
-#' sortDF(df, asc(df$col1), desc(abs(df$col2)))
+#' arrange(df, df$col1)
+#' arrange(df, "col1")
+#' arrange(df, asc(df$col1), desc(abs(df$col2)))
 #' }
-setMethod("sortDF",
+setMethod("arrange",
           signature(x = "DataFrame", col = "characterOrColumn"),
           function(x, col, ...) {
             if (class(col) == "character") {
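
rename rebuilds the projection: columns whose toString matches one of the old names are re-aliased to the corresponding new name, all others pass through unchanged, and a single select yields the result, which is why several columns can be renamed in one call. A sketch following the roxygen example's new_name = existing_column convention (column names hypothetical):

    df2 <- rename(df, delay = df$dep_delay)                  # one column
    columns(df2)                                             # "dep_delay" now appears as "delay"
    df3 <- rename(df, d1 = df$dep_delay, d2 = df$arr_delay)  # several columns at once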
@@ -1013,20 +1115,20 @@ setMethod("sortDF",
             dataFrame(sdf)
           })
 
-#' @rdname sortDF
+#' @rdname arrange
 #' @aliases orderBy,DataFrame,function-method
 setMethod("orderBy",
           signature(x = "DataFrame", col = "characterOrColumn"),
           function(x, col) {
-            sortDF(x, col)
+            arrange(x, col)
           })
 
 #' Filter
 #'
 #' Filter the rows of a DataFrame according to a given condition.
 #'
 #' @param x A DataFrame to be sorted.
-#' @param condition The condition to sort on. This may either be a Column expression
+#' @param condition The condition to filter on. This may either be a Column expression
 #' or a string containing a SQL statement
 #' @return A DataFrame containing only the rows that meet the condition.
 #' @rdname filter
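
Since orderBy now simply delegates to arrange, existing SparkR code keeps working while dplyr users get the spelling they expect. The equivalent forms, as a sketch:

    s1 <- arrange(df, df$col1)                      # Column form
    s2 <- arrange(df, "col1")                       # character form
    s3 <- arrange(df, asc(df$col1), desc(df$col2))  # mixed sort directions
    s4 <- orderBy(df, df$col1)                      # old alias; now delegates to arrange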
@@ -1106,6 +1208,7 @@ setMethod("join",
 #'
 #' Return a new DataFrame containing the union of rows in this DataFrame
 #' and another DataFrame. This is equivalent to `UNION ALL` in SQL.
+#' Note that this does not remove duplicate rows across the two DataFrames.
 #'
 #' @param x A Spark DataFrame
 #' @param y A Spark DataFrame
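
The added doc sentence is easy to gloss over, so here is a concrete sketch of the UNION ALL semantics (df hypothetical):

    both <- unionAll(df, df)
    count(both) == 2 * count(df)        # TRUE: duplicates are kept, as in SQL UNION ALL
    uniq <- distinct(unionAll(df, df))  # pair with distinct() to get SQL UNION semantics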

R/pkg/R/column.R

Lines changed: 28 additions & 4 deletions
@@ -131,6 +131,8 @@ createMethods()
 #' alias
 #'
 #' Set a new name for a column
+#'
+#' @rdname column
 setMethod("alias",
           signature(object = "Column"),
           function(object, data) {
@@ -141,8 +143,12 @@ setMethod("alias",
             }
           })
 
+#' substr
+#'
 #' An expression that returns a substring.
 #'
+#' @rdname column
+#'
 #' @param start starting position
 #' @param stop ending position
 setMethod("substr", signature(x = "Column"),
@@ -152,6 +158,9 @@ setMethod("substr", signature(x = "Column"),
           })
 
 #' Casts the column to a different data type.
+#'
+#' @rdname column
+#'
 #' @examples
 #' \dontrun{
 #' cast(df$age, "string")
@@ -173,8 +182,8 @@ setMethod("cast",
 
 #' Approx Count Distinct
 #'
-#' Returns the approximate number of distinct items in a group.
-#'
+#' @rdname column
+#' @return the approximate number of distinct items in a group.
 setMethod("approxCountDistinct",
           signature(x = "Column"),
           function(x, rsd = 0.95) {
@@ -184,8 +193,8 @@ setMethod("approxCountDistinct",
 
 #' Count Distinct
 #'
-#' returns the number of distinct items in a group.
-#'
+#' @rdname column
+#' @return the number of distinct items in a group.
 setMethod("countDistinct",
           signature(x = "Column"),
           function(x, ...) {
@@ -197,3 +206,18 @@ setMethod("countDistinct",
             column(jc)
           })
 
+#' @rdname column
+#' @aliases countDistinct
+setMethod("n_distinct",
+          signature(x = "Column"),
+          function(x, ...) {
+            countDistinct(x, ...)
+          })
+
+#' @rdname column
+#' @aliases count
+setMethod("n",
+          signature(x = "Column"),
+          function(x) {
+            count(x)
+          })
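
Both helpers are straight Column-level delegations (n to count, n_distinct to countDistinct), so they are usable anywhere their targets are, most naturally inside agg or summarize. A sketch with hypothetical dept/title columns:

    perDept <- agg(group_by(df, df$dept),
                   n(df$dept),            # same as count(df$dept)
                   n_distinct(df$title))  # same as countDistinct(df$title)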
