diff --git a/dpctl/tensor/libtensor/include/kernels/accumulators.hpp b/dpctl/tensor/libtensor/include/kernels/accumulators.hpp index 110010706c..a8ef1c423e 100644 --- a/dpctl/tensor/libtensor/include/kernels/accumulators.hpp +++ b/dpctl/tensor/libtensor/include/kernels/accumulators.hpp @@ -125,54 +125,56 @@ sycl::event inclusive_scan_rec(sycl::queue &exec_q, auto lws = sycl::range<1>(wg_size); auto gws = sycl::range<1>(n_groups * wg_size); + auto ndRange = sycl::nd_range<1>(gws, lws); + slmT slm_iscan_tmp(lws, cgh); - cgh.parallel_for>( - sycl::nd_range<1>(gws, lws), [=, slm_iscan_tmp = std::move(slm_iscan_tmp)](sycl::nd_item<1> it) - { - auto chunk_gid = it.get_global_id(0); - auto lid = it.get_local_id(0); + using KernelName = inclusive_scan_rec_local_scan_krn< + inputT, outputT, n_wi, IndexerT, decltype(transformer)>; + + cgh.parallel_for(ndRange, [=, slm_iscan_tmp = std::move( + slm_iscan_tmp)]( + sycl::nd_item<1> it) { + auto chunk_gid = it.get_global_id(0); + auto lid = it.get_local_id(0); - std::array local_isum; + std::array local_isum; - size_t i = chunk_gid * n_wi; - for (size_t m_wi = 0; m_wi < n_wi; ++m_wi) { - constexpr outputT out_zero(0); + size_t i = chunk_gid * n_wi; + for (size_t m_wi = 0; m_wi < n_wi; ++m_wi) { + constexpr outputT out_zero(0); - local_isum[m_wi] = - (i + m_wi < n_elems) - ? transformer(input[indexer(s0 + s1 * (i + m_wi))]) - : out_zero; - } + local_isum[m_wi] = + (i + m_wi < n_elems) + ? transformer(input[indexer(s0 + s1 * (i + m_wi))]) + : out_zero; + } -// local_isum is now result of -// inclusive scan of locally stored mask indicators #pragma unroll - for (size_t m_wi = 1; m_wi < n_wi; ++m_wi) { - local_isum[m_wi] += local_isum[m_wi - 1]; - } + for (size_t m_wi = 1; m_wi < n_wi; ++m_wi) { + local_isum[m_wi] += local_isum[m_wi - 1]; + } + // local_isum is now result of + // inclusive scan of locally stored inputs - size_t wg_iscan_val = - sycl::inclusive_scan_over_group(it.get_group(), - local_isum.back(), - sycl::plus(), - size_t(0)); + size_t wg_iscan_val = sycl::inclusive_scan_over_group( + it.get_group(), local_isum.back(), sycl::plus(), + size_t(0)); - slm_iscan_tmp[(lid + 1) % wg_size] = wg_iscan_val; - it.barrier(sycl::access::fence_space::local_space); - size_t addand = (lid == 0) ? 0 : slm_iscan_tmp[lid]; - it.barrier(sycl::access::fence_space::local_space); + slm_iscan_tmp[(lid + 1) % wg_size] = wg_iscan_val; + it.barrier(sycl::access::fence_space::local_space); + size_t addand = (lid == 0) ? 0 : slm_iscan_tmp[lid]; #pragma unroll - for (size_t m_wi = 0; m_wi < n_wi; ++m_wi) { - local_isum[m_wi] += addand; - } - - for (size_t m_wi = 0; m_wi < n_wi && i + m_wi < n_elems; ++m_wi) { - output[i + m_wi] = local_isum[m_wi]; - } - }); + for (size_t m_wi = 0; m_wi < n_wi; ++m_wi) { + local_isum[m_wi] += addand; + } + + for (size_t m_wi = 0; m_wi < n_wi && i + m_wi < n_elems; ++m_wi) + { + output[i + m_wi] = local_isum[m_wi]; + } + }); }); sycl::event out_event = inc_scan_phase1_ev;