Skip to content

Commit 1f611d9

Browse files
authored
Merge branch 'master' into add-more-dtypes-in-umath-tests
2 parents 7007d02 + eb8ebf8 commit 1f611d9

File tree

3 files changed

+8
-15
lines changed

3 files changed

+8
-15
lines changed

dpnp/backend/extensions/blas/blas_py.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,7 @@ PYBIND11_MODULE(_blas_impl, m)
6262
using event_vecT = std::vector<sycl::event>;
6363

6464
{
65-
dot_ns::init_dot_dispatch_vector<dot_impl_fn_ptr_t,
66-
blas_ns::DotContigFactory>(
65+
dot_ns::init_dot_dispatch_vector<blas_ns::DotContigFactory>(
6766
dot_dispatch_vector);
6867

6968
auto dot_pyapi = [&](sycl::queue &exec_q, const arrayT &src1,
@@ -81,8 +80,7 @@ PYBIND11_MODULE(_blas_impl, m)
8180
}
8281

8382
{
84-
dot_ns::init_dot_dispatch_vector<dot_impl_fn_ptr_t,
85-
blas_ns::DotcContigFactory>(
83+
dot_ns::init_dot_dispatch_vector<blas_ns::DotcContigFactory>(
8684
dotc_dispatch_vector);
8785

8886
auto dotc_pyapi = [&](sycl::queue &exec_q, const arrayT &src1,
@@ -101,8 +99,7 @@ PYBIND11_MODULE(_blas_impl, m)
10199
}
102100

103101
{
104-
dot_ns::init_dot_dispatch_vector<dot_impl_fn_ptr_t,
105-
blas_ns::DotuContigFactory>(
102+
dot_ns::init_dot_dispatch_vector<blas_ns::DotuContigFactory>(
106103
dotu_dispatch_vector);
107104

108105
auto dotu_pyapi = [&](sycl::queue &exec_q, const arrayT &src1,

dpnp/backend/extensions/blas/dot_common.hpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,13 @@ typedef sycl::event (*dot_impl_fn_ptr_t)(sycl::queue &,
5050
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
5151
namespace py = pybind11;
5252

53-
template <typename dispatchT>
5453
std::pair<sycl::event, sycl::event>
5554
dot_func(sycl::queue &exec_q,
5655
const dpctl::tensor::usm_ndarray &vectorX,
5756
const dpctl::tensor::usm_ndarray &vectorY,
5857
const dpctl::tensor::usm_ndarray &result,
5958
const std::vector<sycl::event> &depends,
60-
const dispatchT &dot_dispatch_vector)
59+
const dot_impl_fn_ptr_t *dot_dispatch_vector)
6160
{
6261
const int vectorX_nd = vectorX.get_ndim();
6362
const int vectorY_nd = vectorY.get_ndim();
@@ -166,12 +165,10 @@ std::pair<sycl::event, sycl::event>
166165
return std::make_pair(args_ev, dot_ev);
167166
}
168167

169-
template <typename dispatchT,
170-
template <typename fnT, typename T>
171-
typename factoryT>
172-
void init_dot_dispatch_vector(dispatchT dot_dispatch_vector[])
168+
template <template <typename fnT, typename T> typename factoryT>
169+
void init_dot_dispatch_vector(dot_impl_fn_ptr_t dot_dispatch_vector[])
173170
{
174-
dpctl_td_ns::DispatchVectorBuilder<dispatchT, factoryT,
171+
dpctl_td_ns::DispatchVectorBuilder<dot_impl_fn_ptr_t, factoryT,
175172
dpctl_td_ns::num_types>
176173
contig;
177174
contig.populate_dispatch_vector(dot_dispatch_vector);

dpnp/backend/extensions/window/common.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,11 @@ sycl::event window_impl(sycl::queue &q,
6767
return window_ev;
6868
}
6969

70-
template <typename dispatchT>
7170
std::pair<sycl::event, sycl::event>
7271
py_window(sycl::queue &exec_q,
7372
const dpctl::tensor::usm_ndarray &result,
7473
const std::vector<sycl::event> &depends,
75-
const dispatchT &window_dispatch_vector)
74+
const window_fn_ptr_t *window_dispatch_vector)
7675
{
7776
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(result);
7877

0 commit comments

Comments
 (0)