Feature: support baichuan serial models, by now, including Baichuan-7…#3009

Merged
cebtenzzre merged 9 commits into ggml-org:master from jameswu2014:master
Sep 14, 2023

@jameswu2014

@jameswu2014 jameswu2014 commented Sep 4, 2023

Contributor

As more people begin to use Baichuan's open-source models, their influence is growing, especially in China, and many community members are interested in adding Baichuan support to llama.cpp. Baichuan is also a very open company and plans to open-source more models in the future. Taking all this into consideration, we would like to add support for the Baichuan models to llama.cpp. This requires some changes, which we hope can be merged into the main branch of llama.cpp. We would also be happy to help maintain Baichuan support in llama.cpp in the future. We sincerely hope that our pull request can be accepted. Thank you.

By the way, these changes mainly add support for Baichuan-7B and Baichuan-13B, as well as future versions.

…B, Baichuan-13B,in the feature, we will support more Baichuan-models
@jameswu2014 jameswu2014 closed this Sep 4, 2023
@jameswu2014 jameswu2014 reopened this Sep 4, 2023
@ggerganov ggerganov added the need feedback Testing and feedback with results are needed label Sep 4, 2023
@ggerganov

Member

Cool!

Let's get some feedback on whether everything runs smoothly, and then we can merge.

@jameswu2014

Contributor Author

OK, looking forward to the feedback.

@yinguobing


I tried this PR and encountered a model conversion issue with Baichuan-13B-Chat:

# Error message
Can not map tensor 'model.layers.15.self_attn.W_pack.weight'

After a little investigation, it seems this line of code ends the model modification loop at layer 15. That is an early stop, because the remaining model parts contain similar layers that still need to be modified:

# Original code
    for i in itertools.count():
        if f"model.layers.{i}.self_attn.W_pack.weight" in model_part:
            print(f"Unpacking and permuting layer {i}")
            tmp[f"model.layers.{i}.self_attn.q_proj.weight"]=reverse_hf_permute_part(model_part[f"model.layers.{i}.self_attn.W_pack.weight"],0,head_count,head_count)
            tmp[f"model.layers.{i}.self_attn.k_proj.weight"]=reverse_hf_permute_part(model_part[f"model.layers.{i}.self_attn.W_pack.weight"],1,head_count,head_count_kv)
            tmp[f"model.layers.{i}.self_attn.v_proj.weight"]=reverse_hf_part(model_part[f"model.layers.{i}.self_attn.W_pack.weight"],2)
            del tmp[f"model.layers.{i}.self_attn.W_pack.weight"]
        else:
            continue # <- Breaks on layer 15

A possible fix:

    for i in range(block_count):
        if f"model.layers.{i}.self_attn.W_pack.weight" in model_part:
            print(f"Unpacking and permuting layer {i}")
            tmp[f"model.layers.{i}.self_attn.q_proj.weight"]=reverse_hf_permute_part(model_part[f"model.layers.{i}.self_attn.W_pack.weight"],0,head_count,head_count)
            tmp[f"model.layers.{i}.self_attn.k_proj.weight"]=reverse_hf_permute_part(model_part[f"model.layers.{i}.self_attn.W_pack.weight"],1,head_count,head_count_kv)
            tmp[f"model.layers.{i}.self_attn.v_proj.weight"]=reverse_hf_part(model_part[f"model.layers.{i}.self_attn.W_pack.weight"],2)
            del tmp[f"model.layers.{i}.self_attn.W_pack.weight"]

The final result: [image: baichuan-13b-chat]
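
The difference between the two loops above can be illustrated with a small, self-contained sketch. The two-part checkpoint and the helper names here are hypothetical and only stand in for the real `model_part` dictionaries. The failing variant is written with `break`, which is an assumption about the behaviour the comment describes; with a literal `continue`, an unbounded `itertools.count()` loop would never terminate.

```python
import itertools


def layer_key(i: int) -> str:
    # Name of the packed QKV tensor for layer i, as it appears in the thread.
    return f"model.layers.{i}.self_attn.W_pack.weight"


def layers_found_early_stop(model_part: dict) -> list:
    """Count up from layer 0 and stop at the first layer missing from this part."""
    found = []
    for i in itertools.count():
        if layer_key(i) not in model_part:
            break
        found.append(i)
    return found


def layers_found_full_scan(model_part: dict, block_count: int) -> list:
    """Check every layer index, so layers stored in later parts are still seen."""
    return [i for i in range(block_count) if layer_key(i) in model_part]


# Hypothetical checkpoint split into two parts: layers 0-1, then layers 2-3.
part1 = {layer_key(0): "w0", layer_key(1): "w1"}
part2 = {layer_key(2): "w2", layer_key(3): "w3"}

print(layers_found_early_stop(part2))    # [] because layer 0 is not in part 2
print(layers_found_full_scan(part2, 4))  # [2, 3]
```

In the second part the early-stop loop finds nothing, so the packed tensors of the later layers would be left unconverted, which is consistent with the "Can not map tensor" error for layer 15.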

@jameswu2014

Contributor Author

Thanks for your feedback. You are right, and we have fixed it.

@yinguobing


Besides the above fix, it would be helpful to provide a sample prompt file for Baichuan-13B-Chat.

Prompt file: ./prompts/chat-with-baichuan.txt

以下内容为人类用户与与一位智能助手的对话。

用户:你好!
助手:

Then test the model like this:

./main \
  --model /path/to/ggml-model-q4_0.gguf \
  --threads 24 \
  --n_predict 2048 \
  --color \
  --interactive \
  --file prompts/chat-with-baichuan.txt \
  --reverse-prompt "用户:"

@jameswu2014

Contributor Author

Good advice!

@MarvinLong

MarvinLong commented Sep 8, 2023

 python convert-baichuan-hf-to-gguf.py public/models/Baichuan-13B-Chat
 

hello print:  BaichuanForCausalLM
gguf: found 3 model parts
num_parts:3

gguf: get model metadata
gguf: get tokenizer metadata
gguf: get sentencepiece tokenizer vocab, scores and token types
gguf: Setting special token type bos to 1
gguf: Setting special token type eos to 2
gguf: Setting special token type pad to 0
gguf: get tensor metadata
gguf: loading model part 'pytorch_model-00001-of-00003.bin'
Unpacking and permuting layer 0
Unpacking and permuting layer 1
Unpacking and permuting layer 2
Unpacking and permuting layer 3
Unpacking and permuting layer 4
Unpacking and permuting layer 5
Unpacking and permuting layer 6
Unpacking and permuting layer 7
Unpacking and permuting layer 8
Unpacking and permuting layer 9
Unpacking and permuting layer 10
Unpacking and permuting layer 11
Unpacking and permuting layer 12
Unpacking and permuting layer 13
Unpacking and permuting layer 14

The conversion got stuck at this point and has already been running for 24 hours. I tried again and hit the same problem.
Environment: A800, Linux

@ggerganov

Member

@MarvinLong It looks like something might be wrong with your data. I'm not sure, but it might be better to retry from scratch.

@ggerganov ggerganov left a comment

Member

Will be merging this soon.

Just a heads up: maintaining this model will primarily rely on contributions from the community. Adding some sort of CI in the future would help guarantee that the implementation is stable. But overall, if breaking changes occur, fixing Baichuan will be a secondary priority.

With time, we will try to refactor the code to reuse common building blocks when building the graphs of different models, and this will probably help to keep everything stable together. I just want to get a few more architectures implemented before abstracting things.

@jameswu2014

Contributor Author

Is it solved? I will try it in my env.

@jameswu2014

This comment was marked as outdated.

@jameswu2014

Contributor Author

OK. Maybe we can help keep the Baichuan models stable. For example, if you refactor the architecture and that causes the Baichuan models to stop working, we can help fix it. Actually, I think it would be great if llama.cpp were refactored. We are looking forward to it.

@MarvinLong


No, I tried building from scratch 3 times, and the result is always the same. I tried https://github.com/ouwei2013/baichuan13b.cpp/tree/master and it works for me. Can you give me a process for building from scratch? Then I can check whether I did something wrong.

Comment thread convert-baichuan-hf-to-gguf.py
@LMX-xin

This comment was marked as outdated.

@cebtenzzre cebtenzzre left a comment

Collaborator

The script currently does not work.

@jameswu2014

Contributor Author

Could you pull it again?

@jameswu2014

Contributor Author

Please try again. It was my fault, sorry.

Comment thread convert-baichuan-hf-to-gguf.py
parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.bin)")
parser.add_argument("ftype", type=int, choices=[0, 1], help="output format - use 0 for float32, 1 for float16", default = 1)

@cebtenzzre cebtenzzre Sep 11, 2023

Collaborator

The default for 'ftype' does not work unless you also use nargs='?'. Someone should fix this in the other scripts as well...
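
The `argparse` behaviour mentioned here can be checked with a minimal sketch. The parser below is a reduced, hypothetical version of the one in the script (only the `model` and `ftype` positionals are kept), and `build_parser` is an illustrative helper, not part of the PR.

```python
import argparse


def build_parser(optional_ftype: bool) -> argparse.ArgumentParser:
    """Build a tiny parser with a positional 'ftype' that has default=1."""
    parser = argparse.ArgumentParser()
    parser.add_argument("model", type=str)
    # Without nargs='?', a positional argument is always required,
    # so its default value is never used.
    extra = {"nargs": "?"} if optional_ftype else {}
    parser.add_argument("ftype", type=int, choices=[0, 1], default=1, **extra)
    return parser


# With nargs='?', omitting ftype falls back to the default.
fixed = build_parser(optional_ftype=True)
print(fixed.parse_args(["some-model-dir"]).ftype)  # 1

# Without nargs='?', omitting ftype is a usage error (argparse exits).
strict = build_parser(optional_ftype=False)
try:
    strict.parse_args(["some-model-dir"])
except SystemExit:
    print("ftype is required when nargs='?' is not set")
```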

Contributor Author

You are right, I have noticed this problem. For now, we would like to stay consistent with the other model conversion scripts. Maybe we will fix it in the future.

@cebtenzzre

cebtenzzre commented Sep 11, 2023

Collaborator

I can't use this script on a 13B model unless I set TMPDIR to disk. I have 24GB of RAM and 24GB of swap. Is this a general limitation of these simpler convert scripts? I've never had such issues with 33B+ models and the standard convert.py.

Traceback (most recent call last):
  File "/home/cebtenzzre/src/forks/llama.cpp/convert-baichuan-hf-to-gguf.py", line 279, in <module>
    gguf_writer.add_tensor(new_name, data)
  File "/home/cebtenzzre/src/forks/llama.cpp/gguf-py/gguf/gguf.py", line 622, in add_tensor
    tensor.tofile(self.temp_file)
OSError: Not enough free space to write 140247040 bytes
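
The traceback shows gguf.py staging tensor data in `self.temp_file`. Assuming that file is created through Python's standard `tempfile` module (this is not confirmed in the thread, though it would explain why setting TMPDIR helps), the following sketch shows how `TMPDIR` changes where such files are placed. The `scratch` directory is a hypothetical stand-in for a directory on a disk with enough free space.

```python
import os
import tempfile


def temp_dir_with_override(path: str) -> str:
    """Return the directory `tempfile` picks once TMPDIR points at `path`."""
    os.environ["TMPDIR"] = path
    tempfile.tempdir = None  # clear the cached directory so TMPDIR is re-read
    return tempfile.gettempdir()


# Stand-in for a directory on a large disk (hypothetical).
scratch = tempfile.mkdtemp()
chosen = temp_dir_with_override(scratch)
print(os.path.samefile(chosen, scratch))

# Temporary files created from now on are placed in that directory.
with tempfile.TemporaryFile() as f:
    f.write(b"tensor bytes")
```

In practice the variable would normally be set in the shell before starting the conversion script, rather than from inside Python.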

@LMX-xin

LMX-xin commented Sep 11, 2023


I tried it and it ran successfully. The Baichuan convert script used 19.4GB of RAM, so it may be a general limitation.

@lx0126z

lx0126z commented Sep 14, 2023


How do I run inference from Python after conversion?

@LMX-xin

LMX-xin commented Sep 14, 2023


Refer to the README for how to run inference after conversion, for example:

# quantize the model to 4-bits (using q4_0 method)
./quantize /path/to/your_converted ./models/7B/ggml-model-q4_0.gguf q4_0

# run the inference
./main -m ./models/7B/ggml-model-q4_0.gguf -n 128

Here, "/path/to/your_converted" is the path to the file produced by the conversion.
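
Since the question was about Python, one simple option is to call the `./main` binary from a Python script. This is only a sketch under stated assumptions: `build_main_command` and `generate` are hypothetical helper names, the binary is assumed to be in the current directory, and only the `-m`, `-n` and `-p` options of `main` are used.

```python
import subprocess


def build_main_command(model: str, prompt: str, n_predict: int = 128,
                       binary: str = "./main") -> list:
    """Assemble the argument list for the llama.cpp `main` example binary."""
    return [binary, "-m", model, "-n", str(n_predict), "-p", prompt]


def generate(model: str, prompt: str, n_predict: int = 128) -> str:
    """Run the binary once and return everything it wrote to stdout."""
    result = subprocess.run(
        build_main_command(model, prompt, n_predict),
        capture_output=True,
        text=True,
        check=True,  # raise CalledProcessError if the binary fails
    )
    return result.stdout


# Only build and show the command here; call generate(...) to actually run it.
cmd = build_main_command("./models/7B/ggml-model-q4_0.gguf", "Hello")
print(cmd)
```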

@ggerganov

Member

I can't use this script on a 13B model unless I set TMPDIR to disk.

The main convert.py script has some extra logic for lazy loading tensors, which allows it to convert a model without loading all of it at once. It's a bit over my head, but we will probably try to move this logic into gguf.py so it can be easily reused by other scripts. For now, the simple convert scripts will have this drawback in favor of simplicity.
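
The general idea of lazy loading can be sketched as follows. This is not the logic used in convert.py; `LazyTensor`, `lazy_tensors` and `convert` are hypothetical names, and zero-filled byte strings stand in for real tensor data. Each tensor is described up front but only read when it is about to be written, so at most one tensor is held in memory at a time.

```python
import io
from dataclasses import dataclass
from typing import BinaryIO, Callable, Dict, Iterator, List


@dataclass
class LazyTensor:
    name: str
    load: Callable[[], bytes]  # reads the data only when called


def lazy_tensors(sizes: Dict[str, int], loaded: List[str]) -> Iterator[LazyTensor]:
    """Yield one LazyTensor per entry; `loaded` records when data is really read."""
    for name, size in sizes.items():
        def load(name: str = name, size: int = size) -> bytes:
            loaded.append(name)
            return bytes(size)  # stand-in for reading from a checkpoint file
        yield LazyTensor(name, load)


def convert(tensors: Iterator[LazyTensor], out: BinaryIO) -> None:
    for tensor in tensors:
        data = tensor.load()  # only this tensor is materialized
        out.write(data)       # and it is written before the next one is loaded


loaded: List[str] = []
out = io.BytesIO()
convert(lazy_tensors({"a": 4, "b": 8}, loaded), out)
print(loaded, len(out.getvalue()))  # ['a', 'b'] 12
```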

@ggerganov ggerganov left a comment

Member

I see some questions are still pending. After resolving any issues, just merge it.

@cebtenzzre cebtenzzre merged commit 4c8643d into ggml-org:master Sep 14, 2023
pkrmf pushed a commit to morlockstudios-com/llama.cpp that referenced this pull request Sep 26, 2023
Seunghhon pushed a commit to Seunghhon/llama.cpp that referenced this pull request Apr 26, 2026
phuongncn pushed a commit to phuongncn/llama.cpp-gx10-dgx-sparks-deepseekv4 that referenced this pull request Apr 28, 2026
ljubomirj pushed a commit to ljubomirj/llama.cpp that referenced this pull request May 6, 2026
AlexiAlp pushed a commit to minghaop/llama.cpp that referenced this pull request Jun 2, 2026
AlexiAlp pushed a commit to minghaop/llama.cpp that referenced this pull request Jun 2, 2026
Labels

need feedback Testing and feedback with results are needed

7 participants