
Commit 5d11de4

Add Qwen2Moe GGUF loading support (#33264)
* update gguf doc, config and tensor mapping
* add qwen2moe architecture support, GGUFQwen2MoeConverter and q4 unit tests
* apply code style fixes
* reformat files
* assign GGUFQwen2Converter to qwen2_moe
1 parent 132e875 commit 5d11de4

File tree (4 files changed: +76 / -5 lines)

- docs/source/en/gguf.md
- src/transformers/integrations/ggml.py
- src/transformers/modeling_gguf_pytorch_utils.py
- tests/quantization/ggml/test_ggml.py

docs/source/en/gguf.md

Lines changed: 1 addition & 0 deletions
@@ -78,6 +78,7 @@ For now the supported model architectures are the architectures that have been v
 - LLaMa
 - Mistral
 - Qwen2
+- Qwen2Moe

 ## Example usage
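The `## Example usage` section of this doc loads GGUF checkpoints through the regular auto classes. A minimal sketch for the newly supported architecture, using the Qwen1.5-MoE repository and q4_0 filename referenced in the tests below (any other Qwen2MoE GGUF repo/file pair should work the same way):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "RichardErkhov/Qwen_-_Qwen1.5-MoE-A2.7B-Chat-gguf"  # repo used in the q4_0 integration test below
gguf_file = "Qwen1.5-MoE-A2.7B-Chat.Q4_0.gguf"                 # q4_0 quantized checkpoint

# The GGUF tensors are dequantized and loaded into the matching Qwen2MoE modules.
tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=gguf_file)
model = AutoModelForCausalLM.from_pretrained(model_id, gguf_file=gguf_file)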

src/transformers/integrations/ggml.py

Lines changed: 37 additions & 1 deletion
@@ -79,6 +79,21 @@
         "output.weight": "lm_head.weight",
         "output_norm": "model.norm",
     },
+    "qwen2moe": {
+        "token_embd": "model.embed_tokens",
+        "blk": "model.layers",
+        "ffn_up": "mlp.up_proj",
+        "ffn_down": "mlp.down_proj",
+        "ffn_gate": "mlp.gate_proj",
+        "ffn_norm": "post_attention_layernorm",
+        "attn_norm": "input_layernorm",
+        "attn_q": "self_attn.q_proj",
+        "attn_v": "self_attn.v_proj",
+        "attn_k": "self_attn.k_proj",
+        "attn_output": "self_attn.o_proj",
+        "output.weight": "lm_head.weight",
+        "output_norm": "model.norm",
+    },
 }
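The mapping just added only covers name prefixes; block indices pass through unchanged. A rough illustration of the renaming it enables (the `rename` helper below is hypothetical, not the library's actual implementation, and the dict is truncated):

# Hypothetical sketch: translate a GGUF tensor name into the Transformers
# state-dict name using the "qwen2moe" prefix mapping added above.
QWEN2MOE_TENSOR_MAP = {
    "token_embd": "model.embed_tokens",
    "blk": "model.layers",
    "ffn_up": "mlp.up_proj",
    "attn_q": "self_attn.q_proj",
    # ... remaining entries as in the diff above
}

def rename(gguf_name: str) -> str:
    # Replace each dotted component that has a mapping; keep the rest (e.g. layer indices).
    return ".".join(QWEN2MOE_TENSOR_MAP.get(part, part) for part in gguf_name.split("."))

print(rename("blk.0.ffn_up.weight"))   # -> model.layers.0.mlp.up_proj.weight
print(rename("token_embd.weight"))     # -> model.embed_tokens.weight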

@@ -123,6 +138,18 @@
         "attention.layer_norm_rms_epsilon": "rms_norm_eps",
         "vocab_size": "vocab_size",
     },
+    "qwen2moe": {
+        "context_length": "max_position_embeddings",
+        "block_count": "num_hidden_layers",
+        "feed_forward_length": "intermediate_size",
+        "embedding_length": "hidden_size",
+        "rope.dimension_count": None,
+        "rope.freq_base": "rope_theta",
+        "attention.head_count": "num_attention_heads",
+        "attention.head_count_kv": "num_key_value_heads",
+        "attention.layer_norm_rms_epsilon": "rms_norm_eps",
+        "vocab_size": "vocab_size",
+    },
     "tokenizer": {
         "ggml.bos_token_id": "bos_token_id",
         "ggml.eos_token_id": "eos_token_id",
@@ -244,7 +271,15 @@ def tokenizer(self, proto):
         bos_token = proto.tokens[proto.bos_token_id] if getattr(proto, "bos_token_id", None) is not None else None
         eos_token = proto.tokens[proto.bos_token_id] if getattr(proto, "eos_token_id", None) is not None else None

-        tokenizer = Tokenizer(BPE(bpe_vocab, merges, unk_token=unk_token, fuse_unk=True, byte_fallback=True))
+        tokenizer = Tokenizer(
+            BPE(
+                bpe_vocab,
+                merges,
+                unk_token=unk_token,
+                fuse_unk=True,
+                byte_fallback=True,
+            )
+        )

         special_tokens = []

@@ -358,6 +393,7 @@ def converted(self) -> Tokenizer:
 GGUF_TO_FAST_CONVERTERS = {
     "llama": GGUFLlamaConverter,
     "qwen2": GGUFQwen2Converter,
+    "qwen2_moe": GGUFQwen2Converter,
 }

src/transformers/modeling_gguf_pytorch_utils.py

Lines changed: 3 additions & 0 deletions
@@ -96,6 +96,9 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
     else:
         updated_architecture = architecture

+    if "qwen2moe" in architecture:
+        updated_architecture = "qwen2_moe"
+
     if architecture not in GGUF_SUPPORTED_ARCHITECTURES:
         raise ValueError(f"Architecture {architecture} not supported")
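GGUF files store this architecture as `qwen2moe` (read from the `general.architecture` field), while the Transformers model type and the converter key registered in ggml.py are `qwen2_moe`, hence the extra normalization. Schematically:

# Schematic only: normalize the GGUF architecture name to the Transformers model type.
architecture = "qwen2moe"  # as reported by the GGUF file
updated_architecture = "qwen2_moe" if "qwen2moe" in architecture else architecture
assert updated_architecture == "qwen2_moe"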

tests/quantization/ggml/test_ggml.py

Lines changed: 35 additions & 4 deletions
@@ -16,7 +16,12 @@
 import unittest

 from transformers import AddedToken, AutoModelForCausalLM, AutoTokenizer
-from transformers.testing_utils import require_gguf, require_torch_gpu, slow, torch_device
+from transformers.testing_utils import (
+    require_gguf,
+    require_torch_gpu,
+    slow,
+    torch_device,
+)
 from transformers.utils import is_torch_available

@@ -33,6 +38,7 @@ class GgufIntegrationTests(unittest.TestCase):
     imatrix_model_id = "duyntnet/TinyLlama-1.1B-Chat-v1.0-imatrix-GGUF"
     mistral_model_id = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
     qwen2_model_id = "Qwen/Qwen1.5-0.5B-Chat-GGUF"
+    qwen2_moe_model_id = "RichardErkhov/Qwen_-_Qwen1.5-MoE-A2.7B-Chat-gguf"
     llama3_model_id = "NousResearch/Meta-Llama-3-8B-GGUF"
     tinyllama_model_id = "PenutChen/TinyLlama-1.1B-Chat-v1.0-GGUF"

@@ -59,6 +65,7 @@ class GgufIntegrationTests(unittest.TestCase):

     q4_0_mistral_model_id = "mistral-7b-instruct-v0.2.Q4_0.gguf"
     q4_0_qwen2_model_id = "qwen1_5-0_5b-chat-q4_0.gguf"
+    q4_0_qwen2_moe_model_id = "Qwen1.5-MoE-A2.7B-Chat.Q4_0.gguf"
     q4_llama3_model_id = "Meta-Llama-3-8B-Q4_K_M.gguf"
     f16_tinyllama_model_id = "TinyLlama-1.1B-Chat-v1.0.FP16.gguf"

@@ -298,7 +305,10 @@ def test_f16(self):
     def test_mistral_q4_0(self):
         tokenizer = AutoTokenizer.from_pretrained(self.mistral_model_id, gguf_file=self.q4_0_mistral_model_id)
         model = AutoModelForCausalLM.from_pretrained(
-            self.mistral_model_id, gguf_file=self.q4_0_mistral_model_id, device_map="auto", torch_dtype=torch.float16
+            self.mistral_model_id,
+            gguf_file=self.q4_0_mistral_model_id,
+            device_map="auto",
+            torch_dtype=torch.float16,
         )

         text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)

@@ -310,7 +320,10 @@ def test_mistral_q4_0(self):
     def test_qwen2_q4_0(self):
         tokenizer = AutoTokenizer.from_pretrained(self.qwen2_model_id, gguf_file=self.q4_0_qwen2_model_id)
         model = AutoModelForCausalLM.from_pretrained(
-            self.qwen2_model_id, gguf_file=self.q4_0_qwen2_model_id, device_map="auto", torch_dtype=torch.float16
+            self.qwen2_model_id,
+            gguf_file=self.q4_0_qwen2_model_id,
+            device_map="auto",
+            torch_dtype=torch.float16,
         )

         text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)

@@ -319,6 +332,21 @@ def test_qwen2_q4_0(self):
         EXPECTED_TEXT = "Hello.jsoup\n\nI am a beginner"
         self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

+    def test_qwen2_moe_q4_0(self):
+        tokenizer = AutoTokenizer.from_pretrained(self.qwen2_moe_model_id, gguf_file=self.q4_0_qwen2_moe_model_id)
+        model = AutoModelForCausalLM.from_pretrained(
+            self.qwen2_moe_model_id,
+            gguf_file=self.q4_0_qwen2_moe_model_id,
+            device_map="auto",
+            torch_dtype=torch.float16,
+        )
+
+        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
+        out = model.generate(**text, max_new_tokens=10)
+
+        EXPECTED_TEXT = "Hello everyone, I'm a newbie here and would like"
+        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
+
     def test_llama3_q4_0_tokenizer(self):
         tokenizer = AutoTokenizer.from_pretrained(self.llama3_model_id, gguf_file=self.q4_llama3_model_id)
         with tempfile.TemporaryDirectory() as tmpdirname:

@@ -331,7 +359,10 @@ def test_llama3_q4_0_tokenizer(self):
     def test_llama3_q4_0(self):
         tokenizer = AutoTokenizer.from_pretrained(self.llama3_model_id, gguf_file=self.q4_llama3_model_id)
         model = AutoModelForCausalLM.from_pretrained(
-            self.llama3_model_id, gguf_file=self.q4_llama3_model_id, device_map="auto", torch_dtype=torch.float16
+            self.llama3_model_id,
+            gguf_file=self.q4_llama3_model_id,
+            device_map="auto",
+            torch_dtype=torch.float16,
         )

         text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
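The new q4_0 test mirrors the existing Mistral/Qwen2 ones; the GGUF integration tests are gated by require_gguf, require_torch_gpu and slow. A sketch of running only the Qwen2MoE test from the repository root (assumes a CUDA GPU, the gguf package, and the RUN_SLOW opt-in used by the Transformers test suite):

# Illustrative invocation: run only the new Qwen2MoE GGUF integration test.
import os
import pytest

os.environ["RUN_SLOW"] = "1"  # opt in to @slow-gated tests before they are collected
raise SystemExit(pytest.main(["-v", "tests/quantization/ggml/test_ggml.py", "-k", "qwen2_moe"]))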
