CUDA: mul_mat_v support for batch sizes > 1 #14262

Open · JohannesGaessler wants to merge 3 commits into master

Conversation

JohannesGaessler (Collaborator)

This PR extends the mul_mat_vec kernels to batch sizes > 1; they seem to be viable up to a batch size of 8. The primary purpose is to help with speculative decoding and batched inference.

Performance changes
| GPU | Model | Microbatch size | Test | t/s master | t/s PR | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| RX 6800 | llama 1B all F32 | 1 | pp512 | 93.97 | 94.58 | 1.01 |
| RX 6800 | llama 1B all F32 | 2 | pp512 | 44.37 | 169.10 | 3.81 |
| RX 6800 | llama 1B all F32 | 3 | pp512 | 65.45 | 211.74 | 3.24 |
| RX 6800 | llama 1B all F32 | 4 | pp512 | 86.84 | 260.85 | 3.00 |
| RX 6800 | llama 1B all F32 | 5 | pp512 | 101.39 | 248.44 | 2.45 |
| RX 6800 | llama 1B all F32 | 6 | pp512 | 120.59 | 275.06 | 2.28 |
| RX 6800 | llama 1B all F32 | 7 | pp512 | 141.30 | 295.12 | 2.09 |
| RX 6800 | llama 1B all F32 | 8 | pp512 | 162.56 | 305.50 | 1.88 |
| RX 6800 | llama 8B BF16 | 1 | pp512 | 21.20 | 21.36 | 1.01 |
| RX 6800 | llama 8B BF16 | 2 | pp512 | 1.89 | 38.77 | 20.55 |
| RX 6800 | llama 8B BF16 | 3 | pp512 | 2.81 | 48.75 | 17.34 |
| RX 6800 | llama 8B BF16 | 4 | pp512 | 3.76 | 56.86 | 15.12 |
| RX 6800 | llama 8B BF16 | 5 | pp512 | 4.61 | 57.07 | 12.39 |
| RX 6800 | llama 8B BF16 | 6 | pp512 | 5.52 | 59.26 | 10.74 |
| RX 6800 | llama 8B BF16 | 7 | pp512 | 6.55 | 60.94 | 9.31 |
| RX 6800 | llama 8B BF16 | 8 | pp512 | 7.47 | 61.88 | 8.28 |
| RX 6800 | llama 8B F16 | 1 | pp512 | 21.38 | 21.31 | 1.00 |
| RX 6800 | llama 8B F16 | 2 | pp512 | 9.60 | 39.14 | 4.08 |
| RX 6800 | llama 8B F16 | 3 | pp512 | 14.16 | 48.64 | 3.43 |
| RX 6800 | llama 8B F16 | 4 | pp512 | 18.90 | 54.80 | 2.90 |
| RX 6800 | llama 8B F16 | 5 | pp512 | 22.84 | 56.92 | 2.49 |
| RX 6800 | llama 8B F16 | 6 | pp512 | 27.33 | 60.27 | 2.21 |
| RX 6800 | llama 8B F16 | 7 | pp512 | 32.07 | 61.76 | 1.93 |
| RX 6800 | llama 8B F16 | 8 | pp512 | 36.72 | 63.06 | 1.72 |
| P40 | llama 1B all F32 | 1 | pp512 | 75.35 | 75.65 | 1.00 |
| P40 | llama 1B all F32 | 2 | pp512 | 140.43 | 143.76 | 1.02 |
| P40 | llama 1B all F32 | 3 | pp512 | 186.86 | 212.35 | 1.14 |
| P40 | llama 1B all F32 | 4 | pp512 | 259.12 | 260.10 | 1.00 |
| P40 | llama 1B all F32 | 5 | pp512 | 304.59 | 304.61 | 1.00 |
| P40 | llama 1B all F32 | 6 | pp512 | 357.97 | 358.68 | 1.00 |
| P40 | llama 1B all F32 | 7 | pp512 | 414.78 | 415.16 | 1.00 |
| P40 | llama 1B all F32 | 8 | pp512 | 475.44 | 476.04 | 1.00 |
| P40 | llama 8B BF16 | 1 | pp512 | 21.15 | 21.21 | 1.00 |
| P40 | llama 8B BF16 | 2 | pp512 | 8.60 | 35.31 | 4.10 |
| P40 | llama 8B BF16 | 3 | pp512 | 12.83 | 39.42 | 3.07 |
| P40 | llama 8B BF16 | 4 | pp512 | 17.09 | 45.63 | 2.67 |
| P40 | llama 8B BF16 | 5 | pp512 | 21.14 | 43.44 | 2.06 |
| P40 | llama 8B BF16 | 6 | pp512 | 25.26 | 53.78 | 2.13 |
| P40 | llama 8B BF16 | 7 | pp512 | 29.71 | 47.35 | 1.59 |
| P40 | llama 8B BF16 | 8 | pp512 | 33.90 | 46.15 | 1.36 |
| P40 | llama 8B F16 | 1 | pp512 | 20.95 | 21.15 | 1.01 |
| P40 | llama 8B F16 | 2 | pp512 | 6.96 | 35.44 | 5.09 |
| P40 | llama 8B F16 | 3 | pp512 | 10.20 | 39.67 | 3.89 |
| P40 | llama 8B F16 | 4 | pp512 | 13.70 | 46.57 | 3.40 |
| P40 | llama 8B F16 | 5 | pp512 | 16.54 | 48.39 | 2.93 |
| P40 | llama 8B F16 | 6 | pp512 | 19.77 | 53.76 | 2.72 |
| P40 | llama 8B F16 | 7 | pp512 | 22.95 | 47.02 | 2.05 |
| P40 | llama 8B F16 | 8 | pp512 | 26.10 | 46.37 | 1.78 |
| RTX 3090 | llama 1B all F32 | 1 | pp512 | 201.17 | 200.97 | 1.00 |
| RTX 3090 | llama 1B all F32 | 2 | pp512 | 325.44 | 379.24 | 1.17 |
| RTX 3090 | llama 1B all F32 | 3 | pp512 | 464.06 | 538.10 | 1.16 |
| RTX 3090 | llama 1B all F32 | 4 | pp512 | 601.38 | 683.79 | 1.14 |
| RTX 3090 | llama 1B all F32 | 5 | pp512 | 743.95 | 740.42 | 1.00 |
| RTX 3090 | llama 1B all F32 | 6 | pp512 | 885.69 | 887.77 | 1.00 |
| RTX 3090 | llama 1B all F32 | 7 | pp512 | 1025.44 | 1024.07 | 1.00 |
| RTX 3090 | llama 1B all F32 | 8 | pp512 | 1178.03 | 1178.45 | 1.00 |
| RTX 3090 | llama 8B BF16 | 1 | pp512 | 58.06 | 58.27 | 1.00 |
| RTX 3090 | llama 8B BF16 | 2 | pp512 | 98.48 | 109.70 | 1.11 |
| RTX 3090 | llama 8B BF16 | 3 | pp512 | 146.26 | 148.08 | 1.01 |
| RTX 3090 | llama 8B BF16 | 4 | pp512 | 195.15 | 194.32 | 1.00 |
| RTX 3090 | llama 8B BF16 | 5 | pp512 | 239.12 | 238.88 | 1.00 |
| RTX 3090 | llama 8B BF16 | 6 | pp512 | 285.15 | 284.49 | 1.00 |
| RTX 3090 | llama 8B BF16 | 7 | pp512 | 330.18 | 329.39 | 1.00 |
| RTX 3090 | llama 8B BF16 | 8 | pp512 | 380.56 | 378.83 | 1.00 |
| RTX 3090 | llama 8B F16 | 1 | pp512 | 58.27 | 58.39 | 1.00 |
| RTX 3090 | llama 8B F16 | 2 | pp512 | 101.39 | 108.35 | 1.07 |
| RTX 3090 | llama 8B F16 | 3 | pp512 | 149.68 | 150.05 | 1.00 |
| RTX 3090 | llama 8B F16 | 4 | pp512 | 198.52 | 198.50 | 1.00 |
| RTX 3090 | llama 8B F16 | 5 | pp512 | 243.57 | 244.09 | 1.00 |
| RTX 3090 | llama 8B F16 | 6 | pp512 | 290.06 | 290.72 | 1.00 |
| RTX 3090 | llama 8B F16 | 7 | pp512 | 340.58 | 340.60 | 1.00 |
| RTX 3090 | llama 8B F16 | 8 | pp512 | 391.75 | 392.27 | 1.00 |
| RTX 4090 | llama 1B all F32 | 1 | pp512 | 231.53 | 232.40 | 1.00 |
| RTX 4090 | llama 1B all F32 | 2 | pp512 | 371.68 | 435.37 | 1.17 |
| RTX 4090 | llama 1B all F32 | 3 | pp512 | 550.96 | 642.04 | 1.17 |
| RTX 4090 | llama 1B all F32 | 4 | pp512 | 733.60 | 851.59 | 1.16 |
| RTX 4090 | llama 1B all F32 | 5 | pp512 | 908.50 | 1031.69 | 1.14 |
| RTX 4090 | llama 1B all F32 | 6 | pp512 | 1102.94 | 1205.03 | 1.09 |
| RTX 4090 | llama 1B all F32 | 7 | pp512 | 1278.15 | 1375.06 | 1.08 |
| RTX 4090 | llama 1B all F32 | 8 | pp512 | 1478.59 | 1560.42 | 1.06 |
| RTX 4090 | llama 8B BF16 | 1 | pp512 | 66.49 | 66.67 | 1.00 |
| RTX 4090 | llama 8B BF16 | 2 | pp512 | 119.44 | 127.02 | 1.06 |
| RTX 4090 | llama 8B BF16 | 3 | pp512 | 177.66 | 187.72 | 1.06 |
| RTX 4090 | llama 8B BF16 | 4 | pp512 | 236.78 | 247.97 | 1.05 |
| RTX 4090 | llama 8B BF16 | 5 | pp512 | 291.99 | 291.87 | 1.00 |
| RTX 4090 | llama 8B BF16 | 6 | pp512 | 348.79 | 348.97 | 1.00 |
| RTX 4090 | llama 8B BF16 | 7 | pp512 | 404.26 | 403.96 | 1.00 |
| RTX 4090 | llama 8B BF16 | 8 | pp512 | 466.13 | 465.14 | 1.00 |
| RTX 4090 | llama 8B F16 | 1 | pp512 | 66.56 | 66.66 | 1.00 |
| RTX 4090 | llama 8B F16 | 2 | pp512 | 117.49 | 126.75 | 1.08 |
| RTX 4090 | llama 8B F16 | 3 | pp512 | 177.31 | 188.96 | 1.07 |
| RTX 4090 | llama 8B F16 | 4 | pp512 | 235.92 | 247.04 | 1.05 |
| RTX 4090 | llama 8B F16 | 5 | pp512 | 290.63 | 289.39 | 1.00 |
| RTX 4090 | llama 8B F16 | 6 | pp512 | 346.52 | 345.44 | 1.00 |
| RTX 4090 | llama 8B F16 | 7 | pp512 | 401.26 | 399.99 | 1.00 |
| RTX 4090 | llama 8B F16 | 8 | pp512 | 462.38 | 461.38 | 1.00 |

On modern NVIDIA GPUs the speedup vs. cuBLAS for FP16 and BF16 is relatively small, though the speedup for FP32 is larger than I expected. Conversely, the FP32 speedup on Pascal is much smaller, if there is any. What I think happened is that the NVIDIA engineers simply put less work into optimizing FP32 GEMM on more modern GPUs. The cuBLAS performance on old NVIDIA GPUs and the hipBLAS performance seem to be very bad for FP16/BF16, so this PR achieves a ridiculous 20x speedup for some use cases; maybe we are running the BLAS libraries in a suboptimal way.

@IMbackK @yeahdongcn it may be worth checking whether the logic I implemented in ggml_cuda_should_use_mmv can be improved for non-NVIDIA hardware.
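
For anyone tuning this on other backends, the kind of decision involved looks roughly like the sketch below. This is a hypothetical, simplified heuristic for illustration only; the actual `ggml_cuda_should_use_mmv` in this PR checks more conditions (types, tensor shapes, device properties), and its thresholds may differ.

```cpp
// Hypothetical sketch of a batched mat-vec dispatch heuristic; not the code
// from this PR. The function name, vendor enum, and thresholds are made up.
#include <cstdint>

enum class vendor_t { NVIDIA, AMD, OTHER };

// Decide whether to use the (batched) mul_mat_vec kernel instead of a
// BLAS/GEMM path for a multiplication with ne11 columns in src1.
static bool should_use_batched_mmv_sketch(vendor_t vendor, int cc_major, int64_t ne11) {
    const int64_t max_batch = 8; // the kernels in this PR are viable up to batch size 8

    if (ne11 > max_batch) {
        return false; // larger batches: a real GEMM wins
    }
    switch (vendor) {
        case vendor_t::AMD:
            // rocBLAS/hipBLAS GEMM is reported to be slow for FP16/BF16 on RDNA,
            // so the custom kernel could be favored more aggressively there.
            return true;
        case vendor_t::NVIDIA:
            // On older architectures (e.g. Pascal, cc 6.x) cuBLAS FP16/BF16 GEMM
            // is also slow; on newer ones the kernel mostly helps small batches.
            return cc_major < 7 || ne11 <= 4;
        default:
            return ne11 == 1; // conservative fallback for unknown backends
    }
}
```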

IMbackK (Collaborator) commented Jun 18, 2025

(hip)rocBLAS performing very poorly on RDNA is a known issue and not down to the exact calls we are using. It's pretty bad for RDNA2, but it gets worse for RDNA3, and for RDNA4 it might as well be broken performance-wise.

On MI hardware the performance is much better, so possibly we will not want to do this there, but it needs benchmarking.

github-actions bot added the labels Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) on Jun 18, 2025
JohannesGaessler (Collaborator, Author)

I forgot: I changed the integer size in the kernel from 64 bit to 32 bit due to issues with register pressure.

slaren (Member) commented Jun 18, 2025

> I forgot: I changed the integer size in the kernel from 64 bit to 32 bit due to issues with register pressure.

I think this is ok as long as the pointers or indices into the weight matrix are still computed with 64-bit math; otherwise it will result in overflows with large matrices. E.g. the Command-R output matrix is 256000*8192 elements, which is very close to the limit of a 32-bit int.
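
A quick back-of-the-envelope check of how close that is (my own illustration, not code from the PR):

```cpp
// Illustration of why offsets that scale with the product of two tensor
// dimensions need 64-bit math. The dimensions are the Command-R output
// matrix mentioned above.
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t n_vocab = 256000;
    const int64_t n_embd  = 8192;
    const int64_t n_elems = n_vocab * n_embd;                 // 2,097,152,000
    std::printf("elements  = %lld\n", (long long) n_elems);
    std::printf("INT32_MAX = %d\n",   INT32_MAX);             // 2,147,483,647
    std::printf("headroom  = %lld\n", (long long) (INT32_MAX - n_elems));
    // The element count still fits in a 32-bit int, but with only ~50M to
    // spare; computing the product in 32-bit intermediates, or a slightly
    // larger matrix, would overflow.
    return 0;
}
```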

JohannesGaessler (Collaborator, Author) commented Jun 18, 2025

Specifically, I changed the calculation of the initial offsets to 64-bit math. That is the only part of the kernel where the pointer offsets scale with the product of two tensor dimensions; the offsets that scale with a single tensor dimension are at least 1024x smaller.
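
To illustrate the scheme described above, here is a minimal sketch of a batched mat-vec style kernel in the same spirit (not the PR's actual `mul_mat_vec` code; the kernel name, FP32-only layout, fixed warp-sized block, and the `ncols_dst` template parameter are simplifications of mine): only the per-row base offset, which scales with the product of two tensor dimensions, is widened to 64 bit, while everything inside a row stays in 32 bit.

```cuda
// Simplified sketch (not the PR's kernel): one block computes one row of the
// weight matrix x against ncols_dst input columns of y (the "batch"), so each
// weight value loaded from memory is reused for every column of the batch.
// Launch with gridDim.x == nrows and blockDim.x == 32.
#include <cstdint>

template <int ncols_dst> // e.g. 1..8, matching the batch sizes targeted here
__global__ void mul_mat_vec_sketch(
        const float * __restrict__ x,   // [nrows, ncols] weight matrix, row-major
        const float * __restrict__ y,   // [ncols_dst, ncols] input columns
        float       * __restrict__ dst, // [ncols_dst, nrows] result
        const int ncols, const int nrows) {
    const int row = blockIdx.x;

    // 64-bit math only for the initial offset: row*ncols can exceed INT32_MAX
    // for large matrices (see the Command-R example above).
    const float * x_row = x + (int64_t) row * ncols;

    float sum[ncols_dst] = {0.0f};

    // 32-bit math inside the row: these offsets are bounded by ncols.
    for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
        const float xi = x_row[col];
        #pragma unroll
        for (int j = 0; j < ncols_dst; ++j) {
            sum[j] += xi * y[j*ncols + col];
        }
    }

    // Warp-level reduction; assumes blockDim.x == 32.
    #pragma unroll
    for (int j = 0; j < ncols_dst; ++j) {
        for (int offset = 16; offset > 0; offset >>= 1) {
            sum[j] += __shfl_xor_sync(0xFFFFFFFF, sum[j], offset);
        }
        if (threadIdx.x == 0) {
            dst[j*nrows + row] = sum[j];
        }
    }
}
```

Reusing the same weight row for all `ncols_dst` columns is what makes small batch sizes attractive here: the extra columns mostly add FMAs rather than extra weight loads, which is why the speedup is largest on the memory-bound FP16/BF16 cases.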

slaren (Member) commented Jun 18, 2025

test-backend-ops crashes:

MUL_MAT_ID(type_a=f32,type_b=f32,n_mats=4,n_used=2,b=0,m=512,n=1,k=256): ggml/src/ggml-cuda/mmv.cu:352: GGML_ASSERT(!ids || ne1 == 1) failed

JohannesGaessler (Collaborator, Author)

Thank you, I forgot to check MUL_MAT_ID for the final version.

yeahdongcn (Collaborator)

Merged your changes along with #13842 and tested on MTT S80 and S4000. All test-backend-ops tests passed.

However, I noticed a slight performance drop on the S4000 when running llama-bench. I’ll investigate further to understand the cause.

IMbackK (Collaborator) commented Jun 19, 2025

On CDNA I am seeing a large (2x+) slowdown starting at batch size 4 in all data types.

I will try to take a look soon, maybe Sunday.
