
Commit e7e2a8f

add hybrid model and update readme
1 parent 5623afa commit e7e2a8f

3 files changed: 71 additions & 36 deletions

.github/workflows/quamba-ci.yml

Lines changed: 10 additions & 10 deletions
@@ -112,11 +112,11 @@ jobs:
 python generate.py state-spaces/mamba2-130m --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition_penalty 1.2 --cache_graph --pretrained_dir pretrained_models/ --quantize --quantize_embedding --quantize_lm_head --w_bits 4 --a_bits 8 --apply_gptq --group_heads
 python generate.py state-spaces/mamba2-130m --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition_penalty 1.2 --cache_graph --pretrained_dir pretrained_models/ --quantize --quantize_embedding --quantize_lm_head --w_bits 4 --a_bits 16 --apply_gptq
 # test generate.py with w4ax hybrid model and store w4ax hybrid models
-# we hack and apply the mamba2-8B hybrid config (searched_1400_v3.json) to state-spaces/mamba2-130m
-# - name: Test w4ax hybrid generate.py
-#   run: |
-#     export CUDA_VISIBLE_DEVICES=7
-#     python generate.py state-spaces/mamba2-130m --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition_penalty 1.2 --cache_graph --pretrained_dir pretrained_models/ --quantize --quantize_embedding --quantize_lm_head --w_bits 4 --apply_gptq --group_heads --hybrid_blocks --hybrid_blocks_config configs/hybrid/mamba2-8b/searched_1400_v3.json
+# we hack and apply the mamba2-8B hybrid config (hybrid_blocks_config.json) to state-spaces/mamba2-130m
+- name: Test w4ax hybrid generate.py
+  run: |
+    export CUDA_VISIBLE_DEVICES=7
+    python generate.py state-spaces/mamba2-130m --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition_penalty 1.2 --cache_graph --pretrained_dir pretrained_models/ --quantize --quantize_embedding --quantize_lm_head --w_bits 4 --apply_gptq --group_heads --hybrid_blocks --hybrid_blocks_config configs/hybrid/mamba2-8b/hybrid_blocks_config.json

 # test loading the stored quantized models with generate.py
 - name: Test loading quantized models
@@ -129,11 +129,11 @@ jobs:
 python generate.py ut-enyac/quamba2-130m-w4a8 --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition_penalty 1.2 --cache_graph --pretrained_dir pretrained_models/
 python generate.py ut-enyac/quamba2-130m-w4a16 --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition_penalty 1.2 --cache_graph --pretrained_dir pretrained_models/
 # test loading the stored w4ax hybrid model with generate.py
-# we hack and apply the mamba2-8B hybrid config (searched_1400_v3.json) to state-spaces/mamba2-130m
-# - name: Test loading w4ax hybrid generate.py
-#   run: |
-#     export CUDA_VISIBLE_DEVICES=7
-#     python generate.py ut-enyac/quamba2-130m-w4aX-searched_1400_v3 --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition_penalty 1.2 --cache_graph --pretrained_dir pretrained_models/
+# we hack and apply the mamba2-8B hybrid config (hybrid_blocks_config.json) to state-spaces/mamba2-130m
+- name: Test loading w4ax hybrid generate.py
+  run: |
+    export CUDA_VISIBLE_DEVICES=7
+    python generate.py ut-enyac/quamba2-130m-w4aX-hybrid_blocks_config --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition_penalty 1.2 --cache_graph --pretrained_dir pretrained_models/
 - name: Clean up pretrained models
   run: |
     rm -rf pretrained_models/ut-enyac/*
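For reference, the two newly enabled steps reduce to the following shell commands, copied from the workflow above (the diff view strips the original YAML indentation, and `CUDA_VISIBLE_DEVICES=7` is the CI runner's GPU index, so adjust both for your own machine):

```bash
# Sketch: reproduce the two new CI steps outside of GitHub Actions.
# Assumes the repo root as the working directory; pick your own GPU index.
export CUDA_VISIBLE_DEVICES=7

# 1) Quantize state-spaces/mamba2-130m with the mamba2-8B W4AX hybrid config and store it
python generate.py state-spaces/mamba2-130m \
  --prompt "My cat wrote all this CUDA code for a new language model and" \
  --topp 0.9 --temperature 0.7 --repetition_penalty 1.2 --cache_graph \
  --pretrained_dir pretrained_models/ \
  --quantize --quantize_embedding --quantize_lm_head \
  --w_bits 4 --apply_gptq --group_heads \
  --hybrid_blocks --hybrid_blocks_config configs/hybrid/mamba2-8b/hybrid_blocks_config.json

# 2) Reload the stored W4AX hybrid model by name
python generate.py ut-enyac/quamba2-130m-w4aX-hybrid_blocks_config \
  --prompt "My cat wrote all this CUDA code for a new language model and" \
  --topp 0.9 --temperature 0.7 --repetition_penalty 1.2 --cache_graph \
  --pretrained_dir pretrained_models/
```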

README.md

Lines changed: 60 additions & 26 deletions
@@ -33,7 +33,7 @@

 ### Clone Quamba
 - Clone the repository with all submodules:
-```
+```bash
 git clone --recurse-submodules git@github.com:enyac-group/Quamba.git
 # or
 cd Quamba
@@ -43,19 +43,19 @@ git submodule update --init --recursive
 - Run in docker (optional)

 To build the docker image with customized kernels, run the following commands:
-```
+```bash
 cd docker
 ./build_docker.sh
 ./run.sh # launch the container
 ```

 Or pull the pre-built docker image:
-```
+```bash
 docker image pull hychiang/quamba-cuda-12.1:latest
 ```

 - Create Quamba conda environment
-```
+```bash
 cd Quamba
 conda create -n quamba python=3.10
 conda activate quamba
@@ -65,102 +65,102 @@ pip install -r requirements.txt
 ### Build 3rd-party Libraries

 - Install `fast-hadamard-transform`:
-```
+```bash
 # set force build to include 12N, 40N from the newer commit
 export FAST_HADAMARD_TRANSFORM_FORCE_BUILD=TRUE
 pip install 3rdparty/fast-hadamard-transform
 ```

 - Install `lm-evaluation-harness`:
-```
+```bash
 # lm_eval-0.4.2 word2number-1.1
 pip install 3rdparty/lm-evaluation-harness
 ```

 - Install mamba
-```
+```bash
 # set force build to use the commit for Quamba
 export MAMBA_FORCE_BUILD=TRUE
 pip install 3rdparty/mamba
 ```

 - Install CUTLASS
-```
+```bash
 # cmake version >= 3.22.1
 bash build_cutlass.sh
 ```

 - Install Megatron-LM
-```
+```bash
 pip install -e 3rdparty/Megatron-LM
 # Megatron-LM may force-install pytorch 2.6.0+cu124,
 # so run `pip install -r requirements.txt` again if necessary
 ```

 ### Build Quamba
-```
+```bash
 pip install .
 ```

 ## Model Zoo
 | Models | W8A8 | W4A8 | W4A16 | W4AX |
 | --------- | ---------|-------------|--------------|------|
 | [Mamba1](https://huggingface.co/collections/ut-enyac/quamba-67edf67881154f4a12e41cb3) | ✅ | ✅ | ✅ | - |
-| [Mamba2](https://huggingface.co/collections/ut-enyac/quamba2-67edf74a0880f7fba8438cc3) | ✅ | ✅ | ✅ | TBD |
+| [Mamba2](https://huggingface.co/collections/ut-enyac/quamba2-67edf74a0880f7fba8438cc3) | ✅ | ✅ | ✅ | 8B |

 ✅: supports all sizes, *e.g.*, Mamba2 130m/370m/780m/1.3b/2.7b/8b

 ## Download Models
-```
+```bash
 # huggingface-cli download ut-enyac/quamba2-{size}-{precision} --local-dir pretrained_models/ut-enyac/quamba2-{size}-{precision}
 huggingface-cli download ut-enyac/quamba2-2.7b-w4a8 --local-dir pretrained_models/ut-enyac/quamba2-2.7b-w4a8
 ```

 ## Generate

-```
+```bash
 python generate.py ut-enyac/quamba2-2.7b-w4a8 --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition_penalty 1.2 --quantize --cache_graph --pretrained_dir pretrained_models
 ```

 ## Evaluate
-```
+```bash
 bash eval.sh ut-enyac/quamba2-2.7b-w4a8
 ```


 ## Profile latency and memory

 - To profile model size, use `--size`:
-```
+```bash
 python profile_mamba.py ut-enyac/quamba2-2.7b-w4a8 --prompt_len 512 --size --pretrained_dir pretrained_models
 ```

 - To profile time-to-first-token (prefilling stage), use `--ttft`:
-```
+```bash
 python profile_mamba.py ut-enyac/quamba2-2.7b-w4a8 --prompt_len 512 --ttft --pretrained_dir pretrained_models
 ```

 - To profile time-per-output-token (generation stage), use `--tpot --cache_graph`:
-```
+```bash
 python profile_mamba.py ut-enyac/quamba2-2.7b-w4a8 --tpot --cache_graph --pretrained_dir pretrained_models
 ```

 - To profile time-to-last-token (prefilling + generation stage), use `--ttlt --cache_graph`:
-```
+```bash
 python profile_mamba.py ut-enyac/quamba2-2.7b-w4a8 --prompt_len 512 --gen_len 512 --ttlt --cache_graph --pretrained_dir pretrained_models
 ```

 ## Chat (Mamba1 Only)

-```
+```bash
 huggingface-cli download ut-enyac/quamba-chat-w4a8 --local-dir pretrained_models/ut-enyac/quamba-chat-w4a8
 python chat.py ut-enyac/quamba-chat-w4a8 --cache_graph --pretrained_dir ./pretrained_models
 ```

 ## Mamba2-8B

 **[TL;DR]** We provide the 8B model in all precision formats on Hugging Face. To use it, run:
-```
+```bash
 huggingface-cli download ut-enyac/quamba2-8b-converted-w4a8 --local-dir pretrained_models/ut-enyac/quamba2-8b-converted-w4a8
 python main.py ut-enyac/quamba2-8b-converted-w4a8 \
 --batch_size 16 \
@@ -173,11 +173,11 @@ python main.py ut-enyac/quamba2-8b-converted-w4a8 \
 ### Convert Nvidia Mamba2-8B to HuggingFace

 Download the checkpoint using `huggingface-cli`:
-```
+```bash
 huggingface-cli download nvidia/mamba2-8b-3t-4k --local-dir ./pretrained_models/mamba2-8b-3t-4k
 ```
 After downloading, you will have the directory `./pretrained_models/mamba2-8b-3t-4k` with a structure like this:
-```
+```bash
 ├── latest_checkpointed_iteration.txt
 ├── mt_nlg_plus_multilingual_ja_zh_the_stack_frac_015_256k.model (This is the tokenizer)
 ├── README.md
@@ -186,7 +186,7 @@ After downloading, you will have the directory `./pretrained_models/mamba2-8b-3t
 └── model_optim_rng.pt (This is the weights)
 ```
 + Run the conversion script to get the model directory:
-```
+```bash
 python convert_mamba2_8b_to_hf.py \
 ./pretrained_models/mamba2-8b-3t-4k/release/mp_rank_00/model_optim_rng.pt \
 ./pretrained_models/mamba2-8b-3t-4k/mt_nlg_plus_multilingual_ja_zh_the_stack_frac_015_256k.model \
@@ -198,7 +198,8 @@ python convert_mamba2_8b_to_hf.py \
 After running, you will see that a directory called `mamba2-8b-converted` has been created. You can then run evaluation and profiling on it as in the instructions above. Note that quantizing the Mamba2-8B model requires at least *24GB* of GPU memory.

 For example:
-```
+```bash
+# use the `--pretrained_dir` flag to store the quantized model
 python main.py pretrained_models/mamba2-8b-converted \
 --batch_size 16 \
 --eval_zero_shot \
@@ -214,13 +215,46 @@ python main.py pretrained_models/mamba2-8b-converted \
 --log_dir logs
 ```

+## Run Mixed-precision Quamba2-8B-W4AX
+**[TL;DR]** We provide the W4AX 8B model on Hugging Face. To use it, run:
+```bash
+huggingface-cli download ut-enyac/quamba2-8b-converted-w4aX --local-dir pretrained_models/ut-enyac/quamba2-8b-converted-w4aX
+python main.py ut-enyac/quamba2-8b-converted-w4aX \
+--batch_size 16 \
+--eval_zero_shot \
+--task_list lambada_openai \
+--pretrained_dir ./pretrained_models \
+--log_dir logs
+```
+
+### Quantize and Evaluate Quamba2-8B-W4AX
+Follow the previous steps to convert the Mamba2-8B model first, and then run:
+```bash
+# use the `--pretrained_dir` flag to store the quantized model
+# it will store the mixed-precision model with the name
+# ut-enyac/mamba2-8b-converted-w4aX-hybrid_blocks_config
+python main.py pretrained_models/mamba2-8b-converted \
+--batch_size 16 \
+--eval_zero_shot \
+--task_list lambada_openai \
+--quantize \
+--group_heads \
+--apply_gptq \
+--quantize_embedding \
+--quantize_lm_head \
+--w_bits 4 \
+--hybrid_blocks \
+--hybrid_blocks_config configs/hybrid/mamba2-8b/hybrid_blocks_config.json \
+--pretrained_dir ./pretrained_models \
+--log_dir logs
+```

 ## Citation
 ```
 @article{chiang2025quamba2,
 title={Quamba2: A Robust and Scalable Post-training Quantization Framework for Selective State Space Models},
 author={Chiang, Hung-Yueh and Chang, Chi-Chih and Frumkin, Natalia and Wu, Kai-Chiang and Abdelfattah, Mohamed S. and Marculescu, Diana},
-journal={arXiv preprint arXiv:2503.22879},
+journal={International Conference on Machine Learning (ICML)},
 year={2025}
 }
 @inproceedings{chiang2025quamba,
@@ -229,4 +263,4 @@ python main.py pretrained_models/mamba2-8b-converted \
 booktitle = {International Conference on Learning Representations (ICLR)},
 year = {2025},
 }
-````
+````
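One note on the new "Quantize and Evaluate Quamba2-8B-W4AX" section added above: per its comment, the quantization run stores the mixed-precision model under the name `ut-enyac/mamba2-8b-converted-w4aX-hybrid_blocks_config`. A hedged sketch, not in the README itself, of how that stored checkpoint could then be reused, mirroring the existing `main.py` examples:

```bash
# Sketch only: evaluate the locally stored mixed-precision 8B model by the name
# noted in the README comment above; flags mirror the other main.py examples.
python main.py ut-enyac/mamba2-8b-converted-w4aX-hybrid_blocks_config \
  --batch_size 16 \
  --eval_zero_shot \
  --task_list lambada_openai \
  --pretrained_dir ./pretrained_models \
  --log_dir logs
```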
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+["W4A16", "W4A8", "W4A8", "W4A8", "W4A16", "W4A16", "W4A8", "W4A8", "W4A8", "W4A16", "W4A16", "W4A8", "W4A8", "W4A8", "W4A8", "W4A8", "W4A8", "W4A16", "W4A8", "W4A8", "W4A8", "W4A16", "W4A16", "W4A8", "W4A8", "W4A8", "W4A8", "W4A8", "W4A16", "W4A8", "W4A8", "W4A8", "W4A8", "W4A16", "W4A8", "W4A16", "W4A8", "W4A8", "W4A8", "W4A8", "W4A8", "W4A8", "W4A8", "W4A8", "W4A8", "W4A8", "W4A8", "W4A8", "W4A16", "W4A16", "W4A8", "W4A8", "W4A8", "W4A8", "W4A16", "W4A8"]
