@@ -253,11 +253,7 @@ llama_context::llama_context(
253253
254254 // graph outputs buffer
255255 {
256- // resized during inference when a batch uses more outputs
257- // Create a dummy batch for initialization.
258- llama_batch dummy_batch = {};
259- dummy_batch.n_tokens = 0 ;
260- if (output_reserve (params.n_seq_max , dummy_batch) < params.n_seq_max ) {
256+ if (output_reserve (params.n_seq_max ) < params.n_seq_max ) {
261257 throw std::runtime_error (" failed to reserve initial output buffer" );
262258 }
263259
@@ -1225,7 +1221,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
12251221 n_queued_tokens += n_tokens;
12261222
12271223 // reserve output buffer
1228- if (output_reserve (n_tokens, batch_inp ) < n_tokens) {
1224+ if (output_reserve (n_tokens) < n_tokens) {
12291225 LLAMA_LOG_ERROR (" %s: could not reserve space for batch with %u outputs\n " , __func__, n_tokens);
12301226 return -2 ;
12311227 };
@@ -1456,6 +1452,23 @@ static void copy_tensor_async_candidates(
14561452 }
14571453}
14581454
1455+ static bool needs_raw_logits (const llama_ubatch & ubatch, const std::map<llama_seq_id, llama_sampler *> & samplers) {
1456+ for (uint32_t i = 0 ; i < ubatch.n_tokens ; i++) {
1457+ if (!ubatch.output [i]) {
1458+ continue ;
1459+ }
1460+
1461+ // Check if the output token has at least one sequence without a backend sampler.
1462+ for (int32_t j = 0 ; j < ubatch.n_seq_id [i]; ++j) {
1463+ llama_seq_id seq_id = ubatch.seq_id [i][j];
1464+ if (samplers.find (seq_id) == samplers.end ()) {
1465+ return true ;
1466+ }
1467+ }
1468+ }
1469+ return false ; // all sequences use backend sampling
1470+ }
1471+
14591472int llama_context::decode (const llama_batch & batch_inp) {
14601473 GGML_ASSERT ((!batch_inp.token && batch_inp.embd ) || (batch_inp.token && !batch_inp.embd )); // NOLINT
14611474
@@ -1588,7 +1601,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
15881601 }
15891602
15901603 // reserve output buffer
1591- if (output_reserve (n_outputs_all, balloc-> get_batch () ) < n_outputs_all) {
1604+ if (output_reserve (n_outputs_all) < n_outputs_all) {
15921605 LLAMA_LOG_ERROR (" %s: could not reserve space for batch with %d outputs\n " , __func__, n_outputs_all);
15931606 return -2 ;
15941607 };
@@ -1661,10 +1674,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
16611674 }
16621675
16631676 // extract logits
1664- // For multi-sequence batches that mix backend samplers and CPU sampler
1665- // this is currently inefficient as we copy all logits even for the
1666- // backend sampled tokens.
1667- if (logits && t_logits && n_outputs > 0 ) {
1677+ if (logits && t_logits && n_outputs > 0 && needs_raw_logits (ubatch, sampling.samplers )) {
16681678 ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend (sched.get (), t_logits);
16691679 GGML_ASSERT (backend_res != nullptr );
16701680 GGML_ASSERT (logits != nullptr );
@@ -1734,11 +1744,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
17341744 }
17351745 }
17361746
1737- // This flag indicates whether a backend sampler has actually sampled a specific
1738- // token, or if it has produced probabilites. If true, we can skip the normal copying of logits and embeddings.
1739- const bool has_sampled = !res->t_sampled .empty () || !res->t_sampled_probs .empty () || !res->t_sampled_logits .empty ();
1740-
1741- if (has_samplers && has_sampled) {
1747+ // Copy backend sampling output if this ubatch produced any sampling tensors.
1748+ if (has_samplers && (!res->t_sampled .empty () || !res->t_sampled_probs .empty () || !res->t_sampled_logits .empty ())) {
17421749 const auto seq_to_output_row = build_seq_to_output_row (ubatch, n_outputs_prev);
17431750 const auto stride = n_vocab;
17441751
@@ -1813,7 +1820,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
18131820// output
18141821//
18151822
1816- uint32_t llama_context::output_reserve (int32_t n_outputs, const llama_batch & batch) {
1823+ uint32_t llama_context::output_reserve (int32_t n_outputs) {
1824+
18171825 const auto & hparams = model.hparams ;
18181826 const auto & vocab = model.vocab ;
18191827
@@ -1832,45 +1840,16 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
18321840 has_embd = true ;
18331841 }
18341842
1835- // Check which sampling modes are needed for the current batch.
1836- // TODO: avoid this branching by working with the worst-case
1837- bool has_sampling = false ;
1838- bool cpu_logits = false ;
1839-
1840- if (batch.logits ) {
1841- for (int32_t i = 0 ; i < batch.n_tokens ; i++) {
1842- if (!batch.logits [i]) {
1843- continue ;
1844- }
1845- for (int32_t j = 0 ; j < batch.n_seq_id [i]; j++) {
1846- llama_seq_id seq_id = batch.seq_id [i][j];
1847- if (sampling.samplers .find (seq_id) != sampling.samplers .end ()) {
1848- has_sampling = true ;
1849- } else {
1850- cpu_logits = true ;
1851- }
1852- }
1853- }
1854- } else {
1855- // When batch.logits is nullptr (when loading state with a dummy batch),
1856- // allocate CPU logits.
1857- cpu_logits = true ;
1858- }
18591843
18601844 size_t backend_float_count = 0 ;
18611845 size_t backend_token_count = 0 ;
18621846
1863- // Allocate CPU logits buffer only if needed by sequences in this batch
1864- logits_size = (has_logits && cpu_logits) ? n_vocab*n_outputs_max : 0 ;
1847+ logits_size = has_logits ? n_vocab*n_outputs_max : 0 ;
18651848 embd_size = has_embd ? n_embd_out*n_outputs_max : 0 ;
18661849
1867- // TODO: avoid this branching by working with the worst-case
1868- if (!has_sampling) {
1869- sampling.logits_size = 0 ;
1870- sampling.probs_size = 0 ;
1871- sampling.sampled_size = 0 ;
1872- sampling.candidates_size = 0 ;
1873- } else {
1850+ // Allocate backend sampling output buffers if there are backend samplers configured.
1851+ const bool has_sampling = !sampling.samplers .empty ();
1852+ if (has_sampling) {
18741853 sampling.logits_size = n_vocab*n_outputs_max;
18751854 sampling.probs_size = n_vocab*n_outputs_max;
18761855 sampling.sampled_size = n_outputs_max;
@@ -1928,7 +1907,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
19281907 size_t offset = 0 ;
19291908 uint8_t * base = (uint8_t *) output_base;
19301909
1931- logits = ( has_logits && cpu_logits) ? output_base : nullptr ;
1910+ logits = has_logits ? output_base : nullptr ;
19321911 offset += logits_size * sizeof (float );
19331912
19341913 embd = has_embd ? (float *) (base + offset) : nullptr ;
@@ -2614,10 +2593,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
26142593 auto n_outputs = this ->n_outputs ;
26152594 io.read_to (&n_outputs, sizeof (n_outputs));
26162595
2617- // Create a dummy batch for state loading.
2618- llama_batch dummy_batch = {};
2619- dummy_batch.n_tokens = 0 ;
2620- if (n_outputs > output_reserve (n_outputs, dummy_batch)) {
2596+ if (n_outputs > output_reserve (n_outputs)) {
26212597 throw std::runtime_error (" could not reserve outputs" );
26222598 }
26232599
@@ -2862,7 +2838,7 @@ void llama_context::opt_epoch_iter(
28622838 }
28632839
28642840 // reserve output buffer
2865- if (output_reserve (n_outputs_all, balloc-> get_batch () ) < n_outputs_all) {
2841+ if (output_reserve (n_outputs_all) < n_outputs_all) {
28662842 LLAMA_LOG_ERROR (" %s: could not reserve space for batch with %d outputs\n " , __func__, n_outputs_all);
28672843 GGML_ABORT (" TODO: handle this error" );
28682844 };
0 commit comments