Skip to content

Commit df7e0c8

Browse files

File tree

3 files changed

+41
-25
lines changed

3 files changed

+41
-25
lines changed

src/TiledArray/device/btas.h

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,12 @@ ::btas::Tensor<T, Range, Storage> gemm(
7777
gemm_helper.compute_matrix_sizes(m, n, k, left.range(), right.range());
7878

7979
// Get the leading dimension for left and right matrices.
80-
const integer lda =
81-
(gemm_helper.left_op() == TiledArray::math::blas::Op::NoTrans ? k : m);
82-
const integer ldb =
83-
(gemm_helper.right_op() == TiledArray::math::blas::Op::NoTrans ? n : k);
80+
const integer lda = std::max(
81+
integer{1},
82+
(gemm_helper.left_op() == TiledArray::math::blas::Op::NoTrans ? k : m));
83+
const integer ldb = std::max(
84+
integer{1},
85+
(gemm_helper.right_op() == TiledArray::math::blas::Op::NoTrans ? n : k));
8486

8587
T factor_t = T(factor);
8688
T zero(0);
@@ -112,10 +114,11 @@ ::btas::Tensor<T, Range, Storage> gemm(
112114

113115
static_assert(::btas::boxrange_iteration_order<Range>::value ==
114116
::btas::boxrange_iteration_order<Range>::row_major);
117+
const integer ldc = std::max(integer{1}, n);
115118
blas::gemm(blas::Layout::ColMajor, gemm_helper.right_op(),
116119
gemm_helper.left_op(), n, m, k, factor_t,
117120
device_data(right.storage()), ldb, device_data(left.storage()),
118-
lda, zero, device_data(result.storage()), n, queue);
121+
lda, zero, device_data(result.storage()), ldc, queue);
119122

120123
device::sync_madness_task_with(stream);
121124
}
@@ -185,10 +188,12 @@ void gemm(::btas::Tensor<T, Range, Storage> &result,
185188
gemm_helper.compute_matrix_sizes(m, n, k, left.range(), right.range());
186189

187190
// Get the leading dimension for left and right matrices.
188-
const integer lda =
189-
(gemm_helper.left_op() == TiledArray::math::blas::Op::NoTrans ? k : m);
190-
const integer ldb =
191-
(gemm_helper.right_op() == TiledArray::math::blas::Op::NoTrans ? n : k);
191+
const integer lda = std::max(
192+
integer{1},
193+
(gemm_helper.left_op() == TiledArray::math::blas::Op::NoTrans ? k : m));
194+
const integer ldb = std::max(
195+
integer{1},
196+
(gemm_helper.right_op() == TiledArray::math::blas::Op::NoTrans ? n : k));
192197

193198
auto &queue = blasqueue_for(result.range());
194199
const auto stream = device::Stream(queue.device(), queue.stream());
@@ -207,10 +212,11 @@ void gemm(::btas::Tensor<T, Range, Storage> &result,
207212

208213
static_assert(::btas::boxrange_iteration_order<Range>::value ==
209214
::btas::boxrange_iteration_order<Range>::row_major);
215+
const integer ldc = std::max(integer{1}, n);
210216
blas::gemm(blas::Layout::ColMajor, gemm_helper.right_op(),
211217
gemm_helper.left_op(), n, m, k, factor_t,
212218
device_data(right.storage()), ldb, device_data(left.storage()),
213-
lda, one, device_data(result.storage()), n, queue);
219+
lda, one, device_data(result.storage()), ldc, queue);
214220
device::sync_madness_task_with(stream);
215221
}
216222
}

