Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 44 additions & 21 deletions src/equal.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ static int raw_equal_scalar(const Rbyte* x, const Rbyte* y, bool na_equal);
static int cpl_equal_scalar(const Rcomplex* x, const Rcomplex* y, bool na_equal);
static int chr_equal_scalar(const SEXP* x, const SEXP* y, bool na_equal);
static int list_equal_scalar(SEXP x, R_len_t i, SEXP y, R_len_t j, bool na_equal);
static int df_equal_scalar(SEXP x, R_len_t i, SEXP y, R_len_t j, bool na_equal);
static int df_equal_scalar(SEXP x, R_len_t i, SEXP y, R_len_t j, bool na_equal, int n_col);


// If `x` is a data frame, it must have been recursively proxied
Expand All @@ -28,7 +28,15 @@ int equal_scalar(SEXP x, R_len_t i, SEXP y, R_len_t j, bool na_equal) {

switch (vec_proxy_typeof(x)) {
case vctrs_type_list: return list_equal_scalar(x, i, y, j, na_equal);
case vctrs_type_dataframe: return df_equal_scalar(x, i, y, j, na_equal);
case vctrs_type_dataframe: {
int n_col = Rf_length(x);

if (n_col != Rf_length(y)) {
Rf_errorcall(R_NilValue, "`x` and `y` must have the same number of columns");
}

return df_equal_scalar(x, i, y, j, na_equal, n_col);
}
default: break;
}

Expand All @@ -55,6 +63,20 @@ int equal_scalar(SEXP x, R_len_t i, SEXP y, R_len_t j, bool na_equal) {
} \
while (0)

#define EQUAL_DF(SCALAR_EQUAL) \
do { \
int n_col = Rf_length(x); \
\
if (n_col != Rf_length(y)) { \
Rf_errorcall(R_NilValue, "`x` and `y` must have the same number of columns"); \
} \
\
for (R_len_t i = 0; i < n; ++i) { \
p[i] = SCALAR_EQUAL(x, i, y, i, na_equal, n_col); \
} \
} \
while (0)

// [[ register() ]]
SEXP vctrs_equal(SEXP x, SEXP y, SEXP na_equal_) {
x = PROTECT(vec_proxy_recursive(x, vctrs_proxy_equal));
Expand All @@ -79,7 +101,7 @@ SEXP vctrs_equal(SEXP x, SEXP y, SEXP na_equal_) {
case vctrs_type_complex: EQUAL(Rcomplex, COMPLEX_RO, cpl_equal_scalar); break;
case vctrs_type_character: EQUAL(SEXP, STRING_PTR_RO, chr_equal_scalar); break;
case vctrs_type_list: EQUAL_BARRIER(list_equal_scalar); break;
case vctrs_type_dataframe: EQUAL_BARRIER(df_equal_scalar); break;
case vctrs_type_dataframe: EQUAL_DF(df_equal_scalar); break;
case vctrs_type_scalar: Rf_errorcall(R_NilValue, "Can't compare scalars with `vctrs_equal()`");
default: Rf_error("Unimplemented type in `vctrs_equal()`");
}
Expand All @@ -90,6 +112,7 @@ SEXP vctrs_equal(SEXP x, SEXP y, SEXP na_equal_) {

#undef EQUAL
#undef EQUAL_BARRIER
#undef EQUAL_DF

// Storing pointed values on the stack helps performance for the
// `!na_equal` cases
Expand Down Expand Up @@ -173,23 +196,8 @@ static int list_equal_scalar(SEXP x, R_len_t i, SEXP y, R_len_t j, bool na_equal
return equal_object(VECTOR_ELT(x, i), VECTOR_ELT(y, j), na_equal);
}

static int df_equal_scalar(SEXP x, R_len_t i, SEXP y, R_len_t j, bool na_equal) {
if (!is_data_frame(y)) {
return false;
}

int p = Rf_length(x);
if (p != Rf_length(y)) {
return false;
}

// Don't worry about names missingness because properly formed
// data frames shouldn't have any missing names
if (!equal_names(x, y)) {
return false;
}

for (int k = 0; k < p; ++k) {
static int df_equal_scalar(SEXP x, R_len_t i, SEXP y, R_len_t j, bool na_equal, int n_col) {
for (int k = 0; k < n_col; ++k) {
int eq = equal_scalar(VECTOR_ELT(x, k), i, VECTOR_ELT(y, k), j, na_equal);

if (eq <= 0) {
Expand Down Expand Up @@ -399,6 +407,20 @@ do { \
} \
while (0)

#define DUPLICATE_ALL_DF(SCALAR_EQUAL) \
do { \
int n_col = Rf_length(x); \
\
for (R_len_t i = 1; i < n; ++i) { \
if (SCALAR_EQUAL(x, 0, x, i, true, n_col)) { \
continue; \
} \
*p = false; \
break; \
} \
} \
while (0)

// [[ register() ]]
SEXP vctrs_duplicate_all(SEXP x) {
x = PROTECT(vec_proxy_recursive(x, vctrs_proxy_equal));
Expand All @@ -422,7 +444,7 @@ SEXP vctrs_duplicate_all(SEXP x) {
case vctrs_type_complex: DUPLICATE_ALL(Rcomplex, COMPLEX_RO, cpl_equal_scalar); break;
case vctrs_type_character: DUPLICATE_ALL(SEXP, STRING_PTR_RO, chr_equal_scalar); break;
case vctrs_type_list: DUPLICATE_ALL_BARRIER(list_equal_scalar); break;
case vctrs_type_dataframe: DUPLICATE_ALL_BARRIER(df_equal_scalar); break;
case vctrs_type_dataframe: DUPLICATE_ALL_DF(df_equal_scalar); break;
case vctrs_type_scalar: Rf_errorcall(R_NilValue, "Can't detect duplicates in scalars with `vctrs_duplicate_all()`");
default: Rf_error("Unimplemented type in `vctrs_duplicate_all()`");
}
Expand All @@ -433,6 +455,7 @@ SEXP vctrs_duplicate_all(SEXP x) {

#undef DUPLICATE_ALL
#undef DUPLICATE_ALL_BARRIER
#undef DUPLICATE_ALL_DF

// -----------------------------------------------------------------------------

Expand Down
13 changes: 9 additions & 4 deletions tests/testthat/test-equal.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,18 @@ test_that("can compare data frames", {
})

test_that("data frames must have same size and columns", {
expect_false(.Call(vctrs_equal,
expect_error(.Call(vctrs_equal,
data.frame(x = 1),
data.frame(x = 1, y = 2),
TRUE
))

expect_false(.Call(vctrs_equal,
),
"must have the same number of columns"
)

# Names are not checked, as `vec_cast_common()` should take care of the type.
# So if `vec_cast_common()` is not called, or is improperly specified, then
# this could result in false equality.
expect_true(.Call(vctrs_equal,
data.frame(x = 1),
data.frame(y = 1),
TRUE
Expand Down