diff --git a/include/matx/transforms/convert/dense2sparse_cusparse.h b/include/matx/transforms/convert/dense2sparse_cusparse.h index d6ae08c25..f8e0465ed 100644 --- a/include/matx/transforms/convert/dense2sparse_cusparse.h +++ b/include/matx/transforms/convert/dense2sparse_cusparse.h @@ -84,18 +84,12 @@ class Dense2SparseHandle_t { using POS = typename TensorTypeO::pos_type; using CRD = typename TensorTypeO::crd_type; - static constexpr int RANKA = TensorTypeA::Rank(); - static constexpr int RANKO = TensorTypeO::Rank(); - /** * Construct a dense2sparse handle. */ Dense2SparseHandle_t(TensorTypeO &o, const TensorTypeA &a, cudaStream_t stream) { MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) - - static_assert(RANKA == RANKO); - params_ = GetConvParams(o, a, stream); [[maybe_unused]] cusparseStatus_t ret = cusparseCreate(&handle_); @@ -261,7 +255,22 @@ void dense2sparse_impl(OutputTensorType &o, const InputTensorType &a, MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) const auto stream = exec.getStream(); - // TODO: some more checking, supported type? on device? etc. + using TA = typename InputTensorType::value_type; + using TO = typename OutputTensorType::value_type; + + // Restrictions. + static_assert(OutputTensorType::Rank() == InputTensorType::Rank(), + "tensors must have same rank"); + static_assert(std::is_same_v, + "tensors must have the same data type"); + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v> || + std::is_same_v>, + "unsupported data type"); // Get parameters required by these tensors (for caching). auto params = diff --git a/include/matx/transforms/convert/sparse2dense_cusparse.h b/include/matx/transforms/convert/sparse2dense_cusparse.h index 24d95dbbe..6a3f65cdd 100644 --- a/include/matx/transforms/convert/sparse2dense_cusparse.h +++ b/include/matx/transforms/convert/sparse2dense_cusparse.h @@ -72,18 +72,12 @@ class Sparse2DenseHandle_t { using TA = typename TensorTypeA::value_type; using TO = typename TensorTypeO::value_type; - static constexpr int RANKA = TensorTypeA::Rank(); - static constexpr int RANKO = TensorTypeO::Rank(); - /** * Construct a sparse2dense handle. */ Sparse2DenseHandle_t(TensorTypeO &o, const TensorTypeA &a, cudaStream_t stream) { MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) - - static_assert(RANKA == RANKO); - params_ = GetConvParams(o, a, stream); [[maybe_unused]] cusparseStatus_t ret = cusparseCreate(&handle_); @@ -221,7 +215,22 @@ void sparse2dense_impl(OutputTensorType &o, const InputTensorType &a, MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) const auto stream = exec.getStream(); - // TODO: some more checking, supported type? on device? etc. + using TA = typename InputTensorType::value_type; + using TO = typename OutputTensorType::value_type; + + // Restrictions. + static_assert(OutputTensorType::Rank() == InputTensorType::Rank(), + "tensors must have same rank"); + static_assert(std::is_same_v, + "tensors must have the same data type"); + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v> || + std::is_same_v>, + "unsupported data type"); // Get parameters required by these tensors (for caching). auto params = diff --git a/include/matx/transforms/matmul/matmul_cusparse.h b/include/matx/transforms/matmul/matmul_cusparse.h index 5fa06fbf2..2c94eb4f9 100644 --- a/include/matx/transforms/matmul/matmul_cusparse.h +++ b/include/matx/transforms/matmul/matmul_cusparse.h @@ -79,10 +79,6 @@ class MatMulCUSPARSEHandle_t { using TB = typename TensorTypeB::value_type; using TC = typename TensorTypeC::value_type; - static constexpr int RANKA = TensorTypeC::Rank(); - static constexpr int RANKB = TensorTypeC::Rank(); - static constexpr int RANKC = TensorTypeC::Rank(); - /** * Construct a sparse GEMM handle * SpMV @@ -94,15 +90,6 @@ class MatMulCUSPARSEHandle_t { const TensorTypeB &b, cudaStream_t stream, float alpha, float beta) { MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) - - static_assert(RANKA == 2); - static_assert(RANKB == 2); - static_assert(RANKC == 2); - - MATX_ASSERT(a.Size(RANKA - 1) == b.Size(RANKB - 2), matxInvalidSize); - MATX_ASSERT(c.Size(RANKC - 1) == b.Size(RANKB - 1), matxInvalidSize); - MATX_ASSERT(c.Size(RANKC - 2) == a.Size(RANKA - 2), matxInvalidSize); - params_ = GetGemmParams(c, a, b, stream, alpha, beta); [[maybe_unused]] cusparseStatus_t ret = cusparseCreate(&handle_); @@ -261,7 +248,30 @@ void sparse_matmul_impl(TensorTypeC &c, const TensorTypeA &a, const TensorTypeB MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) const auto stream = exec.getStream(); - // TODO: some more checking, supported type? on device? etc. + using TA = typename TensorTypeA::value_type; + using TB = typename TensorTypeB::value_type; + using TC = typename TensorTypeC::value_type; + + static constexpr int RANKA = TensorTypeA::Rank(); + static constexpr int RANKB = TensorTypeB::Rank(); + static constexpr int RANKC = TensorTypeC::Rank(); + + // Restrictions. + static_assert(RANKA == 2 && RANKB == 2 && RANKC == 2, + "tensors must have rank-2"); + static_assert(std::is_same_v && + std::is_same_v, + "tensors must have the same data type"); + // TODO: allow MIXED-PRECISION computation! + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v> || + std::is_same_v>, + "unsupported data type"); + MATX_ASSERT( + a.Size(RANKA - 1) == b.Size(RANKB - 2) && + c.Size(RANKC - 1) == b.Size(RANKB - 1) && + c.Size(RANKC - 2) == a.Size(RANKA - 2), matxInvalidSize); // Get parameters required by these tensors (for caching). auto params = diff --git a/include/matx/transforms/solve/solve_cudss.h b/include/matx/transforms/solve/solve_cudss.h index 423376e3c..987bdc433 100644 --- a/include/matx/transforms/solve/solve_cudss.h +++ b/include/matx/transforms/solve/solve_cudss.h @@ -75,23 +75,9 @@ class SolveCUDSSHandle_t { using TB = typename TensorTypeB::value_type; using TC = typename TensorTypeC::value_type; - static constexpr int RANKA = TensorTypeC::Rank(); - static constexpr int RANKB = TensorTypeC::Rank(); - static constexpr int RANKC = TensorTypeC::Rank(); - SolveCUDSSHandle_t(TensorTypeC &c, const TensorTypeA &a, const TensorTypeB &b, cudaStream_t stream) { MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) - - static_assert(RANKA == 2); - static_assert(RANKB == 2); - static_assert(RANKC == 2); - - // Note: B,C transposed! - MATX_ASSERT(a.Size(RANKA - 1) == b.Size(RANKB - 1), matxInvalidSize); - MATX_ASSERT(a.Size(RANKA - 2) == b.Size(RANKB - 1), matxInvalidSize); - MATX_ASSERT(b.Size(RANKB - 2) == c.Size(RANKC - 2), matxInvalidSize); - params_ = GetSolveParams(c, a, b, stream); [[maybe_unused]] cudssStatus_t ret = cudssCreate(&handle_); @@ -100,7 +86,7 @@ class SolveCUDSSHandle_t { // Create cuDSS handle for sparse matrix A. static_assert(is_sparse_tensor_v); MATX_ASSERT(TypeToInt == - TypeToInt, + TypeToInt, matxNotSupported); cudaDataType itp = MatXTypeToCudaType(); cudaDataType dta = MatXTypeToCudaType(); @@ -244,7 +230,29 @@ void sparse_solve_impl_trans(TensorTypeC &c, const TensorTypeA &a, MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) const auto stream = exec.getStream(); - // TODO: some more checking, supported type? on device? etc. + using TA = typename TensorTypeA::value_type; + using TB = typename TensorTypeB::value_type; + using TC = typename TensorTypeC::value_type; + + static constexpr int RANKA = TensorTypeA::Rank(); + static constexpr int RANKB = TensorTypeB::Rank(); + static constexpr int RANKC = TensorTypeC::Rank(); + + // Restrictions. + static_assert(RANKA == 2 && RANKB == 2 && RANKC == 2, + "tensors must have rank-2"); + static_assert(std::is_same_v && + std::is_same_v, + "tensors must have the same data type"); + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v> || + std::is_same_v>, + "unsupported data type"); + MATX_ASSERT( // Note: B,C transposed! + a.Size(RANKA - 1) == b.Size(RANKB - 1) && + a.Size(RANKA - 2) == b.Size(RANKB - 1) && + b.Size(RANKB - 2) == c.Size(RANKC - 2), matxInvalidSize); // Get parameters required by these tensors (for caching). auto params = detail::SolveCUDSSHandle_t::GetSolveParams( @@ -266,12 +274,16 @@ void sparse_solve_impl_trans(TensorTypeC &c, const TensorTypeA &a, // convoluted way of performing the solve step must be removed once cuDSS // supports MATX native row-major storage, which will clean up the copies from // and to memory. +// +// TODO: remove this when cuDSS supports row-major storage +// template void sparse_solve_impl(TensorTypeC &c, const TensorTypeA &a, const TensorTypeB &b, const cudaExecutor &exec) { const auto stream = exec.getStream(); - // Some copying-in hacks, assumes rank 2. + // Some copying-in hacks. + static_assert(TensorTypeB::Rank() == 2 && TensorTypeC::Rank() == 2); using TB = typename TensorTypeB::value_type; using TC = typename TensorTypeB::value_type; TB *bptr;