Add gguf support for bloom #33473

Merged
1 change: 1 addition & 0 deletions docs/source/en/gguf.md
@@ -80,6 +80,7 @@ For now the supported model architectures are the architectures that have been v
- Qwen2
- Qwen2Moe
- Phi3
- Bloom

## Example usage

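Loading the newly supported Bloom architecture follows the same pattern as the other entries in the list above; here is a minimal sketch (the GGUF repo and file names are the ones used by the tests added later in this PR):

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "afrideva/bloom-560m-GGUF"
gguf_file = "bloom-560m.fp16.gguf"

# Both the tokenizer and the dequantized model are reconstructed directly from the GGUF file.
tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=gguf_file)
model = AutoModelForCausalLM.from_pretrained(model_id, gguf_file=gguf_file)

inputs = tokenizer("Hello", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```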
14 changes: 9 additions & 5 deletions src/transformers/convert_slow_tokenizer.py
@@ -328,9 +328,11 @@ def converted(self) -> Tokenizer:


class GPT2Converter(Converter):
def converted(self) -> Tokenizer:
vocab = self.original_tokenizer.encoder
merges = list(self.original_tokenizer.bpe_ranks.keys())
def converted(self, vocab: Dict[str, int] = None, merges: List[Tuple[str, str]] = None) -> Tokenizer:
if not vocab:
vocab = self.original_tokenizer.encoder
if not merges:
merges = list(self.original_tokenizer.bpe_ranks)

tokenizer = Tokenizer(
BPE(
@@ -343,9 +345,11 @@ def converted(self) -> Tokenizer:
)
)

tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=self.original_tokenizer.add_prefix_space)
add_prefix_space = False
add_prefix_space = getattr(self.original_tokenizer, "add_prefix_space", False)
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=add_prefix_space)
tokenizer.decoder = decoders.ByteLevel()
if self.original_tokenizer.add_bos_token:
if getattr(self.original_tokenizer, "add_bos_token", False):
bos = self.original_tokenizer.bos_token
bos_token_id = self.original_tokenizer.bos_token_id
tokenizer.post_processor = processors.TemplateProcessing(
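The net effect of the two hunks above is that `GPT2Converter.converted` can now be fed a vocab/merges pair directly, and it no longer assumes the wrapped object exposes `add_prefix_space` or `add_bos_token`. A toy sketch under those assumptions (the stand-in class and the tiny vocab/merges are hypothetical, not part of this PR):

```python
from transformers.convert_slow_tokenizer import GPT2Converter


class _DummySlowTokenizer:
    # Hypothetical stand-in: with the getattr(..., default) fallbacks above, a bare
    # object is enough once vocab and merges are supplied explicitly.
    pass


# Toy BPE material: every merge product ("H"+"ello", "Ġ"+"world") is itself in the vocab.
vocab = {"H": 0, "ello": 1, "Hello": 2, "Ġ": 3, "world": 4, "Ġworld": 5}
merges = [("H", "ello"), ("Ġ", "world")]

converter = GPT2Converter(_DummySlowTokenizer())
fast_tokenizer = converter.converted(vocab=vocab, merges=merges)
print(fast_tokenizer.get_vocab_size())  # 6
```

This is the entry point that `GGUFBloomConverter` (added in `integrations/ggml.py` below) relies on.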
35 changes: 34 additions & 1 deletion src/transformers/integrations/ggml.py
@@ -25,7 +25,7 @@
from tokenizers.models import BPE

from .. import AddedToken
from ..convert_slow_tokenizer import LlamaConverter, Qwen2Converter
from ..convert_slow_tokenizer import GPT2Converter, LlamaConverter, Qwen2Converter
from ..utils import logging
from ..utils.logging import tqdm

@@ -107,6 +107,19 @@
"output.weight": "lm_head.weight",
"output_norm": "model.norm",
},
"bloom": {
"token_embd.weight": "transformer.word_embeddings.weight",
"token_embd_norm": "transformer.word_embeddings_layernorm",
"blk": "transformer.h",
"ffn_up": "mlp.dense_h_to_4h",
"ffn_down": "mlp.dense_4h_to_h",
"ffn_norm": "post_attention_layernorm",
"attn_norm": "input_layernorm",
"attn_qkv": "self_attention.query_key_value",
"attn_output": "self_attention.dense",
"output.weight": "lm_head.weight",
"output_norm": "transformer.ln_f",
},
}


@@ -183,6 +196,13 @@
"attention.layer_norm_rms_epsilon": "rms_norm_eps",
"vocab_size": "vocab_size",
},
"bloom": {
"block_count": "n_layer",
"embedding_length": "hidden_size",
"attention.head_count": "n_head",
"vocab_size": "vocab_size",
"attention.layer_norm_epsilon": "layer_norm_epsilon",
},
}

GGUF_TOKENIZER_MAPPING = {
@@ -492,11 +512,24 @@ def converted(self) -> Tokenizer:
return tokenizer


class GGUFBloomConverter(GPT2Converter):
def __init__(self, tokenizer_dict):
self.original_tokenizer = GGUFTokenizerSkeleton(tokenizer_dict)
self.additional_kwargs = {}

def converted(self) -> Tokenizer:
vocab = {word: i for i, word in enumerate(self.original_tokenizer.tokens)}
merges = self.original_tokenizer.merges
tokenizer = super().converted(vocab, merges)
return tokenizer


GGUF_TO_FAST_CONVERTERS = {
"llama": GGUFLlamaConverter,
"qwen2": GGUFQwen2Converter,
"qwen2_moe": GGUFQwen2Converter,
"phi3": GGUFPhi3Converter,
"bloom": GGUFBloomConverter,
}


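To illustrate what the new "bloom" tables are for: `GGUF_TENSOR_MAPPING["bloom"]` drives the substring-replacement loop in `modeling_gguf_pytorch_utils.py` (next file), turning GGUF tensor names into Bloom parameter names, while `GGUF_CONFIG_MAPPING["bloom"]` maps GGUF metadata keys (e.g. `embedding_length`) onto `BloomConfig` attributes (`hidden_size`). A standalone sketch of the name rewriting, with a few of the mapping entries inlined:

```python
# Illustrative only -- the real code iterates over GGUF_TENSOR_MAPPING["bloom"];
# a handful of its entries are inlined here.
bloom_tensor_mapping = {
    "token_embd.weight": "transformer.word_embeddings.weight",
    "blk": "transformer.h",
    "attn_qkv": "self_attention.query_key_value",
    "attn_output": "self_attention.dense",
}

name = "blk.0.attn_qkv.weight"
for gguf_key, hf_key in bloom_tensor_mapping.items():
    if gguf_key in name:
        name = name.replace(gguf_key, hf_key)

print(name)  # transformer.h.0.self_attention.query_key_value.weight
```

On the tokenizer side, `GGUFBloomConverter` reuses the `GPT2Converter` changes shown earlier: it builds the vocab and merges from the GGUF tokenizer payload and hands them to `GPT2Converter.converted`.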
34 changes: 34 additions & 0 deletions src/transformers/modeling_gguf_pytorch_utils.py
@@ -169,6 +169,14 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
elif ".attn_k." in name:
weights = reverse_permute_weights(weights, num_heads, num_kv_heads)

if architecture == "bloom" and "attn_qkv" in name:
num_heads = parsed_parameters["config"]["n_head"]
n_embed = parsed_parameters["config"]["hidden_size"]
if "weight" in name:
weights = reverse_reshape_weights(weights, num_heads, n_embed)
else:
weights = reverse_reshape_bias(weights, num_heads, n_embed)

for tensor_name in tensor_key_mapping:
if tensor_name in name:
name = name.replace(tensor_name, tensor_key_mapping[tensor_name])
@@ -191,3 +199,29 @@ def reverse_permute_weights(weights: np.ndarray, n_head: int, num_kv_heads: Opti
dim = weights.shape[0] // n_head // 2
w = weights.reshape(n_head, dim, 2, *weights.shape[1:])
return w.swapaxes(2, 1).reshape(weights.shape)


def reverse_reshape_weights(weights: np.ndarray, n_head: int, n_embed: int):
# Original reshape implementation
# https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L972-L985
q, k, v = np.array_split(weights, 3, axis=0)

q = q.reshape(n_head, n_embed // n_head, n_embed)
k = k.reshape(n_head, n_embed // n_head, n_embed)
v = v.reshape(n_head, n_embed // n_head, n_embed)
qkv_weights = np.stack([q, k, v], axis=1)

return qkv_weights.reshape(n_head * 3 * (n_embed // n_head), n_embed)


def reverse_reshape_bias(weights: np.ndarray, n_head: int, n_embed: int):
# Original reshape implementation
# https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L986-L998
q_bias, k_bias, v_bias = np.array_split(weights, 3)

q_bias = q_bias.reshape(n_head, n_embed // n_head)
k_bias = k_bias.reshape(n_head, n_embed // n_head)
v_bias = v_bias.reshape(n_head, n_embed // n_head)

qkv_bias = np.stack([q_bias, k_bias, v_bias], axis=1).flatten()
return qkv_bias
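A small self-contained shape check of what `reverse_reshape_weights` does, using toy dimensions (the sizes below are hypothetical, not from the PR): llama.cpp's conversion stores the fused attention weight as three stacked blocks `[Q; K; V]`, whereas Bloom's `query_key_value` parameter interleaves the rows per attention head, and this helper restores that per-head layout.

```python
import numpy as np

from transformers.modeling_gguf_pytorch_utils import reverse_reshape_weights

n_head, n_embed = 4, 8  # toy sizes, hypothetical
head_dim = n_embed // n_head

# Fake GGUF-style fused QKV weight: rows 0..7 form the Q block, 8..15 the K block, 16..23 the V block.
gguf_qkv = np.arange(3 * n_embed * n_embed, dtype=np.float32).reshape(3 * n_embed, n_embed)

hf_qkv = reverse_reshape_weights(gguf_qkv, n_head, n_embed)

assert hf_qkv.shape == gguf_qkv.shape  # pure row permutation, the shape is unchanged
# Head 0 of the result is [q0; k0; v0]: its K slice comes from the start of the K block.
assert np.array_equal(hf_qkv[head_dim:2 * head_dim], gguf_qkv[n_embed:n_embed + head_dim])
```

`reverse_reshape_bias` applies the same regrouping to the fused bias vector, which is why `test_bloom_weights_conversion_fp16` below can compare the converted `query_key_value` tensors element-wise against the original `bigscience/bloom-560m` checkpoint.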
4 changes: 2 additions & 2 deletions src/transformers/models/bloom/tokenization_bloom_fast.py
@@ -99,8 +99,8 @@ def __init__(
**kwargs,
):
super().__init__(
vocab_file,
merges_file,
vocab_file=vocab_file,
merges_file=merges_file,
tokenizer_file=tokenizer_file,
unk_token=unk_token,
bos_token=bos_token,
60 changes: 60 additions & 0 deletions tests/quantization/ggml/test_ggml.py
@@ -42,6 +42,8 @@ class GgufIntegrationTests(unittest.TestCase):
llama3_model_id = "NousResearch/Meta-Llama-3-8B-GGUF"
tinyllama_model_id = "PenutChen/TinyLlama-1.1B-Chat-v1.0-GGUF"
phi3_model_id = "microsoft/Phi-3-mini-4k-instruct-gguf"
bloom_model_id = "afrideva/bloom-560m-GGUF"
original_bloom_model_id = "bigscience/bloom-560m"

# standard quants
q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
@@ -69,6 +71,8 @@ class GgufIntegrationTests(unittest.TestCase):
q4_0_qwen2_model_id = "qwen1_5-0_5b-chat-q4_0.gguf"
q4_0_qwen2_moe_model_id = "Qwen1.5-MoE-A2.7B-Chat.Q4_0.gguf"
q4_llama3_model_id = "Meta-Llama-3-8B-Q4_K_M.gguf"
fp16_bloom_model_id = "bloom-560m.fp16.gguf"
q8_bloom_model_id = "bloom-560m.q8_0.gguf"
f16_tinyllama_model_id = "TinyLlama-1.1B-Chat-v1.0.FP16.gguf"

example_text = "Hello"
@@ -385,6 +389,62 @@ def test_llama3_q4_0(self):
EXPECTED_TEXT = "Hello, I am interested in [The Park]\nThe"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

def test_bloom_fp16(self):
tokenizer = AutoTokenizer.from_pretrained(self.bloom_model_id, gguf_file=self.fp16_bloom_model_id)
model = AutoModelForCausalLM.from_pretrained(
self.bloom_model_id,
gguf_file=self.fp16_bloom_model_id,
device_map="auto",
torch_dtype=torch.float16,
)

text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)

EXPECTED_TEXT = "Hello, I just want to say that I am very"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

def test_bloom_q8_0(self):
tokenizer = AutoTokenizer.from_pretrained(self.bloom_model_id, gguf_file=self.q8_bloom_model_id)
model = AutoModelForCausalLM.from_pretrained(
self.bloom_model_id,
gguf_file=self.q8_bloom_model_id,
device_map="auto",
torch_dtype=torch.float16,
)

text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)

EXPECTED_TEXT = "Hello, I just want to say that I am very"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

def test_bloom_weights_conversion_fp16(self):
quantized_model = AutoModelForCausalLM.from_pretrained(
self.bloom_model_id,
gguf_file=self.fp16_bloom_model_id,
device_map="auto",
torch_dtype=torch.float16,
)
original_model = AutoModelForCausalLM.from_pretrained(
self.original_bloom_model_id,
device_map="auto",
torch_dtype=torch.float16,
)

quantized_state_dict = quantized_model.state_dict()
original_state_dict = original_model.state_dict()

for (quantized_name, quantized_param), (original_name, original_param) in zip(
quantized_state_dict.items(), original_state_dict.items()
):
if (
"self_attention.query_key_value" in quantized_name
and "self_attention.query_key_value" in original_name
):
self.assertTrue(quantized_param.shape == original_param.shape)
torch.testing.assert_close(quantized_param, original_param)

def test_tokenization_xnli(self):
import tqdm
from datasets import load_dataset