Add optional MLA #188

Merged: 6 commits, Feb 9, 2025
7 changes: 7 additions & 0 deletions common/common.cpp
@@ -813,6 +813,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.flash_attn = true;
return true;
}
if (arg == "-mla" || arg == "--mla-use") {
params.mla_attn = true;
return true;
}
if (arg == "-co" || arg == "--color") {
params.use_color = true;
return true;
@@ -1452,6 +1456,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", " --keep N", "number of tokens to keep from the initial prompt (default: %d, -1 = all)", params.n_keep });
options.push_back({ "*", " --chunks N", "max number of chunks to process (default: %d, -1 = all)", params.n_chunks });
options.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" });
options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %s)", params.mla_attn ? "enabled" : "disabled" });
options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n"
"in conversation mode, this will be used as system prompt\n"
"(default: '%s')", params.prompt.c_str() });
@@ -2283,6 +2288,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.cb_eval_user_data = params.cb_eval_user_data;
cparams.offload_kqv = !params.no_kv_offload;
cparams.flash_attn = params.flash_attn;
cparams.mla_attn = params.mla_attn;

cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
@@ -3280,6 +3286,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false");
fprintf(stream, "mla_attn: %s # default: false\n", params.mla_attn ? "true" : "false");
fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);

const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices());
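The new switch threads through argument parsing, the usage text, context-parameter construction, and the YAML dump. A hypothetical invocation (binary and model names are placeholders, not from this PR): `./llama-cli -m deepseek-v3.gguf -mla`, or equivalently `--mla-use`; the setting is echoed in the YAML dump as `mla_attn`.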
1 change: 1 addition & 0 deletions common/common.h
@@ -174,6 +174,7 @@ struct gpt_params {
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
bool cont_batching = true; // insert new sequences for decoding on-the-fly
bool flash_attn = false; // flash attention
bool mla_attn = false; // MLA

bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
bool ignore_eos = false; // ignore generated EOS tokens
42 changes: 42 additions & 0 deletions convert_hf_to_gguf.py
@@ -3123,6 +3123,7 @@ def prepare_tensors(self):


@Model.register("DeepseekV2ForCausalLM")
@Model.register("DeepseekV3ForCausalLM")
class DeepseekV2Model(Model):
model_arch = gguf.MODEL_ARCH.DEEPSEEK2

@@ -3144,6 +3145,15 @@ def set_gguf_parameters(self):
self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])

if hparams["scoring_func"] == "sigmoid":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
elif hparams["scoring_func"] == "softmax":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
else:
raise ValueError(f"Unsupported scoring_func value: {hparams['scoring_func']}")

self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])

if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
@@ -3156,6 +3166,17 @@ def set_gguf_parameters(self):
_experts: list[dict[str, Tensor]] | None = None

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# rename e_score_correction_bias tensors
if name.endswith("e_score_correction_bias"):
name = name.replace("e_score_correction_bias", "e_score_correction.bias")

# skip Multi-Token Prediction (MTP) layers
block_count = self.hparams["num_hidden_layers"]
match = re.match(r"model\.layers\.(\d+)", name)
if match and int(match.group(1)) >= block_count:
return []

# process the experts separately
if name.find("mlp.experts") != -1:
n_experts = self.hparams["n_routed_experts"]
@@ -3188,6 +3209,27 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
return tensors
else:
return []
if name.endswith("kv_b_proj.weight"):
name_kb = name.replace("kv_b_proj", "k_b_proj")
name_vb = name.replace("kv_b_proj", "v_b_proj")

n_head_kv = self.hparams["num_key_value_heads"]
v_head_dim = self.hparams["v_head_dim"]
qk_nope_head_dim = self.hparams["qk_nope_head_dim"]

assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim)

kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1])
k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
k_b = k_b.transpose(1, 2)
k_b = k_b.reshape(n_head_kv * data_torch.shape[-1], qk_nope_head_dim)
v_b = v_b.reshape(n_head_kv * v_head_dim, data_torch.shape[-1])

return [
(self.map_tensor_name(name), data_torch),
(self.map_tensor_name(name_kb), k_b),
(self.map_tensor_name(name_vb), v_b)
]

return [(self.map_tensor_name(name), data_torch)]

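The `kv_b_proj` split above is the core of the conversion change: the projection out of the compressed KV latent is additionally stored as a per-head `k_b` (transposed, so it can be applied to the cached latent directly) and a per-head `v_b`. A minimal sketch of the reshape with DeepSeek-V2-Lite-like dimensions (the concrete numbers are assumptions for illustration):

```python
import torch

# dimensions in the style of DeepSeek-V2-Lite (assumed for illustration)
n_head_kv, qk_nope_head_dim, v_head_dim, kv_lora_rank = 16, 128, 128, 512

# kv_b_proj.weight has shape (n_head_kv * (qk_nope + v), kv_lora_rank)
w = torch.randn(n_head_kv * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

kv_b = w.view(n_head_kv, qk_nope_head_dim + v_head_dim, kv_lora_rank)
k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)

# k_b is stored pre-transposed so it can multiply the cached latent directly
k_b = k_b.transpose(1, 2).reshape(n_head_kv * kv_lora_rank, qk_nope_head_dim)
v_b = v_b.reshape(n_head_kv * v_head_dim, kv_lora_rank)

print(k_b.shape)  # torch.Size([8192, 128])
print(v_b.shape)  # torch.Size([2048, 512])
```

Writing the transposed copy at conversion time means inference can matmul against the latent cache without reshuffling the weight at run time.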
35 changes: 32 additions & 3 deletions examples/llama-bench/llama-bench.cpp
@@ -232,6 +232,7 @@ struct cmd_params {
std::vector<int> main_gpu;
std::vector<bool> no_kv_offload;
std::vector<bool> flash_attn;
std::vector<bool> mla_attn;
std::vector<std::vector<float>> tensor_split;
std::vector<bool> use_mmap;
std::vector<bool> embeddings;
@@ -261,6 +262,7 @@ static const cmd_params cmd_params_defaults = {
/* main_gpu */ {0},
/* no_kv_offload */ {false},
/* flash_attn */ {false},
/* mla_attn */ {false},
/* tensor_split */ {std::vector<float>(llama_max_devices(), 0.0f)},
/* use_mmap */ {true},
/* embeddings */ {false},
@@ -294,6 +296,7 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -mg, --main-gpu <i> (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str());
printf(" -mla, --mla-attn <0|1> (default: %s)\n", join(cmd_params_defaults.mla_attn, ",").c_str());
printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
printf(" --numa <distribute|isolate|numactl> (default: disabled)\n");
printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
@@ -526,6 +529,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
}
auto p = string_split<bool>(argv[i], split_delim);
params.flash_attn.insert(params.flash_attn.end(), p.begin(), p.end());
} else if (arg == "-mla" || arg == "--mla-attn") {
if (++i >= argc) {
invalid_param = true;
break;
}
auto p = string_split<bool>(argv[i], split_delim);
params.mla_attn.insert(params.mla_attn.end(), p.begin(), p.end());
} else if (arg == "-mmp" || arg == "--mmap") {
if (++i >= argc) {
invalid_param = true;
@@ -621,6 +631,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; }
if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; }
if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; }
if (params.mla_attn.empty()) { params.mla_attn = cmd_params_defaults.mla_attn; }
if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; }
if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; }
if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; }
@@ -656,6 +667,7 @@ struct cmd_params_instance {
int main_gpu;
bool no_kv_offload;
bool flash_attn;
bool mla_attn;
std::vector<float> tensor_split;
bool use_mmap;
bool embeddings;
@@ -698,6 +710,7 @@ struct cmd_params_instance {
cparams.type_v = type_v;
cparams.offload_kqv = !no_kv_offload;
cparams.flash_attn = flash_attn;
cparams.mla_attn = mla_attn;
cparams.embeddings = embeddings;

return cparams;
@@ -722,6 +735,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
for (const auto & tv : params.type_v)
for (const auto & nkvo : params.no_kv_offload)
for (const auto & fa : params.flash_attn)
for (const auto & mla : params.mla_attn)
for (const auto & nt : params.n_threads) {
for (const auto & n_prompt : params.n_prompt) {
if (n_prompt == 0) {
@@ -743,6 +757,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .main_gpu = */ mg,
/* .no_kv_offload= */ nkvo,
/* .flash_attn = */ fa,
/* .mla_attn = */ mla,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -771,6 +786,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .main_gpu = */ mg,
/* .no_kv_offload= */ nkvo,
/* .flash_attn = */ fa,
/* .mla_attn = */ mla,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -799,6 +815,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .main_gpu = */ mg,
/* .no_kv_offload= */ nkvo,
/* .flash_attn = */ fa,
/* .mla_attn = */ mla,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -827,6 +844,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .main_gpu = */ mg,
/* .no_kv_offload= */ nkvo,
/* .flash_attn = */ fa,
/* .mla_attn = */ mla,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -866,6 +884,7 @@ struct test {
int main_gpu;
bool no_kv_offload;
bool flash_attn;
bool mla_attn;
std::vector<float> tensor_split;
bool use_mmap;
bool embeddings;
@@ -895,6 +914,7 @@ struct test {
main_gpu = inst.main_gpu;
no_kv_offload = inst.no_kv_offload;
flash_attn = inst.flash_attn;
mla_attn = inst.mla_attn;
tensor_split = inst.tensor_split;
use_mmap = inst.use_mmap;
embeddings = inst.embeddings;
@@ -988,7 +1008,7 @@ struct test {
"n_batch", "n_ubatch",
"n_threads", "type_k", "type_v",
"n_gpu_layers", "split_mode",
"main_gpu", "no_kv_offload", "flash_attn",
"main_gpu", "no_kv_offload", "flash_attn", "mla_attn",
"tensor_split", "use_mmap", "embeddings", "repack",
"n_prompt", "n_gen", "test_time",
"avg_ns", "stddev_ns",
@@ -1010,7 +1030,7 @@ struct test {
}
if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" ||
field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" ||
field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack") {
field == "flash_attn" || field == "mla_attn" || field == "use_mmap" || field == "embeddings" || field == "repack") {
return BOOL;
}
if (field == "avg_ts" || field == "stddev_ts") {
@@ -1044,7 +1064,7 @@ struct test {
std::to_string(n_batch), std::to_string(n_ubatch),
std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
std::to_string(n_gpu_layers), split_mode_str(split_mode),
std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn),
std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(mla_attn),
tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack),
std::to_string(n_prompt), std::to_string(n_gen), test_time,
std::to_string(avg_ns()), std::to_string(stdev_ns()),
@@ -1208,6 +1228,9 @@ struct markdown_printer : public printer {
if (field == "flash_attn") {
return 2;
}
if (field == "mla_attn") {
return 3;
}
if (field == "use_mmap") {
return 4;
}
@@ -1242,6 +1265,9 @@ struct markdown_printer : public printer {
if (field == "flash_attn") {
return "fa";
}
if (field == "mla_attn") {
return "mla";
}
if (field == "use_mmap") {
return "mmap";
}
@@ -1294,6 +1320,9 @@ struct markdown_printer : public printer {
if (params.flash_attn.size() > 1 || params.flash_attn != cmd_params_defaults.flash_attn) {
fields.emplace_back("flash_attn");
}
if (params.mla_attn.size() > 1 || params.mla_attn != cmd_params_defaults.mla_attn) {
fields.emplace_back("mla_attn");
}
if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) {
fields.emplace_back("tensor_split");
}
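Like the other boolean parameters, the new `-mla` flag accepts a comma-separated list, so something like `llama-bench -m <model.gguf> -fa 0,1 -mla 0,1` (a hypothetical invocation) should benchmark all four combinations; the markdown output abbreviates the column header to `mla`.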
17 changes: 4 additions & 13 deletions ggml/src/ggml.c
@@ -14056,31 +14056,22 @@ static void ggml_compute_forward_mul_mat(
#endif

#if GGML_USE_IQK_MULMAT
-    if (dst->type == GGML_TYPE_F32 && (ne12*ne13)%nth == 0) {
+    if (dst->type == GGML_TYPE_F32) {
+        int gcd = simple_gcd(ne12*ne13, nth);
         int counter = 0;
         for (int64_t i13 = 0; i13 < ne13; i13++) {
             for (int64_t i12 = 0; i12 < ne12; i12++) {
-                if (counter++ % nth == ith) {
+                if ((counter++ % gcd) == (ith%gcd)) {
                     if (!iqk_mul_mat(ne01, ne11, ne00,
                             src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type),
                             src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11, ///ggml_type_size(src1->type),
                             (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
-                            0, 1)) goto IQK_MulMat_Not_Available1;
+                            ith/gcd, nth/gcd)) goto IQK_MulMat_Not_Available1;
                 }
             }
         }
         return;
     }
-    if (dst->type == GGML_TYPE_F32) {
-        for (int64_t i13 = 0; i13 < ne13; i13++)
-            for (int64_t i12 = 0; i12 < ne12; i12++)
-                if (!iqk_mul_mat(ne01, ne11, ne00,
-                        src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01, ///ggml_type_size(src0->type),
-                        src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11, ///ggml_type_size(src1->type),
-                        (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
-                        ith, nth)) goto IQK_MulMat_Not_Available1;
-        return;
-    }
IQK_MulMat_Not_Available1:;
#endif

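The hunk above generalizes the batched-matmul scheduling: the old code either round-robined whole (i12, i13) pairs across all `nth` threads (only possible when `ne12*ne13` was divisible by `nth`) or fell back to having every thread cooperate on one pair at a time. The new code splits the threads into `gcd(ne12*ne13, nth)` groups; each group takes every gcd-th pair, and the `nth/gcd` threads within a group split that single multiplication. A small Python sketch of the resulting assignment (illustrative only, not the C implementation):

```python
from math import gcd

def schedule(n_pairs: int, nth: int):
    """For each (i12, i13) pair, list the (sub-thread id, sub-thread count)
    assignments, mirroring the counter/gcd logic above."""
    g = gcd(n_pairs, nth)
    plan = []
    for counter in range(n_pairs):
        workers = [(ith // g, nth // g) for ith in range(nth) if counter % g == ith % g]
        plan.append((counter, workers))  # each pair gets exactly nth // g threads
    return plan

# 3 batched matmuls on 4 threads: gcd = 1, so all four threads cooperate on
# every matmul; the old fast path required 3 % 4 == 0 and fell back otherwise
for pair, workers in schedule(3, 4):
    print(pair, workers)
```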
6 changes: 6 additions & 0 deletions gguf-py/gguf/constants.py
@@ -274,6 +274,8 @@ class MODEL_TENSOR(IntEnum):
ATTN_Q_B = auto()
ATTN_KV_A_MQA = auto()
ATTN_KV_B = auto()
ATTN_K_B = auto()
ATTN_V_B = auto()
ATTN_Q_A_NORM = auto()
ATTN_KV_A_NORM = auto()
FFN_SUB_NORM = auto()
@@ -403,6 +405,8 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b",
MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa",
MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b",
MODEL_TENSOR.ATTN_K_B: "blk.{bid}.attn_k_b",
MODEL_TENSOR.ATTN_V_B: "blk.{bid}.attn_v_b",
MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm",
MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm",
@@ -967,6 +971,8 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ATTN_Q_B,
MODEL_TENSOR.ATTN_KV_A_MQA,
MODEL_TENSOR.ATTN_KV_B,
MODEL_TENSOR.ATTN_K_B,
MODEL_TENSOR.ATTN_V_B,
MODEL_TENSOR.ATTN_Q_A_NORM,
MODEL_TENSOR.ATTN_KV_A_NORM,
MODEL_TENSOR.ATTN_OUT,
8 changes: 8 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
@@ -446,6 +446,14 @@ class TensorNameMap:
"model.layers.{bid}.self_attn.kv_b_proj", # deepseek2
),

MODEL_TENSOR.ATTN_K_B: (
"model.layers.{bid}.self_attn.k_b_proj", # deepseek2
),

MODEL_TENSOR.ATTN_V_B: (
"model.layers.{bid}.self_attn.v_b_proj", # deepseek2
),

MODEL_TENSOR.ATTN_Q_A_NORM: (
"model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2
),
1 change: 1 addition & 0 deletions include/llama.h
@@ -374,6 +374,7 @@ extern "C" {
bool embeddings; // if true, extract embeddings (together with logits)
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
bool mla_attn; // whether to use MLA attention [EXPERIMENTAL]

// Abort callback
// if it returns true, execution of llama_decode() will be aborted