diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 58e455ae645ed..4af85ba68a4d3 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -604,7 +604,10 @@ def get_vocab_base(self) -> tuple[list[str], list[int], str]: from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(self.dir_model) - vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab)) + vocab_size = max( + self.hparams.get("vocab_size", len(tokenizer.vocab)), + len(tokenizer.vocab) + ) assert max(tokenizer.vocab.values()) < vocab_size tokpre = self.get_vocab_base_pre(tokenizer) @@ -683,6 +686,14 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "9d032fcbd5501f4a38150912590928bfb36091efb5df11b8e2124b0390e3fb1e": # ref: https://huggingface.co/tiiuae/Falcon3-7B-Base res = "falcon3" + if ( + chkhsh == "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86" or + chkhsh == "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896" or + chkhsh == "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b" or + chkhsh == "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6" + ): + # ref: + res = "falcon-h1" if chkhsh == "8e62295832751ca1e8f92f2226f403dea30dc5165e448b5bfa05af5340c64ec7": # ref: https://huggingface.co/BAAI/bge-large-zh-v1.5 res = "bert-bge-large" @@ -4636,6 +4647,14 @@ def set_gguf_parameters(self): class MambaModel(TextModel): model_arch = gguf.MODEL_ARCH.MAMBA + def __init__(self, dir_model: Path, *args, **kwargs): + # Avoid using AutoConfig for hparams + hparams = kwargs.pop("hparams", None) + if hparams is None: + with open(dir_model / "config.json", "r", encoding="utf-8") as f: + hparams = json.load(f) + super().__init__(dir_model, *args, hparams=hparams, **kwargs) + def set_vocab(self): vocab_size = self.hparams["vocab_size"] # Round vocab size to next multiple of 8 @@ -4710,6 +4729,238 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(new_name, data_torch)] +@ModelBase.register("Mamba2ForCausalLM") +class Mamba2Model(TextModel): + model_arch = gguf.MODEL_ARCH.MAMBA2 + + def __init__(self, dir_model: Path, *args, **kwargs): + # Avoid using AutoConfig for hparams + # It wrongly assumes all Mamba2 models are Mamba-Codestral-7B-v0.1 + hparams = kwargs.pop("hparams", None) + if hparams is None: + with open(dir_model / "config.json", "r", encoding="utf-8") as f: + hparams = json.load(f) + super().__init__(dir_model, *args, hparams=hparams, **kwargs) + + def set_vocab(self): + vocab_size = self.hparams["vocab_size"] + # Round vocab size to next multiple of 16 + pad_vocab = self.hparams.get("pad_vocab_size_multiple", 16) + # pad using ceiling division + # ref: https://stackoverflow.com/a/17511341/22827863 + vocab_size = -(vocab_size // -pad_vocab) * pad_vocab + self.hparams["vocab_size"] = vocab_size + + if (self.dir_model / "tokenizer.model").is_file(): + self._set_vocab_sentencepiece() + elif (self.dir_model / "tokenizer.model.v3").is_file(): + # mamba-codestral + raise NotImplementedError(f"Please rename {self.dir_model / 'tokenizer.model.v3'} to {self.dir_model / 'tokenizer.model'}") + elif (self.dir_model / "tokenizer.json").is_file(): + self._set_vocab_gpt2() + else: + # Use the GPT-NeoX tokenizer when no tokenizer files are present + self._set_vocab_builtin("gpt-neox", vocab_size) + + def set_gguf_parameters(self): + d_model = self.find_hparam(["hidden_size", "d_model", "dim"]) + d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4 + d_inner = 
self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model + d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128 + head_dim = self.find_hparam(["head_dim"], optional=True) or 64 + n_group = self.find_hparam(["n_groups"], optional=True) or 1 + + rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5 + + # Fail early for models which don't have a block expansion factor of 2 + # TODO: does this really matter? + assert d_inner == 2 * d_model + assert d_inner % head_dim == 0 + + self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default + self.gguf_writer.add_embedding_length(d_model) + self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading + self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading + self.gguf_writer.add_block_count(self.block_count) + self.gguf_writer.add_ssm_conv_kernel(d_conv) + self.gguf_writer.add_ssm_inner_size(d_inner) + self.gguf_writer.add_ssm_state_size(d_state) + self.gguf_writer.add_ssm_time_step_rank(d_inner // head_dim) + self.gguf_writer.add_ssm_group_count(n_group) + self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps) + self.gguf_writer.add_file_type(self.ftype) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + + if name.startswith("model.backbone") or name.startswith("model.lm_head"): + # map Mamba-Codestral-7B-v0.1 tensor names to the names used by Mamba-2 + name = name.removeprefix("model.") + + if name.endswith(".dt_bias"): + name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias" + + new_name = self.map_tensor_name(name) + + if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_CONV1D, bid): + data_torch = data_torch.squeeze() + elif any(self.match_model_tensor_name(new_name, t, bid, suffix="") for t in [ + gguf.MODEL_TENSOR.SSM_A, + gguf.MODEL_TENSOR.SSM_D, + ]): + # unsqueeze A to use similar shape semantics as Mamba-1 + # (D is also unsqueezed, but for more straightforward broadcast internally) + data_torch = data_torch.reshape((*data_torch.shape, 1)) + elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid): + d_model = self.find_hparam(["hidden_size", "d_model", "dim"]) + d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model + n_group = self.hparams.get("n_groups", 1) + data_torch = data_torch.reshape((n_group, d_inner // n_group)) + + if name.endswith(".A_log"): + logger.debug("A_log --> A ==> " + new_name) + data_torch = -torch.exp(data_torch) + + yield (new_name, data_torch) + + +@Model.register("FalconH1ForCausalLM") +class FalconH1Model(Mamba2Model): + model_arch = gguf.MODEL_ARCH.FALCON_H1 + + def __init__(self, *args, **kwargs): + # Set the hparam prefixes for Falcon Mamba2 + self.hparam_prefixes = ["mamba"] + + # Initialize the base Mamba2Model + super().__init__(*args, **kwargs) + + # Use Llama conversion for attention + self._transformer_model_class: type[Model] = LlamaModel + + # n_group and d_inner are used during reshape_tensors for mamaba2 + self.d_model = self.find_hparam(["hidden_size", "d_model"]) + self.n_group = self.find_hparam(["n_groups"]) + self.d_inner = self.find_hparam(["expand"]) * self.d_model + + # Initialize any Falcon Mamba2 specific attributes + self.has_attention = True # Falcon Mamba2 has attention components + + # Load Falcon-H1 multipliers from hyperparameters + self.attention_in_multiplier = 
self.find_hparam(["attention_in_multiplier"], optional=True) + self.attention_out_multiplier = self.find_hparam(["attention_out_multiplier"], optional=True) + self.ssm_in_multiplier = self.find_hparam(["ssm_in_multiplier"], optional=True) + self.ssm_out_multiplier = self.find_hparam(["ssm_out_multiplier"], optional=True) + self.mlp_multipliers = self.find_hparam(["mlp_multipliers"], optional=True) + self.ssm_multipliers = self.find_hparam(["ssm_multipliers"], optional=True) + self.intermediate_size = self.find_hparam(["intermediate_size"]) + + def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any: + prefixed = [] + for pfx in self.hparam_prefixes: + prefixed.extend( + "_".join([pfx, k]) + for k in keys + ) + keys = list(keys) + prefixed + return super().find_hparam(keys, *args, **kwargs) + + def _generate_mup_vector(self, block_id: int) -> torch.Tensor: + zxbcdt_multipliers = self.hparams["ssm_multipliers"] + intermediate_size = self.hparams["mamba_d_ssm"] + groups_time_state_size = self.hparams["mamba_n_groups"] * self.hparams["mamba_d_state"] + vector_shape = (2 * intermediate_size + 2 * groups_time_state_size + self.hparams["mamba_n_heads"]) + + mup_vector = torch.ones(1, 1, vector_shape) + mup_vector[:, :, :intermediate_size] *= zxbcdt_multipliers[0] + mup_vector[:, :, intermediate_size:2 * intermediate_size] *= zxbcdt_multipliers[1] + mup_vector[:, :, 2 * intermediate_size:2 * intermediate_size + groups_time_state_size] *= zxbcdt_multipliers[2] + mup_vector[:, :, 2 * intermediate_size + groups_time_state_size:2 * intermediate_size + 2 * groups_time_state_size] *= zxbcdt_multipliers[3] + mup_vector[:, :, 2 * intermediate_size + 2 * groups_time_state_size:] *= zxbcdt_multipliers[4] + + return mup_vector + + def get_tensors(self) -> Iterator[tuple[str, Tensor]]: + for name, tensor in super().get_tensors(): + if name.startswith("model.backbone") or name.startswith("model.lm_head"): + name = name.removeprefix("model.") + yield name, tensor + + if self.ssm_multipliers is not None: + # Insert MUP vector after mamba.dt_bias + if "mamba.dt_bias" in name: + block_match = re.search(r"(?:model\.layers\.)?(\d+)\.mamba\.dt_bias", name) + if block_match: + block_id = int(block_match.group(1)) + # Generate MUP vector with correct name format + mup_tensor = self._generate_mup_vector(block_id) + mup_name = f"blk.{block_id}.ssm_mup_vec" + logger.debug(f"Inserting MUP vector for block {block_id}: {mup_name}") + yield mup_name, mup_tensor + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + ## General Params ## + self.gguf_writer.add_block_count(self.block_count) + self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0)) + self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) + self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) + + ## Mamba mixer params ## + self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"])) + self.gguf_writer.add_ssm_group_count(self.n_group) + self.gguf_writer.add_ssm_inner_size(self.d_inner) + self.gguf_writer.add_ssm_head_dim(d_head := self.find_hparam(["d_head"])) + self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"])) + + ## Attention params ## + self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) + self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in self.hparams else self.hparams["num_attention_heads"]) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) + 
self.gguf_writer.add_key_length(self.hparams["head_dim"]) + self.gguf_writer.add_value_length(self.hparams["head_dim"]) + self.gguf_writer.add_float32("falcon-h1.key_multiplier", self.hparams["key_multiplier"]) + + ## Other params + self.gguf_writer.add_float32("falcon-h1.lm_head_multiplier", self.hparams["lm_head_multiplier"]) + self.gguf_writer.add_float32("falcon-h1.embedding_multiplier", self.hparams["embedding_multiplier"]) + + ## Validation ## + assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported" + assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}" + + + # Add Falcon Mamba2 specific configuration + self.gguf_writer.add_uint32("falcon-h1.ssm.mamba_chunk_size", self.hparams["mamba_chunk_size"]) + self.gguf_writer.add_uint32("falcon-h1.attention.head_dim", self.hparams["head_dim"]) + self.gguf_writer.add_uint32("falcon-h1.ssm.mamba_d_ssm", self.hparams["mamba_d_ssm"]) + self.gguf_writer.add_uint32("falcon-h1.num_attention_heads", self.find_hparam(["num_attention_heads"])) + self.gguf_writer.add_uint32("falcon-h1.num_key_value_heads", + self.find_hparam(["num_key_value_heads"], optional=True) or + self.find_hparam(["num_attention_heads"])) + + # Add multipliers as metadata instead of tensors + self.gguf_writer.add_float32("falcon-h1.attention_in_multiplier", self.attention_in_multiplier) + self.gguf_writer.add_float32("falcon-h1.attention_out_multiplier", self.attention_out_multiplier) + self.gguf_writer.add_float32("falcon-h1.ssm_in_multiplier", self.ssm_in_multiplier) + self.gguf_writer.add_float32("falcon-h1.ssm_out_multiplier", self.ssm_out_multiplier) + + # Add MLP multipliers + if isinstance(self.mlp_multipliers, (list, tuple)) and len(self.mlp_multipliers) == 2: + self.gguf_writer.add_float32("falcon-h1.mlp_gate_multiplier", self.mlp_multipliers[0]) + self.gguf_writer.add_float32("falcon-h1.mlp_down_multiplier", self.mlp_multipliers[1]) + + # Add has MuP flag if SSM multipliers are present + if self.ssm_multipliers is not None: + self.gguf_writer.add_bool("falcon-h1.ssm.has_mup", True) + + # Add any other Falcon Mamba2 specific configuration + self.gguf_writer.add_bool("falcon-h1.mamba_use_mlp", self.find_hparam(["mamba_use_mlp"], optional=True)) + self.gguf_writer.add_bool("falcon-h1.mamba_norm_before_gate", self.find_hparam(["mamba_norm_before_gate"], optional=True)) + self.gguf_writer.add_bool("falcon-h1.mamba_rms_norm", self.find_hparam(["mamba_rms_norm"], optional=True)) + self.gguf_writer.add_float32("falcon-h1.rope_theta", self.find_hparam(["rope_theta"], optional=True)) + + @ModelBase.register("CohereForCausalLM") class CommandR2Model(TextModel): model_arch = gguf.MODEL_ARCH.COMMAND_R @@ -6477,12 +6728,20 @@ def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> st # maybe we should fallback to text model's arch in that case, since not many models have both text_config = hparams.get("text_config", {}) vision_config = hparams.get("vision_config", {}) - arch = hparams["architectures"][0] + arch = None + if (arches := hparams.get("architectures")) is not None and len(arches) > 0: + arch = arches[0] + elif "ssm_cfg" in hparams: + # For non-hf Mamba and Mamba2 models + arch = hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM" + # if "architectures" is found in the sub-config, use that instead if model_type == ModelType.TEXT and text_config.get("architectures") is not None: arch = text_config["architectures"][0] elif model_type == ModelType.MMPROJ and 
vision_config.get("architectures") is not None: arch = vision_config["architectures"][0] + if arch is None: + raise ValueError("Failed to detect model architecture") return arch diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 1a57f1cd75a31..43456bd802103 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1878,7 +1878,8 @@ extern "C" { struct ggml_tensor * dt, struct ggml_tensor * A, struct ggml_tensor * B, - struct ggml_tensor * C); + struct ggml_tensor * C, + struct ggml_tensor * ids); // partition into non-overlapping windows with padding if needed // example: diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 08facb6d03d5e..711d8abcc5fdd 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -7596,120 +7596,210 @@ void ggml_compute_forward_ssm_conv( static void ggml_compute_forward_ssm_scan_f32( const ggml_compute_params * params, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; // s - const ggml_tensor * src1 = dst->src[1]; // x - const ggml_tensor * src2 = dst->src[2]; // dt - const ggml_tensor * src3 = dst->src[3]; // A - const ggml_tensor * src4 = dst->src[4]; // B - const ggml_tensor * src5 = dst->src[5]; // C + const ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs+} + const ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs} + const ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs} + const ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head} + const ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs} + const ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs} + const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs} const int ith = params->ith; const int nth = params->nth; - const int64_t nc = src0->ne[0]; // d_state - const int64_t nr = src0->ne[1]; // d_inner - const int64_t n_t = src1->ne[1]; // number of tokens per sequence - const int64_t n_s = src0->ne[2]; // number of sequences in the batch + const int64_t nc = src0->ne[0]; // d_state + const int64_t nr = src0->ne[1]; // dim + const int64_t nh = src1->ne[1]; // n_head + const int64_t ng = src4->ne[1]; + const int64_t nt = src1->ne[2]; // number of tokens per sequence + const int64_t ns = src1->ne[3]; // number of sequences in the batch - GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); + // can't use ggml_nbytes because src1 is not necessarily contiguous + const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1); + + GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src2->nb[0] == sizeof(float)); GGML_ASSERT(src3->nb[0] == sizeof(float)); GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float)); - // required for the dot product between s and C - GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); - // required for per-sequence offsets for states - GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float)); - // required to get correct offset for state destination (i.e. 
src1->nb[3]) - GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float)); + GGML_ASSERT(src6->nb[0] == sizeof(int32_t)); + // allows optimizing the modulo since n_group should be a power of 2 + GGML_ASSERT((ng & -ng) == ng); - // rows per thread - const int dr = (nr + nth - 1)/nth; + // heads per thread + const int dh = (nh + nth - 1)/nth; - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - const int ir = ir1 - ir0; + // head range for this thread + const int ih0 = dh*ith; + const int ih1 = MIN(ih0 + dh, nh); + + const int32_t * ids = (const int32_t *) src6->data; + + for (int i3 = 0; i3 < ns; ++i3) { + const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns} + float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns} + + for (int i2 = 0; i2 < nt; ++i2) { + const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns} + const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns} + const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh} + const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns} + const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns} + float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns} + + if (src3->ne[0] == 1) { + // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop + + // n_head + for (int h = ih0; h < ih1; ++h) { + // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 + const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h]; + const float dA = expf(dt_soft_plus * A[h]); + + // dim + for (int i1 = 0; i1 < nr; ++i1) { + const int ii = i1 + h*nr; + const float x_dt = x[ii] * dt_soft_plus; + float sumf = 0.0f; +#if defined(GGML_SIMD) + #if defined(__ARM_FEATURE_SVE) + const int ggml_f32_epr = svcntw(); + const int ggml_f32_step = 1 * ggml_f32_epr; + + const int np = (nc & ~(ggml_f32_step - 1)); + + GGML_F32_VEC sum = GGML_F32_VEC_ZERO; + + GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA); + GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt); + + for (int i = 0; i < np; i += ggml_f32_step) { + // TODO: maybe unroll more? 
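+                        // each vector lane handles one d_state element: new_state = s0*dA + B*x_dt (t0 below),
+                        // new_state*C is accumulated into the running dot product, and new_state is stored back into s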
+ for (int j = 0; j < 1; j++) { + GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc); + GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc); + GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc); + + t0 = GGML_F32_VEC_MUL(t0, adA); + t1 = GGML_F32_VEC_MUL(t1, axdt); + + t0 = GGML_F32_VEC_ADD(t0, t1); + + sum = GGML_F32_VEC_FMA(sum, t0, t2); + + GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0); + } + } + + sumf = GGML_F32xt_REDUCE_ONE(sum); + #else + const int np = (nc & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; + + GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA); + GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt); + + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + GGML_F32_VEC az[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc); + ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc); + az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc); + + ax[j] = GGML_F32_VEC_MUL(ax[j], adA); + ay[j] = GGML_F32_VEC_MUL(ay[j], axdt); + + ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]); - #ifdef __ARM_FEATURE_SVE - for (int i3 = 0; i3 < n_s; ++i3) { - for (int i2 = 0; i2 < n_t; ++i2) { - const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} - const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} - const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} - const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} - const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} - float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s} - - // use the output as the source for the next token-wise iterations - if (i2 > 0) { s0 = s; } - - // d_inner - for (int i1 = 0; i1 < ir; ++i1) { - float dt_soft_plus = dt[i1] <= 20.0f ? 
log1pf(expf(dt[i1])) : dt[i1]; - float x_dt = x[i1] * dt_soft_plus; - svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt); - svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus); - svfloat32_t r1_vector = GGML_F32_VEC_ZERO; - - for (int64_t k = 0; k < nc; k += svcntw()) { - svfloat32_t vA = GGML_F32_VEC_LOAD(&A[i1*nc + k]); - svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k]); - svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k]); - svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[i1*nc + k]); - - svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA); - t1 = exp_ps_sve(svptrue_b32(), t1); - svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB); - - vs0 = GGML_F32_VEC_FMA(vs0, t1, t2); - r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector); - - GGML_F32_VEC_STORE(&s[i1*nc + k], vs0); + sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]); + + GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]); + } + } + + // reduce sum0..sum3 to sum0 + GGML_F32_VEC_REDUCE(sumf, sum); + #endif +#else + const int np = 0; +#endif + // d_state + for (int i0 = np; i0 < nc; ++i0) { + const int i = i0 + ii*nc; + const int ig = i0 + (h & (ng - 1))*nc; + // state = prev_state * dA + dB * x + const float state = (s0[i] * dA) + (B[ig] * x_dt); + // y = rowwise_dotprod(state, C) + sumf += state * C[ig]; + s[i] = state; + } + y[ii] = sumf; } - y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector); } - } - } - #else - for (int i3 = 0; i3 < n_s; ++i3) { - for (int i2 = 0; i2 < n_t; ++i2) { - const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} - const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} - const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} - const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} - const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} - float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s} - - // use the output as the source for the next token-wise iterations - if (i2 > 0) { s0 = s; } - - // d_inner - for (int i1 = 0; i1 < ir; ++i1) { - // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 - float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; - float x_dt = x[i1] * dt_soft_plus; - float sumf = 0.0f; - // d_state - for (int i0 = 0; i0 < nc; ++i0) { - int i = i0 + i1*nc; - // state = prev_state * dA + dB * x - float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); - // y = rowwise_dotprod(state, C) - sumf += state * C[i0]; - s[i] = state; + } else { + // Mamba-1 has an element-wise decay factor for the states + + // n_head + for (int h = ih0; h < ih1; ++h) { + // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 + const float dt_soft_plus = dt[h] <= 20.0f ? 
log1pf(expf(dt[h])) : dt[h]; + + // dim + for (int i1 = 0; i1 < nr; ++i1) { + const int ii = i1 + h*nr; + const float x_dt = x[ii] * dt_soft_plus; +#if defined(__ARM_FEATURE_SVE) + svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt); + svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus); + svfloat32_t r1_vector = GGML_F32_VEC_ZERO; + + // d_state + // TODO: what happens when (d_state % svcntw()) != 0? + for (int64_t k = 0; k < nc; k += svcntw()) { + svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]); + svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + (h & (ng - 1))*nc]); + svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + (h & (ng - 1))*nc]); + svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]); + + svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA); + t1 = exp_ps_sve(svptrue_b32(), t1); + svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB); + + vs0 = GGML_F32_VEC_FMA(t2, vs0, t1); + r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector); + + GGML_F32_VEC_STORE(&s[ii*nc + k], vs0); + } + y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector); +#else + float sumf = 0.0f; + // NOTE: can't really use GGML_SIMD here because d_state is usually 16 + // and also because expf is used within the loop. + // d_state + for (int i0 = 0; i0 < nc; ++i0) { + const int i = i0 + ii*nc; + const int ig = i0 + (h & (ng - 1))*nc; + // state = prev_state * dA + dB * x + const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt); + // y = rowwise_dotprod(state, C) + sumf += state * C[ig]; + s[i] = state; + } + y[ii] = sumf; +#endif } - y[i1] = sumf; } } + // use the output as the source when it's not the first token-wise iteration + s0 = s; } - #endif + } } void ggml_compute_forward_ssm_scan( diff --git a/ggml/src/ggml-cpu/simd-mappings.h b/ggml/src/ggml-cpu/simd-mappings.h index 2e3669c0186c9..91bb867bf57b8 100644 --- a/ggml/src/ggml-cpu/simd-mappings.h +++ b/ggml/src/ggml-cpu/simd-mappings.h @@ -32,7 +32,7 @@ #define GGML_F32xt_LOAD(...) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__) #define GGML_F32xt_STORE_IMPL(pg,a,b) svst1_f32(pg, a, b) #define GGML_F32xt_STORE(...) GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__) -#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, a, b, c) +#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, b, c, a) #define GGML_F32xt_FMA(...) GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__) #define GGML_F32xt_ADD_IMPL(pg, a, b) svadd_f32_m(pg, a, b) #define GGML_F32xt_ADD(...) 
GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__) diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index f7614568ea388..7e61b5bf965a3 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -37,35 +37,35 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G for (int i = 0; i < np; i += ggml_f32_step) { ax1 = GGML_F32_VEC_LOAD(x + i); ay1 = GGML_F32_VEC_LOAD(y + i); - sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1); + sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1); ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr); ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr); - sum2 = GGML_F32_VEC_FMA(ax2, ay2, sum2); + sum2 = GGML_F32_VEC_FMA(sum2, ax2, ay2); ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr); ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr); - sum3 = GGML_F32_VEC_FMA(ax3, ay3, sum3); + sum3 = GGML_F32_VEC_FMA(sum3, ax3, ay3); ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr); ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr); - sum4 = GGML_F32_VEC_FMA(ax4, ay4, sum4); + sum4 = GGML_F32_VEC_FMA(sum4, ax4, ay4); ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr); ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr); - sum5 = GGML_F32_VEC_FMA(ax5, ay5, sum5); + sum5 = GGML_F32_VEC_FMA(sum5, ax5, ay5); ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr); ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr); - sum6 = GGML_F32_VEC_FMA(ax6, ay6, sum6); + sum6 = GGML_F32_VEC_FMA(sum6, ax6, ay6); ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr); ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr); - sum7 = GGML_F32_VEC_FMA(ax7, ay7, sum7); + sum7 = GGML_F32_VEC_FMA(sum7, ax7, ay7); ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr); ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr); - sum8 = GGML_F32_VEC_FMA(ax8, ay8, sum8); + sum8 = GGML_F32_VEC_FMA(sum8, ax8, ay8); } // leftovers // Since 8 unrolls are done in above loop, leftovers lie in range [0, ggml_f32_step] which is handled in below loop @@ -73,7 +73,7 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G for (int i = np; i < np2; i += ggml_f32_epr) { ax1 = GGML_F32_VEC_LOAD(x + i); ay1 = GGML_F32_VEC_LOAD(y + i); - sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1); + sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1); } // maximum number of leftover elements will be less that ggml_f32_epr. 
Apply predicated svmad on available elements only if (np2 < n) { diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 09dbade2179fb..a144259800477 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -163,49 +163,49 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const ax1 = GGML_F32_VEC_LOAD(x + i); ay1 = GGML_F32_VEC_LOAD(y + i); - ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1); + ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx); GGML_F32_VEC_STORE(y + i, ay1); ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr); ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr); - ay2 = GGML_F32_VEC_FMA(ax2, vx, ay2); + ay2 = GGML_F32_VEC_FMA(ay2, ax2, vx); GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2); ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr); ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr); - ay3 = GGML_F32_VEC_FMA(ax3, vx, ay3); + ay3 = GGML_F32_VEC_FMA(ay3, ax3, vx); GGML_F32_VEC_STORE(y + i + 2*ggml_f32_epr, ay3); ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr); ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr); - ay4 = GGML_F32_VEC_FMA(ax4, vx, ay4); + ay4 = GGML_F32_VEC_FMA(ay4, ax4, vx); GGML_F32_VEC_STORE(y + i + 3*ggml_f32_epr, ay4); ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr); ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr); - ay5 = GGML_F32_VEC_FMA(ax5, vx, ay5); + ay5 = GGML_F32_VEC_FMA(ay5, ax5, vx); GGML_F32_VEC_STORE(y + i + 4*ggml_f32_epr, ay5); ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr); ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr); - ay6 = GGML_F32_VEC_FMA(ax6, vx, ay6); + ay6 = GGML_F32_VEC_FMA(ay6, ax6, vx); GGML_F32_VEC_STORE(y + i + 5*ggml_f32_epr, ay6); ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr); ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr); - ay7 = GGML_F32_VEC_FMA(ax7, vx, ay7); + ay7 = GGML_F32_VEC_FMA(ay7, ax7, vx); GGML_F32_VEC_STORE(y + i + 6*ggml_f32_epr, ay7); ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr); ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr); - ay8 = GGML_F32_VEC_FMA(ax8, vx, ay8); + ay8 = GGML_F32_VEC_FMA(ay8, ax8, vx); GGML_F32_VEC_STORE(y + i + 7*ggml_f32_epr, ay8); } @@ -215,7 +215,7 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const for (int i = np; i < np2; i += ggml_f32_epr) { ax1 = GGML_F32_VEC_LOAD(x + i); ay1 = GGML_F32_VEC_LOAD(y + i); - ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1); + ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx); GGML_F32_VEC_STORE(y + i, ay1); } diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 17eab976f3ad1..9942ff5fbf4d7 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -488,26 +488,25 @@ typedef struct { typedef struct { int64_t d_state; int64_t d_inner; + int64_t n_head; + int64_t n_group; int64_t n_seq_tokens; int64_t n_seqs; - uint64_t nb00; uint64_t nb01; uint64_t nb02; - uint64_t nb10; + uint64_t nb03; uint64_t nb11; uint64_t nb12; uint64_t nb13; - uint64_t nb20; uint64_t nb21; uint64_t nb22; - uint64_t nb30; uint64_t nb31; - uint64_t nb40; uint64_t nb41; uint64_t nb42; - uint64_t nb50; + uint64_t nb43; uint64_t nb51; uint64_t nb52; + uint64_t nb53; } ggml_metal_kargs_ssm_scan; typedef struct { diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index bc93bc633a49b..9fee938242936 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -191,6 +191,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte GGML_METAL_KERNEL_TYPE_NORM, 
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, + GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, @@ -1147,6 +1148,7 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction); @@ -2632,71 +2634,91 @@ static bool ggml_metal_encode_node( struct ggml_tensor * src3 = node->src[3]; struct ggml_tensor * src4 = node->src[4]; struct ggml_tensor * src5 = node->src[5]; + struct ggml_tensor * src6 = node->src[6]; GGML_ASSERT(src3); GGML_ASSERT(src4); GGML_ASSERT(src5); + GGML_ASSERT(src6); size_t offs_src3 = 0; size_t offs_src4 = 0; size_t offs_src5 = 0; + size_t offs_src6 = 0; id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; id id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil; id id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil; + id id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil; - const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30); + const int64_t ne30 = src3->ne[0]; const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31); - const uint64_t nb30 = src3->nb[0]; + const uint64_t nb30 = src3->nb[0]; GGML_UNUSED(nb30); const uint64_t nb31 = src3->nb[1]; const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40); - const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41); + const int64_t ne41 = src4->ne[1]; const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42); + const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43); - const uint64_t nb40 = src4->nb[0]; + const uint64_t nb40 = src4->nb[0]; GGML_UNUSED(nb40); const uint64_t nb41 = src4->nb[1]; const uint64_t nb42 = src4->nb[2]; + const uint64_t nb43 = src4->nb[3]; const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50); const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51); const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52); + const int64_t ne53 = src5->ne[3]; GGML_UNUSED(ne53); - const uint64_t nb50 = src5->nb[0]; + const uint64_t nb50 = src5->nb[0]; GGML_UNUSED(nb50); const uint64_t nb51 = src5->nb[1]; const uint64_t nb52 = src5->nb[2]; + const uint64_t nb53 = src5->nb[3]; + + const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60); + + const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60); const int64_t d_state = ne00; const int64_t d_inner = ne01; - const int64_t n_seq_tokens = ne11; - const int64_t n_seqs = ne02; + const int64_t n_head = ne02; + const int64_t n_group = ne41; + const int64_t n_seq_tokens = ne12; + const int64_t n_seqs = ne13; + + id pipeline = nil; - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; + if (ne30 == 1) { + // Mamba-2 + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; + } ggml_metal_kargs_ssm_scan args = { - /*.d_state =*/ d_state, - /*.d_inner =*/ d_inner, + /*.d_state =*/ d_state, + /*.d_inner =*/ d_inner, + /*.n_head =*/ n_head, + /*.n_group =*/ n_group, /*.n_seq_tokens 
=*/ n_seq_tokens, - /*.n_seqs =*/ n_seqs, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb13 =*/ nb13, - /*.nb20 =*/ nb20, - /*.nb21 =*/ nb21, - /*.nb22 =*/ nb22, - /*.nb30 =*/ nb30, - /*.nb31 =*/ nb31, - /*.nb40 =*/ nb40, - /*.nb41 =*/ nb41, - /*.nb42 =*/ nb42, - /*.nb50 =*/ nb50, - /*.nb51 =*/ nb51, - /*.nb52 =*/ nb52, + /*.n_seqs =*/ n_seqs, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.nb21 =*/ nb21, + /*.nb22 =*/ nb22, + /*.nb31 =*/ nb31, + /*.nb41 =*/ nb41, + /*.nb42 =*/ nb42, + /*.nb43 =*/ nb43, + /*.nb51 =*/ nb51, + /*.nb52 =*/ nb52, + /*.nb53 =*/ nb53, }; [encoder setComputePipelineState:pipeline]; @@ -2706,10 +2728,17 @@ static bool ggml_metal_encode_node( [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; - [encoder setBytes:&args length:sizeof(args) atIndex:7]; + [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:7]; + [encoder setBytes:&args length:sizeof(args) atIndex:8]; - [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + if (ne30 == 1) { + // Mamba-2 + [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } else { + GGML_ASSERT(d_inner == 1); + [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } } break; case GGML_OP_RWKV_WKV6: { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 5d7760217f826..232276a5dac2f 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1293,7 +1293,7 @@ kernel void kernel_ssm_conv_f32( x[0] = sumf; } -// ref: ggml.c:ggml_compute_forward_ssm_scan_f32 +// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part kernel void kernel_ssm_scan_f32( device const void * src0, device const void * src1, @@ -1301,46 +1301,119 @@ kernel void kernel_ssm_scan_f32( device const void * src3, device const void * src4, device const void * src5, + device const void * src6, device float * dst, constant ggml_metal_kargs_ssm_scan & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { - const int64_t ir = tgpig.x; - const int64_t i3 = tgpig.y; + const int64_t i1 = 0; + const int64_t ir = tgpig.x; // current head + const int64_t i3 = tgpig.y; // current seq + + const uint64_t nb00 = sizeof(float); + const uint64_t nb10 = sizeof(float); + const uint64_t nb20 = sizeof(float); const int64_t nc = args.d_state; - // const int64_t nr = args.d_inner; + const int64_t nr = args.d_inner; + const int64_t nh = args.n_head; + const int64_t ng = args.n_group; const int64_t n_t = args.n_seq_tokens; - // const int64_t n_s = args.n_seqs; + + const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float); + + device const int32_t * ids = (device const int32_t *) src6; + + device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); + device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * s0 = (device const 
float *) ((device const char *) src0 + ir*args.nb01 + i3*args.nb02); - device const float * x = (device const float *) ((device const char *) src1 + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); - device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22); - device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); - device const float * B = (device const float *) ((device const char *) src4 + i2*args.nb41 + i3*args.nb42); - device const float * C = (device const float *) ((device const char *) src5 + i2*args.nb51 + i3*args.nb52); - device float * y = (device float *) ((device char *) dst + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); // TODO: do not use src1 strides - device float * s = (device float *) ((device char *) dst + ir*args.nb01 + i3*args.nb02 + args.nb13); - - if (i2 > 0) { - s0 = s; - } - - // i1 == 0 - float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; - float x_dt = x[0] * dt_soft_plus; + device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns} + device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns} + device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh} + device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns} + device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns} + device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} + + const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; + const float x_dt = x[0] * dt_soft_plus; float sumf = 0.0f; for (int64_t i0 = 0; i0 < nc; ++i0) { - int64_t i = i0; - float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt); + const int64_t i = i0 + i1*nc; + const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt); sumf += state * C[i0]; s[i] = state; } y[0] = sumf; + + // recurse + s0 = s; + } +} + +// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part +// TODO: optimize (e.g. 
by parallelizing over d_state) +kernel void kernel_ssm_scan_f32_group( + device const void * src0, + device const void * src1, + device const void * src2, + device const void * src3, + device const void * src4, + device const void * src5, + device const void * src6, + device float * dst, + constant ggml_metal_kargs_ssm_scan & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i1 = tgpig.x; + const int64_t ir = tgpig.y; // current head + const int64_t i3 = tgpig.z; // current seq + + const uint64_t nb00 = sizeof(float); + const uint64_t nb10 = sizeof(float); + const uint64_t nb20 = sizeof(float); + + const int64_t nc = args.d_state; + const int64_t nr = args.d_inner; + const int64_t nh = args.n_head; + const int64_t ng = args.n_group; + const int64_t n_t = args.n_seq_tokens; + + const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float); + + device const int32_t * ids = (device const int32_t *) src6; + + device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); + device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); + + for (int64_t i2 = 0; i2 < n_t; ++i2) { + device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns} + device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns} + device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh} + device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns} + device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns} + device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} + + const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; + const float x_dt = x[0] * dt_soft_plus; + const float dA = exp(dt_soft_plus * A[0]); + float sumf = 0.0f; + + for (int64_t i0 = 0; i0 < nc; ++i0) { + const int64_t i = i0 + i1*nc; + const float state = (s0[i] * dA) + (B[i0] * x_dt); + sumf += state * C[i0]; + s[i] = state; + } + + y[0] = sumf; + + // recurse + s0 = s; } } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 196b7b8f3e2ae..d798f919fd27f 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -4575,7 +4575,6 @@ struct ggml_tensor * ggml_ssm_conv( const int64_t n_s = sx->ne[2]; // TODO: maybe support other strides than 1? - // FIXME: this is always true? 
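+    // sx holds the previous conv state (d_conv - 1 columns) followed by the n_t current tokens, hence the shape checked below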
GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t); GGML_ASSERT(sx->ne[1] == d_inner); GGML_ASSERT(n_t >= 0); @@ -4598,36 +4597,49 @@ struct ggml_tensor * ggml_ssm_scan( struct ggml_tensor * dt, struct ggml_tensor * A, struct ggml_tensor * B, - struct ggml_tensor * C) { + struct ggml_tensor * C, + struct ggml_tensor * ids) { GGML_ASSERT(ggml_is_contiguous(s)); - GGML_ASSERT(ggml_is_contiguous(x)); GGML_ASSERT(ggml_is_contiguous(dt)); GGML_ASSERT(ggml_is_contiguous(A)); - GGML_ASSERT(ggml_is_matrix(A)); - GGML_ASSERT(ggml_is_3d(B)); - GGML_ASSERT(ggml_is_3d(s)); + GGML_ASSERT(x->nb[0] == ggml_type_size(x->type)); GGML_ASSERT(B->nb[0] == ggml_type_size(B->type)); GGML_ASSERT(C->nb[0] == ggml_type_size(C->type)); - GGML_ASSERT(ggml_are_same_shape(x, dt)); + GGML_ASSERT(x->nb[1] == x->ne[0]*x->nb[0]); + GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]); + GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]); GGML_ASSERT(ggml_are_same_shape(B, C)); + GGML_ASSERT(ids->type == GGML_TYPE_I32); { const int64_t d_state = s->ne[0]; - const int64_t d_inner = s->ne[1]; - const int64_t n_seq_tokens = x->ne[1]; - const int64_t n_seqs = x->ne[2]; - - GGML_ASSERT(s->ne[2] == n_seqs); - GGML_ASSERT(x->ne[0] == d_inner); - GGML_ASSERT(A->ne[0] == d_state); - GGML_ASSERT(A->ne[1] == d_inner); + const int64_t head_dim = x->ne[0]; + const int64_t n_head = x->ne[1]; + const int64_t n_seq_tokens = x->ne[2]; + const int64_t n_seqs = x->ne[3]; + + GGML_ASSERT(dt->ne[0] == n_head); + GGML_ASSERT(dt->ne[1] == n_seq_tokens); + GGML_ASSERT(dt->ne[2] == n_seqs); + GGML_ASSERT(ggml_is_3d(dt)); + GGML_ASSERT(s->ne[1] == head_dim); + GGML_ASSERT(s->ne[2] == n_head); GGML_ASSERT(B->ne[0] == d_state); - GGML_ASSERT(B->ne[1] == n_seq_tokens); - GGML_ASSERT(B->ne[2] == n_seqs); + GGML_ASSERT(B->ne[2] == n_seq_tokens); + GGML_ASSERT(B->ne[3] == n_seqs); + GGML_ASSERT(ids->ne[0] == n_seqs); + GGML_ASSERT(ggml_is_vector(ids)); + GGML_ASSERT(A->ne[1] == n_head); + GGML_ASSERT(ggml_is_matrix(A)); + + if (A->ne[0] != 1) { + // Mamba-1 has more granular decay factors + GGML_ASSERT(A->ne[0] == d_state); + } } // concatenated y + ssm_states - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s)); + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + s->ne[0]*s->ne[1]*s->ne[2]*ids->ne[0]); result->op = GGML_OP_SSM_SCAN; result->src[0] = s; @@ -4636,6 +4648,7 @@ struct ggml_tensor * ggml_ssm_scan( result->src[3] = A; result->src[4] = B; result->src[5] = C; + result->src[6] = ids; return result; } diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 834a1d5e1a97e..9476f1a1b11c2 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -164,7 +164,9 @@ class SSM: INNER_SIZE = "{arch}.ssm.inner_size" STATE_SIZE = "{arch}.ssm.state_size" TIME_STEP_RANK = "{arch}.ssm.time_step_rank" + GROUP_COUNT = "{arch}.ssm.group_count" DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms" + HEAD_DIM = "{arch}.ssm.head_dim" class WKV: HEAD_SIZE = "{arch}.wkv.head_size" @@ -319,6 +321,7 @@ class MODEL_ARCH(IntEnum): RWKV7 = auto() ARWKV7 = auto() MAMBA = auto() + MAMBA2 = auto() XVERSE = auto() COMMAND_R = auto() COHERE2 = auto() @@ -346,6 +349,7 @@ class MODEL_ARCH(IntEnum): BAILINGMOE = auto() DOTS1 = auto() ARCEE = auto() + FALCON_H1 = auto() class VISION_PROJECTOR_TYPE(IntEnum): @@ -404,7 +408,9 @@ class MODEL_TENSOR(IntEnum): SSM_DT = auto() SSM_A = auto() SSM_D = auto() + SSM_NORM = auto() SSM_OUT = auto() + SSM_MUP_VEC = auto() TIME_MIX_W0 = auto() TIME_MIX_W1 = 
auto() TIME_MIX_W2 = auto() @@ -602,6 +608,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.RWKV7: "rwkv7", MODEL_ARCH.ARWKV7: "arwkv7", MODEL_ARCH.MAMBA: "mamba", + MODEL_ARCH.MAMBA2: "mamba2", MODEL_ARCH.XVERSE: "xverse", MODEL_ARCH.COMMAND_R: "command-r", MODEL_ARCH.COHERE2: "cohere2", @@ -629,6 +636,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.BAILINGMOE: "bailingmoe", MODEL_ARCH.DOTS1: "dots1", MODEL_ARCH.ARCEE: "arcee", + MODEL_ARCH.FALCON_H1: "falcon-h1", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { @@ -666,7 +674,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp", MODEL_TENSOR.FFN_GATE_INP_SHEXP: "blk.{bid}.ffn_gate_inp_shexp", MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", - MODEL_TENSOR.FFN_PRE_NORM: "blk.{bid}.ffn_norm", + MODEL_TENSOR.FFN_PRE_NORM: "blk.{bid}.ffn_pre_norm", MODEL_TENSOR.FFN_POST_NORM: "blk.{bid}.post_ffw_norm", MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate", MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", @@ -687,7 +695,9 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt", MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", + MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm", MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", + MODEL_TENSOR.SSM_MUP_VEC: "blk.{bid}.ssm_mup_vec", MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0", MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1", MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2", @@ -1636,6 +1646,19 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_D, MODEL_TENSOR.SSM_OUT, ], + MODEL_ARCH.MAMBA2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.SSM_IN, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_D, + MODEL_TENSOR.SSM_NORM, + MODEL_TENSOR.SSM_OUT, + ], MODEL_ARCH.XVERSE: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -2156,6 +2179,41 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.BAILINGMOE: [ MODEL_TENSOR.ROPE_FREQS, ], + MODEL_ARCH.FALCON_H1: [ + # Token embedding + MODEL_TENSOR.TOKEN_EMBD, + + # Input layernorm + MODEL_TENSOR.ATTN_NORM, + + # Attention components + MODEL_TENSOR.ATTN_Q, # Query projection + MODEL_TENSOR.ATTN_K, # Key projection + MODEL_TENSOR.ATTN_V, # Value projection + MODEL_TENSOR.ATTN_OUT, # Output projection + + # SSM components (Mamba2 specific) + MODEL_TENSOR.SSM_MUP_VEC, # Mup vector + MODEL_TENSOR.SSM_IN, # Input projection for SSM + MODEL_TENSOR.SSM_CONV1D, # Convolution layer + MODEL_TENSOR.SSM_DT, # Delta time projection + MODEL_TENSOR.SSM_A, # A parameter (log form) + MODEL_TENSOR.SSM_D, # D parameter + MODEL_TENSOR.SSM_NORM, # Normalization in SSM + MODEL_TENSOR.SSM_OUT, # Output projection + + # Pre-feedforward layernorm + MODEL_TENSOR.FFN_PRE_NORM, + + # Feed-forward network components + MODEL_TENSOR.FFN_GATE, # Gate projection (SwiGLU) + MODEL_TENSOR.FFN_DOWN, # Down projection + MODEL_TENSOR.FFN_UP, # Up projection + + # Post-feedforward layernorm + MODEL_TENSOR.OUTPUT_NORM, # Final layer norm + MODEL_TENSOR.OUTPUT, # Output projection (lm_head) + ], } # @@ -2405,6 +2463,7 @@ class VisionProjectorType: KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK +KEY_SSM_GROUP_COUNT = Keys.SSM.GROUP_COUNT KEY_SSM_DT_B_C_RMS = Keys.SSM.DT_B_C_RMS # tokenization diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 54ca0c33fd336..fde78894d4fac 100644 --- a/gguf-py/gguf/gguf_writer.py +++ 
b/gguf-py/gguf/gguf_writer.py @@ -843,9 +843,22 @@ def add_ssm_state_size(self, value: int) -> None: def add_ssm_time_step_rank(self, value: int) -> None: self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value) + def add_ssm_group_count(self, value: int) -> None: + self.add_uint32(Keys.SSM.GROUP_COUNT.format(arch=self.arch), value) + def add_ssm_dt_b_c_rms(self, value: bool) -> None: self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value) + def add_ssm_head_dim(self, value: int) -> None: + self.add_uint32(Keys.SSM.HEAD_DIM.format(arch=self.arch), value) + + def add_attn_head_count(self, count: int) -> None: + self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count) + + def add_key_value_head_count(self, count: int) -> None: + self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count) + + def add_tokenizer_model(self, model: str) -> None: self.add_string(Keys.Tokenizer.MODEL, model) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 79f044d2a5945..79d336539c483 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -286,6 +286,7 @@ class TensorNameMap: # Post feed-forward norm MODEL_TENSOR.FFN_PRE_NORM: ( "model.layers.{bid}.pre_feedforward_layernorm", # gemma2 + "model.layers.{bid}.pre_ff_layernorm.weight", # falcon-h1 ), # Post feed-forward norm @@ -356,6 +357,7 @@ class TensorNameMap: "model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged) "model.layers.{bid}.feed_forward.experts.up_proj", # llama4 "encoder.layers.{bid}.mlp.experts.mlp.w1", # nomic-bert-moe + "model.layers.{bid}.feed_forward.up_proj", # falcon-h1 ), MODEL_TENSOR.FFN_UP_SHEXP: ( @@ -392,6 +394,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged) "model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged) "model.layers.{bid}.feed_forward.experts.gate_proj", # llama4 + "model.layers.{bid}.feed_forward.down_proj", # falcon-h1 ), MODEL_TENSOR.FFN_GATE_SHEXP: ( @@ -431,6 +434,14 @@ class TensorNameMap: "transformer_encoder.{bid}.ffn.w3", # neobert ), + MODEL_TENSOR.SSM_MUP_VEC: ( + "model.layers.{bid}.mamba.mup_vector", # falcon-h1 + ), + + MODEL_TENSOR.SSM_NORM: ( + "model.layers.{bid}.mamba.norm", + ), + MODEL_TENSOR.FFN_DOWN_EXP: ( "layers.{bid}.feed_forward.experts.w2", # mixtral (merged) "transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged) @@ -477,17 +488,19 @@ class TensorNameMap: "encoder.layers.{bid}.norm2", # nomic-bert "transformer.decoder_layer.{bid}.rms_norm_3", # Grok "encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2 - "encoder.layer.{bid}.layer_norm_2" # jina-v2-code + "encoder.layer.{bid}.layer_norm_2", # jina-v2-code ), MODEL_TENSOR.SSM_IN: ( "model.layers.{bid}.in_proj", "backbone.layers.{bid}.mixer.in_proj", + "model.layers.{bid}.mamba.in_proj", # falcon-h1 ), MODEL_TENSOR.SSM_CONV1D: ( "model.layers.{bid}.conv1d", "backbone.layers.{bid}.mixer.conv1d", + "model.layers.{bid}.mamba.conv1d", # falcon-h1 ), MODEL_TENSOR.SSM_X: ( @@ -498,21 +511,29 @@ class TensorNameMap: MODEL_TENSOR.SSM_DT: ( "model.layers.{bid}.dt_proj", "backbone.layers.{bid}.mixer.dt_proj", + "model.layers.{bid}.mamba.dt_proj", # falcon-h1 ), MODEL_TENSOR.SSM_A: ( "model.layers.{bid}.A_log", "backbone.layers.{bid}.mixer.A_log", + "model.layers.{bid}.mamba.A_log", # falcon-h1 ), MODEL_TENSOR.SSM_D: ( "model.layers.{bid}.D", "backbone.layers.{bid}.mixer.D", + "model.layers.{bid}.mamba.D", # falcon-h1 + ), + + MODEL_TENSOR.SSM_NORM: ( + 
"backbone.layers.{bid}.mixer.norm", # mamba2 ), MODEL_TENSOR.SSM_OUT: ( "model.layers.{bid}.out_proj", "backbone.layers.{bid}.mixer.out_proj", + "model.layers.{bid}.mamba.out_proj", # falcon-h1 ), MODEL_TENSOR.TIME_MIX_W0: ( diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 70be604e4b0d3..e70c8d9a183df 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -23,6 +23,7 @@ add_library(llama llama-kv-cache-unified.cpp llama-kv-cache-unified-iswa.cpp llama-kv-cache-recurrent.cpp + llama-kv-cache-hybrid-recurrent.cpp llama-memory.cpp llama-mmap.cpp llama-model-loader.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index de8d289cf967e..3bda84d389831 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -44,6 +44,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GEMMA3, "gemma3" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, + { LLM_ARCH_MAMBA2, "mamba2" }, { LLM_ARCH_XVERSE, "xverse" }, { LLM_ARCH_COMMAND_R, "command-r" }, { LLM_ARCH_COHERE2, "cohere2" }, @@ -73,6 +74,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, { LLM_ARCH_PLM, "plm" }, { LLM_ARCH_BAILINGMOE, "bailingmoe" }, + { LLM_ARCH_FALCON_H1, "falcon-h1" }, { LLM_ARCH_DOTS1, "dots1" }, { LLM_ARCH_ARCEE, "arcee" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, @@ -124,6 +126,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" }, { LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" }, { LLM_KV_INTERLEAVE_MOE_LAYER_STEP, "%s.interleave_moe_layer_step" }, + { LLM_KV_ATTN_HEAD_DIM, "%s.attention.head_dim" }, { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, @@ -147,6 +150,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, + { LLM_KV_ATTENTION_LAYER_INDICES, "%s.attention.layer_indices" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, @@ -163,11 +167,30 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_SPLIT_COUNT, "split.count" }, { LLM_KV_SPLIT_TENSORS_COUNT, "split.tensors.count" }, - { LLM_KV_SSM_CONV_KERNEL, "%s.ssm.conv_kernel" }, - { LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" }, - { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" }, - { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" }, - { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, + { LLM_KV_SSM_CONV_KERNEL, "%s.ssm.conv_kernel" }, + { LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" }, + { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" }, + { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" }, + { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, + { LLM_KV_SSM_GROUP_COUNT, "%s.ssm.group_count" }, + { LLM_KV_SSM_HEAD_DIM, "%s.ssm.head_dim" }, + { LLM_KV_MAMBA_D_SSM, "%s.ssm.mamba_d_ssm" }, + + { LLM_KV_FALCON_H1_USE_MLP, "%s.mamba_use_mlp" }, + { LLM_KV_FALCON_H1_ATTENTION_IN_MULTIPLIER, "%s.attention_in_multiplier" }, + { LLM_KV_FALCON_H1_ATTENTION_OUT_MULTIPLIER, "%s.attention_out_multiplier" }, + { LLM_KV_FALCON_H1_SSM_IN_MULTIPLIER, "%s.ssm_in_multiplier" }, + { LLM_KV_FALCON_H1_SSM_OUT_MULTIPLIER, "%s.ssm_out_multiplier" }, + { LLM_KV_FALCON_H1_MLP_GATE_MULTIPLIER, "%s.mlp_gate_multiplier" }, + { LLM_KV_FALCON_H1_MLP_DOWN_MULTIPLIER, "%s.mlp_down_multiplier" }, + { LLM_KV_FALCON_H1_SSM_HAS_MUP, "%s.ssm.has_mup" }, + { 
LLM_KV_FALCON_H1_MAMBA_NORM_BEFORE_GATE, "%s.mamba_norm_before_gate" }, + { LLM_KV_FALCON_H1_MAMBA_RMS_NORM, "%s.mamba_rms_norm" }, + { LLM_KV_FALCON_H1_ROPE_THETA, "%s.rope_theta" }, + { LLM_KV_FALCON_H1_KEY_MULTIPLIER, "%s.key_multiplier" }, + { LLM_KV_FALCON_H1_LM_HEAD_MULTIPLIER, "%s.lm_head_multiplier" }, + { LLM_KV_FALCON_H1_EMBEDDING_MULTIPLIER, "%s.embedding_multiplier" }, + { LLM_KV_FALCON_H1_MAMBA_CHUNK_SIZE, "%s.ssm.mamba_chunk_size" }, { LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" }, @@ -352,6 +375,31 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_FALCON_H1, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_MUP_VEC, "blk.%d.ssm_mup_vec" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + { LLM_TENSOR_FFN_PRE_NORM, "blk.%d.ffn_pre_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_GROK, { @@ -964,6 +1012,22 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, }, }, + { + LLM_ARCH_MAMBA2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + }, + }, { LLM_ARCH_XVERSE, { @@ -1636,6 +1700,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_FINAL_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_ROPE_FREQS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}}, @@ -1704,6 +1769,8 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_SSM_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}}, {LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}}, {LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_MUP_VEC, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -1727,6 +1794,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_ATTN_OUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, 
GGML_OP_MUL}}, {LLM_TENSOR_FFN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_FFN_PRE_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_FFN_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_FFN_NORM_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_Q_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -1816,3 +1884,28 @@ llm_arch llm_arch_from_string(const std::string & name) { const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) { return LLM_TENSOR_INFOS.at(tensor); } + +bool llm_arch_is_recurrent(const llm_arch & arch) { + switch (arch) { + case LLM_ARCH_MAMBA: + case LLM_ARCH_MAMBA2: + case LLM_ARCH_RWKV6: + case LLM_ARCH_RWKV6QWEN2: + case LLM_ARCH_RWKV7: + case LLM_ARCH_ARWKV7: + return true; + default: + return false; + } +} + +bool llm_arch_is_hybrid_recurrent(const llm_arch & arch) { + // TODO: There are currently no hybrid models! Once there are, this will be + // the place to identify them + switch (arch) { + case LLM_ARCH_FALCON_H1: + return true; + default: + return false; + } +} diff --git a/src/llama-arch.h b/src/llama-arch.h index 3e8a61da3c13e..14acd4bd97b7f 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -13,6 +13,7 @@ enum llm_arch { LLM_ARCH_LLAMA4, LLM_ARCH_DECI, LLM_ARCH_FALCON, + LLM_ARCH_FALCON_H1, LLM_ARCH_BAICHUAN, LLM_ARCH_GROK, LLM_ARCH_GPT2, @@ -48,6 +49,7 @@ enum llm_arch { LLM_ARCH_GEMMA3, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, + LLM_ARCH_MAMBA2, LLM_ARCH_XVERSE, LLM_ARCH_COMMAND_R, LLM_ARCH_COHERE2, @@ -151,6 +153,28 @@ enum llm_kv { LLM_KV_ATTENTION_SCALE, LLM_KV_ATTENTION_KEY_LENGTH_MLA, LLM_KV_ATTENTION_VALUE_LENGTH_MLA, + LLM_KV_ATTENTION_LAYER_INDICES, + + // Falcon-H1 specific + LLM_KV_ATTN_HEAD_DIM, + LLM_KV_SSM_HEAD_DIM, + LLM_KV_MAMBA_D_SSM, + LLM_KV_N_LAYER, + LLM_KV_FALCON_H1_USE_MLP, + LLM_KV_FALCON_H1_ATTENTION_IN_MULTIPLIER, + LLM_KV_FALCON_H1_ATTENTION_OUT_MULTIPLIER, + LLM_KV_FALCON_H1_SSM_IN_MULTIPLIER, + LLM_KV_FALCON_H1_SSM_OUT_MULTIPLIER, + LLM_KV_FALCON_H1_MLP_GATE_MULTIPLIER, + LLM_KV_FALCON_H1_MLP_DOWN_MULTIPLIER, + LLM_KV_FALCON_H1_SSM_HAS_MUP, + LLM_KV_FALCON_H1_MAMBA_NORM_BEFORE_GATE, + LLM_KV_FALCON_H1_MAMBA_RMS_NORM, + LLM_KV_FALCON_H1_ROPE_THETA, + LLM_KV_FALCON_H1_KEY_MULTIPLIER, + LLM_KV_FALCON_H1_LM_HEAD_MULTIPLIER, + LLM_KV_FALCON_H1_EMBEDDING_MULTIPLIER, + LLM_KV_FALCON_H1_MAMBA_CHUNK_SIZE, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_SECTIONS, @@ -171,6 +195,7 @@ enum llm_kv { LLM_KV_SSM_CONV_KERNEL, LLM_KV_SSM_STATE_SIZE, LLM_KV_SSM_TIME_STEP_RANK, + LLM_KV_SSM_GROUP_COUNT, LLM_KV_SSM_DT_B_C_RMS, LLM_KV_WKV_HEAD_SIZE, @@ -273,6 +298,7 @@ enum llm_tensor { LLM_TENSOR_SSM_DT, LLM_TENSOR_SSM_A, LLM_TENSOR_SSM_D, + LLM_TENSOR_SSM_NORM, LLM_TENSOR_SSM_OUT, LLM_TENSOR_TIME_MIX_W0, LLM_TENSOR_TIME_MIX_W1, @@ -366,6 +392,9 @@ enum llm_tensor { LLM_TENSOR_POS_NET_ATTN_K, LLM_TENSOR_POS_NET_ATTN_V, LLM_TENSOR_POS_NET_ATTN_OUT, + LLM_TENSOR_SSM_MUP_VEC, + LLM_TENSOR_FFN_PRE_NORM, + LLM_TENSOR_FINAL_NORM, }; enum llm_tensor_layer { @@ -439,3 +468,6 @@ const char * llm_arch_name(llm_arch arch); llm_arch llm_arch_from_string(const std::string & name); const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor); + +bool llm_arch_is_recurrent(const llm_arch& arch); +bool llm_arch_is_hybrid_recurrent(const llm_arch& arch); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index f56a58e9b6ec6..b4063e0f2b3f0 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1157,12 +1157,12 @@ int llama_context::decode(const llama_batch & 
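Aside: llm_arch_is_recurrent() and llm_arch_is_hybrid_recurrent(), added above, classify architectures so callers can pick a memory implementation without a per-arch switch. A hedged sketch of such a caller; the enum and helper function are illustrative only, not part of the diff, whereas the two predicates and the arch values are:

```cpp
#include "llama-arch.h"

// Illustrative only: how a caller might branch on the new predicates.
enum class memory_kind { unified, recurrent, hybrid_recurrent };

static memory_kind choose_memory_kind(llm_arch arch) {
    if (llm_arch_is_hybrid_recurrent(arch)) {
        return memory_kind::hybrid_recurrent; // e.g. LLM_ARCH_FALCON_H1
    }
    if (llm_arch_is_recurrent(arch)) {
        return memory_kind::recurrent;        // e.g. LLM_ARCH_MAMBA, LLM_ARCH_MAMBA2, RWKV variants
    }
    return memory_kind::unified;              // regular attention KV cache
}
```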
batch_inp) { } } + // set to total number of outputs in the batch, for use in llama_get_logits_ith + n_outputs = n_outputs_all; + // make the outputs have the same order they had in the user-provided batch // note: this is mostly relevant for recurrent models atm if (!sorted_output) { - const uint32_t n_vocab = model.vocab.n_tokens(); - const uint64_t n_embd = model.hparams.n_embd; - GGML_ASSERT((size_t) n_outputs == out_ids.size()); // TODO: is there something more efficient which also minimizes swaps? @@ -1179,12 +1179,12 @@ int llama_context::decode(const llama_batch & batch_inp) { } std::swap(out_ids[i], out_ids[j_min]); if (logits_size > 0) { - for (uint32_t k = 0; k < n_vocab; k++) { + for (int32_t k = 0; k < n_vocab; k++) { std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]); } } if (embd_size > 0) { - for (uint32_t k = 0; k < n_embd; k++) { + for (int64_t k = 0; k < n_embd; k++) { std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]); } } @@ -1196,6 +1196,8 @@ int llama_context::decode(const llama_batch & batch_inp) { output_ids[out_ids[i]] = i; } } + // sorted, so no need for the indices anymore + out_ids.clear(); } // wait for the computation to finish (automatically done when obtaining the model output) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 337fb5cb0df36..33df2c06039cf 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -7,6 +7,7 @@ #include "llama-kv-cache-unified.h" #include "llama-kv-cache-unified-iswa.h" #include "llama-kv-cache-recurrent.h" +#include "llama-kv-cache-hybrid-recurrent.h" #include #include @@ -238,7 +239,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { } } -void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) { +void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) { GGML_UNUSED(ubatch); const int64_t n_kv = kv_state->get_n_kv(); @@ -254,6 +255,11 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) { } } +llm_graph_input_rs_hybrid_recurrent::llm_graph_input_rs_hybrid_recurrent( + const llama_kv_cache_hybrid_recurrent_state * kv_state) : + llm_graph_input_rs(kv_state->get_state_recurrent()) { +} + void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) { GGML_UNUSED(ubatch); @@ -359,6 +365,13 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) { } } +llm_graph_input_attn_kv_hybrid_recurrent::llm_graph_input_attn_kv_hybrid_recurrent( + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_hybrid_recurrent_state * kv_state) : + llm_graph_input_attn_kv_unified(hparams, cparams, kv_state->get_state_attn()) { +} + void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) { if (self_kq_mask) { kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); @@ -591,6 +604,11 @@ ggml_tensor * llm_graph_context::build_ffn( case LLM_FFN_PAR: { cur = build_lora_mm(gate, cur); + + if (arch == LLM_ARCH_FALCON_H1) { + cur = ggml_scale(ctx0, cur, hparams.mlp_gate_multiplier); + } + cb(cur, "ffn_gate", il); } break; } @@ -678,6 +696,9 @@ ggml_tensor * llm_graph_context::build_ffn( // GLM4 seems to have numerical issues with half-precision accumulators ggml_mul_mat_set_prec(cur, GGML_PREC_F32); } + if (arch == LLM_ARCH_FALCON_H1) { + cur = ggml_scale(ctx0, cur, hparams.mlp_down_multiplier); + } } if (down_b) { @@ -961,23 +982,6 @@ ggml_tensor * llm_graph_context::build_inp_cls() const { return cur; } -ggml_tensor * llm_graph_context::build_inp_s_copy() 
const { - const auto * kv_state = static_cast(mstate); - - auto inp = std::make_unique(kv_state); - - const auto n_kv = kv_state->get_n_kv(); - - auto & cur = inp->s_copy; - - cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv); - ggml_set_input(cur); - - res->add_input(std::move(inp)); - - return cur; -} - ggml_tensor * llm_graph_context::build_inp_cross_embd() const { auto inp = std::make_unique(cross); @@ -1291,6 +1295,75 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } +llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const { + auto inp = std::make_unique( + hparams, cparams, static_cast(mstate)); + + { + GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers"); + + const auto n_kv = inp->kv_state->get_n_kv(); + + inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + //cb(inp->self_kq_mask, "KQ_mask", -1); + ggml_set_input(inp->self_kq_mask); + + inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + } + + return (llm_graph_input_attn_kv_hybrid_recurrent *) res->add_input(std::move(inp)); +} + +ggml_tensor * llm_graph_context::build_attn( + llm_graph_input_attn_kv_hybrid_recurrent * inp, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_b, + ggml_tensor * v_mla, + float kq_scale, + int il) const { + // these nodes are added to the graph together so that they are not reordered + // by doing so, the number of splits in the graph is reduced + ggml_build_forward_expand(gf, q_cur); + ggml_build_forward_expand(gf, k_cur); + ggml_build_forward_expand(gf, v_cur); + + const auto * kv_state = static_cast(mstate)->get_state_attn(); + + // store to KV cache + { + ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il)); + ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il)); + } + + const auto & kq_mask = inp->get_kq_mask(); + + ggml_tensor * q = q_cur; + ggml_tensor * k = kv_state->get_k(ctx0, il); + ggml_tensor * v = kv_state->get_v(ctx0, il); + + ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); + cb(cur, "kqv_out", il); + + if (wo) { + cur = build_lora_mm(wo, cur); + if (arch == LLM_ARCH_GLM4) { + // GLM4 seems to have numerical issues with half-precision accumulators + ggml_mul_mat_set_prec(cur, GGML_PREC_F32); + } + } + + if (wo_b) { + cur = ggml_add(ctx0, cur, wo_b); + } + + return cur; +} + llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const { const auto * kv_state = static_cast(mstate); @@ -1429,15 +1502,14 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } - ggml_tensor * llm_graph_context::build_recurrent_state( + const llama_kv_cache_recurrent_state * kv_state, ggml_cgraph * gf, ggml_tensor * s, ggml_tensor * state_copy, int32_t state_size, int32_t n_seqs, - bool avoid_copies) const { - const auto * kv_state = static_cast(mstate); + const std::function & get_state_rows) const { const auto n_kv = kv_state->get_n_kv(); const auto kv_head = kv_state->get_head(); @@ -1452,17 +1524,11 @@ ggml_tensor * llm_graph_context::build_recurrent_state( ggml_tensor * output_states; - if (!avoid_copies) { - // copy states - // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv - // {state_size, kv_size} -> {state_size, n_seqs} 
- output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0)); - ggml_build_forward_expand(gf, output_states); - } else { - // FIXME: make the gathering operation happen before the copy below - // (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?) - output_states = states; - } + // copy states + // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv + // {state_size, kv_size} -> {state_size, n_seqs} + output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0)); + ggml_build_forward_expand(gf, output_states); // copy extra states which won't be changed further (between n_seqs and n_kv) ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0])); @@ -1474,10 +1540,63 @@ ggml_tensor * llm_graph_context::build_recurrent_state( return output_states; } +llm_graph_input_rs * llm_graph_context::build_rs_inp_recurrent() const { + const auto * kv_state = static_cast(mstate); + + auto inp = std::make_unique(kv_state); + + const auto n_kv = kv_state->get_n_kv(); + + auto & cur = inp->s_copy; + + cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv); + ggml_set_input(cur); + + return (llm_graph_input_rs *) res->add_input(std::move(inp)); +} + +ggml_tensor * llm_graph_context::build_rs( + llm_graph_input_rs * inp, + ggml_cgraph * gf, + ggml_tensor * s, + int32_t state_size, + int32_t n_seqs, + const std::function & get_state_rows) const { + + const auto * kv_state = static_cast(mstate); + return build_recurrent_state(kv_state, gf, s, inp->s_copy, state_size, n_seqs, get_state_rows); +} + +llm_graph_input_rs_hybrid_recurrent * llm_graph_context::build_rs_inp_hybrid_recurrent() const { + auto inp = std::make_unique( + static_cast(mstate)); + + const auto n_kv = inp->kv_state->get_n_kv(); + + auto & cur = inp->s_copy; + + cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv); + ggml_set_input(cur); + + return (llm_graph_input_rs_hybrid_recurrent *) res->add_input(std::move(inp)); +} + +ggml_tensor * llm_graph_context::build_rs( + llm_graph_input_rs_hybrid_recurrent * inp, + ggml_cgraph * gf, + ggml_tensor * s, + int32_t state_size, + int32_t n_seqs, + const std::function & get_state_rows) const { + + const auto * kv_state = static_cast(mstate)->get_state_recurrent(); + return build_recurrent_state(kv_state, gf, s, inp->s_copy, state_size, n_seqs, get_state_rows); +} + ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( - ggml_cgraph * gf, - ggml_tensor * state_copy, - const llama_ubatch & ubatch, + llm_graph_input_rs * inp, + ggml_cgraph * gf, + const llama_ubatch & ubatch, int il) const { const auto * kv_state = static_cast(mstate); @@ -1487,8 +1606,8 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( ggml_tensor * token_shift_all = kv_state->get_k_l(il); - ggml_tensor * token_shift = build_recurrent_state( - gf, token_shift_all, state_copy, + ggml_tensor * token_shift = build_rs( + inp, gf, token_shift_all, hparams.n_embd_k_s(), n_seqs); token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs); diff --git a/src/llama-graph.h b/src/llama-graph.h index 87813119b1a3c..bb1df567c7904 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -22,6 +22,7 @@ struct llama_memory_state_i; class llama_kv_cache_unified_state; class llama_kv_cache_unified_iswa_state; class llama_kv_cache_recurrent_state; +class llama_kv_cache_hybrid_recurrent_state; // certain 
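Aside: the hunks above replace build_inp_s_copy() and the old avoid_copies flag with typed input builders (build_rs_inp_recurrent(), build_rs_inp_hybrid_recurrent()) and build_rs() overloads that forward a get_state_rows callback, defaulting to ggml_get_rows, into build_recurrent_state(). A minimal call-site sketch, assuming it runs inside an llm_graph_context member where gf, il, n_seqs and the recurrent kv_state are in scope, and assuming the callback takes (ggml_context *, states, ids) like ggml_get_rows does (the template arguments are elided in the diff text above):

```cpp
// build the shared s_copy input once per graph
llm_graph_input_rs * inp = build_rs_inp_recurrent();

// inside a layer: gather the conv states of the sequences in this ubatch
ggml_tensor * conv_states_all = kv_state->get_k_l(il);
ggml_tensor * conv = build_rs(inp, gf, conv_states_all,
                              hparams.n_embd_k_s(), n_seqs); // uses ggml_get_rows by default

// a custom gather can be injected instead; this callback is what replaces the
// old avoid_copies flag:
// build_rs(inp, gf, conv_states_all, hparams.n_embd_k_s(), n_seqs,
//          [&](ggml_context * ctx, ggml_tensor * s, ggml_tensor * ids) {
//              return ggml_get_rows(ctx, s, ids); // or a fused gather + reshape
//          });
```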
models (typically multi-modal) can produce different types of graphs enum llm_graph_type { @@ -188,10 +189,10 @@ class llm_graph_input_cls : public llm_graph_input_i { const llama_cparams & cparams; }; -class llm_graph_input_s_copy : public llm_graph_input_i { +class llm_graph_input_rs : public llm_graph_input_i { public: - llm_graph_input_s_copy(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {} - virtual ~llm_graph_input_s_copy() = default; + llm_graph_input_rs(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {} + virtual ~llm_graph_input_rs() = default; void set_input(const llama_ubatch * ubatch) override; @@ -200,6 +201,12 @@ class llm_graph_input_s_copy : public llm_graph_input_i { const llama_kv_cache_recurrent_state * kv_state; }; +class llm_graph_input_rs_hybrid_recurrent : public llm_graph_input_rs { +public: + llm_graph_input_rs_hybrid_recurrent(const llama_kv_cache_hybrid_recurrent_state * kv_state); + virtual ~llm_graph_input_rs_hybrid_recurrent() = default; +}; + class llm_graph_input_cross_embd : public llm_graph_input_i { public: llm_graph_input_cross_embd( @@ -257,6 +264,15 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i { const llama_kv_cache_unified_state * kv_state; }; +class llm_graph_input_attn_kv_hybrid_recurrent : public llm_graph_input_attn_kv_unified { +public: + llm_graph_input_attn_kv_hybrid_recurrent( + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_hybrid_recurrent_state * kv_state); + virtual ~llm_graph_input_attn_kv_hybrid_recurrent() = default; +}; + class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i { public: llm_graph_input_attn_kv_unified_iswa( @@ -508,7 +524,6 @@ struct llm_graph_context { ggml_tensor * build_inp_out_ids() const; ggml_tensor * build_inp_mean() const; ggml_tensor * build_inp_cls() const; - ggml_tensor * build_inp_s_copy() const; ggml_tensor * build_inp_cross_embd() const; ggml_tensor * build_inp_pos_bucket_enc() const; @@ -589,22 +604,60 @@ struct llm_graph_context { float kq_scale, int il) const; + llm_graph_input_attn_kv_hybrid_recurrent * build_attn_inp_kv_hybrid_recurrent() const; + + ggml_tensor * build_attn( + llm_graph_input_attn_kv_hybrid_recurrent * inp, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] + ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] + ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] + ggml_tensor * kq_b, + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] + float kq_scale, + int il) const; // // recurrent // ggml_tensor * build_recurrent_state( - ggml_cgraph * gf, - ggml_tensor * s, - ggml_tensor * state_copy, - int32_t state_size, - int32_t n_seqs, - bool avoid_copies = false) const; + const llama_kv_cache_recurrent_state * kv_state, + ggml_cgraph * gf, + ggml_tensor * s, + ggml_tensor * state_copy, + int32_t state_size, + int32_t n_seqs, + const std::function + & get_state_rows = ggml_get_rows) const; + + llm_graph_input_rs * build_rs_inp_recurrent() const; + + ggml_tensor * build_rs( + llm_graph_input_rs * inp, + ggml_cgraph * gf, + ggml_tensor * s, + int32_t state_size, + int32_t n_seqs, + const std::function + & get_state_rows = ggml_get_rows) const; + + llm_graph_input_rs_hybrid_recurrent * build_rs_inp_hybrid_recurrent() const; + + ggml_tensor * build_rs( + llm_graph_input_rs_hybrid_recurrent * inp, + ggml_cgraph * gf, + ggml_tensor * s, + int32_t state_size, + int32_t 
n_seqs, + const std::function + & get_state_rows = ggml_get_rows) const; ggml_tensor * build_rwkv_token_shift_load( - ggml_cgraph * gf, - ggml_tensor * state_copy, - const llama_ubatch & ubatch, + llm_graph_input_rs * inp, + ggml_cgraph * gf, + const llama_ubatch & ubatch, int il) const; ggml_tensor * build_rwkv_token_shift_store( diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 1499eb08a5dd9..27f24d7fb679b 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -73,7 +73,8 @@ uint32_t llama_hparams::n_embd_k_s() const { // TODO: maybe support other convolution strides than 1 // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed - return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; + // Corresponds to Mamba's conv_states size + return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_mamba_d_ssm + 2*ssm_n_group*ssm_d_state); } uint32_t llama_hparams::n_embd_v_s() const { @@ -83,7 +84,11 @@ uint32_t llama_hparams::n_embd_v_s() const { } // corresponds to Mamba's ssm_states size - return ssm_d_state * ssm_d_inner; + return ssm_d_state * ssm_mamba_d_ssm; +} + +bool llama_hparams::recurrent_layer(uint32_t il) const { + return recurrent_layer_arr[il]; } bool llama_hparams::is_swa(uint32_t il) const { diff --git a/src/llama-hparams.h b/src/llama-hparams.h index b2bcb8b01a18b..5241631d5fb59 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -110,12 +110,41 @@ struct llama_hparams { std::array swa_layers; // for State Space Models - uint32_t ssm_d_conv = 0; - uint32_t ssm_d_inner = 0; - uint32_t ssm_d_state = 0; - uint32_t ssm_dt_rank = 0; - - bool ssm_dt_b_c_rms = false; + uint32_t ssm_d_conv = 0; + uint32_t ssm_d_inner = 0; + uint32_t ssm_d_state = 0; + uint32_t ssm_dt_rank = 0; + uint32_t ssm_n_group = 0; + bool ssm_dt_b_c_rms = false; + uint32_t ssm_head_dim = 0; + uint32_t ssm_mamba_d_ssm = 0; + + // Falcon-H1 specific parameters + uint32_t attn_head_dim = 0; + bool mamba_use_mlp = false; + bool mamba_norm_before_gate = false; + bool mamba_rms_norm = false; + float attention_in_multiplier = 1.0f; + float attention_out_multiplier = 1.0f; + float ssm_in_multiplier = 1.0f; + float ssm_out_multiplier = 1.0f; + float mlp_gate_multiplier = 1.0f; + float mlp_down_multiplier = 1.0f; + float key_multiplier = 1.0f; + float lm_head_multiplier = 1.0f; + float rope_theta = 10000.0f; + bool ssm_has_mup = false; + float embedding_multiplier = 1.0f; + uint32_t vocab_size = 0; + uint32_t intermediate_size = 0; + float mamba_expand = 0.0f; + bool ssm_rms_norm = false; + bool ssm_conv_bias = false; + bool ssm_proj_bias = false; + uint32_t chunk_size = 0; + + // for hybrid state space models + std::array recurrent_layer_arr; float f_clamp_kqv = 0.0f; float f_max_alibi_bias = 0.0f; @@ -186,6 +215,9 @@ struct llama_hparams { // dimension of the recurrent state embeddings uint32_t n_embd_v_s() const; + // whether or not the given layer is recurrent (for hybrid models) + bool recurrent_layer(uint32_t il) const; + bool is_swa(uint32_t il) const; }; diff --git a/src/llama-kv-cache-hybrid-recurrent.cpp b/src/llama-kv-cache-hybrid-recurrent.cpp new file mode 100644 index 0000000000000..46449b5178b33 --- /dev/null +++ b/src/llama-kv-cache-hybrid-recurrent.cpp @@ -0,0 +1,249 @@ +#include "llama-kv-cache-hybrid-recurrent.h" + +#include "llama-impl.h" +#include "llama-model.h" +#include "llama-context.h" + +// +// llama_kv_cache_hybrid_recurrent +// + +llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent( + const llama_model & 
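Aside: n_embd_k_s() and n_embd_v_s() above now size the per-layer recurrent state from the Mamba-2 dimensions: the conv state keeps d_conv - 1 columns for the SSM channels plus the 2 * n_group * d_state grouped B/C channels, and the ssm state keeps d_state values per SSM channel. A worked example with illustrative, non-model-specific values:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // illustrative hyperparameters, not taken from any specific checkpoint
    const uint32_t d_conv = 4, d_ssm = 4096, n_group = 1, d_state = 128;

    // conv state: the first column is shifted out every step, so only
    // d_conv - 1 columns are stored, covering the SSM channels plus the
    // grouped B/C channels
    const uint32_t n_embd_k_s = (d_conv - 1) * (d_ssm + 2*n_group*d_state); // 3 * 4352 = 13056

    // ssm state: d_state values per SSM channel
    const uint32_t n_embd_v_s = d_state * d_ssm;                            // 524288

    std::printf("conv state per layer: %u, ssm state per layer: %u\n", n_embd_k_s, n_embd_v_s);
}
```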
model, + /* attn */ + ggml_type attn_type_k, + ggml_type attn_type_v, + bool attn_v_trans, + uint32_t attn_kv_size, + uint32_t attn_n_pad, + uint32_t attn_n_swa, + llama_swa_type attn_swa_type, + /* recurrent */ + ggml_type recurrent_type_k, + ggml_type recurrent_type_v, + uint32_t recurrent_kv_size, + /* common */ + uint32_t n_seq_max, + bool offload, + /* layer filters */ + layer_filter_cb && attn_filter, + layer_filter_cb && recurrent_filter) : + hparams(model.hparams), + kv_attn(new llama_kv_cache_unified( + model, + attn_filter == nullptr ? + [&](int32_t il) { return model.hparams.recurrent_layer(il); } + : attn_filter, + attn_type_k, + attn_type_v, + attn_v_trans, + offload, + attn_kv_size, + n_seq_max, + attn_n_pad, + attn_n_swa, + attn_swa_type + )), + kv_recurrent(new llama_kv_cache_recurrent( + model, + recurrent_filter == nullptr ? + [&](int32_t il) { return model.hparams.recurrent_layer(il); } + : recurrent_filter, + recurrent_type_k, + recurrent_type_v, + offload, + recurrent_kv_size, + n_seq_max + )) {} + +llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) { + + // since this includes a recurrent cache, we cannot use split_simple + auto sbatch = llama_sbatch(batch, hparams.n_embd, false); + + // follow the recurrent pattern for creating the ubatch splits + std::vector ubatches; + while (sbatch.n_tokens > 0) { + llama_ubatch ubatch; + + if (embd_pooled) { + // Pooled embeddings cannot be split across ubatches (yet) + ubatch = sbatch.split_seq(n_ubatch); + } else { + ubatch = sbatch.split_equal(n_ubatch); + } + + ubatches.push_back(ubatch); + } + + // prepare the recurrent batches first + if (!kv_recurrent->prepare(ubatches)) { + // TODO: will the recurrent cache be in an undefined state at this point? + LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + // prepare the attention cache + auto heads_attn = kv_attn->prepare(ubatches); + if (heads_attn.empty()) { + LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + return std::make_unique( + this, std::move(sbatch), std::move(heads_attn), std::move(ubatches)); +} + +llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_full() { + return std::make_unique(this); +} + +llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_update(llama_context * lctx, bool optimize) { + return std::make_unique( + static_cast( kv_attn ->init_update(lctx, optimize).release()), + static_cast(kv_recurrent->init_update(lctx, optimize).release())); +} + +bool llama_kv_cache_hybrid_recurrent::get_can_shift() const { + // Shifting is trivially supported for recurrent + return kv_attn->get_can_shift(); +} + +void llama_kv_cache_hybrid_recurrent::clear(bool data) { + kv_attn ->clear(data); + kv_recurrent->clear(data); +} + +bool llama_kv_cache_hybrid_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + // Try removing from the recurrent cache first since it may fail. If it does + // fail, the cache will not have been mutated. 
+ if (!kv_recurrent->seq_rm(seq_id, p0, p1)) { + return false; + } + return kv_attn->seq_rm(seq_id, p0, p1); +} + +void llama_kv_cache_hybrid_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + kv_attn ->seq_cp(seq_id_src, seq_id_dst, p0, p1); + kv_recurrent->seq_cp(seq_id_src, seq_id_dst, p0, p1); +} + +void llama_kv_cache_hybrid_recurrent::seq_keep(llama_seq_id seq_id) { + kv_attn ->seq_keep(seq_id); + kv_recurrent->seq_keep(seq_id); +} + +void llama_kv_cache_hybrid_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + kv_attn->seq_add(seq_id, p0, p1, shift); + kv_recurrent->seq_add(seq_id, p0, p1, shift); +} + +void llama_kv_cache_hybrid_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + kv_attn ->seq_div(seq_id, p0, p1, d); + kv_recurrent->seq_div(seq_id, p0, p1, d); +} + +llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_min(llama_seq_id seq_id) const { + // the min of the total cache is the max of the two caches' min values + return std::max(kv_attn->seq_pos_min(seq_id), kv_recurrent->seq_pos_min(seq_id)); +} + +llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_max(llama_seq_id seq_id) const { + // the max of the total cache is the min of the two caches' max values + return std::min(kv_attn->seq_pos_max(seq_id), kv_recurrent->seq_pos_max(seq_id)); +} + +void llama_kv_cache_hybrid_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { + kv_attn ->state_write(io, seq_id); + kv_recurrent->state_write(io, seq_id); +} + +void llama_kv_cache_hybrid_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) { + kv_attn ->state_read(io, seq_id); + kv_recurrent->state_read(io, seq_id); +} + +llama_kv_cache_unified * llama_kv_cache_hybrid_recurrent::get_kv_attn() const { + return kv_attn.get(); +} + +llama_kv_cache_recurrent * llama_kv_cache_hybrid_recurrent::get_kv_recurrent() const { + return kv_recurrent.get(); +} + +llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_memory_status status) + : status(status), + state_attn(new llama_kv_cache_unified_state(status)), + state_recurrent(new llama_kv_cache_recurrent_state(status)) {} + +llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv) + : status(LLAMA_MEMORY_STATUS_SUCCESS), + state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn())), + state_recurrent(new llama_kv_cache_recurrent_state(status, kv->get_kv_recurrent())) {} + +llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state( + llama_kv_cache_unified_state * state_unified, + llama_kv_cache_recurrent_state * state_recurrent) + : status(LLAMA_MEMORY_STATUS_NO_UPDATE), + state_attn(state_unified), + state_recurrent(state_recurrent) {} + +llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state( + llama_kv_cache_hybrid_recurrent * kv, + llama_sbatch sbatch, + std::vector heads_attn, + std::vector ubatches) + : status(LLAMA_MEMORY_STATUS_SUCCESS), + sbatch(std::move(sbatch)), + ubatches(std::move(ubatches)), + // note: here we copy the ubatches. 
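Aside: seq_pos_min()/seq_pos_max() above return the intersection of what the two child caches hold for a sequence: the hybrid minimum is the larger of the two minimums and the hybrid maximum is the smaller of the two maximums, because a position is only usable if both caches still have it. A small worked example with made-up position ranges:

```cpp
#include <algorithm>
#include <cstdio>

int main() {
    // hypothetical positions held for one sequence by each child cache
    const int attn_min = 0,  attn_max = 100; // attention cache
    const int rec_min  = 50, rec_max  = 100; // recurrent cache

    // a position is only valid for the hybrid cache if BOTH children still have it
    const int hyb_min = std::max(attn_min, rec_min); // 50
    const int hyb_max = std::min(attn_max, rec_max); // 100

    std::printf("hybrid range: [%d, %d]\n", hyb_min, hyb_max);
}
```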
not sure if this is ideal + state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn(), {}, std::move(heads_attn), this->ubatches)), + state_recurrent(new llama_kv_cache_recurrent_state(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent(), {}, this->ubatches)) {} + + +bool llama_kv_cache_hybrid_recurrent_state::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + state_attn ->next(); + state_recurrent->next(); + + if (++i_next >= ubatches.size()) { + return false; + } + + return true; +} + +bool llama_kv_cache_hybrid_recurrent_state::apply() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + bool res = true; + + res = res & state_attn ->apply(); + res = res & state_recurrent->apply(); + + return res; +} + +std::vector & llama_kv_cache_hybrid_recurrent_state::out_ids() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return sbatch.out_ids; +} + +llama_memory_status llama_kv_cache_hybrid_recurrent_state::get_status() const { + return status; +} + +const llama_ubatch & llama_kv_cache_hybrid_recurrent_state::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + return ubatches[i_next]; +} + +const llama_kv_cache_unified_state * llama_kv_cache_hybrid_recurrent_state::get_state_attn() const { + return static_cast(state_attn.get()); +} + +const llama_kv_cache_recurrent_state * llama_kv_cache_hybrid_recurrent_state::get_state_recurrent() const { + return static_cast(state_recurrent.get()); +} diff --git a/src/llama-kv-cache-hybrid-recurrent.h b/src/llama-kv-cache-hybrid-recurrent.h new file mode 100644 index 0000000000000..17a72c613d7b9 --- /dev/null +++ b/src/llama-kv-cache-hybrid-recurrent.h @@ -0,0 +1,146 @@ +#pragma once + +#include "llama-batch.h" +#include "llama-graph.h" +#include "llama-kv-cache-recurrent.h" +#include "llama-kv-cache-unified.h" +#include "llama-kv-cells.h" +#include "llama-memory.h" + +#include +#include + +// +// llama_kv_cache_hybrid_recurrent +// + +// utilizes instances of llama_kv_cache_recurrent and llama_kv_cache_unified to +// support models where each layer may be either attention-based or recurrent + +class llama_kv_cache_hybrid_recurrent : public llama_memory_i { +public: + + // this callback is used to filter out layers that should not be included in the cache + using layer_filter_cb = std::function; + + llama_kv_cache_hybrid_recurrent( + const llama_model & model, + /* attn */ + ggml_type attn_type_k, + ggml_type attn_type_v, + bool attn_v_trans, + uint32_t attn_kv_size, + uint32_t attn_n_pad, + uint32_t attn_n_swa, + llama_swa_type attn_swa_type, + /* recurrent */ + ggml_type recurrent_type_k, + ggml_type recurrent_type_v, + uint32_t recurrent_kv_size, + /* common */ + uint32_t n_seq_max, + bool offload, + /* layer filters */ + layer_filter_cb && attn_filter = nullptr, + layer_filter_cb && recurrent_filter = nullptr); + + ~llama_kv_cache_hybrid_recurrent() = default; + + // + // llama_memory_i + // + + llama_memory_state_ptr init_batch( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled) override; + + llama_memory_state_ptr init_full() override; + + llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override; + + bool get_can_shift() const override; + + void clear(bool data) override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos 
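Aside: the hybrid state above keeps the two child states in lock-step, next() advances both and returns false once the ubatch list is exhausted, while apply() only succeeds if both children apply. An illustrative consumption loop, assuming the interfaces shown in this diff; the actual caller lives in llama_context::decode() and may differ in detail:

```cpp
// sketch only: names follow the memory interfaces in this diff
llama_memory_state_ptr mstate = kv->init_batch(batch, n_ubatch, /*embd_pooled=*/false);
if (mstate->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
    do {
        mstate->apply();                          // applies to BOTH child caches for the hybrid state
        const llama_ubatch & ub = mstate->get_ubatch();
        // ... build and compute the graph for ub ...
    } while (mstate->next());                     // advances both children, false when done
}
```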
shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; + + // + // llama_kv_cache_hybrid_recurrent specific API + // + + llama_kv_cache_unified * get_kv_attn () const; + llama_kv_cache_recurrent * get_kv_recurrent() const; + +private: + const llama_hparams & hparams; + + const std::unique_ptr kv_attn; + const std::unique_ptr kv_recurrent; +}; + +class llama_kv_cache_hybrid_recurrent_state : public llama_memory_state_i { +public: + using llama_kv_cache_unified_state_ptr = std::unique_ptr; + using llama_kv_cache_recurrent_state_ptr = std::unique_ptr; + + // init failure + explicit llama_kv_cache_hybrid_recurrent_state(llama_memory_status status); + + // init full + explicit llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv); + + // init update + explicit llama_kv_cache_hybrid_recurrent_state( + llama_kv_cache_unified_state * state_unified, + llama_kv_cache_recurrent_state * state_recurrent); + + // init success + llama_kv_cache_hybrid_recurrent_state( + llama_kv_cache_hybrid_recurrent * kv, + llama_sbatch sbatch, + std::vector heads_attn, + std::vector ubatches); + + ~llama_kv_cache_hybrid_recurrent_state() = default; + + bool next() override; + bool apply() override; + + std::vector & out_ids() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_kv_cache_hybrid_recurrent_state + // + + const llama_kv_cache_unified_state * get_state_attn () const; + const llama_kv_cache_recurrent_state * get_state_recurrent() const; + +private: + const llama_memory_status status; + + llama_sbatch sbatch; + + // the index of the next ubatch to process + size_t i_next = 0; + + std::vector ubatches; + + const llama_memory_state_ptr state_attn; + const llama_memory_state_ptr state_recurrent; +}; diff --git a/src/llama-kv-cache-recurrent.cpp b/src/llama-kv-cache-recurrent.cpp index 8f6f120f682b7..d1b3ca520380f 100644 --- a/src/llama-kv-cache-recurrent.cpp +++ b/src/llama-kv-cache-recurrent.cpp @@ -16,12 +16,13 @@ // llama_kv_cache_recurrent::llama_kv_cache_recurrent( - const llama_model & model, - ggml_type type_k, - ggml_type type_v, - bool offload, - uint32_t kv_size, - uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) { + const llama_model & model, + layer_filter_cb && filter, + ggml_type type_k, + ggml_type type_v, + bool offload, + uint32_t kv_size, + uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) { const int32_t n_layer = hparams.n_layer; LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n", @@ -44,27 +45,27 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; - ggml_context * ctx = ggml_init(params); if (!ctx) { + std::printf("Failed to create ggml context for kv cache\n"); return nullptr; } - ctx_map[buft] = ctx; ctxs.emplace_back(ctx); return ctx; } - return it->second; }; - k_l.reserve(n_layer); - v_l.reserve(n_layer); + k_l.resize(n_layer); + v_l.resize(n_layer); for (int i = 0; i < n_layer; i++) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = 
hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); + if (filter && !filter(i)) { + LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, i); + continue; + } const char * dev_name = "CPU"; @@ -84,12 +85,12 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( throw std::runtime_error("failed to create ggml context for kv cache"); } - ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); - ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, hparams.n_embd_k_s()*kv_size); + ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, hparams.n_embd_v_s()*kv_size); ggml_format_name(k, "cache_k_l%d", i); ggml_format_name(v, "cache_v_l%d", i); - k_l.push_back(k); - v_l.push_back(v); + k_l[i] = k; + v_l[i] = v; } // allocate tensors and initialize the buffers to avoid NaNs in the padding @@ -644,7 +645,9 @@ size_t llama_kv_cache_recurrent::size_k_bytes() const { size_t size_k_bytes = 0; for (const auto & k : k_l) { - size_k_bytes += ggml_nbytes(k); + if (k != nullptr) { + size_k_bytes += ggml_nbytes(k); + } } return size_k_bytes; @@ -654,7 +657,9 @@ size_t llama_kv_cache_recurrent::size_v_bytes() const { size_t size_v_bytes = 0; for (const auto & v : v_l) { - size_v_bytes += ggml_nbytes(v); + if (v != nullptr) { + size_v_bytes += ggml_nbytes(v); + } } return size_v_bytes; @@ -748,14 +753,13 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std // Iterate and write all the keys first, each row is a cell // Get whole range at a time for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); // Write key type const int32_t k_type_i = (int32_t)k_l[il]->type; io.write(&k_type_i, sizeof(k_type_i)); // Write row size of key - const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); + const uint64_t k_size_row = ggml_row_size(k_l[il]->type, hparams.n_embd_k_s()); io.write(&k_size_row, sizeof(k_size_row)); // Read each range of cells of k_size length each into tmp_buf and write out @@ -768,14 +772,13 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std if (!v_trans) { for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); // Write value type const int32_t v_type_i = (int32_t)v_l[il]->type; io.write(&v_type_i, sizeof(v_type_i)); // Write row size of value - const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); + const uint64_t v_size_row = ggml_row_size(v_l[il]->type, hparams.n_embd_v_s()); io.write(&v_size_row, sizeof(v_size_row)); // Read each range of cells of v_size length each into tmp_buf and write out @@ -789,7 +792,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std // When v is transposed, we also need the element size and get the element ranges from each row const uint32_t kv_size = size; for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_s = hparams.n_embd_v_s(); // Write value type const int32_t v_type_i = (int32_t)v_l[il]->type; @@ -800,10 +803,10 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std io.write(&v_size_el, sizeof(v_size_el)); // Write GQA embedding size - io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa)); + io.write(&n_embd_v_s, sizeof(n_embd_v_s)); // For each row, we get the element values of each cell - for 
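Aside: with the layer filter added above, the recurrent cache only allocates state tensors for layers the filter accepts (k_l/v_l are now resized to n_layer and indexed, leaving skipped layers null, which is why the byte-size accounting below guards against nullptr), and each tensor is sized purely by n_embd_k_s()/n_embd_v_s(). A hedged construction sketch, assuming model and n_seq_max are in scope; the type and size choices are placeholders, not the values llama.cpp actually picks:

```cpp
#include <memory>

// illustrative construction of a filtered recurrent cache; the filter mirrors
// the default installed by the hybrid cache above
auto kv = std::make_unique<llama_kv_cache_recurrent>(
    model,
    [&](int32_t il) { return model.hparams.recurrent_layer(il); }, // keep only recurrent layers
    GGML_TYPE_F32,              // type_k (conv states)
    GGML_TYPE_F32,              // type_v (ssm states)
    /*offload  =*/ true,
    /*kv_size  =*/ n_seq_max,   // recurrent caches track one state slot per sequence
    /*n_seq_max=*/ n_seq_max);
```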
(uint32_t j = 0; j < n_embd_v_gqa; ++j) { + for (uint32_t j = 0; j < n_embd_v_s; ++j) { // Read each range of cells of v_size_el length each into tmp_buf and write out for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; @@ -936,7 +939,6 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); // Read type of key int32_t k_type_i_ref; @@ -950,7 +952,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce // Read row size of key uint64_t k_size_row_ref; io.read_to(&k_size_row_ref, sizeof(k_size_row_ref)); - const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); + const size_t k_size_row = ggml_row_size(k_l[il]->type, hparams.n_embd_k_s()); if (k_size_row != k_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); return false; @@ -964,7 +966,6 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce if (!v_trans) { for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); // Read type of value int32_t v_type_i_ref; @@ -978,7 +979,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce // Read row size of value uint64_t v_size_row_ref; io.read_to(&v_size_row_ref, sizeof(v_size_row_ref)); - const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); + const size_t v_size_row = ggml_row_size(v_l[il]->type, hparams.n_embd_v_s()); if (v_size_row != v_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); return false; @@ -992,7 +993,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce } else { // For each layer, read the values for each cell (transposed) for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_s = hparams.n_embd_v_s(); // Read type of value int32_t v_type_i_ref; @@ -1012,17 +1013,17 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce return false; } - // Read GQA embedding size - uint32_t n_embd_v_gqa_ref; - io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); - if (n_embd_v_gqa != n_embd_v_gqa_ref) { - LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il); + // Read state embedding size + uint32_t n_embd_v_s_ref; + io.read_to(&n_embd_v_s_ref, sizeof(n_embd_v_s_ref)); + if (n_embd_v_s != n_embd_v_s_ref) { + LLAMA_LOG_ERROR("%s: mismatched state embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_s, n_embd_v_s_ref, il); return false; } if (cell_count) { // For each row in the transposed matrix, read the values for the whole cell range - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + for (uint32_t j = 0; j < n_embd_v_s; ++j) { const size_t dst_offset = (head + j * size) * v_size_el; ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); } diff --git a/src/llama-kv-cache-recurrent.h b/src/llama-kv-cache-recurrent.h index f9b01a6513393..b89590b7be6a2 100644 --- 
a/src/llama-kv-cache-recurrent.h +++ b/src/llama-kv-cache-recurrent.h @@ -15,13 +15,18 @@ // see the implementation of llama_kv_cache_unified_state_i for an example how to do it class llama_kv_cache_recurrent : public llama_memory_i { public: + + // this callback is used to filter out layers that should not be included in the cache + using layer_filter_cb = std::function; + llama_kv_cache_recurrent( - const llama_model & model, - ggml_type type_k, - ggml_type type_v, - bool offload, - uint32_t kv_size, - uint32_t n_seq_max); + const llama_model & model, + layer_filter_cb && filter, + ggml_type type_k, + ggml_type type_v, + bool offload, + uint32_t kv_size, + uint32_t n_seq_max); ~llama_kv_cache_recurrent() = default; diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index 3b37679859d39..d4412288925c3 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -68,8 +68,8 @@ llama_kv_cache_unified::llama_kv_cache_unified( continue; } - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); const char * dev_name = "CPU"; @@ -1430,7 +1430,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); // Write key type const int32_t k_type_i = (int32_t)layer.k->type; @@ -1452,7 +1452,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); // Write value type const int32_t v_type_i = (int32_t)layer.v->type; @@ -1476,7 +1476,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); // Write value type const int32_t v_type_i = (int32_t)layer.v->type; @@ -1621,7 +1621,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); // Read type of key int32_t k_type_i_ref; @@ -1651,7 +1651,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); // Read type of value int32_t v_type_i_ref; @@ -1681,7 +1681,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); // Read type of value int32_t v_type_i_ref; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 
a5eb122f998d8..b64c3529f1e28 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -9,6 +9,7 @@ #include "llama-kv-cache-unified.h" #include "llama-kv-cache-unified-iswa.h" #include "llama-kv-cache-recurrent.h" +#include "llama-kv-cache-hybrid-recurrent.h" #include "ggml-cpp.h" @@ -204,23 +205,27 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w } break; case GGML_OP_SSM_CONV: { - // FIXME - ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 12345, w->ne[1], 6789); + const int64_t n_seq_tokens = 512; + const int64_t n_seqs = 3; + ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0] - 1 + n_seq_tokens, w->ne[1], n_seqs); op_tensor = ggml_ssm_conv(ctx, conv_x, w); } break; case GGML_OP_SSM_SCAN: { - // FIXME - const int64_t d_state = w->ne[0]; - const int64_t d_inner = w->ne[1]; + // w is ssm_a, which is used to distinguish Mamba-1 and Mamba-2 + const int64_t d_state = w->ne[0] == 1 ? hparams.ssm_d_state : w->ne[0]; + const int64_t n_head = w->ne[1]; + const int64_t head_dim = hparams.ssm_d_inner / n_head; + const int64_t n_group = hparams.ssm_n_group; const int64_t n_seq_tokens = 512; - const int64_t n_seqs = 1; - ggml_tensor * s = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, d_inner, n_seqs); - ggml_tensor * x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs); - ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs); - ggml_tensor * B = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs); - ggml_tensor * C = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs); - op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C); + const int64_t n_seqs = 3; + ggml_tensor * s = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, head_dim, n_head, n_seqs); + ggml_tensor * x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_seq_tokens, n_seqs); + ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_head, n_seq_tokens, n_seqs); + ggml_tensor * B = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * C = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); + op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C, ids); } break; case GGML_OP_RWKV_WKV6: { @@ -470,6 +475,14 @@ void llama_model::load_hparams(llama_model_loader & ml) { std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0); std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0); std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0); + // std::fill( + // hparams.recurrent_layer_arr.begin(), + // hparams.recurrent_layer_arr.end(), + // llm_arch_is_recurrent(ml.get_arch())); + std::fill( + hparams.recurrent_layer_arr.begin(), + hparams.recurrent_layer_arr.end(), + true); std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0); @@ -1055,6 +1068,38 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_MAMBA2: + { + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 24: + 
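Aside: the SSM_SCAN branch of weight_buft_supported() above builds dummy operands whose shapes depend on whether the probed weight (ssm_a) comes from Mamba-1 or Mamba-2: Mamba-2 stores A as one scalar per head, shape {1, n_head}, while Mamba-1 stores a full {d_state, d_inner} matrix, so ne[0] == 1 tells them apart. A small restatement of that heuristic in isolation; the helper below is hypothetical and only uses the ggml_tensor::ne fields:

```cpp
// Hypothetical helper restating the shape heuristic: `a` is a layer's ssm_a weight.
struct ssm_dims { int64_t d_state; int64_t n_head; bool is_mamba2; };

static ssm_dims classify_ssm_a(const ggml_tensor * a, int64_t hparams_d_state) {
    ssm_dims d;
    d.is_mamba2 = (a->ne[0] == 1);                          // Mamba-2: A is {1, n_head}
    d.d_state   = d.is_mamba2 ? hparams_d_state : a->ne[0]; // Mamba-1: A is {d_state, d_inner}
    d.n_head    = a->ne[1];
    return d;
}
```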
switch (hparams.n_embd) { + case 768: type = LLM_TYPE_SMALL; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 48: + switch (hparams.n_embd) { + case 1024: type = LLM_TYPE_MEDIUM; break; + case 1536: type = LLM_TYPE_LARGE; break; + case 2048: type = LLM_TYPE_XL; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 64: + switch (hparams.n_embd) { + case 2560: type = LLM_TYPE_3B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_XVERSE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -1465,6 +1510,53 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_FALCON_H1: + { + // Common parameters + ml.get_key(LLM_KV_VOCAB_SIZE, hparams.vocab_size); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // SSM parameters + ml.get_key(LLM_KV_MAMBA_D_SSM, hparams.ssm_mamba_d_ssm); + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + ml.get_key(LLM_KV_SSM_HEAD_DIM, hparams.ssm_head_dim); + ml.get_key(LLM_KV_FALCON_H1_MAMBA_CHUNK_SIZE, hparams.chunk_size); + + // Falcon-H1 parameters + ml.get_key(LLM_KV_ATTN_HEAD_DIM, hparams.attn_head_dim); + ml.get_key(LLM_KV_FALCON_H1_USE_MLP, hparams.mamba_use_mlp); + ml.get_key(LLM_KV_FALCON_H1_ATTENTION_IN_MULTIPLIER, hparams.attention_in_multiplier); + ml.get_key(LLM_KV_FALCON_H1_ATTENTION_OUT_MULTIPLIER, hparams.attention_out_multiplier); + ml.get_key(LLM_KV_FALCON_H1_SSM_IN_MULTIPLIER, hparams.ssm_in_multiplier); + ml.get_key(LLM_KV_FALCON_H1_SSM_OUT_MULTIPLIER, hparams.ssm_out_multiplier); + ml.get_key(LLM_KV_FALCON_H1_MLP_GATE_MULTIPLIER, hparams.mlp_gate_multiplier); + ml.get_key(LLM_KV_FALCON_H1_MLP_DOWN_MULTIPLIER, hparams.mlp_down_multiplier); + ml.get_key(LLM_KV_FALCON_H1_SSM_HAS_MUP, hparams.ssm_has_mup); + ml.get_key(LLM_KV_FALCON_H1_MAMBA_NORM_BEFORE_GATE, hparams.mamba_norm_before_gate); + ml.get_key(LLM_KV_FALCON_H1_MAMBA_RMS_NORM, hparams.mamba_rms_norm); + ml.get_key(LLM_KV_FALCON_H1_ROPE_THETA, hparams.rope_theta); + ml.get_key(LLM_KV_FALCON_H1_KEY_MULTIPLIER, hparams.key_multiplier); + ml.get_key(LLM_KV_FALCON_H1_LM_HEAD_MULTIPLIER, hparams.lm_head_multiplier); + ml.get_key(LLM_KV_FALCON_H1_EMBEDDING_MULTIPLIER, hparams.embedding_multiplier); + + switch (hparams.n_layer) { + case 36: + type = LLM_TYPE_0_5B; break; + case 24: + type = LLM_TYPE_1_5B; break; + case 66: + type = LLM_TYPE_1B; break; + case 32: + type = LLM_TYPE_3B; break; + case 44: + type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_DOTS1: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -3030,6 +3122,54 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}, 0); layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_inner}, 0); + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); + } + } break; + case LLM_ARCH_MAMBA2: + { + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_head = 
hparams.ssm_dt_rank; + const int64_t n_group = hparams.ssm_n_group; + const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head; + + // only an expansion factor of 2 is supported for now + GGML_ASSERT(2 * n_embd == d_inner); + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); + + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, 0); + + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_head}, 0); + + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_head}, 0); + + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); + // out_proj layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); } @@ -4184,6 +4324,88 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); } } break; + case LLM_ARCH_FALCON_H1: + { + // Common + const int64_t hidden_size = hparams.n_embd; // hidden_size + const int64_t vocab_size = hparams.vocab_size; // vocab_size + + // mamba2 Mixer SSM params + const int64_t ssm_conv_kernel_size = hparams.ssm_d_conv; // ssm_conv_kernel_size + const int64_t ssm_n_groups = hparams.ssm_n_group; // ssm_n_groups + const int64_t ssm_state_size = hparams.ssm_d_state; // ssm_state_size + const int64_t ssm_intermediate_size = hparams.ssm_mamba_d_ssm > 0 ? hparams.ssm_mamba_d_ssm : int(hparams.mamba_expand * hidden_size); // TODO expand + const int64_t ssm_num_heads = hparams.ssm_dt_rank; // ssm_num_heads + const int64_t ssm_conv_dim = ssm_intermediate_size + 2 * ssm_n_groups * ssm_state_size; + const int64_t ssm_projection_size = ssm_intermediate_size + ssm_conv_dim + ssm_num_heads; + + // attn params + const int64_t attn_num_attention_head = hparams.n_head(0); // rename to: attn_num_attention_head + const int64_t attn_num_key_value_head = hparams.n_head_kv(0); + const int64_t attn_head_dim = hparams.attn_head_dim > 0 ? 
hparams.attn_head_dim : hidden_size / attn_num_attention_head; + + // ffn params + const int64_t ffn_intermediate_size = hparams.n_ff(0); + + // embeddings + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hidden_size, vocab_size}, 0); + + // output + { + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hidden_size, vocab_size}, TENSOR_NOT_REQUIRED); + final_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {hidden_size}, 0); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + /*SSM LAYERS*/ + // ssm in + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {hidden_size, ssm_projection_size}, 0); + layer.ssm_in_b = create_tensor(tn(LLM_TENSOR_SSM_IN, "bias", i), {n_embd, ssm_projection_size}, llama_model_loader::TENSOR_NOT_REQUIRED); + // ssm 1d conv + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {ssm_conv_kernel_size, ssm_conv_dim}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {ssm_conv_dim}, llama_model_loader::TENSOR_NOT_REQUIRED); + // ssm_dt + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {ssm_num_heads}, 0); + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, ssm_num_heads}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, ssm_num_heads}, 0); + if (hparams.ssm_has_mup == true) { + layer.ssm_mup_vec = create_tensor(tn(LLM_TENSOR_SSM_MUP_VEC, i), {2*ssm_intermediate_size + 2*ssm_n_groups*ssm_state_size + ssm_num_heads}, 0); + } + // ssm_norm + if (hparams.mamba_rms_norm == true) { + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {ssm_intermediate_size / ssm_n_groups, ssm_n_groups}, 0); + } + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {ssm_intermediate_size, hidden_size}, 0); + + /*ATTENTION LAYERS*/ + // attention layers (with optional bias) + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {hidden_size, attn_head_dim * attn_num_attention_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {hidden_size, attn_num_key_value_head * attn_head_dim}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {hidden_size, attn_num_key_value_head * attn_head_dim}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {attn_head_dim * attn_num_attention_head, hidden_size}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {hidden_size}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {attn_num_key_value_head * attn_head_dim}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {attn_num_key_value_head * attn_head_dim}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {hidden_size}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {hidden_size}, 0); + + + // feed forward (w/ optional biases) + layer.ffn_pre_norm = create_tensor(tn(LLM_TENSOR_FFN_PRE_NORM, i), {hidden_size}, 0); + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? 
llama_model_loader::TENSOR_DUPLICATED : 0)); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {hidden_size, ffn_intermediate_size}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { ffn_intermediate_size, hidden_size}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {hidden_size, ffn_intermediate_size}, 0); + + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {ffn_intermediate_size}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {hidden_size}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {ffn_intermediate_size}, llama_model_loader::TENSOR_NOT_REQUIRED); + } + } break; case LLM_ARCH_DOTS1: { const int64_t n_ff_exp = hparams.n_ff_exp; @@ -4506,10 +4728,14 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn); LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown"); + } + + if (arch == LLM_ARCH_MAMBA || arch == LLM_ARCH_MAMBA2) { LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); + LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); if (!classifier_labels.empty()) { @@ -5299,6 +5525,286 @@ struct llm_build_baichuan : public llm_graph_context { } }; +struct llm_build_falcon_h1 : public llm_graph_context { + const llama_model & model; + + llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + inpL = ggml_scale(ctx0, inpL, hparams.embedding_multiplier); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + // Build the inputs in the recurrent cache + auto * inp_rs = build_rs_inp_hybrid_recurrent(); + + // Build the inputs in the attention cache + auto * inp_attn = build_attn_inp_kv_hybrid_recurrent(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 
1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + cur = ggml_scale(ctx0, cur, hparams.attention_in_multiplier); + + // self-attention + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = ggml_scale(ctx0, Kcur, hparams.key_multiplier); + + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, 0, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, 0, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + ggml_tensor * attn_out = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + attn_out = ggml_scale(ctx0, attn_out, hparams.attention_out_multiplier); + cb(attn_out, "attn_out", il); + + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + // Mamba2 layer + cur = ggml_scale(ctx0, cur, hparams.ssm_in_multiplier); + cb(cur, "ssm_in", il); + + ggml_tensor * ssm_out = build_mamba2_layer(inp_rs, gf, cur, ubatch, il); + ssm_out = ggml_scale(ctx0, ssm_out, hparams.ssm_out_multiplier); + cb(ssm_out, "ssm_out", il); + + // // Aggregation + cur = ggml_add(ctx0, attn_out, ssm_out); + inpSA = ggml_add(ctx0, cur, inpSA); + cb(cur, "layer_out", il); + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = inpSA; + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].ffn_pre_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, inpSA); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.final_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + cur = ggml_scale(ctx0, cur, hparams.lm_head_multiplier); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } + + ggml_tensor * build_mamba2_layer( + llm_graph_input_rs_hybrid_recurrent * inp, + ggml_cgraph * gf, + ggml_tensor * cur, + const llama_ubatch & ubatch, + int il) const { + const auto * kv_state = static_cast(mstate)->get_state_recurrent(); + + const auto kv_head = kv_state->get_head(); + + const int64_t d_conv = hparams.ssm_d_conv; + const 
int64_t d_ssm = hparams.ssm_mamba_d_ssm; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_head = hparams.ssm_dt_rank; + const int64_t head_dim = hparams.ssm_head_dim == 0 ? d_inner / n_head : hparams.ssm_head_dim; + const int64_t n_group = hparams.ssm_n_group; + const int64_t n_seqs = ubatch.n_seqs; + + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + + ggml_tensor * conv_states_all = kv_state->get_k_l(il); + ggml_tensor * ssm_states_all = kv_state->get_v_l(il); + + ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_k_s(), n_seqs); + conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_ssm + 2*n_group*d_state, n_seqs); + + // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} + cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); + + // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads + + // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs} + ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur); + + + // check if the models has ssm_multipliers (MuP) + if (hparams.ssm_has_mup) { + struct ggml_tensor * mup_vec = model.layers[il].ssm_mup_vec; + cur = ggml_mul(ctx0, zxBCdt, mup_vec); + zxBCdt = cur; + } + + // split the above in three + ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0); + ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_ssm + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_ssm*ggml_element_size(zxBCdt)); + ggml_tensor * dt = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_ssm + 2*n_group*d_state)*ggml_element_size(zxBCdt)); + + // conv + { + // => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs} + ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, xBC), 0); + + + // copy last (d_conv - 1) columns back into the state cache + ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_ssm + 2*n_group*d_state, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0])); + + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, last_conv, + ggml_view_1d(ctx0, conv_states_all, + (d_conv - 1)*(d_ssm + 2*n_group*d_state)*(n_seqs), + kv_head*(d_conv - 1)*(d_ssm + 2*n_group*d_state)*ggml_element_size(conv_states_all)))); + + // 1D convolution + // The equivalent is to make a self-overlapping view of conv_x + // over d_conv columns at each stride in the 3rd dimension, + // then element-wise multiply that with the conv1d weight, + // then sum the elements of each row, + // (the last two steps are a dot product over rows (also doable with mul_mat)) + // then permute away the ne[0] dimension, + // and then you're left with the resulting x tensor. + // For simultaneous sequences, all sequences need to have the same length. 
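+            // Illustrative scalar form of the step below (per sequence): for every channel
+            // c in [0, d_ssm + 2*n_group*d_state) and every token t in [0, n_seq_tokens),
+            //     xBC[c, t] = sum over k in [0, d_conv) of conv_x[c, t + k] * conv1d_weight[c, k]
+            // where conv1d_weight is model.layers[il].ssm_conv1d viewed as {d_conv, channels}.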
+ xBC = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d); + + // bias + xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b); + + xBC = ggml_silu(ctx0, xBC); + } + + // ssm + { + // These correspond to V K Q in SSM/attention duality + ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*xBC->nb[0], xBC->nb[1], xBC->nb[2], 0); + + ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_ssm*ggml_element_size(xBC)); + + ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_ssm + n_group*d_state)*ggml_element_size(xBC)); + + // {n_head, n_seq_tokens, n_seqs} + dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b); + + + ggml_tensor * A = model.layers[il].ssm_a; + + // use the states and the indices provided by build_rs + // (this is necessary in order to properly use the states before they are overwritten, + // while avoiding to make unnecessary copies of the states) + auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { + ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, kv_state->get_size()); + + // TODO: use semistructured matrices to implement state-space duality + // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} + return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); + }; + + ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_v_s(), ubatch.n_seqs, get_ssm_rows); + + // store last states + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, + ggml_view_1d(ctx0, y_ssm, d_state*d_ssm*n_seqs, ggml_nelements(x)*x->nb[0]), + ggml_view_1d(ctx0, ssm_states_all, d_state*d_ssm*n_seqs, kv_head*d_state*d_ssm*ggml_element_size(ssm_states_all)))); + + ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0); + + // TODO: skip computing output earlier for unused tokens + + y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); + y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z))); + + // grouped RMS norm + if (hparams.mamba_rms_norm){ + y = ggml_reshape_4d(ctx0, y, d_ssm / n_group, n_group, n_seq_tokens, n_seqs); + y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il); + } + y = ggml_reshape_3d(ctx0, y, d_ssm, n_seq_tokens, n_seqs); + + // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} + cur = build_lora_mm(model.layers[il].ssm_out, y); + } + + // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); + cb(cur, "mamba_out", il); + return cur; + } +}; + struct llm_build_xverse : public llm_graph_context { llm_build_xverse(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -9111,7 +9617,7 @@ struct llm_build_mamba : public llm_graph_context { // {n_embd, n_tokens} inpL = build_inp_embd(model.tok_embd); - ggml_tensor * state_copy = build_inp_s_copy(); + auto * rs_inp = build_rs_inp_recurrent(); for (int il = 0; il < n_layer; ++il) { // norm @@ -9120,7 +9626,11 @@ struct llm_build_mamba : public llm_graph_context { LLM_NORM_RMS, il); cb(cur, "attn_norm", il); - cur = build_mamba_layer(gf, cur, state_copy, ubatch, il); + if (model.arch == LLM_ARCH_MAMBA2) { + cur = 
build_mamba2_layer(rs_inp, gf, cur, ubatch, il); + } else { + cur = build_mamba_layer(rs_inp, gf, cur, ubatch, il); + } if (il == n_layer - 1) { // skip computing output for unused tokens @@ -9156,13 +9666,12 @@ struct llm_build_mamba : public llm_graph_context { ggml_build_forward_expand(gf, cur); } - // TODO: split ggml_tensor * build_mamba_layer( - ggml_cgraph * gf, - ggml_tensor * cur, - ggml_tensor * state_copy, - const llama_ubatch & ubatch, - int il) const { + llm_graph_input_rs * inp, + ggml_cgraph * gf, + ggml_tensor * cur, + const llama_ubatch & ubatch, + int il) const { const auto * kv_state = static_cast(mstate); const auto kv_head = kv_state->get_head(); @@ -9171,6 +9680,8 @@ struct llm_build_mamba : public llm_graph_context { const int64_t d_inner = hparams.ssm_d_inner; const int64_t d_state = hparams.ssm_d_state; const int64_t dt_rank = hparams.ssm_dt_rank; + const int64_t n_head = d_inner; + const int64_t head_dim = 1; const int64_t n_seqs = ubatch.n_seqs; // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers) const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms; @@ -9186,15 +9697,8 @@ struct llm_build_mamba : public llm_graph_context { ggml_tensor * conv_states_all = kv_state->get_k_l(il); ggml_tensor * ssm_states_all = kv_state->get_v_l(il); - // (ab)using the KV cache to store the states - ggml_tensor * conv = build_recurrent_state( - gf, conv_states_all, state_copy, - hparams.n_embd_k_s(), n_seqs); + ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_k_s(), n_seqs); conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs); - ggml_tensor * ssm = build_recurrent_state( - gf, ssm_states_all, state_copy, - hparams.n_embd_v_s(), n_seqs); - ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); @@ -9243,8 +9747,8 @@ struct llm_build_mamba : public llm_graph_context { ggml_tensor * x_db = build_lora_mm(model.layers[il].ssm_x, x); // split ggml_tensor * dt = ggml_view_3d(ctx0, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0); - ggml_tensor * B = ggml_view_3d(ctx0, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank); - ggml_tensor * C = ggml_view_3d(ctx0, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state)); + ggml_tensor * B = ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank); + ggml_tensor * C = ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state)); // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers if (ssm_dt_b_c_rms) { @@ -9257,32 +9761,173 @@ struct llm_build_mamba : public llm_graph_context { dt = build_lora_mm(model.layers[il].ssm_dt, dt); dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b); - // Custom operator to optimize the parallel associative scan - // as described in the Annex D of the Mamba paper. 
- // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - ggml_tensor * y_ssm = ggml_ssm_scan(ctx0, ssm, x, dt, model.layers[il].ssm_a, B, C); + cur = x; + x = ggml_reshape_4d(ctx0, x, head_dim, n_head, n_seq_tokens, n_seqs); + + ggml_tensor * A = model.layers[il].ssm_a; + + // use the states and the indices provided by build_rs + // (this is necessary in order to properly use the states before they are overwritten, + // while avoiding to make unnecessary copies of the states) + auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { + ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, kv_state->get_size()); + + // Custom operator to optimize the parallel associative scan + // as described in the Annex D of the Mamba paper. + // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} + return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); + }; + + ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_v_s(), ubatch.n_seqs, get_ssm_rows); + + // store last states + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, + ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]*x->ne[3]), + ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); + + ggml_tensor * y = ggml_view_3d(ctx0, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[2], x->nb[3], 0); + + // TODO: skip computing output earlier for unused tokens + + y = ggml_add(ctx0, y, ggml_mul(ctx0, cur, model.layers[il].ssm_d)); + y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z))); + + // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} + cur = build_lora_mm(model.layers[il].ssm_out, y); + } + + // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); + // cb(cur, "mamba_out", il); + + return cur; + } + + ggml_tensor * build_mamba2_layer( + llm_graph_input_rs * inp, + ggml_cgraph * gf, + ggml_tensor * cur, + const llama_ubatch & ubatch, + int il) const { + const auto * kv_state = static_cast(mstate); + + const auto kv_head = kv_state->get_head(); + + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_head = hparams.ssm_dt_rank; + const int64_t head_dim = d_inner / n_head; + const int64_t n_group = hparams.ssm_n_group; + const int64_t n_seqs = ubatch.n_seqs; + + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + + ggml_tensor * conv_states_all = kv_state->get_k_l(il); + ggml_tensor * ssm_states_all = kv_state->get_v_l(il); + + ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_k_s(), n_seqs); + conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs); + + // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} + cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); + + // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads + + // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs} + ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur); + + // split the above in three + ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0); + ggml_tensor 
* xBC = ggml_view_3d(ctx0, zxBCdt, d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_inner*ggml_element_size(zxBCdt)); + ggml_tensor * dt = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_inner + 2*n_group*d_state)*ggml_element_size(zxBCdt)); + + // conv + { + // => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs} + ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, xBC), 0); + + // copy last (d_conv - 1) columns back into the state cache + ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0])); + + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, last_conv, + ggml_view_1d(ctx0, conv_states_all, + (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs), + kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all)))); + + // 1D convolution + // The equivalent is to make a self-overlapping view of conv_x + // over d_conv columns at each stride in the 3rd dimension, + // then element-wise multiply that with the conv1d weight, + // then sum the elements of each row, + // (the last two steps are a dot product over rows (also doable with mul_mat)) + // then permute away the ne[0] dimension, + // and then you're left with the resulting x tensor. + // For simultaneous sequences, all sequences need to have the same length. + xBC = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d); + + // bias + xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b); + + xBC = ggml_silu(ctx0, xBC); + } + + // ssm + { + // These correspond to V K Q in SSM/attention duality + ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*xBC->nb[0], xBC->nb[1], xBC->nb[2], 0); + ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_inner*ggml_element_size(xBC)); + ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_inner + n_group*d_state)*ggml_element_size(xBC)); + + // {n_head, n_seq_tokens, n_seqs} + dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b); + + ggml_tensor * A = model.layers[il].ssm_a; + + // use the states and the indices provided by build_rs + // (this is necessary in order to properly use the states before they are overwritten, + // while avoiding to make unnecessary copies of the states) + auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { + ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, kv_state->get_size()); + + // TODO: use semistructured matrices to implement state-space duality + // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} + return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); + }; + + ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_v_s(), ubatch.n_seqs, get_ssm_rows); // store last states ggml_build_forward_expand(gf, ggml_cpy(ctx0, - ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]), + ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]), ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); - ggml_tensor * y = ggml_view_3d(ctx0, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[1], x->nb[2], 0); + ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, 
n_head, n_seq_tokens, n_seqs, x->nb[1], n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0); // TODO: skip computing output earlier for unused tokens - // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs} y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z))); + // grouped RMS norm + y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); + y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il); + y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs); + // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} cur = build_lora_mm(model.layers[il].ssm_out, y); } // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); - //cb(cur, "mamba_out", il); + // cb(cur, "mamba_out", il); return cur; } @@ -11904,10 +12549,10 @@ struct llm_build_rwkv6_base : public llm_graph_context { } ggml_tensor * build_rwkv6_time_mix( + llm_graph_input_rs * inp, ggml_cgraph * gf, ggml_tensor * cur, ggml_tensor * x_prev, - ggml_tensor * state_copy, const llama_ubatch & ubatch, int il) const { const auto * kv_state = static_cast(mstate); @@ -12031,8 +12676,8 @@ struct llm_build_rwkv6_base : public llm_graph_context { k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w)); } - ggml_tensor * wkv_state = build_recurrent_state( - gf, kv_state->get_v_l(il), state_copy, + ggml_tensor * wkv_state = build_rs( + inp, gf, kv_state->get_v_l(il), hparams.n_embd_v_s(), n_seqs); ggml_tensor * wkv_output; @@ -12087,7 +12732,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base { inpL = build_inp_embd(model.tok_embd); inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); - ggml_tensor * state_copy = build_inp_s_copy(); + auto * rs_inp = build_rs_inp_recurrent(); const auto n_embd = hparams.n_embd; const auto n_seq_tokens = ubatch.n_seq_tokens; @@ -12097,9 +12742,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base { const llama_layer * layer = &model.layers[il]; inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); - ggml_tensor * token_shift = build_rwkv_token_shift_load( - gf, state_copy, ubatch, il - ); + ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il); ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0); ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift)); @@ -12114,7 +12757,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base { 1 ); - cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il); + cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il); ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); cb(ffn_inp, "ffn_inp", il); @@ -12184,7 +12827,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base { inpL = build_inp_embd(model.tok_embd); - ggml_tensor * state_copy = build_inp_s_copy(); + auto * rs_inp = build_rs_inp_recurrent(); const auto n_embd = hparams.n_embd; const auto n_seq_tokens = ubatch.n_seq_tokens; @@ -12194,9 +12837,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base { const llama_layer * layer = &model.layers[il]; inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); - ggml_tensor * token_shift = build_rwkv_token_shift_load( - gf, state_copy, ubatch, il - ); + ggml_tensor * 
token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il); ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il); cb(att_norm, "attn_norm", il); @@ -12208,7 +12849,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base { 1 ); - cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il); + cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il); token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)); ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il)); @@ -12296,10 +12937,10 @@ struct llm_build_rwkv7_base : public llm_graph_context { } ggml_tensor * build_rwkv7_time_mix( + llm_graph_input_rs * inp, ggml_cgraph * gf, ggml_tensor * cur, ggml_tensor * x_prev, - ggml_tensor * state_copy, ggml_tensor *& first_layer_value, const llama_ubatch & ubatch, int il) const { @@ -12382,8 +13023,8 @@ struct llm_build_rwkv7_base : public llm_graph_context { v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens); a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens); - ggml_tensor * wkv_state = build_recurrent_state( - gf, kv_state->get_v_l(il), state_copy, + ggml_tensor * wkv_state = build_rs( + inp, gf, kv_state->get_v_l(il), hparams.n_embd_v_s(), n_seqs); ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state); @@ -12440,7 +13081,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base { inpL = build_inp_embd(model.tok_embd); inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); - ggml_tensor * state_copy = build_inp_s_copy(); + auto * rs_inp = build_rs_inp_recurrent(); const auto n_embd = hparams.n_embd; const auto n_seq_tokens = ubatch.n_seq_tokens; @@ -12450,9 +13091,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base { const llama_layer * layer = &model.layers[il]; inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); - ggml_tensor * token_shift = build_rwkv_token_shift_load( - gf, state_copy, ubatch, il - ); + ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il); ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0); ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift)); @@ -12467,7 +13106,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base { 1 ); - cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il); + cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il); ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); cb(ffn_inp, "ffn_inp", il); @@ -12533,7 +13172,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base { inpL = build_inp_embd(model.tok_embd); - ggml_tensor * state_copy = build_inp_s_copy(); + auto * rs_inp = build_rs_inp_recurrent(); const auto n_embd = hparams.n_embd; const auto n_seq_tokens = ubatch.n_seq_tokens; @@ -12543,9 +13182,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base { const llama_layer * layer = &model.layers[il]; inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); - ggml_tensor * token_shift = build_rwkv_token_shift_load( - gf, state_copy, ubatch, il - ); + ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il); ggml_tensor * att_norm = build_norm(inpL, 
layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il); cb(att_norm, "attn_norm", il); @@ -12557,7 +13194,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base { 1 ); - cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il); + cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il); token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)); ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il)); @@ -13738,6 +14375,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_memory_i * res; switch (arch) { + // Models that need specific instantiation should be handled in the + // switch statement case LLM_ARCH_BERT: case LLM_ARCH_JINA_BERT_V2: case LLM_ARCH_NOMIC_BERT: @@ -13747,57 +14386,75 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, { res = nullptr; } break; - case LLM_ARCH_MAMBA: - case LLM_ARCH_RWKV6: - case LLM_ARCH_RWKV6QWEN2: - case LLM_ARCH_RWKV7: - case LLM_ARCH_ARWKV7: - { - res = new llama_kv_cache_recurrent( - *this, - GGML_TYPE_F32, - GGML_TYPE_F32, - cparams.offload_kqv, - std::max((uint32_t) 1, cparams.n_seq_max), - cparams.n_seq_max); - } break; + // Models that need standard caching should rely on recurrent/hybrid + // checks default: { - const auto padding = llama_kv_cache_unified::get_padding(cparams); - - cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); - - LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); - - if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { - GGML_ASSERT(hparams.is_swa_any()); - - res = new llama_kv_cache_unified_iswa( - *this, - params.type_k, - params.type_v, - !cparams.flash_attn, - cparams.offload_kqv, - params.swa_full, - cparams.n_ctx, - cparams.n_seq_max, - cparams.n_ubatch, - padding); - } else { - GGML_ASSERT(!hparams.is_swa_any()); - - res = new llama_kv_cache_unified( + if (llm_arch_is_recurrent(arch)) { + res = new llama_kv_cache_recurrent( *this, nullptr, - params.type_k, - params.type_v, - !cparams.flash_attn, + GGML_TYPE_F32, + GGML_TYPE_F32, cparams.offload_kqv, - cparams.n_ctx, - cparams.n_seq_max, - padding, - hparams.n_swa, - hparams.swa_type); + std::max((uint32_t) 1, cparams.n_seq_max), + cparams.n_seq_max); + } else if (llm_arch_is_hybrid_recurrent(arch)) { + const auto padding = llama_kv_cache_unified::get_padding(cparams); + + cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); + + res = new llama_kv_cache_hybrid_recurrent( + /* model */ *this, + /* attn_type_k */ params.type_k, + /* attn_type_v */ params.type_v, + /* attn_v_trans */ !cparams.flash_attn, + /* attn_kv_size */ cparams.n_ctx, + /* attn_n_pad */ padding, + /* attn_n_swa */ hparams.n_swa, + /* attn_swa_type */ hparams.swa_type, + /* recurrent_type_k */ GGML_TYPE_F32, + /* recurrent_type_v */ GGML_TYPE_F32, + /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), + /* n_seq_max */ cparams.n_seq_max, + /* offload */ cparams.offload_kqv); + } else { + const auto padding = llama_kv_cache_unified::get_padding(cparams); + + cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); + + LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); + + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + GGML_ASSERT(hparams.is_swa_any()); + + res = new llama_kv_cache_unified_iswa( + *this, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + params.swa_full, + cparams.n_ctx, + 
cparams.n_seq_max, + cparams.n_ubatch, + padding); + } else { + GGML_ASSERT(!hparams.is_swa_any()); + + res = new llama_kv_cache_unified( + *this, + nullptr, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + cparams.n_ctx, + cparams.n_seq_max, + padding, + hparams.n_swa, + hparams.swa_type); + } } } } @@ -13945,6 +14602,7 @@ llm_graph_result_ptr llama_model::build_graph( llm = std::make_unique(*this, params, gf); } break; case LLM_ARCH_MAMBA: + case LLM_ARCH_MAMBA2: { llm = std::make_unique(*this, params, gf); } break; @@ -14077,6 +14735,10 @@ llm_graph_result_ptr llama_model::build_graph( { llm = std::make_unique(*this, params, gf); } break; + case LLM_ARCH_FALCON_H1: + { + llm = std::make_unique(*this, params, gf); + } break; case LLM_ARCH_DOTS1: { llm = std::make_unique(*this, params, gf); @@ -14201,6 +14863,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_REFACT: case LLM_ARCH_BLOOM: case LLM_ARCH_MAMBA: + case LLM_ARCH_MAMBA2: case LLM_ARCH_JINA_BERT_V2: case LLM_ARCH_T5: case LLM_ARCH_T5ENCODER: @@ -14235,6 +14898,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_CHAMELEON: case LLM_ARCH_BAILINGMOE: case LLM_ARCH_NEO_BERT: + case LLM_ARCH_FALCON_H1: case LLM_ARCH_ARCEE: return LLAMA_ROPE_TYPE_NORM; @@ -14377,14 +15041,7 @@ llama_token llama_model_decoder_start_token(const llama_model * model) { } bool llama_model_is_recurrent(const llama_model * model) { - switch (model->arch) { - case LLM_ARCH_MAMBA: return true; - case LLM_ARCH_RWKV6: return true; - case LLM_ARCH_RWKV6QWEN2: return true; - case LLM_ARCH_RWKV7: return true; - case LLM_ARCH_ARWKV7: return true; - default: return false; - } + return llm_arch_is_recurrent(model->arch); } const std::vector> & llama_internal_get_tensor_map(const llama_model * model) { diff --git a/src/llama-model.h b/src/llama-model.h index 06e6c687943cc..cb24c4acb7b24 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -169,6 +169,8 @@ struct llama_layer { struct ggml_tensor * ffn_sub_norm = nullptr; struct ggml_tensor * attn_norm_cross = nullptr; struct ggml_tensor * attn_norm_enc = nullptr; + struct ggml_tensor * ssm_norm = nullptr; + struct ggml_tensor * final_norm = nullptr; // attention struct ggml_tensor * wq = nullptr; @@ -211,6 +213,7 @@ struct llama_layer { struct ggml_tensor * layer_out_norm_b = nullptr; struct ggml_tensor * ffn_norm_exps = nullptr; struct ggml_tensor * ffn_norm_enc = nullptr; + struct ggml_tensor * ffn_pre_norm = nullptr; // ff struct ggml_tensor * ffn_gate = nullptr; // w1 @@ -254,6 +257,10 @@ struct llama_layer { struct ggml_tensor * ssm_conv1d_b = nullptr; struct ggml_tensor * ssm_dt_b = nullptr; + // falcon-h1 + struct ggml_tensor * ssm_in_b = nullptr; + struct ggml_tensor * ssm_mup_vec = nullptr; + // rwkv struct ggml_tensor * time_mix_w1 = nullptr; struct ggml_tensor * time_mix_w2 = nullptr; @@ -344,6 +351,8 @@ struct llama_model { struct ggml_tensor * output = nullptr; struct ggml_tensor * output_b = nullptr; struct ggml_tensor * output_norm_enc = nullptr; + struct ggml_tensor * final_norm = nullptr; + // classifier struct ggml_tensor * cls = nullptr; diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index dd2251ef3cbef..da12ec0d8223a 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1519,6 +1519,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "llama-v3" || tokenizer_pre == "llama-bpe"|| tokenizer_pre == "falcon3" || + tokenizer_pre == 
"falcon-h1" || tokenizer_pre == "pixtral") { pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3; ignore_merges = true; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 509a4b35f57cb..653ad0d9bd91e 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1854,28 +1854,58 @@ struct test_ssm_scan : public test_case { const ggml_type type; const int64_t d_state; - const int64_t d_inner; + const int64_t head_dim; + const int64_t n_head; + const int64_t n_group; const int64_t n_seq_tokens; const int64_t n_seqs; std::string vars() override { - return VARS_TO_STR5(type, d_state, d_inner, n_seq_tokens, n_seqs); + return VARS_TO_STR7(type, d_state, head_dim, n_head, n_group, n_seq_tokens, n_seqs); } test_ssm_scan(ggml_type type = GGML_TYPE_F32, - int64_t d_state = 32, int64_t d_inner = 32, int64_t n_seq_tokens = 32, int64_t n_seqs = 32) - : type(type), d_state(d_state), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {} + int64_t d_state = 32, + int64_t head_dim = 1, // non-zero for Mamba-2 + int64_t n_head = 32, + int64_t n_group = 1, + int64_t n_seq_tokens = 32, + int64_t n_seqs = 32) + : type(type), d_state(d_state), head_dim(head_dim), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * s = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, d_inner, n_seqs, 1 }.data()); - ggml_tensor * x = ggml_new_tensor(ctx, type, 4, std::vector{ d_inner, n_seq_tokens, n_seqs, 1 }.data()); - ggml_tensor * dt = ggml_new_tensor(ctx, type, 4, std::vector{ d_inner, n_seq_tokens, n_seqs, 1 }.data()); - ggml_tensor * A = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, d_inner, 1 , 1 }.data()); - ggml_tensor * B = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, n_seq_tokens, n_seqs, 1 }.data()); - ggml_tensor * C = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, n_seq_tokens, n_seqs, 1 }.data()); - ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C); + ggml_tensor * s = ggml_new_tensor_4d(ctx, type, d_state, head_dim, n_head, n_seqs); + ggml_tensor * x = ggml_new_tensor_4d(ctx, type, head_dim, n_head, n_seq_tokens, n_seqs); + ggml_tensor * dt = ggml_new_tensor_3d(ctx, type, n_head, n_seq_tokens, n_seqs); + ggml_tensor * A = ggml_new_tensor_2d(ctx, type, (head_dim > 1) ? 
1 : d_state, n_head); + ggml_tensor * B = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * C = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); + ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, ids); return out; } + + // similar to test_mul_mat_id + void initialize_tensors(ggml_context * ctx) override { + std::random_device rd; + std::default_random_engine rng(rd()); + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + if (t->type == GGML_TYPE_I32) { + if (ggml_is_view_op(t->op)) { continue; } + // ids + for (int64_t r = 0; r < ggml_nrows(t); r++) { + std::vector data(t->ne[0]); + for (int i = 0; i < t->ne[0]; i++) { + data[i] = i; + } + std::shuffle(data.begin(), data.end(), rng); + ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t)); + } + } else { + init_tensor_uniform(t); + } + } + } }; // GGML_OP_RWKV_WKV6 @@ -4194,7 +4224,8 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1})); test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1})); - test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4)); + test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1 + test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 32, 32, 2, 32, 4)); // Mamba-2 test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 1, 1)); test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 1));