Skip to content

Commit d104fec

Browse files
authored
Generalize DIA to DIA-I and DIA-J (#972)
* Generalize DIA to DIA-I and DIA-J DIA-J is most common but DIA-I is used by cuSPARSE legacy * improved method doc * add index tag
1 parent 0b2dfde commit d104fec

File tree

5 files changed

+64
-22
lines changed

5 files changed

+64
-22
lines changed

examples/sparse_tensor.cu

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,10 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv) {
5555
experimental::CSR::print();
5656
experimental::CSC::print();
5757
experimental::DCSR::print();
58-
experimental::DIA::print();
59-
experimental::SkewDIA::print();
58+
experimental::DIAI::print();
59+
experimental::DIAJ::print();
60+
experimental::SkewDIAI::print();
61+
experimental::SkewDIAJ::print();
6062
experimental::BSR<2, 2>::print(); // 2x2 blocks
6163
experimental::COO4::print(); // 4-dim tensor in COO
6264
experimental::CSF5::print(); // 5-dim tensor in CSF
@@ -195,8 +197,8 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv) {
195197
auto doffsets = make_tensor<int>({3});
196198
dvals.SetVals({-1, -1, -1, -1, -1, 0, 4, 4, 4, 4, 4, 4, 0, 1, 1, 1, 1, 1});
197199
doffsets.SetVals({-1, 0, 1});
198-
auto Adia = experimental::make_tensor_dia(dvals, doffsets, {6, 6});
199-
print(Adia);
200+
auto AdiaJ = experimental::make_tensor_dia<experimental::DIA_INDEX_J>(dvals, doffsets, {6, 6});
201+
print(AdiaJ);
200202

201203
//
202204
// Perform a direct SpMV. This is also the correct way of performing
@@ -205,7 +207,7 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv) {
205207
auto V = make_tensor<float>({6});
206208
auto R = make_tensor<float>({6});
207209
V.SetVals({1, 2, 3, 4, 5, 6});
208-
(R = matvec(Adia, V)).run(exec);
210+
(R = matvec(AdiaJ, V)).run(exec);
209211
print(R);
210212

211213
MATX_EXIT_HANDLER();

include/matx/core/make_sparse_tensor.h

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -204,11 +204,14 @@ auto make_zero_tensor_csc(const index_t (&shape)[2],
204204

205205
// Constructs a sparse matrix in DIA format directly from the values and the
206206
// offset vectors. For an m x n matrix, this format uses a linearized storage
207-
// where each diagonal has n entries, padded with zeros on the right for the
208-
// lower triangular part and padded with zeros on the left for the upper
209-
// triagonal part. This format is most efficient for matrices with only a
210-
// few nonzero diagonals that are close to the main diagonal.
211-
template <typename ValTensor, typename CrdTensor>
207+
// where each diagonal has n entries and is accessed by index I or index J.
208+
// For index I, diagonals padded with zeros on the left for the lower triangular
209+
// part and padded with zeros on the right for the upper triagonal part. This
210+
// is vv. when using index J. This format is most efficient for matrices with
211+
// only a few nonzero diagonals that are close to the main diagonal.
212+
struct DIA_INDEX_I {};
213+
struct DIA_INDEX_J {};
214+
template <typename IDX, typename ValTensor, typename CrdTensor>
212215
auto make_tensor_dia(ValTensor &val, CrdTensor &off,
213216
const index_t (&shape)[2]) {
214217
using VAL = typename ValTensor::value_type;
@@ -224,7 +227,8 @@ auto make_tensor_dia(ValTensor &val, CrdTensor &off,
224227
matxMemorySpace_t space = GetPointerKind(val.GetStorage().data());
225228
auto tp = makeDefaultNonOwningZeroStorage<POS>(2, space);
226229
setVal(tp.data() + 1, static_cast<POS>(val.Size(0)), space);
227-
// Construct DIA.
230+
// Construct DIA-I/J.
231+
using DIA = std::conditional_t<std::is_same_v<IDX, DIA_INDEX_I>, DIAI, DIAJ>;
228232
return sparse_tensor_t<VAL, CRD, POS, DIA>(
229233
shape, val.GetStorage(),
230234
{makeDefaultNonOwningEmptyStorage<CRD>(), off.GetStorage()},

include/matx/core/sparse_tensor_format.h

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,18 @@ template <int D, class... S> class SparseTensorFormat {
187187
return false;
188188
}
189189

190-
static constexpr bool isDIA() {
190+
static constexpr bool isDIAI() {
191+
if constexpr (LVL == 2) {
192+
using type0 = cuda::std::tuple_element_t<0, LvlSpecs>;
193+
using type1 = cuda::std::tuple_element_t<1, LvlSpecs>;
194+
return type0::Expr::op == LvlOp::Sub && type0::Expr::di == 1 &&
195+
type0::Expr::cj == 0 && type0::Type::isCompressed() &&
196+
type1::Expr::isId(0) && type1::Type::isRange();
197+
}
198+
return false;
199+
}
200+
201+
static constexpr bool isDIAJ() {
191202
if constexpr (LVL == 2) {
192203
using type0 = cuda::std::tuple_element_t<0, LvlSpecs>;
193204
using type1 = cuda::std::tuple_element_t<1, LvlSpecs>;
@@ -337,9 +348,13 @@ using DCSC =
337348
SparseTensorFormat<2, LvlSpec<D1, Compressed>, LvlSpec<D0, Compressed>>;
338349
using CROW = SparseTensorFormat<2, LvlSpec<D0, Compressed>, LvlSpec<D1, Dense>>;
339350
using CCOL = SparseTensorFormat<2, LvlSpec<D1, Compressed>, LvlSpec<D0, Dense>>;
340-
using DIA =
351+
using DIAI =
352+
SparseTensorFormat<2, LvlSpec<Sub<1, 0>, Compressed>, LvlSpec<D0, Range>>;
353+
using DIAJ =
341354
SparseTensorFormat<2, LvlSpec<Sub<1, 0>, Compressed>, LvlSpec<D1, Range>>;
342-
using SkewDIA =
355+
using SkewDIAI =
356+
SparseTensorFormat<2, LvlSpec<Add<1, 0>, Compressed>, LvlSpec<D0, Range>>;
357+
using SkewDIAJ =
343358
SparseTensorFormat<2, LvlSpec<Add<1, 0>, Compressed>, LvlSpec<D1, Range>>;
344359

345360
// Sparse Block Matrices.

include/matx/kernels/matvec.cuh

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,32 @@
3939

4040
namespace matx {
4141

42-
// Kernel that performs SpMV for an m x n DIA matrix.
42+
// Kernel that performs SpMV for an m x n DIA-I matrix.
4343
template <typename VAL, typename CRD>
44-
__global__ void dia_spmv_kernel(VAL *A, CRD *diags, uint64_t numD, VAL *B,
45-
VAL *C, uint64_t m, uint64_t n) {
44+
__global__ void diai_spmv_kernel(VAL *A, CRD *diags, uint64_t numDiags, VAL *B,
45+
VAL *C, uint64_t m, uint64_t n) {
4646
uint64_t i = blockIdx.x * blockDim.x + threadIdx.x;
4747
if (i < m) {
4848
VAL acc = 0.0;
49-
for (uint64_t d = 0; d < numD; d++) { // numD-DIA SpMV
50-
int64_t j = i + diags[d]; // signed
49+
for (uint64_t d = 0; d < numDiags; d++) {
50+
int64_t j = i + diags[d]; // signed
51+
if (0 <= j && j < n) {
52+
acc += A[d * m + i] * B[j];
53+
}
54+
}
55+
C[i] = acc;
56+
}
57+
}
58+
59+
// Kernel that performs SpMV for an m x n DIA-J matrix.
60+
template <typename VAL, typename CRD>
61+
__global__ void diaj_spmv_kernel(VAL *A, CRD *diags, uint64_t numDiags, VAL *B,
62+
VAL *C, uint64_t m, uint64_t n) {
63+
uint64_t i = blockIdx.x * blockDim.x + threadIdx.x;
64+
if (i < m) {
65+
VAL acc = 0.0;
66+
for (uint64_t d = 0; d < numDiags; d++) {
67+
int64_t j = i + diags[d]; // signed
5168
if (0 <= j && j < n) {
5269
acc += A[d * n + j] * B[j];
5370
}

include/matx/transforms/matmul/matvec_cusparse.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ void sparse_matvec_impl(TensorTypeC &C, const TensorTypeA &a,
303303
MATX_ASSERT(b.Stride(RANKB - 1) == 1 && c.Stride(RANKC - 1) == 1,
304304
matxInvalidParameter);
305305

306-
if constexpr (atype::Format::isDIA()) {
306+
if constexpr (atype::Format::isDIAI() || atype::Format::isDIAJ()) {
307307

308308
// Fall back to a hand-written kernel for DIA format, since
309309
// this format is not supported in cuSPARSE. The hand-written
@@ -325,8 +325,12 @@ void sparse_matvec_impl(TensorTypeC &C, const TensorTypeA &a,
325325
uint32_t THREADS = static_cast<uint32_t>(std::min(m, 1024LU));
326326
uint32_t BATCHES = static_cast<uint32_t>(
327327
cuda::std::ceil(static_cast<double>(m) / THREADS));
328-
dia_spmv_kernel<<<BATCHES, THREADS, 0, stream>>>(AD, diags, numD, BD, CD, m,
329-
n);
328+
if constexpr (atype::Format::isDIAI())
329+
diai_spmv_kernel<<<BATCHES, THREADS, 0, stream>>>(AD, diags, numD, BD, CD,
330+
m, n);
331+
else
332+
diaj_spmv_kernel<<<BATCHES, THREADS, 0, stream>>>(AD, diags, numD, BD, CD,
333+
m, n);
330334
#endif
331335

332336
} else {

0 commit comments

Comments
 (0)