src/TiledArray/external/btas.h

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -661,16 +661,19 @@ inline btas::Tensor<T, Range, Storage> gemm(
661661
gemm_helper.compute_matrix_sizes(m, n, k, left.range(), right.range());
662662

663663
// Get the leading dimension for left and right matrices.
664-
const integer lda =
665-
(gemm_helper.left_op() == TiledArray::math::blas::Op::NoTrans ? k : m);
666-
const integer ldb =
667-
(gemm_helper.right_op() == TiledArray::math::blas::Op::NoTrans ? n : k);
664+
const integer lda = std::max(
665+
integer{1},
666+
(gemm_helper.left_op() == TiledArray::math::blas::Op::NoTrans ? k : m));
667+
const integer ldb = std::max(
668+
integer{1},
669+
(gemm_helper.right_op() == TiledArray::math::blas::Op::NoTrans ? n : k));
668670

669671
T factor_t(factor);
670672

673+
const integer ldc = std::max(integer{1}, n);
671674
TiledArray::math::blas::gemm(gemm_helper.left_op(), gemm_helper.right_op(), m,
672675
n, k, factor_t, left.data(), lda, right.data(),
673-
ldb, T(0), result.data(), n);
676+
ldb, T(0), result.data(), ldc);
674677

675678
return result;
676679
}
@@ -736,16 +739,19 @@ inline void gemm(btas::Tensor<T, Range, Storage>& result,
736739
gemm_helper.compute_matrix_sizes(m, n, k, left.range(), right.range());
737740

738741
// Get the leading dimension for left and right matrices.
739-
const integer lda =
740-
(gemm_helper.left_op() == TiledArray::math::blas::Op::NoTrans ? k : m);
741-
const integer ldb =
742-
(gemm_helper.right_op() == TiledArray::math::blas::Op::NoTrans ? n : k);
742+
const integer lda = std::max(
743+
integer{1},
744+
(gemm_helper.left_op() == TiledArray::math::blas::Op::NoTrans ? k : m));
745+
const integer ldb = std::max(
746+
integer{1},
747+
(gemm_helper.right_op() == TiledArray::math::blas::Op::NoTrans ? n : k));
743748

744749
T factor_t(factor);
745750

751+
const integer ldc = std::max(integer{1}, n);
746752
TiledArray::math::blas::gemm(gemm_helper.left_op(), gemm_helper.right_op(), m,
747753
n, k, factor_t, left.data(), lda, right.data(),
748-
ldb, T(1), result.data(), n);
754+
ldb, T(1), result.data(), ldc);
749755
}
750756

751757
// sum of the hyperdiagonal elements

src/TiledArray/tensor/tensor.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2648,10 +2648,13 @@ void gemm(Alpha alpha, const Tensor<As...>& A, const Tensor<Bs...>& B,
26482648
gemm_helper.compute_matrix_sizes(m, n, k, A.range(), B.range());
26492649

26502650
// Get the leading dimension for left and right matrices.
2651-
const integer lda =
2652-
(gemm_helper.left_op() == TiledArray::math::blas::NoTranspose ? k : m);
2653-
const integer ldb =
2654-
(gemm_helper.right_op() == TiledArray::math::blas::NoTranspose ? n : k);
2651+
const integer lda = std::max(
2652+
integer{1},
2653+
(gemm_helper.left_op() == TiledArray::math::blas::NoTranspose ? k : m));
2654+
const integer ldb = std::max(
2655+
integer{1},
2656+
(gemm_helper.right_op() == TiledArray::math::blas::NoTranspose ? n
2657+
: k));
26552658

26562659
// may need to split gemm into multiply + accumulate for tracing purposes
26572660
#ifdef TA_ENABLE_TILE_OPS_LOGGING
@@ -2719,8 +2722,9 @@ void gemm(Alpha alpha, const Tensor<As...>& A, const Tensor<Bs...>& B,
27192722
}
27202723
}
27212724
#else // TA_ENABLE_TILE_OPS_LOGGING
2725+
const integer ldc = std::max(integer{1}, n);
27222726
math::blas::gemm(gemm_helper.left_op(), gemm_helper.right_op(), m, n, k,
2723-
alpha, A.data(), lda, B.data(), ldb, beta, C.data(), n);
2727+
alpha, A.data(), lda, B.data(), ldb, beta, C.data(), ldc);
27242728
#endif // TA_ENABLE_TILE_OPS_LOGGING
27252729
}
27262730
}

0 commit comments

Comments
 (0)