Skip to content

ggml : add ggml_set_rows #14274

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,7 @@ extern "C" {
GGML_OP_TRANSPOSE,
GGML_OP_GET_ROWS,
GGML_OP_GET_ROWS_BACK,
GGML_OP_SET_ROWS,
GGML_OP_DIAG,
GGML_OP_DIAG_MASK_INF,
GGML_OP_DIAG_MASK_ZERO,
Expand Down Expand Up @@ -1374,6 +1375,12 @@ extern "C" {
struct ggml_tensor * b, // row indices
struct ggml_tensor * c); // data for ggml_get_rows, only used for its shape

// Write the rows of b into a at the row positions listed in c.
// The implementation in ggml.c requires a to be F16, b to be F32 and c to be I32,
// with b->ne[2] == c->ne[1] and c->ne[3] == 1.
// The returned tensor is created with ggml_view_tensor(ctx, a) and has op GGML_OP_SET_ROWS.
GGML_API struct ggml_tensor * ggml_set_rows(
struct ggml_context * ctx,
struct ggml_tensor * a, // destination (F16); the result is a view of this tensor
struct ggml_tensor * b, // source rows (F32)
struct ggml_tensor * c); // row indices (I32): destination row for each source row

GGML_API struct ggml_tensor * ggml_diag(
struct ggml_context * ctx,
struct ggml_tensor * a);
Expand Down
5 changes: 5 additions & 0 deletions ggml/src/ggml-cpu/ggml-cpu.c
Original file line number Diff line number Diff line change
Expand Up @@ -1891,6 +1891,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_get_rows_back(params, tensor);
} break;
case GGML_OP_SET_ROWS:
{
ggml_compute_forward_set_rows(params, tensor);
} break;
case GGML_OP_DIAG:
{
ggml_compute_forward_diag(params, tensor);
Expand Down Expand Up @@ -2240,6 +2244,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
n_tasks = n_threads;
} break;
case GGML_OP_GET_ROWS:
case GGML_OP_SET_ROWS:
{
// FIXME: get_rows can use additional threads, but the cost of launching additional threads
// decreases performance with GPU offloading
Expand Down
59 changes: 59 additions & 0 deletions ggml/src/ggml-cpu/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4470,6 +4470,65 @@ void ggml_compute_forward_get_rows(
//}
}

// Scatter F32 source rows into the destination, converting each row to F16.
//
//   src0 = dst->src[0] : source rows, read as float
//   src1 = dst->src[1] : row indices, read as int32_t
//
// For every element (i10, i11, i12) of src1, source row (i10, i11, i12) of src0 is
// converted with ggml_cpu_fp32_to_fp16 and stored in destination row
// (src1[i10, i11, i12], i11, i12). The index elements are split into contiguous
// ranges, one per worker thread (params->ith of params->nth).
static void ggml_compute_forward_set_rows_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    // presumably declares the ne0x/nb0x (src0), ne1x/nb1x (src1) and nex/nbx (dst) locals used below
    GGML_TENSOR_BINARY_OP_LOCALS

    const int64_t nc = ne00;                 // elements per row
    const int64_t nr = ggml_nelements(src1); // number of rows to write (one per index)

    assert(ne0 == nc);
    assert(ne02 == ne11);
    assert(nb00 == sizeof(float));
    assert(ggml_nrows(src0) == nr);

    const int ith = params->ith;
    const int nth = params->nth;

    // rows per thread
    // kept in 64 bits: nr is an int64_t count and must not be truncated to int
    const int64_t dr = (nr + nth - 1)/nth;

    // row range for this thread
    const int64_t ir0 = dr*ith;
    const int64_t ir1 = MIN(ir0 + dr, nr);

    for (int64_t i = ir0; i < ir1; ++i) {
        // unflatten i into the (i10, i11, i12) coordinates of the index tensor
        const int64_t i12 = i/(ne11*ne10);
        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);

        // destination row selected by this index element
        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);

        GGML_ASSERT(i01 >= 0 && i01 < ne1);

        ggml_cpu_fp32_to_fp16(
                (const float *) ((char *) src0->data + i10*nb01 + i11*nb02 + i12*nb03),
                (ggml_fp16_t *) ((char *) dst->data + i01*nb1 + i11*nb2 + i12*nb3), nc);
    }
}

void ggml_compute_forward_set_rows(
const ggml_compute_params * params,
ggml_tensor * dst) {

const ggml_tensor * src0 = dst->src[0];

switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_set_rows_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}

// ggml_compute_forward_get_rows_back

static void ggml_compute_forward_get_rows_back_f32_f16(
Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-cpu/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ void ggml_compute_forward_permute(const struct ggml_compute_params * params, str
void ggml_compute_forward_transpose(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_get_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_get_rows_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_set_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_diag(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_diag_mask_inf(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_diag_mask_zero(const struct ggml_compute_params * params, struct ggml_tensor * dst);
Expand Down
26 changes: 24 additions & 2 deletions ggml/src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -985,7 +985,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"OPT_STEP_ADAMW",
};

static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
Expand Down Expand Up @@ -1080,7 +1080,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"adamw(x)",
};

static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");

static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

Expand Down Expand Up @@ -3393,6 +3393,28 @@ struct ggml_tensor * ggml_get_rows_back(
return result;
}

// ggml_set_rows

// Build a GGML_OP_SET_ROWS node that writes the rows of b into a at the row
// positions given by c.
//
//   a : destination, F16
//   b : source rows, F32
//   c : row indices, I32 (one index per source row)
//
// The result is created with ggml_view_tensor(ctx, a); its src[0] is b and its
// src[1] is c.
struct ggml_tensor * ggml_set_rows(
        struct ggml_context * ctx,
        struct ggml_tensor * a,
        struct ggml_tensor * b,
        struct ggml_tensor * c) {
    // shape checks are done here, at graph-build time: the CPU kernel copies
    // b->ne[0] elements into each destination row and addresses the source rows
    // with the coordinates of the index tensor, so a mismatch would read or
    // write out of bounds
    GGML_ASSERT(a->ne[0] == b->ne[0]);
    GGML_ASSERT(b->ne[1] == c->ne[0]);
    GGML_ASSERT(b->ne[2] == c->ne[1]);
    GGML_ASSERT(b->ne[3] == c->ne[2]);
    GGML_ASSERT(c->ne[3] == 1);

    GGML_ASSERT(a->type == GGML_TYPE_F16);
    GGML_ASSERT(b->type == GGML_TYPE_F32);
    // NOTE(review): a maintainer suggested using I64 indices here instead of I32
    GGML_ASSERT(c->type == GGML_TYPE_I32);

    struct ggml_tensor * result = ggml_view_tensor(ctx, a);

    result->op = GGML_OP_SET_ROWS;
    result->src[0] = b;
    result->src[1] = c;

    return result;
}

// ggml_diag

struct ggml_tensor * ggml_diag(
Expand Down
Loading