diff --git a/dpctl/tensor/libtensor/include/kernels/sorting.hpp b/dpctl/tensor/libtensor/include/kernels/sorting.hpp
index e7a024259e..e577a1a52a 100644
--- a/dpctl/tensor/libtensor/include/kernels/sorting.hpp
+++ b/dpctl/tensor/libtensor/include/kernels/sorting.hpp
@@ -539,34 +539,32 @@ sort_over_work_group_contig_impl(sycl::queue &q,
             sycl::group_barrier(it.get_group());
 
             bool data_in_temp = false;
-            size_t sorted_size = 1;
-            while (true) {
-                const size_t nelems_sorted_so_far = sorted_size * chunk;
-                if (nelems_sorted_so_far < wg_chunk_size) {
-                    const size_t q = (lid / sorted_size);
-                    const size_t start_1 =
-                        sycl::min(2 * nelems_sorted_so_far * q, wg_chunk_size);
-                    const size_t end_1 = sycl::min(
-                        start_1 + nelems_sorted_so_far, wg_chunk_size);
-                    const size_t end_2 =
-                        sycl::min(end_1 + nelems_sorted_so_far, wg_chunk_size);
-                    const size_t offset = chunk * (lid - q * sorted_size);
-
-                    if (data_in_temp) {
-                        merge_impl(offset, scratch_space, work_space, start_1,
-                                   end_1, end_2, start_1, comp, chunk);
-                    }
-                    else {
-                        merge_impl(offset, work_space, scratch_space, start_1,
-                                   end_1, end_2, start_1, comp, chunk);
-                    }
-                    sycl::group_barrier(it.get_group());
-
-                    data_in_temp = !data_in_temp;
-                    sorted_size *= 2;
+            size_t n_chunks_merged = 1;
+
+            // merge chunk while n_chunks_merged * chunk < wg_chunk_size
+            const size_t max_chunks_merged = 1 + ((wg_chunk_size - 1) / chunk);
+            for (; n_chunks_merged < max_chunks_merged;
+                 data_in_temp = !data_in_temp, n_chunks_merged *= 2)
+            {
+                const size_t nelems_sorted_so_far = n_chunks_merged * chunk;
+                const size_t q = (lid / n_chunks_merged);
+                const size_t start_1 =
+                    sycl::min(2 * nelems_sorted_so_far * q, wg_chunk_size);
+                const size_t end_1 =
+                    sycl::min(start_1 + nelems_sorted_so_far, wg_chunk_size);
+                const size_t end_2 =
+                    sycl::min(end_1 + nelems_sorted_so_far, wg_chunk_size);
+                const size_t offset = chunk * (lid - q * n_chunks_merged);
+
+                if (data_in_temp) {
+                    merge_impl(offset, scratch_space, work_space, start_1,
+                               end_1, end_2, start_1, comp, chunk);
+                }
+                else {
+                    merge_impl(offset, work_space, scratch_space, start_1,
+                               end_1, end_2, start_1, comp, chunk);
                 }
-                else
-                    break;
+                sycl::group_barrier(it.get_group());
             }
 
             const auto &out_src = (data_in_temp) ? scratch_space : work_space;
diff --git a/dpctl/tensor/libtensor/source/sorting/sorting_common.hpp b/dpctl/tensor/libtensor/source/sorting/sorting_common.hpp
index 62b30ccb8f..bda9227b71 100644
--- a/dpctl/tensor/libtensor/source/sorting/sorting_common.hpp
+++ b/dpctl/tensor/libtensor/source/sorting/sorting_common.hpp
@@ -41,7 +41,7 @@ template <typename fpT> struct ExtendedRealFPLess
     /* [R, nan] */
     bool operator()(const fpT v1, const fpT v2) const
     {
-        return (!sycl::isnan(v1) && (sycl::isnan(v2) || (v1 < v2)));
+        return (!std::isnan(v1) && (std::isnan(v2) || (v1 < v2)));
     }
 };
 
@@ -49,7 +49,7 @@ template <typename fpT> struct ExtendedRealFPGreater
 {
     bool operator()(const fpT v1, const fpT v2) const
     {
-        return (!sycl::isnan(v2) && (sycl::isnan(v1) || (v2 < v1)));
+        return (!std::isnan(v2) && (std::isnan(v1) || (v2 < v1)));
     }
 };
 
@@ -64,14 +64,14 @@ template <typename cT> struct ExtendedComplexFPLess
         const realT real1 = std::real(v1);
         const realT real2 = std::real(v2);
 
-        const bool r1_nan = sycl::isnan(real1);
-        const bool r2_nan = sycl::isnan(real2);
+        const bool r1_nan = std::isnan(real1);
+        const bool r2_nan = std::isnan(real2);
 
         const realT imag1 = std::imag(v1);
         const realT imag2 = std::imag(v2);
 
-        const bool i1_nan = sycl::isnan(imag1);
-        const bool i2_nan = sycl::isnan(imag2);
+        const bool i1_nan = std::isnan(imag1);
+        const bool i2_nan = std::isnan(imag2);
 
         const int idx1 = ((r1_nan) ? 2 : 0) + ((i1_nan) ? 1 : 0);
         const int idx2 = ((r2_nan) ? 2 : 0) + ((i2_nan) ? 1 : 0);