@@ -77,10 +77,12 @@ ::btas::Tensor<T, Range, Storage> gemm(
77
77
gemm_helper.compute_matrix_sizes (m, n, k, left.range (), right.range ());
78
78
79
79
// Get the leading dimension for left and right matrices.
80
- const integer lda =
81
- (gemm_helper.left_op () == TiledArray::math::blas::Op::NoTrans ? k : m);
82
- const integer ldb =
83
- (gemm_helper.right_op () == TiledArray::math::blas::Op::NoTrans ? n : k);
80
+ const integer lda = std::max (
81
+ integer{1 },
82
+ (gemm_helper.left_op () == TiledArray::math::blas::Op::NoTrans ? k : m));
83
+ const integer ldb = std::max (
84
+ integer{1 },
85
+ (gemm_helper.right_op () == TiledArray::math::blas::Op::NoTrans ? n : k));
84
86
85
87
T factor_t = T (factor);
86
88
T zero (0 );
@@ -112,10 +114,11 @@ ::btas::Tensor<T, Range, Storage> gemm(
112
114
113
115
static_assert (::btas::boxrange_iteration_order<Range>::value ==
114
116
::btas::boxrange_iteration_order<Range>::row_major);
117
+ const integer ldc = std::max (integer{1 }, n);
115
118
blas::gemm (blas::Layout::ColMajor, gemm_helper.right_op (),
116
119
gemm_helper.left_op (), n, m, k, factor_t ,
117
120
device_data (right.storage ()), ldb, device_data (left.storage ()),
118
- lda, zero, device_data (result.storage ()), n , queue);
121
+ lda, zero, device_data (result.storage ()), ldc , queue);
119
122
120
123
device::sync_madness_task_with (stream);
121
124
}
@@ -185,10 +188,12 @@ void gemm(::btas::Tensor<T, Range, Storage> &result,
185
188
gemm_helper.compute_matrix_sizes (m, n, k, left.range (), right.range ());
186
189
187
190
// Get the leading dimension for left and right matrices.
188
- const integer lda =
189
- (gemm_helper.left_op () == TiledArray::math::blas::Op::NoTrans ? k : m);
190
- const integer ldb =
191
- (gemm_helper.right_op () == TiledArray::math::blas::Op::NoTrans ? n : k);
191
+ const integer lda = std::max (
192
+ integer{1 },
193
+ (gemm_helper.left_op () == TiledArray::math::blas::Op::NoTrans ? k : m));
194
+ const integer ldb = std::max (
195
+ integer{1 },
196
+ (gemm_helper.right_op () == TiledArray::math::blas::Op::NoTrans ? n : k));
192
197
193
198
auto &queue = blasqueue_for (result.range ());
194
199
const auto stream = device::Stream (queue.device (), queue.stream ());
@@ -207,10 +212,11 @@ void gemm(::btas::Tensor<T, Range, Storage> &result,
207
212
208
213
static_assert (::btas::boxrange_iteration_order<Range>::value ==
209
214
::btas::boxrange_iteration_order<Range>::row_major);
215
+ const integer ldc = std::max (integer{1 }, n);
210
216
blas::gemm (blas::Layout::ColMajor, gemm_helper.right_op (),
211
217
gemm_helper.left_op (), n, m, k, factor_t ,
212
218
device_data (right.storage ()), ldb, device_data (left.storage ()),
213
- lda, one, device_data (result.storage ()), n , queue);
219
+ lda, one, device_data (result.storage ()), ldc , queue);
214
220
device::sync_madness_task_with (stream);
215
221
}
216
222
}
0 commit comments