diff --git a/dpnp/backend/extensions/lapack/CMakeLists.txt b/dpnp/backend/extensions/lapack/CMakeLists.txt index c2ef9cc8f5c9..0b4a678d63e3 100644 --- a/dpnp/backend/extensions/lapack/CMakeLists.txt +++ b/dpnp/backend/extensions/lapack/CMakeLists.txt @@ -34,6 +34,7 @@ set(_module_src ${CMAKE_CURRENT_SOURCE_DIR}/getrf.cpp ${CMAKE_CURRENT_SOURCE_DIR}/getrf_batch.cpp ${CMAKE_CURRENT_SOURCE_DIR}/getri_batch.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/getrs.cpp ${CMAKE_CURRENT_SOURCE_DIR}/heevd.cpp ${CMAKE_CURRENT_SOURCE_DIR}/orgqr.cpp ${CMAKE_CURRENT_SOURCE_DIR}/orgqr_batch.cpp diff --git a/dpnp/backend/extensions/lapack/gesv.cpp b/dpnp/backend/extensions/lapack/gesv.cpp index 08a544f18f9b..3dc4ec6b125d 100644 --- a/dpnp/backend/extensions/lapack/gesv.cpp +++ b/dpnp/backend/extensions/lapack/gesv.cpp @@ -93,15 +93,18 @@ static sycl::event gesv_impl(sycl::queue exec_q, gesv_event = mkl_lapack::gesv( exec_q, - n, // The order of the matrix A (0 ≤ n). - nrhs, // The number of right-hand sides B (0 ≤ nrhs). + n, // The order of the square matrix A + // and the number of rows in matrix B (0 ≤ n). + nrhs, // The number of right-hand sides, + // i.e., the number of columns in matrix B (0 ≤ nrhs). a, // Pointer to the square coefficient matrix A (n x n). lda, // The leading dimension of a, must be at least max(1, n). ipiv, // The pivot indices that define the permutation matrix P; // row i of the matrix was interchanged with row ipiv(i), // must be at least max(1, n). b, // Pointer to the right hand side matrix B (n x nrhs). - ldb, // The leading dimension of b, must be at least max(1, n). + ldb, // The leading dimension of matrix B, + // must be at least max(1, n). scratchpad, // Pointer to scratchpad memory to be used by MKL // routine for storing intermediate results. scratchpad_size, depends); @@ -252,13 +255,12 @@ std::pair char *coeff_matrix_data = coeff_matrix.get_data(); char *dependent_vals_data = dependent_vals.get_data(); - const std::int64_t n = coeff_matrix_shape[0]; - const std::int64_t m = dependent_vals_shape[0]; + const std::int64_t n = dependent_vals_shape[0]; const std::int64_t nrhs = (dependent_vals_nd > 1) ? dependent_vals_shape[1] : 1; const std::int64_t lda = std::max(1UL, n); - const std::int64_t ldb = std::max(1UL, m); + const std::int64_t ldb = std::max(1UL, n); std::vector host_task_events; sycl::event gesv_ev = diff --git a/dpnp/backend/extensions/lapack/getrs.cpp b/dpnp/backend/extensions/lapack/getrs.cpp new file mode 100644 index 000000000000..1d15a608c252 --- /dev/null +++ b/dpnp/backend/extensions/lapack/getrs.cpp @@ -0,0 +1,314 @@ +//***************************************************************************** +// Copyright (c) 2024, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#include + +// dpctl tensor headers +#include "utils/memory_overlap.hpp" +#include "utils/type_utils.hpp" + +#include "getrs.hpp" +#include "linalg_exceptions.hpp" +#include "types_matrix.hpp" + +#include "dpnp_utils.hpp" + +namespace dpnp +{ +namespace backend +{ +namespace ext +{ +namespace lapack +{ +namespace mkl_lapack = oneapi::mkl::lapack; +namespace py = pybind11; +namespace type_utils = dpctl::tensor::type_utils; + +typedef sycl::event (*getrs_impl_fn_ptr_t)(sycl::queue, + oneapi::mkl::transpose, + const std::int64_t, + const std::int64_t, + char *, + std::int64_t, + std::int64_t *, + char *, + std::int64_t, + std::vector &, + const std::vector &); + +static getrs_impl_fn_ptr_t getrs_dispatch_vector[dpctl_td_ns::num_types]; + +template +static sycl::event getrs_impl(sycl::queue exec_q, + oneapi::mkl::transpose trans, + const std::int64_t n, + const std::int64_t nrhs, + char *in_a, + std::int64_t lda, + std::int64_t *ipiv, + char *in_b, + std::int64_t ldb, + std::vector &host_task_events, + const std::vector &depends) +{ + type_utils::validate_type_for_device(exec_q); + + T *a = reinterpret_cast(in_a); + T *b = reinterpret_cast(in_b); + + const std::int64_t scratchpad_size = + mkl_lapack::getrs_scratchpad_size(exec_q, trans, n, nrhs, lda, ldb); + T *scratchpad = nullptr; + + std::stringstream error_msg; + std::int64_t info = 0; + bool is_exception_caught = false; + + sycl::event getrs_event; + try { + scratchpad = sycl::malloc_device(scratchpad_size, exec_q); + + getrs_event = mkl_lapack::getrs( + exec_q, + trans, // Specifies the operation: whether or not to transpose + // matrix A. Can be 'N' for no transpose, 'T' for transpose, + // and 'C' for conjugate transpose. + n, // The order of the square matrix A + // and the number of rows in matrix B (0 ≤ n). + // It must be a non-negative integer. + nrhs, // The number of right-hand sides, + // i.e., the number of columns in matrix B (0 ≤ nrhs). + a, // Pointer to the square matrix A (n x n). + lda, // The leading dimension of matrix A, must be at least max(1, + // n). It must be at least max(1, n). + ipiv, // Pointer to the output array of pivot indices that were used + // during factorization (n, ). + b, // Pointer to the matrix B of right-hand sides (ldb, nrhs). + ldb, // The leading dimension of matrix B, must be at least max(1, + // n). + scratchpad, // Pointer to scratchpad memory to be used by MKL + // routine for storing intermediate results. + scratchpad_size, depends); + } catch (mkl_lapack::exception const &e) { + is_exception_caught = true; + info = e.info(); + + if (info < 0) { + error_msg << "Parameter number " << -info + << " had an illegal value."; + } + else if (info == scratchpad_size && e.detail() != 0) { + error_msg + << "Insufficient scratchpad size. Required size is at least " + << e.detail(); + } + else if (info > 0) { + is_exception_caught = false; + if (scratchpad != nullptr) { + sycl::free(scratchpad, exec_q); + } + throw LinAlgError("The solve could not be completed."); + } + else { + error_msg << "Unexpected MKL exception caught during getrs() " + "call:\nreason: " + << e.what() << "\ninfo: " << e.info(); + } + } catch (sycl::exception const &e) { + is_exception_caught = true; + error_msg << "Unexpected SYCL exception caught during getrs() call:\n" + << e.what(); + } + + if (is_exception_caught) // an unexpected error occurs + { + if (scratchpad != nullptr) { + sycl::free(scratchpad, exec_q); + } + + throw std::runtime_error(error_msg.str()); + } + + sycl::event clean_up_event = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(getrs_event); + auto ctx = exec_q.get_context(); + cgh.host_task([ctx, scratchpad]() { sycl::free(scratchpad, ctx); }); + }); + host_task_events.push_back(clean_up_event); + return getrs_event; +} + +std::pair + getrs(sycl::queue exec_q, + dpctl::tensor::usm_ndarray a_array, + dpctl::tensor::usm_ndarray ipiv_array, + dpctl::tensor::usm_ndarray b_array, + const std::vector &depends) +{ + const int a_array_nd = a_array.get_ndim(); + const int b_array_nd = b_array.get_ndim(); + const int ipiv_array_nd = ipiv_array.get_ndim(); + + if (a_array_nd != 2) { + throw py::value_error( + "The LU-factorized array has ndim=" + std::to_string(a_array_nd) + + ", but a 2-dimensional array is expected."); + } + if (b_array_nd > 2) { + throw py::value_error( + "The right-hand sides array has ndim=" + + std::to_string(b_array_nd) + + ", but a 1-dimensional or a 2-dimensional array is expected."); + } + if (ipiv_array_nd != 1) { + throw py::value_error("The array of pivot indices has ndim=" + + std::to_string(ipiv_array_nd) + + ", but a 1-dimensional array is expected."); + } + + const py::ssize_t *a_array_shape = a_array.get_shape_raw(); + const py::ssize_t *b_array_shape = b_array.get_shape_raw(); + + if (a_array_shape[0] != a_array_shape[1]) { + throw py::value_error("The LU-factorized array must be square," + " but got a shape of (" + + std::to_string(a_array_shape[0]) + ", " + + std::to_string(a_array_shape[1]) + ")."); + } + + // check compatibility of execution queue and allocation queue + if (!dpctl::utils::queues_are_compatible(exec_q, + {a_array, b_array, ipiv_array})) + { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(a_array, b_array)) { + throw py::value_error("The LU-factorized and right-hand sides arrays " + "are overlapping segments of memory"); + } + + bool is_a_array_c_contig = a_array.is_c_contiguous(); + bool is_a_array_f_contig = a_array.is_f_contiguous(); + bool is_b_array_f_contig = b_array.is_f_contiguous(); + bool is_ipiv_array_c_contig = ipiv_array.is_c_contiguous(); + bool is_ipiv_array_f_contig = ipiv_array.is_f_contiguous(); + if (!is_a_array_c_contig && !is_a_array_f_contig) { + throw py::value_error("The LU-factorized array " + "must be either C-contiguous " + "or F-contiguous"); + } + if (!is_b_array_f_contig) { + throw py::value_error("The right-hand sides array " + "must be F-contiguous"); + } + if (!is_ipiv_array_c_contig || !is_ipiv_array_f_contig) { + throw py::value_error("The array of pivot indices " + "must be contiguous"); + } + + auto array_types = dpctl_td_ns::usm_ndarray_types(); + int a_array_type_id = + array_types.typenum_to_lookup_id(a_array.get_typenum()); + int b_array_type_id = + array_types.typenum_to_lookup_id(b_array.get_typenum()); + + if (a_array_type_id != b_array_type_id) { + throw py::value_error("The types of the LU-factorized and " + "right-hand sides arrays are mismatched"); + } + + getrs_impl_fn_ptr_t getrs_fn = getrs_dispatch_vector[a_array_type_id]; + if (getrs_fn == nullptr) { + throw py::value_error( + "No getrs implementation defined for the provided type " + "of the input matrix."); + } + + auto ipiv_types = dpctl_td_ns::usm_ndarray_types(); + int ipiv_array_type_id = + ipiv_types.typenum_to_lookup_id(ipiv_array.get_typenum()); + + if (ipiv_array_type_id != static_cast(dpctl_td_ns::typenum_t::INT64)) { + throw py::value_error("The type of 'ipiv_array' must be int64."); + } + + const std::int64_t n = a_array_shape[0]; + const std::int64_t nrhs = (b_array_nd > 1) ? b_array_shape[1] : 1; + + const std::int64_t lda = std::max(1UL, n); + const std::int64_t ldb = std::max(1UL, n); + + // Use transpose::T if the LU-factorized array is passed as C-contiguous. + // For F-contiguous we use transpose::N. + oneapi::mkl::transpose trans = is_a_array_c_contig + ? oneapi::mkl::transpose::T + : oneapi::mkl::transpose::N; + + char *a_array_data = a_array.get_data(); + char *b_array_data = b_array.get_data(); + char *ipiv_array_data = ipiv_array.get_data(); + + std::int64_t *ipiv = reinterpret_cast(ipiv_array_data); + + std::vector host_task_events; + sycl::event getrs_ev = + getrs_fn(exec_q, trans, n, nrhs, a_array_data, lda, ipiv, b_array_data, + ldb, host_task_events, depends); + + sycl::event args_ev = dpctl::utils::keep_args_alive( + exec_q, {a_array, b_array, ipiv_array}, host_task_events); + + return std::make_pair(args_ev, getrs_ev); +} + +template +struct GetrsContigFactory +{ + fnT get() + { + if constexpr (types::GetrsTypePairSupportFactory::is_defined) { + return getrs_impl; + } + else { + return nullptr; + } + } +}; + +void init_getrs_dispatch_vector(void) +{ + dpctl_td_ns::DispatchVectorBuilder + contig; + contig.populate_dispatch_vector(getrs_dispatch_vector); +} +} // namespace lapack +} // namespace ext +} // namespace backend +} // namespace dpnp diff --git a/dpnp/backend/extensions/lapack/getrs.hpp b/dpnp/backend/extensions/lapack/getrs.hpp new file mode 100644 index 000000000000..ca78ed8b80de --- /dev/null +++ b/dpnp/backend/extensions/lapack/getrs.hpp @@ -0,0 +1,52 @@ +//***************************************************************************** +// Copyright (c) 2024, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** + +#pragma once + +#include +#include + +#include + +namespace dpnp +{ +namespace backend +{ +namespace ext +{ +namespace lapack +{ +extern std::pair + getrs(sycl::queue exec_q, + dpctl::tensor::usm_ndarray a_array, + dpctl::tensor::usm_ndarray ipiv_array, + dpctl::tensor::usm_ndarray b_array, + const std::vector &depends = {}); + +extern void init_getrs_dispatch_vector(void); +} // namespace lapack +} // namespace ext +} // namespace backend +} // namespace dpnp diff --git a/dpnp/backend/extensions/lapack/lapack_py.cpp b/dpnp/backend/extensions/lapack/lapack_py.cpp index eb815ac9f6ba..e6b4365a906d 100644 --- a/dpnp/backend/extensions/lapack/lapack_py.cpp +++ b/dpnp/backend/extensions/lapack/lapack_py.cpp @@ -35,6 +35,7 @@ #include "gesvd.hpp" #include "getrf.hpp" #include "getri.hpp" +#include "getrs.hpp" #include "heevd.hpp" #include "linalg_exceptions.hpp" #include "orgqr.hpp" @@ -54,6 +55,7 @@ void init_dispatch_vectors(void) lapack_ext::init_getrf_batch_dispatch_vector(); lapack_ext::init_getrf_dispatch_vector(); lapack_ext::init_getri_batch_dispatch_vector(); + lapack_ext::init_getrs_dispatch_vector(); lapack_ext::init_orgqr_batch_dispatch_vector(); lapack_ext::init_orgqr_dispatch_vector(); lapack_ext::init_potrf_batch_dispatch_vector(); @@ -130,6 +132,13 @@ PYBIND11_MODULE(_lapack_impl, m) py::arg("stride_ipiv"), py::arg("batch_size"), py::arg("depends") = py::list()); + m.def("_getrs", &lapack_ext::getrs, + "Call `getrs` from OneMKL LAPACK library to return " + "the solves of linear equations with an LU-factored " + "square coefficient matrix, with multiple right-hand sides", + py::arg("sycl_queue"), py::arg("a_array"), py::arg("ipiv_array"), + py::arg("b_array"), py::arg("depends") = py::list()); + m.def("_heevd", &lapack_ext::heevd, "Call `heevd` from OneMKL LAPACK library to return " "the eigenvalues and eigenvectors of a complex Hermitian matrix", diff --git a/dpnp/backend/extensions/lapack/types_matrix.hpp b/dpnp/backend/extensions/lapack/types_matrix.hpp index 9a0ab36c8a45..a5edffa56dc0 100644 --- a/dpnp/backend/extensions/lapack/types_matrix.hpp +++ b/dpnp/backend/extensions/lapack/types_matrix.hpp @@ -225,6 +225,34 @@ struct GetriBatchTypePairSupportFactory dpctl_td_ns::NotDefinedEntry>::is_defined; }; +/** + * @brief A factory to define pairs of supported types for which + * MKL LAPACK library provides support in oneapi::mkl::lapack::getrs + * function. + * + * @tparam T Type of array containing input matrix (LU-factored form) + * and the array of multiple dependent variables, + * as well as the output array for storing the solutions to a system of linear + * equations. + */ +template +struct GetrsTypePairSupportFactory +{ + static constexpr bool is_defined = std::disjunction< + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + dpctl_td_ns::TypePairDefinedEntry, + T, + std::complex>, + dpctl_td_ns::TypePairDefinedEntry, + T, + std::complex>, + // fall-through + dpctl_td_ns::NotDefinedEntry>::is_defined; +}; + /** * @brief A factory to define pairs of supported types for which * MKL LAPACK library provides support in oneapi::mkl::lapack::heevd diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index ed9d0b88472d..6dd41493914a 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -1897,39 +1897,72 @@ def dpnp_solve(a, b): out_v = out_v.reshape(orig_shape_b) return out_v else: - # oneMKL LAPACK gesv overwrites `a` and `b` and assumes fortran-like array as input. - # Allocate 'F' order memory for dpnp arrays to comply with these requirements. - a_f = dpnp.empty_like( - a, order="F", dtype=res_type, usm_type=res_usm_type + # Due to MKLD-17226 (bug with incorrect checking ldb parameter + # in oneapi::mkl::lapack::gesv_scratchad_size that raises an error + # `invalid argument` when nrhs > n) we can not use _gesv directly. + # This w/a uses _getrf and _getrs instead + # to handle cases where nrhs > n for a.shape = (n x n) + # and b.shape = (n x nrhs). + + # oneMKL LAPACK getrf overwrites `a`. + a_h = dpnp.empty_like( + a, order="C", dtype=res_type, usm_type=res_usm_type ) - # use DPCTL tensor function to fill the coefficient matrix array - # with content from the input array `a` + # use DPCTL tensor function to fill the сopy of the input array + # from the input array a_ht_copy_ev, a_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( - src=a_usm_arr, dst=a_f.get_array(), sycl_queue=a.sycl_queue + src=a_usm_arr, dst=a_h.get_array(), sycl_queue=a.sycl_queue ) - b_f = dpnp.empty_like( + # oneMKL LAPACK getrs overwrites `b` and assumes fortran-like array as input. + # Allocate 'F' order memory for dpnp arrays to comply with these requirements. + b_h = dpnp.empty_like( b, order="F", dtype=res_type, usm_type=res_usm_type ) # use DPCTL tensor function to fill the array of multiple dependent variables # with content from the input array `b` b_ht_copy_ev, b_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( - src=b_usm_arr, dst=b_f.get_array(), sycl_queue=b.sycl_queue + src=b_usm_arr, dst=b_h.get_array(), sycl_queue=b.sycl_queue ) - # Call the LAPACK extension function _gesv to solve the system of linear - # equations with the coefficient square matrix and the dependent variables array. - ht_lapack_ev, _ = li._gesv( - exec_q, a_f.get_array(), b_f.get_array(), [a_copy_ev, b_copy_ev] + n = a.shape[0] + + ipiv_h = dpnp.empty_like( + a, + shape=(n,), + dtype=dpnp.int64, ) + dev_info_h = [0] - ht_lapack_ev.wait() - b_ht_copy_ev.wait() - a_ht_copy_ev.wait() + # Call the LAPACK extension function _getrf + # to perform LU decomposition of the input matrix + ht_getrf_ev, getrf_ev = li._getrf( + exec_q, + a_h.get_array(), + ipiv_h.get_array(), + dev_info_h, + [a_copy_ev], + ) + + _check_lapack_dev_info(dev_info_h) + + # Call the LAPACK extension function _getrs + # to solve the system of linear equations with an LU-factored + # coefficient square matrix, with multiple right-hand sides. + ht_getrs_ev, _ = li._getrs( + exec_q, + a_h.get_array(), + ipiv_h.get_array(), + b_h.get_array(), + [b_copy_ev, getrf_ev], + ) + + ht_list_ev = [a_ht_copy_ev, b_ht_copy_ev, ht_getrf_ev, ht_getrs_ev] + dpctl.SyclEvent.wait_for(ht_list_ev) - return b_f + return b_h def dpnp_slogdet(a): diff --git a/tests/test_linalg.py b/tests/test_linalg.py index 9a53b315763f..b5b0d8f3897e 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -1436,6 +1436,21 @@ def test_solve(self, dtype): assert_allclose(expected, result, rtol=1e-06) + @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) + def test_solve_nrhs_greater_n(self, dtype): + # Test checking the case when nrhs > n for + # for a.shape = (n x n) and b.shape = (n x nrhs). + a_np = numpy.array([[1, 2], [3, 5]], dtype=dtype) + b_np = numpy.array([[1, 1, 1], [2, 2, 2]], dtype=dtype) + + a_dp = inp.array(a_np) + b_dp = inp.array(b_np) + + expected = numpy.linalg.solve(a_np, b_np) + result = inp.linalg.solve(a_dp, b_dp) + + assert_dtype_allclose(result, expected) + @pytest.mark.parametrize("a_dtype", get_all_dtypes(no_bool=True)) @pytest.mark.parametrize("b_dtype", get_all_dtypes(no_bool=True)) def test_solve_diff_type(self, a_dtype, b_dtype): diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index 8aa3680d3f33..bfb8df06e663 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -1916,23 +1916,44 @@ def test_where(device): valid_devices, ids=[device.filter_string for device in valid_devices], ) -def test_solve(device): - x = [[1.0, 2.0], [3.0, 5.0]] - y = [1.0, 2.0] +@pytest.mark.parametrize( + "matrix, vector", + [ + ([[1, 2], [3, 5]], numpy.empty((2, 0))), + ([[1, 2], [3, 5]], [1, 2]), + ( + [ + [[1, 1, 1], [0, 2, 5], [2, 5, -1]], + [[3, -1, 1], [1, 2, 3], [2, 3, 1]], + [[1, 4, 1], [1, 2, -2], [4, 1, 2]], + ], + [[6, -4, 27], [9, -6, 15], [15, 1, 11]], + ), + ], + ids=[ + "2D_Matrix_Empty_Vector", + "2D_Matrix_1D_Vector", + "3D_Matrix_and_Vectors", + ], +) +def test_solve(matrix, vector, device): + a_np = numpy.array(matrix) + b_np = numpy.array(vector) + + a_dp = dpnp.array(a_np, device=device) + b_dp = dpnp.array(b_np, device=device) - numpy_x = numpy.array(x) - numpy_y = numpy.array(y) - dpnp_x = dpnp.array(x, device=device) - dpnp_y = dpnp.array(y, device=device) + if a_dp.ndim > 2 and a_dp.device.sycl_device.is_cpu: + pytest.skip("SAT-6842: reported hanging in public CI") - result = dpnp.linalg.solve(dpnp_x, dpnp_y) - expected = numpy.linalg.solve(numpy_x, numpy_y) + result = dpnp.linalg.solve(a_dp, b_dp) + expected = numpy.linalg.solve(a_np, b_np) assert_dtype_allclose(result, expected) result_queue = result.sycl_queue - assert_sycl_queue_equal(result_queue, dpnp_x.sycl_queue) - assert_sycl_queue_equal(result_queue, dpnp_y.sycl_queue) + assert_sycl_queue_equal(result_queue, a_dp.sycl_queue) + assert_sycl_queue_equal(result_queue, b_dp.sycl_queue) @pytest.mark.parametrize(