@@ -1170,6 +1170,9 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
                 size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc.
                 size += sizeof(int64_t) * (1 + op->src[0]->ne[2]) * op->src[1]->ne[2];
                 return true;
+            case GGML_OP_GET_ROWS:
+                size = 0; // GET_ROWS (standard and repacked) doesn't need a work buffer
+                return true;
             default:
                 // GGML_ABORT("fatal error");
                 break;
@@ -1185,6 +1188,9 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
             case GGML_OP_MUL_MAT_ID:
                 forward_mul_mat_id(params, op);
                 return true;
+            case GGML_OP_GET_ROWS:
+                forward_get_rows(params, op);
+                return true;
             default:
                 // GGML_ABORT("fatal error");
                 break;
@@ -1390,6 +1396,132 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
 #undef MMID_MATRIX_ROW
     }
 
+    void forward_get_rows(const ggml_compute_params * params,
+                          ggml_tensor * dst) {
+        const ggml_tensor * src0 = dst->src[0];
+
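+        // only tensors repacked as q4_0x8 have a dequantization path here so far;
+        // other repack layouts would need their own get_rows handling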
+        switch (src0->type) {
+            case GGML_TYPE_Q4_0: {
+                ggml_compute_forward_get_rows_q4_0x8(params, dst);
+            } break;
+            default:
+                GGML_ABORT("fatal error");
+                break;
+        }
+    }
+
+    static void ggml_compute_forward_get_rows_q4_0x8(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+        const ggml_tensor * src0 = dst->src[0];
+        const ggml_tensor * src1 = dst->src[1];
+
+        GGML_TENSOR_BINARY_OP_LOCALS
+
+        const int64_t nc = ne00;
+        const int64_t nr = ggml_nelements(src1);
+
+        assert(ne0 == nc);
+        assert(ne02 == ne11);
+        assert(nb00 == ggml_type_size(src0->type));
+        assert(ggml_nrows(dst) == nr);
+
+        const int ith = params->ith;
+        const int nth = params->nth;
+
+        // rows per thread
+        const int dr = (nr + nth - 1) / nth;
+
+        // row range for this thread
+        const int ir0 = dr * ith;
+        const int ir1 = MIN(ir0 + dr, nr);
+
+        constexpr int nrows_interleaved = 8;
+        const size_t sizeof_one_repacked_block = sizeof(block_q4_0x8);
+
+        const int num_repacked_blocks_per_row_width = nc / QK4_0;
+
+        const size_t stride_between_actual_row_groups = num_repacked_blocks_per_row_width * sizeof_one_repacked_block;
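+        // i.e. each group of 8 logical rows occupies nc/QK4_0 consecutive
+        // block_q4_0x8 blocks, and groups follow one another at this stride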
+
+        for (int64_t i = ir0; i < ir1; ++i) {
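+            // decompose the flat destination row index i into src1 coordinates (i10, i11, i12)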
+            const int64_t i12 = i / (ne11 * ne10);
+            const int64_t i11 = (i - i12 * ne11 * ne10) / ne10;
+            const int64_t i10 = (i - i12 * ne11 * ne10 - i11 * ne10);
+            const int64_t i01 = *(int32_t *)((char *)src1->data + i10 * nb10 + i11 * nb11 + i12 * nb12); // original logical row
+
+            GGML_ASSERT(i01 >= 0 && i01 < ne01);
+
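+            // in the interleaved layout, logical row i01 lives in row group
+            // i01/8 and occupies slot i01%8 within each block of that group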
+            const int row_group_idx    = i01 / nrows_interleaved;
+            const int row_idx_in_group = i01 % nrows_interleaved;
+
+            const char * base_ptr_for_higher_dims_in_src0 = (const char *)src0->data + i11 * nb02 + i12 * nb03;
+
+            // Pointer to the first block_q4_0x8 of the identified row_group_idx
+            const block_q4_0x8 * p_first_repacked_block_of_group_x8 = (const block_q4_0x8 *)(base_ptr_for_higher_dims_in_src0 + row_group_idx * stride_between_actual_row_groups);
+
+            dequantize_row_q4_0x8(
+                p_first_repacked_block_of_group_x8,
+                (float *)((char *)dst->data + i10 * nb1 + i11 * nb2 + i12 * nb3), nc, row_idx_in_group);
+        }
+    }
+
+    /**
+     * Dequantizes a single logical row from data repacked with quant interleaving.
+     *
+     * @param p_repacked_group_column_blocks Pointer to the start of 'block_q4_0x8' for the row group.
+     * @param y Output buffer for the dequantized float values.
+     * @param k Total number of elements (columns) in the logical row.
+     * @param row_idx_in_group Index (0-7) of the logical row to dequantize.
+     */
+    static void dequantize_row_q4_0x8(
+        const block_q4_0x8 * GGML_RESTRICT p_repacked_group_column_blocks,
+        float * GGML_RESTRICT y,
+        int64_t k,
+        int row_idx_in_group) {
+        const int GGML_Q4_0_X8_INTERLEAVE_SIZE = 8;
+        assert(k % QK4_0 == 0);
+        assert(row_idx_in_group >= 0 && row_idx_in_group < GGML_Q4_0_X8_INTERLEAVE_SIZE);
+
+        const int nb = k / QK4_0;
+        const int bytes_for_half_elements = (QK4_0 / 2) / 2;
+
+        const int offset_to_second_half_data = bytes_for_half_elements * GGML_Q4_0_X8_INTERLEAVE_SIZE;
+        const uint64_t xor_mask = 0x8888888888888888ULL;
+        const int qk4_0_half_elements = QK4_0 / 2;
+
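+        // within one block_q4_0x8 the 8 rows' quant bytes are interleaved in
+        // 8-byte chunks: the first 64 bytes of qs hold each row's first 8
+        // quant bytes, the next 64 bytes hold each row's second 8 quant bytes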
+        for (int i = 0; i < nb; ++i) {
+            const block_q4_0x8 * current_column_repacked_block = &p_repacked_group_column_blocks[i];
+            const float d_val = GGML_FP16_TO_FP32(current_column_repacked_block->d[row_idx_in_group]);
+            float * y_curr = y + i * QK4_0;
+
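+            // the repack step XORed every quant byte with 0x88, flipping the
+            // high bit of each nibble; XORing with the same mask restores the
+            // original q4_0 bytes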
+            const int8_t * qs_first_half_repacked_ptr = &(current_column_repacked_block->qs[row_idx_in_group * bytes_for_half_elements]);
+
+            uint64_t first_half_chunk_u64;
+            memcpy(&first_half_chunk_u64, qs_first_half_repacked_ptr, sizeof(uint64_t));
+            first_half_chunk_u64 ^= xor_mask; // reverse the XOR
+            const uint8_t * original_qs_first_half_bytes = (const uint8_t *)&first_half_chunk_u64;
+
+            const int8_t * qs_second_half_repacked_ptr = &(current_column_repacked_block->qs[offset_to_second_half_data + (row_idx_in_group * bytes_for_half_elements)]);
+
+            uint64_t second_half_chunk_u64;
+            memcpy(&second_half_chunk_u64, qs_second_half_repacked_ptr, sizeof(uint64_t));
+            second_half_chunk_u64 ^= xor_mask; // reverse the XOR
+            const uint8_t * original_qs_second_half_bytes = (const uint8_t *)&second_half_chunk_u64;
+
+            // dequantize all QK4_0 elements of this block
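+            // q4_0 packs element j in the low nibble of byte j and element
+            // j + QK4_0/2 in the high nibble, with a fixed zero point of 8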
+            for (int j = 0; j < bytes_for_half_elements; ++j) {
+                const uint8_t quant_byte_first = original_qs_first_half_bytes[j];
+                y_curr[j]                       = ((quant_byte_first & 0x0F) - 8) * d_val;
+                y_curr[j + qk4_0_half_elements] = ((quant_byte_first >> 4) - 8) * d_val;
+
+                const uint8_t quant_byte_second = original_qs_second_half_bytes[j];
+                const int out_idx_base_second_half = j + bytes_for_half_elements; // offset for the second set of low nibbles
+                y_curr[out_idx_base_second_half]                       = ((quant_byte_second & 0x0F) - 8) * d_val;
+                y_curr[out_idx_base_second_half + qk4_0_half_elements] = ((quant_byte_second >> 4) - 8) * d_val;
+            }
+        }
+    }
+
 
     int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {
         GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type),
                        (int) NB_COLS, (int) INTER_SIZE);
@@ -1522,12 +1654,23 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
             // if (op->src[1]->type == GGML_TYPE_Q8_0) {
             //     return true;
             // }
+        } else if (op->op == GGML_OP_GET_ROWS
+                   && op->src[0]->buffer
+                   && (ggml_n_dims(op->src[0]) == 2)
+                   && op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type()
+                   && ggml_repack_get_optimal_repack_type(op->src[0])) {
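+            // the row indices in src1 must live in host memory for this CPU path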
+            if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
+                return false;
+            }
+            if (op->src[0]->type == GGML_TYPE_Q4_0) {
+                return true;
+            }
         }
         return false;
     }
 
     ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
-        if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID) {
+        if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID || op->op == GGML_OP_GET_ROWS) {
             if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type()) {
                 return (ggml::cpu::tensor_traits *) op->src[0]->extra;
             }