Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions include/matx/transforms/convert/dense2sparse_cusparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,18 +84,12 @@ class Dense2SparseHandle_t {
using POS = typename TensorTypeO::pos_type;
using CRD = typename TensorTypeO::crd_type;

static constexpr int RANKA = TensorTypeA::Rank();
static constexpr int RANKO = TensorTypeO::Rank();

/**
* Construct a dense2sparse handle.
*/
Dense2SparseHandle_t(TensorTypeO &o, const TensorTypeA &a,
cudaStream_t stream) {
MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)

static_assert(RANKA == RANKO);

params_ = GetConvParams(o, a, stream);

[[maybe_unused]] cusparseStatus_t ret = cusparseCreate(&handle_);
Expand Down Expand Up @@ -261,7 +255,22 @@ void dense2sparse_impl(OutputTensorType &o, const InputTensorType &a,
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
const auto stream = exec.getStream();

// TODO: some more checking, supported type? on device? etc.
using TA = typename InputTensorType::value_type;
using TO = typename OutputTensorType::value_type;

// Restrictions.
static_assert(OutputTensorType::Rank() == InputTensorType::Rank(),
"tensors must have same rank");
static_assert(std::is_same_v<TA, TO>,
"tensors must have the same data type");
static_assert(std::is_same_v<TA, int8_t> ||
std::is_same_v<TA, matx::matxFp16> ||
std::is_same_v<TA, matx::matxBf16> ||
std::is_same_v<TA, float> ||
std::is_same_v<TA, double> ||
std::is_same_v<TA, cuda::std::complex<float>> ||
std::is_same_v<TA, cuda::std::complex<double>>,
"unsupported data type");

// Get parameters required by these tensors (for caching).
auto params =
Expand Down
23 changes: 16 additions & 7 deletions include/matx/transforms/convert/sparse2dense_cusparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,18 +72,12 @@ class Sparse2DenseHandle_t {
using TA = typename TensorTypeA::value_type;
using TO = typename TensorTypeO::value_type;

static constexpr int RANKA = TensorTypeA::Rank();
static constexpr int RANKO = TensorTypeO::Rank();

/**
* Construct a sparse2dense handle.
*/
Sparse2DenseHandle_t(TensorTypeO &o, const TensorTypeA &a,
cudaStream_t stream) {
MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)

static_assert(RANKA == RANKO);

params_ = GetConvParams(o, a, stream);

[[maybe_unused]] cusparseStatus_t ret = cusparseCreate(&handle_);
Expand Down Expand Up @@ -221,7 +215,22 @@ void sparse2dense_impl(OutputTensorType &o, const InputTensorType &a,
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
const auto stream = exec.getStream();

// TODO: some more checking, supported type? on device? etc.
using TA = typename InputTensorType::value_type;
using TO = typename OutputTensorType::value_type;

// Restrictions.
static_assert(OutputTensorType::Rank() == InputTensorType::Rank(),
"tensors must have same rank");
static_assert(std::is_same_v<TA, TO>,
"tensors must have the same data type");
static_assert(std::is_same_v<TA, int8_t> ||
std::is_same_v<TA, matx::matxFp16> ||
std::is_same_v<TA, matx::matxBf16> ||
std::is_same_v<TA, float> ||
std::is_same_v<TA, double> ||
std::is_same_v<TA, cuda::std::complex<float>> ||
std::is_same_v<TA, cuda::std::complex<double>>,
"unsupported data type");

// Get parameters required by these tensors (for caching).
auto params =
Expand Down
38 changes: 24 additions & 14 deletions include/matx/transforms/matmul/matmul_cusparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,6 @@ class MatMulCUSPARSEHandle_t {
using TB = typename TensorTypeB::value_type;
using TC = typename TensorTypeC::value_type;

static constexpr int RANKA = TensorTypeC::Rank();
static constexpr int RANKB = TensorTypeC::Rank();
static constexpr int RANKC = TensorTypeC::Rank();

/**
* Construct a sparse GEMM handle
* SpMV
Expand All @@ -94,15 +90,6 @@ class MatMulCUSPARSEHandle_t {
const TensorTypeB &b, cudaStream_t stream, float alpha,
float beta) {
MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)

static_assert(RANKA == 2);
static_assert(RANKB == 2);
static_assert(RANKC == 2);

MATX_ASSERT(a.Size(RANKA - 1) == b.Size(RANKB - 2), matxInvalidSize);
MATX_ASSERT(c.Size(RANKC - 1) == b.Size(RANKB - 1), matxInvalidSize);
MATX_ASSERT(c.Size(RANKC - 2) == a.Size(RANKA - 2), matxInvalidSize);

params_ = GetGemmParams(c, a, b, stream, alpha, beta);

[[maybe_unused]] cusparseStatus_t ret = cusparseCreate(&handle_);
Expand Down Expand Up @@ -261,7 +248,30 @@ void sparse_matmul_impl(TensorTypeC &c, const TensorTypeA &a, const TensorTypeB
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
const auto stream = exec.getStream();

// TODO: some more checking, supported type? on device? etc.
using TA = typename TensorTypeA::value_type;
using TB = typename TensorTypeB::value_type;
using TC = typename TensorTypeC::value_type;

static constexpr int RANKA = TensorTypeA::Rank();
static constexpr int RANKB = TensorTypeB::Rank();
static constexpr int RANKC = TensorTypeC::Rank();

// Restrictions.
static_assert(RANKA == 2 && RANKB == 2 && RANKC == 2,
"tensors must have rank-2");
static_assert(std::is_same_v<TC, TA> &&
std::is_same_v<TC, TB>,
"tensors must have the same data type");
// TODO: allow MIXED-PRECISION computation!
static_assert(std::is_same_v<TC, float> ||
std::is_same_v<TC, double> ||
std::is_same_v<TC, cuda::std::complex<float>> ||
std::is_same_v<TC, cuda::std::complex<double>>,
"unsupported data type");
MATX_ASSERT(
a.Size(RANKA - 1) == b.Size(RANKB - 2) &&
c.Size(RANKC - 1) == b.Size(RANKB - 1) &&
c.Size(RANKC - 2) == a.Size(RANKA - 2), matxInvalidSize);

// Get parameters required by these tensors (for caching).
auto params =
Expand Down
46 changes: 29 additions & 17 deletions include/matx/transforms/solve/solve_cudss.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,23 +75,9 @@ class SolveCUDSSHandle_t {
using TB = typename TensorTypeB::value_type;
using TC = typename TensorTypeC::value_type;

static constexpr int RANKA = TensorTypeC::Rank();
static constexpr int RANKB = TensorTypeC::Rank();
static constexpr int RANKC = TensorTypeC::Rank();

SolveCUDSSHandle_t(TensorTypeC &c, const TensorTypeA &a, const TensorTypeB &b,
cudaStream_t stream) {
MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)

static_assert(RANKA == 2);
static_assert(RANKB == 2);
static_assert(RANKC == 2);

// Note: B,C transposed!
MATX_ASSERT(a.Size(RANKA - 1) == b.Size(RANKB - 1), matxInvalidSize);
MATX_ASSERT(a.Size(RANKA - 2) == b.Size(RANKB - 1), matxInvalidSize);
MATX_ASSERT(b.Size(RANKB - 2) == c.Size(RANKC - 2), matxInvalidSize);

params_ = GetSolveParams(c, a, b, stream);

[[maybe_unused]] cudssStatus_t ret = cudssCreate(&handle_);
Expand All @@ -100,7 +86,7 @@ class SolveCUDSSHandle_t {
// Create cuDSS handle for sparse matrix A.
static_assert(is_sparse_tensor_v<TensorTypeA>);
MATX_ASSERT(TypeToInt<typename TensorTypeA::pos_type> ==
TypeToInt<typename TensorTypeA::crd_type>,
TypeToInt<typename TensorTypeA::crd_type>,
matxNotSupported);
cudaDataType itp = MatXTypeToCudaType<typename TensorTypeA::crd_type>();
cudaDataType dta = MatXTypeToCudaType<TA>();
Expand Down Expand Up @@ -244,7 +230,29 @@ void sparse_solve_impl_trans(TensorTypeC &c, const TensorTypeA &a,
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
const auto stream = exec.getStream();

// TODO: some more checking, supported type? on device? etc.
using TA = typename TensorTypeA::value_type;
using TB = typename TensorTypeB::value_type;
using TC = typename TensorTypeC::value_type;

static constexpr int RANKA = TensorTypeA::Rank();
static constexpr int RANKB = TensorTypeB::Rank();
static constexpr int RANKC = TensorTypeC::Rank();

// Restrictions.
static_assert(RANKA == 2 && RANKB == 2 && RANKC == 2,
"tensors must have rank-2");
static_assert(std::is_same_v<TC, TA> &&
std::is_same_v<TC, TB>,
"tensors must have the same data type");
static_assert(std::is_same_v<TC, float> ||
std::is_same_v<TC, double> ||
std::is_same_v<TC, cuda::std::complex<float>> ||
std::is_same_v<TC, cuda::std::complex<double>>,
"unsupported data type");
MATX_ASSERT( // Note: B,C transposed!
a.Size(RANKA - 1) == b.Size(RANKB - 1) &&
a.Size(RANKA - 2) == b.Size(RANKB - 1) &&
b.Size(RANKB - 2) == c.Size(RANKC - 2), matxInvalidSize);

// Get parameters required by these tensors (for caching).
auto params = detail::SolveCUDSSHandle_t<TensorTypeC, TensorTypeA, TensorTypeB>::GetSolveParams(
Expand All @@ -266,12 +274,16 @@ void sparse_solve_impl_trans(TensorTypeC &c, const TensorTypeA &a,
// convoluted way of performing the solve step must be removed once cuDSS
// supports MatX's native row-major storage, which will eliminate the extra
// copies to and from memory.
//
// TODO: remove this when cuDSS supports row-major storage
//
template <typename TensorTypeC, typename TensorTypeA, typename TensorTypeB>
void sparse_solve_impl(TensorTypeC &c, const TensorTypeA &a,
const TensorTypeB &b, const cudaExecutor &exec) {
const auto stream = exec.getStream();

// Some copying-in hacks, assumes rank 2.
// Some copying-in hacks.
static_assert(TensorTypeB::Rank() == 2 && TensorTypeC::Rank() == 2);
using TB = typename TensorTypeB::value_type;
using TC = typename TensorTypeB::value_type;
TB *bptr;
Expand Down