Skip to content

Commit cd42f3e

Browse files
committed
Address review comments
Signed-off-by: Xiaodong Ye <[email protected]>
1 parent 3e5b500 commit cd42f3e

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1198,14 +1198,12 @@ static void ggml_cuda_op_mul_mat_cublas(
11981198

11991199
const int cc = ggml_cuda_info().devices[id].cc;
12001200

1201-
const bool support_bf16 = GGML_CUDA_CC_IS_NVIDIA(cc) || GGML_CUDA_CC_IS_AMD(cc) ||
1201+
const bool supports_bf16 = GGML_CUDA_CC_IS_NVIDIA(cc) || GGML_CUDA_CC_IS_AMD(cc) ||
12021202
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
12031203

1204-
const bool support_fp16 = (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
1205-
GGML_CUDA_CC_IS_AMD(cc) || (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
12061204
const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
12071205

1208-
if (support_bf16 && src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
1206+
if (supports_bf16 && src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
12091207
ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
12101208
if (src1->type != GGML_TYPE_BF16) {
12111209
const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
@@ -1233,7 +1231,7 @@ static void ggml_cuda_op_mul_mat_cublas(
12331231

12341232
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
12351233
to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff*src1_ncols, stream);
1236-
} else if (support_fp16 && use_fp16) {
1234+
} else if (fast_fp16_hardware_available(cc) && use_fp16) {
12371235
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
12381236
ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
12391237
if (src0->type != GGML_TYPE_F16) {

0 commit comments

Comments
 (0)