Skip to content

Commit 9dd452d

Browse files
author
Iwan Kawrakow
committed
Fix iq5_ks on NEON
1 parent 2881f5f commit 9dd452d

File tree

1 file changed

+23
-16
lines changed

1 file changed

+23
-16
lines changed

ggml/src/iqk/iqk_mul_mat.cpp

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11207,7 +11207,8 @@ struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> {
1120711207
};
1120811208

1120911209
struct DequantizerIQ5KS final : public BaseDequantizer<block_iq5_ks, true> {
11210-
DequantizerIQ5KS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x2(iq5nl_values)) {}
11210+
DequantizerIQ5KS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc),
11211+
values(vld1q_s8_x4(iq5nl_values)) {}
1121111212

1121211213
constexpr static int num_blocks() { return 8; }
1121311214
constexpr static bool should_scale_quants() { return false; }
@@ -11216,7 +11217,11 @@ struct DequantizerIQ5KS final : public BaseDequantizer<block_iq5_ks, true> {
1121611217
inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
1121711218
(void)q8;
1121811219
(void)acc;
11219-
auto scales16 = vaddq_s16(vreinterpretq_s16_u16(vandq_u16(vmovl_u8(vld1_u8(x[i].scales)), mask)), m127);
11220+
auto sas8 = vld1_u8(x[i].scales);
11221+
auto scales16 = vaddq_s16(vreinterpretq_s16_u16(vandq_u16(vmovl_u8(sas8), mask)), m127);
11222+
hbits = vld1q_u8_x2(x[i].qh);
11223+
sas = vcombine_u8(sas8, sas8);
11224+
sas = vshlq_n_u8(vandq_u8(sas, vdupq_n_u8(1)), 5);
1122011225
int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))};
1122111226
return scales;
1122211227
}
@@ -11226,27 +11231,29 @@ struct DequantizerIQ5KS final : public BaseDequantizer<block_iq5_ks, true> {
1122611231
if (j == 1) {
1122711232
for (int k = 0; k < 2; ++k) hbits.val[k] = vshrq_n_u8(hbits.val[k], 4);
1122811233
}
11229-
bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), hm));
11230-
bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), hm));
11231-
bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 3), hm));
11232-
bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 3), hm));
11233-
bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 2), hm));
11234-
bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 2), hm));
11235-
bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 1), hm));
11236-
bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 1), hm));
11237-
for (int k = 0; k < 4; ++k) {
11238-
bits.b1.val[k] = vqtbl2q_s8(values, bits.b1.val[k]);
11239-
bits.b2.val[k] = vqtbl2q_s8(values, bits.b2.val[k]);
11240-
}
11234+
auto shift = vdupq_n_u8((x[i].scales[4*j+0] & 1) << 5);
11235+
bits.b1.val[0] = vaddq_u8(shift, vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), hm)));
11236+
bits.b1.val[1] = vaddq_u8(shift, vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), hm)));
11237+
shift = vdupq_n_u8((x[i].scales[4*j+1] & 1) << 5);
11238+
bits.b1.val[2] = vaddq_u8(shift, vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 3), hm)));
11239+
bits.b1.val[3] = vaddq_u8(shift, vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 3), hm)));
11240+
for (int k = 0; k < 4; ++k) bits.b1.val[k] = vqtbl4q_s8(values, bits.b1.val[k]);
11241+
shift = vdupq_n_u8((x[i].scales[4*j+2] & 1) << 5);
11242+
bits.b2.val[0] = vaddq_u8(shift, vorrq_u8(bits.b2.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 2), hm)));
11243+
bits.b2.val[1] = vaddq_u8(shift, vorrq_u8(bits.b2.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 2), hm)));
11244+
shift = vdupq_n_u8((x[i].scales[4*j+3] & 1) << 5);
11245+
bits.b2.val[2] = vaddq_u8(shift, vorrq_u8(bits.b2.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 1), hm)));
11246+
bits.b2.val[3] = vaddq_u8(shift, vorrq_u8(bits.b2.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 1), hm)));
11247+
for (int k = 0; k < 4; ++k) bits.b2.val[k] = vqtbl4q_s8(values, bits.b2.val[k]);
1124111248
}
1124211249

1124311250
Q4bits bits;
11244-
const int8x16x2_t values;
11245-
const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06});
11251+
const int8x16x4_t values;
1124611252
const uint8x16_t hm = vdupq_n_u8(0x10);
1124711253
const uint16x8_t mask = vdupq_n_u16(254);
1124811254
const int16x8_t m127 = vdupq_n_s16(-127);
1124911255
uint8x16x2_t hbits;
11256+
uint8x16_t sas;
1125011257

1125111258
};
1125211259

0 commit comments

Comments
 (0)