Skip to content

Commit 078d9a3

Browse files
authored
Use group_load_store compiler extension (#2123)
* Use group_load_store compiler extension * Tune description of dpnp.bincount
1 parent 7bfe0c8 commit 078d9a3

File tree

2 files changed

+24
-13
lines changed

2 files changed

+24
-13
lines changed

dpnp/backend/kernels/dpnp_krnl_elemwise.cpp

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@
4141
using dpctl::tensor::kernels::alignment_utils::is_aligned;
4242
using dpctl::tensor::kernels::alignment_utils::required_alignment;
4343

44+
using sycl::ext::oneapi::experimental::group_load;
45+
using sycl::ext::oneapi::experimental::group_store;
46+
4447
template <typename T>
4548
constexpr T dispatch_erf_op(T elem)
4649
{
@@ -523,41 +526,49 @@ static void func_map_init_elemwise_1arg_1type(func_map_t &fmap)
523526
_DataType_input2, \
524527
_DataType_output>) \
525528
{ \
526-
sycl::vec<_DataType_input1, vec_sz> x1 = \
527-
sg.load<vec_sz>(input1_multi_ptr); \
528-
sycl::vec<_DataType_input2, vec_sz> x2 = \
529-
sg.load<vec_sz>(input2_multi_ptr); \
529+
sycl::vec<_DataType_input1, vec_sz> x1{}; \
530+
sycl::vec<_DataType_input2, vec_sz> x2{}; \
531+
\
532+
group_load(sg, input1_multi_ptr, x1); \
533+
group_load(sg, input2_multi_ptr, x2); \
530534
\
531535
res_vec = __vec_operation__; \
532536
} \
533537
else /* input types don't match result type, so \
534538
explicit casting is required */ \
535539
{ \
540+
sycl::vec<_DataType_input1, vec_sz> tmp_x1{}; \
541+
sycl::vec<_DataType_input2, vec_sz> tmp_x2{}; \
542+
\
543+
group_load(sg, input1_multi_ptr, tmp_x1); \
544+
group_load(sg, input2_multi_ptr, tmp_x2); \
545+
\
536546
sycl::vec<_DataType_output, vec_sz> x1 = \
537547
dpnp_vec_cast<_DataType_output, \
538548
_DataType_input1, vec_sz>( \
539-
sg.load<vec_sz>(input1_multi_ptr)); \
549+
tmp_x1); \
540550
sycl::vec<_DataType_output, vec_sz> x2 = \
541551
dpnp_vec_cast<_DataType_output, \
542552
_DataType_input2, vec_sz>( \
543-
sg.load<vec_sz>(input2_multi_ptr)); \
553+
tmp_x2); \
544554
\
545555
res_vec = __vec_operation__; \
546556
} \
547557
} \
548558
else { \
549-
sycl::vec<_DataType_input1, vec_sz> x1 = \
550-
sg.load<vec_sz>(input1_multi_ptr); \
551-
sycl::vec<_DataType_input2, vec_sz> x2 = \
552-
sg.load<vec_sz>(input2_multi_ptr); \
559+
sycl::vec<_DataType_input1, vec_sz> x1{}; \
560+
sycl::vec<_DataType_input2, vec_sz> x2{}; \
561+
\
562+
group_load(sg, input1_multi_ptr, x1); \
563+
group_load(sg, input2_multi_ptr, x2); \
553564
\
554565
for (size_t k = 0; k < vec_sz; ++k) { \
555566
const _DataType_output input1_elem = x1[k]; \
556567
const _DataType_output input2_elem = x2[k]; \
557568
res_vec[k] = __operation__; \
558569
} \
559570
} \
560-
sg.store<vec_sz>(result_multi_ptr, res_vec); \
571+
group_store(sg, res_vec, result_multi_ptr); \
561572
} \
562573
else { \
563574
for (size_t k = start + sg.get_local_id()[0]; \

dpnp/dpnp_iface_histograms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ def bincount(x, weights=None, minlength=None):
325325
-------
326326
out : dpnp.ndarray of ints
327327
The result of binning the input array.
328-
The length of `out` is equal to ``np.amax(x) + 1``.
328+
The length of `out` is equal to ``dpnp.max(x) + 1``.
329329
330330
See Also
331331
--------
@@ -353,7 +353,7 @@ def bincount(x, weights=None, minlength=None):
353353
...
354354
TypeError: x must be an integer array
355355
356-
A possible use of ``bincount`` is to perform sums over
356+
A possible use of :obj:`dpnp.bincount` is to perform sums over
357357
variable-size chunks of an array, using the `weights` keyword.
358358
359359
>>> w = np.array([0.3, 0.5, 0.2, 0.7, 1., -0.6], dtype=np.float32) # weights

0 commit comments

Comments
 (0)