Skip to content

Commit eef375c

Browse files
authored
sampling : remove sampling branching in output_reserve (#18811)
* sampling : remove sampling branching in output_reserve. This commit updates output_reserve in llama-context.cpp to always allocate sampling buffers regardless of whether sampling is needed for the current batch. The motivation for this is to avoid reallocations and branching based on the sampling requirements of the batch.
1 parent 06961e2 commit eef375c

2 files changed

Lines changed: 33 additions & 57 deletions

File tree

src/llama-context.cpp

Lines changed: 32 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -253,11 +253,7 @@ llama_context::llama_context(
253253

254254
// graph outputs buffer
255255
{
256-
// resized during inference when a batch uses more outputs
257-
// Create a dummy batch for initialization.
258-
llama_batch dummy_batch = {};
259-
dummy_batch.n_tokens = 0;
260-
if (output_reserve(params.n_seq_max, dummy_batch) < params.n_seq_max) {
256+
if (output_reserve(params.n_seq_max) < params.n_seq_max) {
261257
throw std::runtime_error("failed to reserve initial output buffer");
262258
}
263259

@@ -1225,7 +1221,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
12251221
n_queued_tokens += n_tokens;
12261222

12271223
// reserve output buffer
1228-
if (output_reserve(n_tokens, batch_inp) < n_tokens) {
1224+
if (output_reserve(n_tokens) < n_tokens) {
12291225
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
12301226
return -2;
12311227
};
@@ -1456,6 +1452,23 @@ static void copy_tensor_async_candidates(
14561452
}
14571453
}
14581454

1455+
static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map<llama_seq_id, llama_sampler *> & samplers) {
1456+
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
1457+
if (!ubatch.output[i]) {
1458+
continue;
1459+
}
1460+
1461+
// Check if the output token has at least one sequence without a backend sampler.
1462+
for (int32_t j = 0; j < ubatch.n_seq_id[i]; ++j) {
1463+
llama_seq_id seq_id = ubatch.seq_id[i][j];
1464+
if (samplers.find(seq_id) == samplers.end()) {
1465+
return true;
1466+
}
1467+
}
1468+
}
1469+
return false; // all sequences use backend sampling
1470+
}
1471+
14591472
int llama_context::decode(const llama_batch & batch_inp) {
14601473
GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
14611474

@@ -1588,7 +1601,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
15881601
}
15891602

15901603
// reserve output buffer
1591-
if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) {
1604+
if (output_reserve(n_outputs_all) < n_outputs_all) {
15921605
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
15931606
return -2;
15941607
};
@@ -1661,10 +1674,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
16611674
}
16621675

16631676
// extract logits
1664-
// For multi-sequence batches that mix backend samplers and CPU sampler
1665-
// this is currently inefficient as we copy all logits even for the
1666-
// backend sampled tokens.
1667-
if (logits && t_logits && n_outputs > 0) {
1677+
if (logits && t_logits && n_outputs > 0 && needs_raw_logits(ubatch, sampling.samplers)) {
16681678
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
16691679
GGML_ASSERT(backend_res != nullptr);
16701680
GGML_ASSERT(logits != nullptr);
@@ -1734,11 +1744,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
17341744
}
17351745
}
17361746

1737-
// This flag indicates whether a backend sampler has actually sampled a specific
1738-
// token, or if it has produced probabilities. If true, we can skip the normal copying of logits and embeddings.
1739-
const bool has_sampled = !res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty();
1740-
1741-
if (has_samplers && has_sampled) {
1747+
// Copy backend sampling output if this ubatch produced any sampling tensors.
1748+
if (has_samplers && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty())) {
17421749
const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev);
17431750
const auto stride = n_vocab;
17441751

@@ -1813,7 +1820,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
18131820
// output
18141821
//
18151822

1816-
uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & batch) {
1823+
uint32_t llama_context::output_reserve(int32_t n_outputs) {
1824+
18171825
const auto & hparams = model.hparams;
18181826
const auto & vocab = model.vocab;
18191827

@@ -1832,45 +1840,16 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
18321840
has_embd = true;
18331841
}
18341842

1835-
// Check which sampling modes are needed for the current batch.
1836-
// TODO: avoid this branching by working with the worst-case
1837-
bool has_sampling = false;
1838-
bool cpu_logits = false;
1839-
1840-
if (batch.logits) {
1841-
for (int32_t i = 0; i < batch.n_tokens; i++) {
1842-
if (!batch.logits[i]) {
1843-
continue;
1844-
}
1845-
for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
1846-
llama_seq_id seq_id = batch.seq_id[i][j];
1847-
if (sampling.samplers.find(seq_id) != sampling.samplers.end()) {
1848-
has_sampling = true;
1849-
} else {
1850-
cpu_logits = true;
1851-
}
1852-
}
1853-
}
1854-
} else {
1855-
// When batch.logits is nullptr (when loading state with a dummy batch),
1856-
// allocate CPU logits.
1857-
cpu_logits = true;
1858-
}
18591843

