Skip to content

Commit 19fe95a

Browse files
Merge pull request #1721 from Mhmd-Hisham/quantization-packing-bug-fix
[CUDA] Fixing quantization uint8 packing bug for NF4 and FP4
2 parents 4265392 + 639f8c0 commit 19fe95a

File tree

2 files changed

+50
-22
lines changed

2 files changed

+50
-22
lines changed

csrc/kernels.cu

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,6 @@ __global__ void kQuantizeBlockwise(
431431
LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0);
432432
}
433433

434-
unsigned char packed_4bit = 0;
435434
switch (DATA_TYPE) {
436435
case General8bit:
437436
#pragma unroll NUM_PER_TH
@@ -445,17 +444,15 @@ __global__ void kQuantizeBlockwise(
445444
case FP4:
446445
#pragma unroll NUM_PER_TH
447446
for (int j = 0; j < NUM_PER_TH / 2; j++) {
448-
packed_4bit |= dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4;
449-
packed_4bit |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max);
450-
qvals[j] = packed_4bit;
447+
qvals[j] = dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4;
448+
qvals[j] |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max);
451449
}
452450
break;
453451
case NF4:
454452
#pragma unroll NUM_PER_TH
455453
for (int j = 0; j < NUM_PER_TH / 2; j++) {
456-
packed_4bit |= dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4;
457-
packed_4bit |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max);
458-
qvals[j] = packed_4bit;
454+
qvals[j] = dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4;
455+
qvals[j] |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max);
459456
}
460457
break;
461458
}

tests/test_functional.py

Lines changed: 46 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,21 +1125,52 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
11251125

11261126
# With larger block sizes, we can expect this to blow up.
11271127
# At blocksize>=1024, don't even bother looking at relerr.
1128-
if blocksize <= 64:
1129-
assert err.item() < 0.1
1130-
assert relerr.item() < 0.28
1131-
elif blocksize <= 256:
1132-
assert err.item() < 0.11
1133-
assert relerr.item() < 0.30
1134-
elif blocksize <= 512:
1135-
assert err.item() < 0.12
1136-
assert relerr.item() < 0.31
1137-
elif quant_type == "fp4":
1138-
# 1024 => 0.48, 2048 => 0.52, 4096 => 0.56
1139-
assert err.item() < 0.08 + math.log2(blocksize) * 4e-2
1140-
else:
1141-
# 1024 => 0.8, 2048 => 0.88, 4096 => 0.96
1142-
assert err.item() < math.log2(blocksize) * 8e-2
1128+
#
1129+
# Actually, the above no longer holds now that the integer packing bug is fixed.
1130+
# The values below are the errors averaged over 1k samples per test configuration, measured after the fix; the asserts allow a 1e-3 margin on top of them.
1131+
error_dict = dict()
1132+
error_dict["fp4"] = dict()
1133+
error_dict["nf4"] = dict()
1134+
error_dict["fp4"]["err"] = {
1135+
64: 0.096545,
1136+
128: 0.102947,
1137+
256: 0.108685,
1138+
512: 0.114087,
1139+
1024: 0.119312,
1140+
2048: 0.124460,
1141+
4096: 0.129573,
1142+
}
1143+
error_dict["fp4"]["rel_err"] = {
1144+
64: 0.260130,
1145+
128: 0.275734,
1146+
256: 0.289842,
1147+
512: 0.302852,
1148+
1024: 0.314982,
1149+
2048: 0.326402,
1150+
4096: 0.337228,
1151+
}
1152+
1153+
error_dict["nf4"]["err"] = {
1154+
64: 0.072792,
1155+
128: 0.076835,
1156+
256: 0.080326,
1157+
512: 0.083535,
1158+
1024: 0.086603,
1159+
2048: 0.089592,
1160+
4096: 0.092537,
1161+
}
1162+
error_dict["nf4"]["rel_err"] = {
1163+
64: 0.203299,
1164+
128: 0.215252,
1165+
256: 0.226044,
1166+
512: 0.236021,
1167+
1024: 0.245365,
1168+
2048: 0.254146,
1169+
4096: 0.262457,
1170+
}
1171+
1172+
assert err < error_dict[quant_type]["err"][blocksize] + 1e-3
1173+
assert relerr < error_dict[quant_type]["rel_err"][blocksize] + 1e-3
11431174

11441175
@pytest.mark.parametrize("device", get_available_devices())
11451176
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])

0 commit comments

Comments
 (0)