@@ -188,38 +188,23 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
188
188
189
189
void llm_graph_input_cls::set_input (const llama_ubatch * ubatch) {
190
190
const int64_t n_tokens = ubatch->n_tokens ;
191
- const int64_t n_seq_tokens = ubatch->n_seq_tokens ;
192
191
const int64_t n_seqs_unq = ubatch->n_seqs_unq ;
193
192
194
193
if (cparams.embeddings && (
195
- cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
196
- cparams.pooling_type == LLAMA_POOLING_TYPE_RANK
197
- )) {
194
+ cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
195
+ cparams.pooling_type == LLAMA_POOLING_TYPE_RANK ||
196
+ cparams.pooling_type == LLAMA_POOLING_TYPE_LAST
197
+ )) {
198
198
GGML_ASSERT (cls);
199
199
GGML_ASSERT (ggml_backend_buffer_is_host (cls->buffer ));
200
200
201
201
uint32_t * data = (uint32_t *) cls->data ;
202
202
memset (cls->data , 0 , n_seqs_unq*ggml_element_size (cls));
203
203
204
- for (int i = 0 ; i < n_tokens; i += n_seq_tokens) {
205
- for (int s = 0 ; s < ubatch->n_seq_id [i]; ++s) {
206
- const llama_seq_id seq_id = ubatch->seq_id [i][s];
207
- const int32_t seq_idx = ubatch->seq_idx [seq_id];
208
-
209
- data[seq_idx] = i;
210
- }
211
- }
212
- }
213
-
214
- if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
215
- GGML_ASSERT (cls);
216
- GGML_ASSERT (ggml_backend_buffer_is_host (cls->buffer ));
217
-
218
- uint32_t * data = (uint32_t *) cls->data ;
219
- memset (cls->data , 0 , n_seqs_unq*ggml_element_size (cls));
204
+ std::vector<int > target_pos (n_seqs_unq, -1 );
205
+ std::vector<int > target_row (n_seqs_unq, -1 );
220
206
221
- std::vector<int > last_pos (n_seqs_unq, -1 );
222
- std::vector<int > last_row (n_seqs_unq, -1 );
207
+ bool last = cparams.pooling_type == LLAMA_POOLING_TYPE_LAST;
223
208
224
209
for (int i = 0 ; i < n_tokens; ++i) {
225
210
const llama_pos pos = ubatch->pos [i];
@@ -228,16 +213,20 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
228
213
const llama_seq_id seq_id = ubatch->seq_id [i][s];
229
214
const int32_t seq_idx = ubatch->seq_idx [seq_id];
230
215
231
- if (pos >= last_pos[seq_idx]) {
232
- last_pos[seq_idx] = pos;
233
- last_row[seq_idx] = i;
216
+ if (
217
+ (target_pos[seq_idx] == -1 ) ||
218
+ ( last && pos >= target_pos[seq_idx]) ||
219
+ (!last && pos < target_pos[seq_idx])
220
+ ) {
221
+ target_pos[seq_idx] = pos;
222
+ target_row[seq_idx] = i;
234
223
}
235
224
}
236
225
}
237
226
238
227
for (int s = 0 ; s < n_seqs_unq; ++s) {
239
- if (last_row [s] >= 0 ) {
240
- data[s] = last_row [s];
228
+ if (target_row [s] >= 0 ) {
229
+ data[s] = target_row [s];
241
230
}
242
231
}
243
232
}
0 commit comments