Skip to content

Commit a118d80

Browse files
authored
embeddings: fix extraction of CLS pooling results (#14927)
* embeddings: fix extraction of CLS pooling results * merge RANK pooling into CLS case for inputs
1 parent 61550f8 commit a118d80

File tree

1 file changed

+16
-27
lines changed

1 file changed

+16
-27
lines changed

src/llama-graph.cpp

Lines changed: 16 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -188,38 +188,23 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
188188

189189
void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
190190
const int64_t n_tokens = ubatch->n_tokens;
191-
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
192191
const int64_t n_seqs_unq = ubatch->n_seqs_unq;
193192

194193
if (cparams.embeddings && (
195-
cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
196-
cparams.pooling_type == LLAMA_POOLING_TYPE_RANK
197-
)) {
194+
cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
195+
cparams.pooling_type == LLAMA_POOLING_TYPE_RANK ||
196+
cparams.pooling_type == LLAMA_POOLING_TYPE_LAST
197+
)) {
198198
GGML_ASSERT(cls);
199199
GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
200200

201201
uint32_t * data = (uint32_t *) cls->data;
202202
memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
203203

204-
for (int i = 0; i < n_tokens; i += n_seq_tokens) {
205-
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
206-
const llama_seq_id seq_id = ubatch->seq_id[i][s];
207-
const int32_t seq_idx = ubatch->seq_idx[seq_id];
208-
209-
data[seq_idx] = i;
210-
}
211-
}
212-
}
213-
214-
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
215-
GGML_ASSERT(cls);
216-
GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
217-
218-
uint32_t * data = (uint32_t *) cls->data;
219-
memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
204+
std::vector<int> target_pos(n_seqs_unq, -1);
205+
std::vector<int> target_row(n_seqs_unq, -1);
220206

221-
std::vector<int> last_pos(n_seqs_unq, -1);
222-
std::vector<int> last_row(n_seqs_unq, -1);
207+
bool last = cparams.pooling_type == LLAMA_POOLING_TYPE_LAST;
223208

224209
for (int i = 0; i < n_tokens; ++i) {
225210
const llama_pos pos = ubatch->pos[i];
@@ -228,16 +213,20 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
228213
const llama_seq_id seq_id = ubatch->seq_id[i][s];
229214
const int32_t seq_idx = ubatch->seq_idx[seq_id];
230215

231-
if (pos >= last_pos[seq_idx]) {
232-
last_pos[seq_idx] = pos;
233-
last_row[seq_idx] = i;
216+
if (
217+
(target_pos[seq_idx] == -1) ||
218+
( last && pos >= target_pos[seq_idx]) ||
219+
(!last && pos < target_pos[seq_idx])
220+
) {
221+
target_pos[seq_idx] = pos;
222+
target_row[seq_idx] = i;
234223
}
235224
}
236225
}
237226

238227
for (int s = 0; s < n_seqs_unq; ++s) {
239-
if (last_row[s] >= 0) {
240-
data[s] = last_row[s];
228+
if (target_row[s] >= 0) {
229+
data[s] = target_row[s];
241230
}
242231
}
243232
}

0 commit comments

Comments
 (0)