diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 1a57f1cd75a31..a3e34621704d3 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -470,6 +470,7 @@ extern "C" { GGML_OP_TRANSPOSE, GGML_OP_GET_ROWS, GGML_OP_GET_ROWS_BACK, + GGML_OP_SET_ROWS, GGML_OP_DIAG, GGML_OP_DIAG_MASK_INF, GGML_OP_DIAG_MASK_ZERO, @@ -1374,6 +1375,12 @@ extern "C" { struct ggml_tensor * b, // row indices struct ggml_tensor * c); // data for ggml_get_rows, only used for its shape + GGML_API struct ggml_tensor * ggml_set_rows( + struct ggml_context * ctx, + struct ggml_tensor * a, // destination + struct ggml_tensor * b, // source + struct ggml_tensor * c); // row indices + GGML_API struct ggml_tensor * ggml_diag( struct ggml_context * ctx, struct ggml_tensor * a); diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 2c12e493bc9b0..9d73392e40507 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -1891,6 +1891,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_get_rows_back(params, tensor); } break; + case GGML_OP_SET_ROWS: + { + ggml_compute_forward_set_rows(params, tensor); + } break; case GGML_OP_DIAG: { ggml_compute_forward_diag(params, tensor); @@ -2240,6 +2244,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { n_tasks = n_threads; } break; case GGML_OP_GET_ROWS: + case GGML_OP_SET_ROWS: { // FIXME: get_rows can use additional threads, but the cost of launching additional threads // decreases performance with GPU offloading diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 08facb6d03d5e..43572d87ee439 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -4470,6 +4470,65 @@ void ggml_compute_forward_get_rows( //} } +static void ggml_compute_forward_set_rows_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS + + const int64_t nc = ne00; + const int64_t nr = ggml_nelements(src1); + + assert(ne0 == nc); + assert(ne02 == ne11); + assert(nb00 == sizeof(float)); + assert(ggml_nrows(src0) == nr); + + const int ith = params->ith; + const int nth = params->nth; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int64_t i = ir0; i < ir1; ++i) { + const int64_t i12 = i/(ne11*ne10); + const int64_t i11 = (i - i12*ne11*ne10)/ne10; + const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); + const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + + GGML_ASSERT(i01 >= 0 && i01 < ne1); + + ggml_cpu_fp32_to_fp16( + (const float *) ((char *) src0->data + i10*nb01 + i11*nb02 + i12*nb03), + (ggml_fp16_t *) ((char *) dst->data + i01*nb1 + i11*nb2 + i12*nb3), nc); + } +} + +void ggml_compute_forward_set_rows( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_set_rows_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_get_rows_back static void ggml_compute_forward_get_rows_back_f32_f16( diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index dc081b9e66397..91151e311390e 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -53,6 +53,7 @@ void ggml_compute_forward_permute(const struct ggml_compute_params * params, str void ggml_compute_forward_transpose(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_get_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_get_rows_back(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_set_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_diag(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_diag_mask_inf(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_diag_mask_zero(const struct ggml_compute_params * params, struct ggml_tensor * dst); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index a8edad3778aa9..8bf6efa2ef494 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -985,7 +985,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "OPT_STEP_ADAMW", }; -static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82"); +static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1080,7 +1080,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "adamw(x)", }; -static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82"); +static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -3393,6 +3393,28 @@ struct ggml_tensor * ggml_get_rows_back( return result; } +// ggml_set_rows + +struct ggml_tensor * ggml_set_rows( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c) { + GGML_ASSERT(b->ne[2] == c->ne[1]); + GGML_ASSERT(c->ne[3] == 1); + GGML_ASSERT(a->type == GGML_TYPE_F16); + GGML_ASSERT(b->type == GGML_TYPE_F32); + GGML_ASSERT(c->type == GGML_TYPE_I32); + + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + result->op = GGML_OP_SET_ROWS; + result->src[0] = b; + result->src[1] = c; + + return result; +} + // ggml_diag struct ggml_tensor * ggml_diag(