@@ -1125,21 +1125,52 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
1125
1125
1126
1126
# With larger block sizes, we can expect this to blow up.
1127
1127
# At blocksize>=1024, don't even bother looking at relerr.
1128
- if blocksize <= 64 :
1129
- assert err .item () < 0.1
1130
- assert relerr .item () < 0.28
1131
- elif blocksize <= 256 :
1132
- assert err .item () < 0.11
1133
- assert relerr .item () < 0.30
1134
- elif blocksize <= 512 :
1135
- assert err .item () < 0.12
1136
- assert relerr .item () < 0.31
1137
- elif quant_type == "fp4" :
1138
- # 1024 => 0.48, 2048 => 0.52, 4096 => 0.56
1139
- assert err .item () < 0.08 + math .log2 (blocksize ) * 4e-2
1140
- else :
1141
- # 1024 => 0.8, 2048 => 0.88, 4096 => 0.96
1142
- assert err .item () < math .log2 (blocksize ) * 8e-2
1128
+ #
1129
+ # Actually, the above is not true anymore after fixing the integer packing bug.
1130
+ # The following values were taken from averaging 1k samples per test configuration after fixing the bug.
1131
+ error_dict = dict ()
1132
+ error_dict ["fp4" ] = dict ()
1133
+ error_dict ["nf4" ] = dict ()
1134
+ error_dict ["fp4" ]["err" ] = {
1135
+ 64 : 0.096545 ,
1136
+ 128 : 0.102947 ,
1137
+ 256 : 0.108685 ,
1138
+ 512 : 0.114087 ,
1139
+ 1024 : 0.119312 ,
1140
+ 2048 : 0.124460 ,
1141
+ 4096 : 0.129573 ,
1142
+ }
1143
+ error_dict ["fp4" ]["rel_err" ] = {
1144
+ 64 : 0.260130 ,
1145
+ 128 : 0.275734 ,
1146
+ 256 : 0.289842 ,
1147
+ 512 : 0.302852 ,
1148
+ 1024 : 0.314982 ,
1149
+ 2048 : 0.326402 ,
1150
+ 4096 : 0.337228 ,
1151
+ }
1152
+
1153
+ error_dict ["nf4" ]["err" ] = {
1154
+ 64 : 0.072792 ,
1155
+ 128 : 0.076835 ,
1156
+ 256 : 0.080326 ,
1157
+ 512 : 0.083535 ,
1158
+ 1024 : 0.086603 ,
1159
+ 2048 : 0.089592 ,
1160
+ 4096 : 0.092537 ,
1161
+ }
1162
+ error_dict ["nf4" ]["rel_err" ] = {
1163
+ 64 : 0.203299 ,
1164
+ 128 : 0.215252 ,
1165
+ 256 : 0.226044 ,
1166
+ 512 : 0.236021 ,
1167
+ 1024 : 0.245365 ,
1168
+ 2048 : 0.254146 ,
1169
+ 4096 : 0.262457 ,
1170
+ }
1171
+
1172
+ assert err < error_dict [quant_type ]["err" ][blocksize ] + 1e-3
1173
+ assert relerr < error_dict [quant_type ]["rel_err" ][blocksize ] + 1e-3
1143
1174
1144
1175
@pytest .mark .parametrize ("device" , get_available_devices ())
1145
1176
@pytest .mark .parametrize ("quant_type" , ["fp4" , "nf4" ])
0 commit comments