Skip to content

Commit 3cc0b6e

Browse files
committed
Split int8 and bfloat16 runs for benchmarks
1 parent 2f0774f commit 3cc0b6e

File tree

2 files changed

+25
-5
lines changed

2 files changed

+25
-5
lines changed

.github/workflows/triton-benchmarks.yml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,14 +206,22 @@ jobs:
206206
source ../../scripts/capture-hw-details.sh
207207
python ../../scripts/build_report.py $REPORTS/matmul-performance-postop-gelu.csv $REPORTS/gemm-postop-gelu-triton-report.csv --benchmark gemm-postop-gelu --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
208208
209-
- name: Run Triton GEMM + PostOp (add matrix) kernel benchmark
209+
- name: Run Triton GEMM + PostOp (add matrix) kernel benchmark bfloat16
210210
if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'gemm_postop_addmatrix_benchmark.py') }}
211211
run: |
212212
cd benchmarks/triton_kernels_benchmark
213213
python gemm_postop_addmatrix_benchmark.py --reports $REPORTS
214214
source ../../scripts/capture-hw-details.sh
215215
python ../../scripts/build_report.py $REPORTS/matmul-performance-postop-addmatrix.csv $REPORTS/gemm-postop-addmatrix-triton-report.csv --benchmark gemm-postop-addmatrix --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
216216
217+
- name: Run Triton GEMM + PostOp (add matrix) kernel benchmark int8
218+
if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'gemm_postop_addmatrix_benchmark.py') }}
219+
run: |
220+
cd benchmarks/triton_kernels_benchmark
221+
INT8_ONLY=1 python gemm_postop_addmatrix_benchmark.py --reports $REPORTS
222+
source ../../scripts/capture-hw-details.sh
223+
python ../../scripts/build_report.py $REPORTS/matmul-performance-postop-addmatrix-int8.csv $REPORTS/gemm-postop-addmatrix-int8-triton-report.csv --benchmark gemm-postop-addmatrix-int8 --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
224+
217225
- name: Run Triton FA kernel benchmark
218226
if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flash_attention_fwd_benchmark.py') }}
219227
run: |

benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,20 @@
1212

1313
import triton_kernels_benchmark as benchmark_suit
1414

15+
import os
16+
17+
# Environment switches read once at import time; a switch is on only when
# its variable is set to exactly "1".
INT8_ONLY_OPTION = os.environ.get("INT8_ONLY", "0") == "1"
ALL_DTYPES_OPTION = os.environ.get("ALL_DTYPES", "0") == "1"


def dtypes():
    """Return the list of torch dtypes selected by the environment switches.

    ``ALL_DTYPES`` takes precedence and yields both bfloat16 and int8;
    otherwise ``INT8_ONLY`` yields int8 alone; with neither switch set the
    result is bfloat16 alone.
    """
    if ALL_DTYPES_OPTION:
        return [torch.bfloat16, torch.int8]
    return [torch.int8] if INT8_ONLY_OPTION else [torch.bfloat16]
28+
1529

1630
@triton.autotune(
1731
configs=[
@@ -214,9 +228,7 @@ def matmul(a, b, d, c):
214228
# argument names to use as an x-axis for the plot
215229
x_names=['B', 'M', 'K', 'N', 'dtype'],
216230
# different possible values for `x_name`
217-
x_vals=[[1, 1024 * i, 1024 * i, 1024 * i, dtype]
218-
for i in [1, 2, 4, 8]
219-
for dtype in [torch.bfloat16, torch.int8]] + #
231+
x_vals=[[1, 1024 * i, 1024 * i, 1024 * i, dtype] for i in [1, 2, 4, 8] for dtype in dtypes()] + #
220232
[[*shape, dtype]
221233
for shape in [[1, 1, 5120, 13824], #
222234
[1, 4, 4096, 12288], #
@@ -238,7 +250,7 @@ def matmul(a, b, d, c):
238250
[32, 4096, 4096, 128], #
239251
[4096, 8, 128, 16384], #
240252
[4096, 8, 16384, 128]]
241-
for dtype in [torch.bfloat16, torch.int8]],
253+
for dtype in dtypes()],
242254
line_arg='provider',
243255
# argument name whose value corresponds to a different line in the plot
244256
# possible values for `line_arg``

0 commit comments

Comments
 (0)