diff --git a/src/equal.c b/src/equal.c index 434102ae7..1ad476e3c 100644 --- a/src/equal.c +++ b/src/equal.c @@ -8,7 +8,7 @@ static int raw_equal_scalar(const Rbyte* x, const Rbyte* y, bool na_equal); static int cpl_equal_scalar(const Rcomplex* x, const Rcomplex* y, bool na_equal); static int chr_equal_scalar(const SEXP* x, const SEXP* y, bool na_equal); static int list_equal_scalar(SEXP x, R_len_t i, SEXP y, R_len_t j, bool na_equal); -static int df_equal_scalar(SEXP x, R_len_t i, SEXP y, R_len_t j, bool na_equal); +static int df_equal_scalar(SEXP x, R_len_t i, SEXP y, R_len_t j, bool na_equal, int n_col); // If `x` is a data frame, it must have been recursively proxied @@ -28,7 +28,15 @@ int equal_scalar(SEXP x, R_len_t i, SEXP y, R_len_t j, bool na_equal) { switch (vec_proxy_typeof(x)) { case vctrs_type_list: return list_equal_scalar(x, i, y, j, na_equal); - case vctrs_type_dataframe: return df_equal_scalar(x, i, y, j, na_equal); + case vctrs_type_dataframe: { + int n_col = Rf_length(x); + + if (n_col != Rf_length(y)) { + Rf_errorcall(R_NilValue, "`x` and `y` must have the same number of columns"); + } + + return df_equal_scalar(x, i, y, j, na_equal, n_col); + } default: break; } @@ -55,6 +63,20 @@ int equal_scalar(SEXP x, R_len_t i, SEXP y, R_len_t j, bool na_equal) { } \ while (0) +#define EQUAL_DF(SCALAR_EQUAL) \ + do { \ + int n_col = Rf_length(x); \ + \ + if (n_col != Rf_length(y)) { \ + Rf_errorcall(R_NilValue, "`x` and `y` must have the same number of columns"); \ + } \ + \ + for (R_len_t i = 0; i < n; ++i) { \ + p[i] = SCALAR_EQUAL(x, i, y, i, na_equal, n_col); \ + } \ + } \ + while (0) + // [[ register() ]] SEXP vctrs_equal(SEXP x, SEXP y, SEXP na_equal_) { x = PROTECT(vec_proxy_recursive(x, vctrs_proxy_equal)); @@ -79,7 +101,7 @@ SEXP vctrs_equal(SEXP x, SEXP y, SEXP na_equal_) { case vctrs_type_complex: EQUAL(Rcomplex, COMPLEX_RO, cpl_equal_scalar); break; case vctrs_type_character: EQUAL(SEXP, STRING_PTR_RO, chr_equal_scalar); break; case vctrs_type_list: EQUAL_BARRIER(list_equal_scalar); break; - case vctrs_type_dataframe: EQUAL_BARRIER(df_equal_scalar); break; + case vctrs_type_dataframe: EQUAL_DF(df_equal_scalar); break; case vctrs_type_scalar: Rf_errorcall(R_NilValue, "Can't compare scalars with `vctrs_equal()`"); default: Rf_error("Unimplemented type in `vctrs_equal()`"); } @@ -90,6 +112,7 @@ SEXP vctrs_equal(SEXP x, SEXP y, SEXP na_equal_) { #undef EQUAL #undef EQUAL_BARRIER +#undef EQUAL_DF // Storing pointed values on the stack helps performance for the // `!na_equal` cases @@ -173,23 +196,8 @@ static int list_equal_scalar(SEXP x, R_len_t i, SEXP y, R_len_t j, bool na_equal return equal_object(VECTOR_ELT(x, i), VECTOR_ELT(y, j), na_equal); } -static int df_equal_scalar(SEXP x, R_len_t i, SEXP y, R_len_t j, bool na_equal) { - if (!is_data_frame(y)) { - return false; - } - - int p = Rf_length(x); - if (p != Rf_length(y)) { - return false; - } - - // Don't worry about names missingness because properly formed - // data frames shouldn't have any missing names - if (!equal_names(x, y)) { - return false; - } - - for (int k = 0; k < p; ++k) { +static int df_equal_scalar(SEXP x, R_len_t i, SEXP y, R_len_t j, bool na_equal, int n_col) { + for (int k = 0; k < n_col; ++k) { int eq = equal_scalar(VECTOR_ELT(x, k), i, VECTOR_ELT(y, k), j, na_equal); if (eq <= 0) { @@ -399,6 +407,20 @@ do { \ } \ while (0) +#define DUPLICATE_ALL_DF(SCALAR_EQUAL) \ +do { \ + int n_col = Rf_length(x); \ + \ + for (R_len_t i = 1; i < n; ++i) { \ + if (SCALAR_EQUAL(x, 0, x, i, true, n_col)) { \ + continue; \ + } \ + *p = false; \ + break; \ + } \ +} \ +while (0) + // [[ register() ]] SEXP vctrs_duplicate_all(SEXP x) { x = PROTECT(vec_proxy_recursive(x, vctrs_proxy_equal)); @@ -422,7 +444,7 @@ SEXP vctrs_duplicate_all(SEXP x) { case vctrs_type_complex: DUPLICATE_ALL(Rcomplex, COMPLEX_RO, cpl_equal_scalar); break; case vctrs_type_character: DUPLICATE_ALL(SEXP, STRING_PTR_RO, chr_equal_scalar); break; case vctrs_type_list: DUPLICATE_ALL_BARRIER(list_equal_scalar); break; - case vctrs_type_dataframe: DUPLICATE_ALL_BARRIER(df_equal_scalar); break; + case vctrs_type_dataframe: DUPLICATE_ALL_DF(df_equal_scalar); break; case vctrs_type_scalar: Rf_errorcall(R_NilValue, "Can't detect duplicates in scalars with `vctrs_duplicate_all()`"); default: Rf_error("Unimplemented type in `vctrs_duplicate_all()`"); } @@ -433,6 +455,7 @@ SEXP vctrs_duplicate_all(SEXP x) { #undef DUPLICATE_ALL #undef DUPLICATE_ALL_BARRIER +#undef DUPLICATE_ALL_DF // ----------------------------------------------------------------------------- diff --git a/tests/testthat/test-equal.R b/tests/testthat/test-equal.R index bd2b3bb74..f041478bd 100644 --- a/tests/testthat/test-equal.R +++ b/tests/testthat/test-equal.R @@ -43,13 +43,18 @@ test_that("can compare data frames", { }) test_that("data frames must have same size and columns", { - expect_false(.Call(vctrs_equal, + expect_error(.Call(vctrs_equal, data.frame(x = 1), data.frame(x = 1, y = 2), TRUE - )) - - expect_false(.Call(vctrs_equal, + ), + "must have the same number of columns" + ) + + # Names are not checked, as `vec_cast_common()` should take care of the type. + # So if `vec_cast_common()` is not called, or is improperly specified, then + # this could result in false equality. + expect_true(.Call(vctrs_equal, data.frame(x = 1), data.frame(y = 1), TRUE