18601844
size_t backend_float_count = 0;
18611845
size_t backend_token_count = 0;
18621846

1863-
// Allocate CPU logits buffer only if needed by sequences in this batch
1864-
logits_size = (has_logits && cpu_logits) ? n_vocab*n_outputs_max : 0;
1847+
logits_size = has_logits ? n_vocab*n_outputs_max : 0;
18651848
embd_size = has_embd ? n_embd_out*n_outputs_max : 0;
18661849

1867-
// TODO: avoid this branching by working with the worst-case
1868-
if (!has_sampling) {
1869-
sampling.logits_size = 0;
1870-
sampling.probs_size = 0;
1871-
sampling.sampled_size = 0;
1872-
sampling.candidates_size = 0;
1873-
} else {
1850+
// Allocate backend sampling output buffers if there are backend samplers configured.
1851+
const bool has_sampling = !sampling.samplers.empty();
1852+
if (has_sampling) {
18741853
sampling.logits_size = n_vocab*n_outputs_max;
18751854
sampling.probs_size = n_vocab*n_outputs_max;
18761855
sampling.sampled_size = n_outputs_max;
@@ -1928,7 +1907,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
19281907
size_t offset = 0;
19291908
uint8_t * base = (uint8_t *) output_base;
19301909

1931-
logits = (has_logits && cpu_logits) ? output_base : nullptr;
1910+
logits = has_logits ? output_base : nullptr;
19321911
offset += logits_size * sizeof(float);
19331912

19341913
embd = has_embd ? (float *) (base + offset) : nullptr;
@@ -2614,10 +2593,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
26142593
auto n_outputs = this->n_outputs;
26152594
io.read_to(&n_outputs, sizeof(n_outputs));
26162595

2617-
// Create a dummy batch for state loading.
2618-
llama_batch dummy_batch = {};
2619-
dummy_batch.n_tokens = 0;
2620-
if (n_outputs > output_reserve(n_outputs, dummy_batch)) {
2596+
if (n_outputs > output_reserve(n_outputs)) {
26212597
throw std::runtime_error("could not reserve outputs");
26222598
}
26232599

@@ -2862,7 +2838,7 @@ void llama_context::opt_epoch_iter(
28622838
}
28632839

28642840
// reserve output buffer
2865-
if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) {
2841+
if (output_reserve(n_outputs_all) < n_outputs_all) {
28662842
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
28672843
GGML_ABORT("TODO: handle this error");
28682844
};

src/llama-context.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ struct llama_context {
212212

213213
// Make sure enough space is available for outputs.
214214
// Returns max number of outputs for which space was reserved.
215-
uint32_t output_reserve(int32_t n_outputs, const llama_batch & batch);
215+
uint32_t output_reserve(int32_t n_outputs);
216216

217217
void output_reorder();
218218

0 commit comments

Comments
 (0)