Commit 66625a5

graph : reduce splits for recurrent and hybrid models (#14825)
* graph : avoid creating redundant s_copy views
* graph : comment the s_copy views
1 parent 6e67254 commit 66625a5
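
Why this reduces splits: every ggml_view_1d call adds a view node to the compute graph, so rebuilding identical views of the s_copy input inside each layer's build_rs multiplied the node count with model depth; per the commit title, dropping those redundant views reduces the number of graph splits the backend scheduler produces for recurrent and hybrid models. Below is a minimal, self-contained sketch against the public ggml API (illustrative, not llama.cpp code; n_layer, n_rs, n_seqs, and d are made-up sizes) contrasting the per-layer-view pattern with the shared-view pattern by node count:

#include "ggml.h"
#include <cstdio>

int main() {
    // metadata-only context: we only build graphs here, we never run them
    ggml_init_params params = { /*mem_size*/ 64*1024*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ true };
    ggml_context * ctx = ggml_init(params);

    const int n_layer = 4, n_rs = 8, n_seqs = 3, d = 16;

    ggml_tensor * s_copy = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_rs);
    ggml_tensor * states = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d, n_rs);

    // old pattern: every layer makes its own view of s_copy
    ggml_cgraph * gf_old = ggml_new_graph(ctx);
    for (int il = 0; il < n_layer; ++il) {
        ggml_tensor * v = ggml_view_1d(ctx, s_copy, n_seqs, 0);
        ggml_build_forward_expand(gf_old, ggml_get_rows(ctx, states, v));
    }

    // new pattern: one shared view, reused by every layer
    ggml_cgraph * gf_new = ggml_new_graph(ctx);
    ggml_tensor * shared = ggml_view_1d(ctx, s_copy, n_seqs, 0);
    for (int il = 0; il < n_layer; ++il) {
        ggml_build_forward_expand(gf_new, ggml_get_rows(ctx, states, shared));
    }

    // old: n_layer view nodes + n_layer get_rows nodes
    // new: 1 view node + n_layer get_rows nodes
    printf("nodes: old=%d new=%d\n", ggml_graph_n_nodes(gf_old), ggml_graph_n_nodes(gf_new));

    ggml_free(ctx);
    return 0;
}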

File tree

src/llama-graph.cpp
src/llama-graph.h

2 files changed: +34, −21 lines

src/llama-graph.cpp

Lines changed: 23 additions & 15 deletions
@@ -1644,56 +1644,62 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 
 ggml_tensor * llm_graph_context::build_rs(
          ggml_tensor * s,
-         ggml_tensor * state_copy,
+         ggml_tensor * state_copy_main,
+         ggml_tensor * state_copy_extra,
              int32_t   state_size,
              int32_t   n_seqs,
-            uint32_t   n_kv,
-            uint32_t   kv_head,
-            uint32_t   kv_size,
+            uint32_t   n_rs,
+            uint32_t   rs_head,
+            uint32_t   rs_size,
              int32_t   rs_zero,
     const llm_graph_get_rows_fn & get_state_rows) const {
 
-    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size);
 
     // Clear a single state which will then be copied to the other cleared states.
     // Note that this is a no-op when the view is zero-sized.
     ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
     ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
 
     // copy states
-    // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
-    // {state_size, kv_size} -> {state_size, n_seqs}
-    ggml_tensor * output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+    // NOTE: assuming the copy destinations are ALL contained between rs_head and rs_head + n_rs
+    // {state_size, rs_size} -> {state_size, n_seqs}
+    ggml_tensor * output_states = get_state_rows(ctx0, states, state_copy_main);
     ggml_build_forward_expand(gf, output_states);
 
-    // copy extra states which won't be changed further (between n_seqs and n_kv)
-    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
+    // copy extra states which won't be changed further (between n_seqs and n_rs)
+    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra);
     ggml_build_forward_expand(gf,
         ggml_cpy(ctx0,
             states_extra,
-            ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
+            ggml_view_1d(ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size(s))));
 
     return output_states;
 }
 
 static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
                       ggml_context * ctx0,
+                const llama_ubatch & ubatch,
     const llama_memory_recurrent_context * mctx_cur) {
 
     auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
 
-    const auto n_rs = mctx_cur->get_n_rs();
+    const int64_t n_rs   = mctx_cur->get_n_rs();
+    const int64_t n_seqs = ubatch.n_seqs;
 
     inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
     ggml_set_input(inp->s_copy);
 
+    inp->s_copy_main  = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
+    inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
+
     return inp;
 }
 
 llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
     const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
-    auto inp = build_rs_inp_impl(ctx0, mctx_cur);
+    auto inp = build_rs_inp_impl(ctx0, ubatch, mctx_cur);
 
     return (llm_graph_input_rs *) res->add_input(std::move(inp));
 }
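
The two views partition the copy map: entries [0, n_seqs) drive the get_state_rows gather that feeds the current ubatch, while entries [n_seqs, n_rs) only shuffle untouched states back into the cache. (The rs_zero view above is a separate trick: it becomes zero-sized, hence a no-op, when no state needs clearing.) Below is a small runnable sketch of the ggml_view_1d offset arithmetic used by build_rs_inp_impl, with made-up sizes:

#include "ggml.h"
#include <cstdio>

int main() {
    ggml_init_params params = { /*mem_size*/ 16*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ false };
    ggml_context * ctx = ggml_init(params);

    const int64_t n_rs   = 8; // recurrent-state slots visible to this graph
    const int64_t n_seqs = 3; // sequences in the current ubatch

    ggml_tensor * s_copy = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_rs);
    for (int64_t i = 0; i < n_rs; ++i) {
        ((int32_t *) s_copy->data)[i] = (int32_t) i; // identity copy map for the demo
    }

    // the same two views that build_rs_inp_impl creates once per graph:
    ggml_tensor * s_copy_main  = ggml_view_1d(ctx, s_copy, n_seqs, 0);
    ggml_tensor * s_copy_extra = ggml_view_1d(ctx, s_copy, n_rs - n_seqs, n_seqs*s_copy->nb[0]);

    // both alias s_copy's buffer; no data is duplicated or re-uploaded
    printf("main : %lld entries, byte offset %zu\n", (long long) s_copy_main->ne[0],
           (size_t) ((char *) s_copy_main->data  - (char *) s_copy->data));
    printf("extra: %lld entries, byte offset %zu\n", (long long) s_copy_extra->ne[0],
           (size_t) ((char *) s_copy_extra->data - (char *) s_copy->data));

    ggml_free(ctx);
    return 0;
}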
@@ -1706,7 +1712,9 @@ ggml_tensor * llm_graph_context::build_rs(
     const llm_graph_get_rows_fn & get_state_rows) const {
     const auto * kv_state = inp->mctx;
 
-    return build_rs(s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
+    return build_rs(s, inp->s_copy_main, inp->s_copy_extra, state_size, n_seqs,
+            kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(),
+            get_state_rows);
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
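
Layer code reaches the shared views only through this inp-taking overload, so they stay an implementation detail of the graph input. A hypothetical call site follows; the tensor name, state-size variable, and exact parameter order are illustrative assumptions, not lines from this commit:

// Hypothetical layer code: ask the wrapper for this ubatch's states.
// 'conv_states_all' (the cached states tensor for this layer) and
// 'd_state' (the per-state element count) are assumed names.
ggml_tensor * states = build_rs(inp, conv_states_all, d_state, n_seqs);
// states is {d_state, n_seqs}: one row per sequence in the ubatch, while
// the extra rows were copied back into the cache unchanged by build_rs.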
@@ -1753,7 +1761,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
 llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
 
-    auto inp_rs   = build_rs_inp_impl(ctx0, mctx_cur->get_recr());
+    auto inp_rs   = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
     auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
 
     auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);

src/llama-graph.h

Lines changed: 11 additions & 6 deletions
@@ -214,7 +214,12 @@ class llm_graph_input_rs : public llm_graph_input_i {
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    ggml_tensor * s_copy; // I32 [kv_size]
+    ggml_tensor * s_copy; // I32 [n_rs]
+
+    // views of s_copy, computed once per graph
+    // and shared across layers which use build_rs
+    ggml_tensor * s_copy_main;  // I32 [n_seqs]
+    ggml_tensor * s_copy_extra; // I32 [n_rs - n_seqs]
 
     const llama_memory_recurrent_context * mctx;
 };
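
The views own no data: set_input (defined elsewhere in llama-graph.cpp and untouched by this commit) fills the backing s_copy tensor once per ubatch, and both views read through it, which is what lets them be built once and shared across layers. A hedged sketch of that relationship, assuming an s_copy(i) accessor on the recurrent context that returns the source slot for state i (illustrative, not verbatim source):

// Sketch only, not verbatim llama.cpp: one write to the backing tensor
// populates both views, since s_copy_main and s_copy_extra alias it.
void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
    GGML_UNUSED(ubatch);

    int32_t * data = (int32_t *) s_copy->data;
    for (int64_t i = 0; i < s_copy->ne[0]; ++i) {
        data[i] = mctx->s_copy(i); // assumed accessor: source slot for state i
    }
}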
@@ -730,20 +735,20 @@ struct llm_graph_context {
     // recurrent
     //
 
-    // TODO: avoid notion of "kv"
     // TODO: move this implementation to llama_memory_recurrent.
     //       this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
     //       when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
     //       implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
     //       `llama_memory_recurrent`
     ggml_tensor * build_rs(
             ggml_tensor * s,
-            ggml_tensor * state_copy,
+            ggml_tensor * state_copy_main,
+            ggml_tensor * state_copy_extra,
                 int32_t   state_size,
                 int32_t   n_seqs,
-               uint32_t   n_kv,
-               uint32_t   kv_head,
-               uint32_t   kv_size,
+               uint32_t   n_rs,
+               uint32_t   rs_head,
+               uint32_t   rs_size,
                 int32_t   rs_zero,
     const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
