Skip to content

Commit 389ee69

Browse files
committed
matmul bug fixes. 1) if beta!=0 copy C in. 2) detect additional
case that cublas doesn't support.
1 parent 141df30 commit 389ee69

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

include/matx/transforms/matmul.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1107,7 +1107,8 @@ __MATX_INLINE__ auto getCublasSupportedTensor( const Op &in, cudaStream_t stream
11071107
(in.Stride(RANK-1) != (index_t)1 && in.Stride(RANK-2) != (index_t)1) ||
11081108
// cublas allows 0 strides, but verify that the corresponding size is 1
11091109
(in.Stride(RANK-1) == (index_t)0 && in.Size(RANK-1) != (index_t)1) ||
1110-
(in.Stride(RANK-2) == (index_t)0 && in.Size(RANK-2) != (index_t)1)
1110+
(in.Stride(RANK-2) == (index_t)0 && in.Size(RANK-2) != (index_t)1) ||
1111+
in.Stride(RANK-2) == 0 // WAR for CUBLAS bug
11111112
) {
11121113
supported = false;
11131114
}
@@ -1192,6 +1193,10 @@ void matmul_impl(TensorTypeC C, const TensorTypeA A,
11921193
if(!b.isSameView(B_)) {
11931194
(b = B_).run(stream);
11941195
}
1196+
1197+
if(beta != 0 && !c.isSameView(C)) {
1198+
(c = C).run(stream);
1199+
}
11951200

11961201
#if MATX_ENABLE_CUTLASS != 1
11971202
// cublasLt does not allow transpose modes on C. Thus we need to make sure that the right most dimension has a stride of 1.

0 commit comments

Comments
 (0)