Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions examples/sparse_tensor.cu
Original file line number Diff line number Diff line change
Expand Up @@ -211,5 +211,17 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv) {
(R = matvec(AdiaJ, V)).run(exec);
print(R);

//
// Perform a direct solve. This is only supported for a tri-diagonal
// matrix in DIA-I format where the rhs is overwritten with the answer.
//
dvals.SetVals({0, -1, -1, -1, -1, -1, 4, 4, 4, 4, 4, 4, 1, 1, 1, 1, 1, 0});
auto AdiaI = experimental::make_tensor_dia<experimental::DIA_INDEX_I>(
dvals, doffsets, {6, 6});
auto Rhs = make_tensor<float, 2>({2, 6});
Rhs.SetVals({{6, 10, 14, 18, 22, 19}, {36, 34, 38, 42, 46, 37}});
(Rhs = solve(AdiaI, Rhs)).run(exec);
print(Rhs);

MATX_EXIT_HANDLER();
}
3 changes: 2 additions & 1 deletion include/matx/operators/solve.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

#include "matx/core/type_utils.h"
#include "matx/operators/base_operator.h"
#include "matx/transforms/solve/solve_cusparse.h"
#ifdef MATX_EN_CUDSS
#include "matx/transforms/solve/solve_cudss.h"
#endif
Expand Down Expand Up @@ -92,7 +93,7 @@ class SolveOp : public BaseOp<SolveOp<OpA, OpB>> {
static_assert(!is_sparse_tensor_v<OpB>, "sparse rhs not implemented");
if constexpr (is_sparse_tensor_v<OpA>) {
if constexpr (OpA::Format::isDIAI() || OpA::Format::isDIAJ()) {
MATX_THROW(matxNotSupported, "DIA support coming soon");
sparse_dia_solve_impl(cuda::std::get<0>(out), a_, b_, ex);
} else {
#ifdef MATX_EN_CUDSS
sparse_solve_impl(cuda::std::get<0>(out), a_, b_, ex);
Expand Down
186 changes: 186 additions & 0 deletions include/matx/transforms/solve/solve_cusparse.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
////////////////////////////////////////////////////////////////////////////////
// BSD 3-Clause License
//
// Copyright (c) 2025, NVIDIA Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
/////////////////////////////////////////////////////////////////////////////////

#pragma once

#include <cusparse.h>

#include <numeric>

#include "matx/core/cache.h"
#include "matx/core/sparse_tensor.h"
#include "matx/core/tensor.h"
#include "matx/kernels/matvec.cuh"

namespace matx {

namespace detail {

// A tridiagonal solver that uses the cuSPARSE legacy API. The setup is
// relatively simple, which is why we forego the usual path of caching
// shared context. Rather, we just do a single-shot solve.
template <class VAL>
inline void SolveTridiagonalSystem(int m, int n, VAL *dl, VAL *dm, VAL *du,
VAL *b) {
cusparseHandle_t handle = nullptr; // TODO: share handle globally?
[[maybe_unused]] cusparseStatus_t ret = cusparseCreate(&handle);
MATX_ASSERT(ret == CUSPARSE_STATUS_SUCCESS, matxSolverError);

size_t workspaceSize = 0;
void *workspace = nullptr;

if constexpr (std::is_same_v<VAL, float>) {
ret = cusparseSgtsv2_bufferSizeExt(handle, m, n, dl, dm, du, b, /*ldb*/ m,
&workspaceSize);
} else if constexpr (std::is_same_v<VAL, double>) {
ret = cusparseDgtsv2_bufferSizeExt(handle, m, n, dl, dm, du, b, /*ldb*/ m,
&workspaceSize);
} else if constexpr (std::is_same_v<VAL, cuFloatComplex>) {
ret = cusparseCgtsv2_bufferSizeExt(handle, m, n, dl, dm, du, b, /*ldb*/ m,
&workspaceSize);
} else if constexpr (std::is_same_v<VAL, cuDoubleComplex>) {
ret = cusparseZgtsv2_bufferSizeExt(handle, m, n, dl, dm, du, b, /*ldb*/ m,
&workspaceSize);
} else {
MATX_THROW(matxNotSupported, "Unsupported type for tri-diagonal solve");
}
MATX_ASSERT(ret == CUSPARSE_STATUS_SUCCESS, matxSolverError);

matxAlloc((void **)&workspace, workspaceSize, MATX_DEVICE_MEMORY);

if constexpr (std::is_same_v<VAL, float>) {
ret = cusparseSgtsv2(handle, m, n, dl, dm, du, b, /*ldb*/ m, workspace);
} else if constexpr (std::is_same_v<VAL, double>) {
ret = cusparseDgtsv2(handle, m, n, dl, dm, du, b, /*ldb*/ m, workspace);
} else if constexpr (std::is_same_v<VAL, cuFloatComplex>) {
ret = cusparseCgtsv2(handle, m, n, dl, dm, du, b, /*ldb*/ m, workspace);
} else if constexpr (std::is_same_v<VAL, cuDoubleComplex>) {
ret = cusparseZgtsv2(handle, m, n, dl, dm, du, b, /*ldb*/ m, workspace);
}
MATX_ASSERT(ret == CUSPARSE_STATUS_SUCCESS, matxSolverError);

matxFree(workspace);

ret = cusparseDestroy(handle);
MATX_ASSERT(ret == CUSPARSE_STATUS_SUCCESS, matxSolverError);
}

template <typename Op>
__MATX_INLINE__ auto getCuSparseSolveSupportedTensor(const Op &in,
cudaStream_t stream) {
const auto func = [&]() {
if constexpr (is_tensor_view_v<Op>) {
return in.Stride(Op::Rank() - 1) == 1;
} else {
return true;
}
};
return GetSupportedTensor(in, func, MATX_ASYNC_DEVICE_MEMORY, stream);
}

} // end namespace detail

template <typename TensorTypeC, typename TensorTypeA, typename TensorTypeB>
void sparse_dia_solve_impl(TensorTypeC &C, const TensorTypeA &a,
const TensorTypeB &B, const cudaExecutor &exec) {
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
const auto stream = exec.getStream();

// Transform into supported form.
auto b = getCuSparseSolveSupportedTensor(B, stream);
auto c = getCuSparseSolveSupportedTensor(C, stream);
if (!is_matx_transform_op<TensorTypeB>() && !b.isSameView(B)) {
(b = B).run(stream);
}

using atype = TensorTypeA;
using btype = decltype(b);
using ctype = decltype(c);

using TA = typename atype::value_type;
using TB = typename btype::value_type;
using TC = typename ctype::value_type;

static constexpr int RANKA = atype::Rank();
static constexpr int RANKB = btype::Rank();
static constexpr int RANKC = ctype::Rank();

// Restrictions.
static_assert(RANKA == 2 && RANKB == 2 && RANKC == 2,
"tensors must have rank-2");
static_assert(std::is_same_v<TC, TA> && std::is_same_v<TC, TB>,
"tensors must have the same data type");
static_assert(std::is_same_v<TC, float> || std::is_same_v<TC, double> ||
std::is_same_v<TC, cuda::std::complex<float>> ||
std::is_same_v<TC, cuda::std::complex<double>>,
"unsupported data type");
MATX_ASSERT( // Note: B,C transposed!
a.Size(RANKA - 1) == a.Size(RANKA - 2) && // square
a.Size(RANKA - 1) == b.Size(RANKB - 1) &&
a.Size(RANKA - 2) == c.Size(RANKC - 1) &&
b.Size(RANKB - 2) == c.Size(RANKC - 2),
matxInvalidSize);
MATX_ASSERT(b.Stride(RANKB - 1) == 1 && c.Stride(RANKC - 1) == 1,
matxInvalidParameter);

if constexpr (atype::Format::isDIAI()) {
// These are *run-time* checks.
if (!c.isSameView(b)) {
MATX_THROW(matxNotSupported, "Tridiagonal solve overwrites rhs");
}
using CRD = typename atype::crd_type;
CRD *diags = a.CRDData(1);
const index_t numD = a.crdSize(1);
if (numD != 3 || diags[0] != -1 || diags[1] != 0 || diags[2] != 1) {
MATX_THROW(matxNotSupported, "Only tridiagonal solve supported");
}
using T = std::conditional_t<
std::is_same_v<TA, cuda::std::complex<double>>, cuDoubleComplex,
std::conditional_t<std::is_same_v<TA, cuda::std::complex<float>>,
cuFloatComplex, TA>>;
T *AD = reinterpret_cast<T *>(a.Data());
T *BD = reinterpret_cast<T *>(b.Data());
const int m = static_cast<int>(a.Size(RANKA - 2));
const int n = static_cast<int>(b.Size(RANKB - 2));
detail::SolveTridiagonalSystem<T>(m, n, AD, AD + m, AD + m + m, BD);
} else {
MATX_THROW(matxNotSupported, "Tridiagonal solve requires I-index DIAG");
}

// Copy transformed output back.
if (!c.isSameView(C)) {
(C = c).run(stream);
}
}

} // end namespace matx
64 changes: 58 additions & 6 deletions test/00_sparse/Dia.cu
Original file line number Diff line number Diff line change
Expand Up @@ -102,20 +102,22 @@ template <typename T> static auto makeC() {
return C;
}

template <typename T> class MatvecSparseTest : public ::testing::Test {
template <typename T> class DiaSparseTest : public ::testing::Test {
protected:
using GTestType = cuda::std::tuple_element_t<0, T>;
using GExecType = cuda::std::tuple_element_t<1, T>;
void SetUp() override { CheckTestTypeSupport<GTestType>(); }
float thresh = 0.001f;
};

template <typename T>
class MatvecSparseTestsAll : public MatvecSparseTest<T> {};
template <typename T> class DiaSparseTestsAll : public DiaSparseTest<T> {};

TYPED_TEST_SUITE(MatvecSparseTestsAll, MatXFloatNonComplexHalfTypesCUDAExec);
template <typename T> class DiaSolveSparseTestsAll : public DiaSparseTest<T> {};

TYPED_TEST(MatvecSparseTestsAll, MatvecDIAI) {
TYPED_TEST_SUITE(DiaSparseTestsAll, MatXFloatNonComplexHalfTypesCUDAExec);
TYPED_TEST_SUITE(DiaSolveSparseTestsAll, MatXFloatNonHalfTypesCUDAExec);

TYPED_TEST(DiaSparseTestsAll, MatvecDIAI) {
MATX_ENTER_HANDLER();
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
using ExecType = cuda::std::tuple_element_t<1, TypeParam>;
Expand Down Expand Up @@ -147,7 +149,7 @@ TYPED_TEST(MatvecSparseTestsAll, MatvecDIAI) {
MATX_EXIT_HANDLER();
}

TYPED_TEST(MatvecSparseTestsAll, MatvecDIAJ) {
TYPED_TEST(DiaSparseTestsAll, MatvecDIAJ) {
MATX_ENTER_HANDLER();
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
using ExecType = cuda::std::tuple_element_t<1, TypeParam>;
Expand Down Expand Up @@ -178,3 +180,53 @@ TYPED_TEST(MatvecSparseTestsAll, MatvecDIAJ) {

MATX_EXIT_HANDLER();
}

TYPED_TEST(DiaSolveSparseTestsAll, SolveDIAI) {
MATX_ENTER_HANDLER();
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
using ExecType = cuda::std::tuple_element_t<1, TypeParam>;

ExecType exec{};

auto A = makeDIA<TestType, experimental::DIA_INDEX_I>();

const auto m = A.Size(0);

auto X = make_tensor<TestType, 2>({2, m});

X(0, 0) = static_cast<TestType>(3);
X(0, 1) = static_cast<TestType>(6);
X(0, 2) = static_cast<TestType>(11);
X(0, 3) = static_cast<TestType>(13);
X(1, 0) = static_cast<TestType>(30);
X(1, 1) = static_cast<TestType>(60);
X(1, 2) = static_cast<TestType>(110);
X(1, 3) = static_cast<TestType>(130);

// Solve.
(X = solve(A, X)).run(exec);

// Verify result.
exec.sync();
auto E = make_tensor<TestType>({2, 4});
E(0, 0) = static_cast<TestType>(1);
E(0, 1) = static_cast<TestType>(2);
E(0, 2) = static_cast<TestType>(3);
E(0, 3) = static_cast<TestType>(4);
E(1, 0) = static_cast<TestType>(10);
E(1, 1) = static_cast<TestType>(20);
E(1, 2) = static_cast<TestType>(30);
E(1, 3) = static_cast<TestType>(40);
for (index_t i = 0; i < 2; i++) {
for (index_t j = 0; j < 4; j++) {
if constexpr (is_complex_v<TestType>) {
ASSERT_NEAR(X(i, j).real(), E(i, j).real(), this->thresh);
ASSERT_NEAR(X(i, j).imag(), E(i, j).imag(), this->thresh);
} else {
ASSERT_NEAR(X(i, j), E(i, j), this->thresh);
}
}
}

MATX_EXIT_HANDLER();
}