Open
Description
I'm trying to train Llama 3.1 with the GRIT pipeline.
At first I only changed --model_name_or_path
and ran the training code (the training script I used is below):
#!/bin/bash
#SBATCH --time=6:00:00
#SBATCH --job-name=grit_train
#SBATCH --gres=gpu:h100-96:2
#SBATCH --mem=60G
#SBATCH --output=/home/e/e1347696/unified_encoder_decoder/logs/grit_train_out.log
#SBATCH --error=/home/e/e1347696/unified_encoder_decoder/logs/grit_train_err.log
source ~/.bashrc
conda activate grit_eval
export CUDA_HOME='/usr/local/cuda-12.1'
# CUDA_VISIBLE_DEVICES=$(python train/gritlm/mig_uuid_setup.py)
export CUDA_VISIBLE_DEVICES=0,1
cd /home/e/e1347696/unified_encoder_decoder
# nvidia-smi
deepspeed \
--num_gpus=2 \
--module train.gritlm.training.run \
--output_dir results/GritLM-7B-training \
--model_name_or_path model/Llama-3.1-8B \
--train_data data/grit_training_data \
--max_example_num_per_dataset 1000 \
--learning_rate 2e-5 \
--lr_scheduler_type linear \
--warmup_ratio 0.03 \
--max_steps 1253 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 256 \
--per_device_generative_bs 32 \
--dataloader_drop_last \
--normalized \
--temperature 0.02 \
--train_group_size 2 \
--negatives_cross_device \
--query_max_len 256 \
--passage_max_len 1024 \
--mode unified \
--logging_steps 1 \
--bf16 \
--pooling_method mean \
--use_unique_indices \
--loss_gen_type mixed \
--attn bbcc \
--attn_implementation sdpa \
--no_gen_gas \
--gradient_checkpointing \
--save_steps 1000 \
--split_emb \
--deepspeed scripts/configs/config_8gpusds_m7.json
But this fails with the error TypeError: LlamaModel.forward() got an unexpected keyword argument 'is_causal'.
I looked into it and found several related issues: #34, #32 and #19.
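For anyone hitting the same message, here is a minimal sketch of why it appears and how to check for it before launching a long job. It does not use the real LlamaModel: StockModel and PatchedModel are hypothetical stand-ins, and accepts_kwarg and try_forward are illustrative helper names, not part of GRIT or transformers.

```python
import inspect


def accepts_kwarg(fn, name):
    """Return True if `fn` can be called with keyword argument `name`."""
    params = inspect.signature(fn).parameters
    if name in params:
        # Positional-only parameters cannot be passed by keyword.
        return params[name].kind is not inspect.Parameter.POSITIONAL_ONLY
    # A **kwargs catch-all also accepts arbitrary keywords.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


class StockModel:
    """Hypothetical stand-in for a model whose forward() has no is_causal."""

    def forward(self, input_ids, attention_mask=None):
        return input_ids


class PatchedModel:
    """Hypothetical stand-in for a modeling file that adds is_causal."""

    def forward(self, input_ids, attention_mask=None, is_causal=True):
        return input_ids


def try_forward(model):
    """Call forward() with is_causal and report whether it was accepted."""
    try:
        model.forward([1, 2, 3], is_causal=False)
        return "ok"
    except TypeError as exc:
        return f"TypeError: {exc}"


if __name__ == "__main__":
    for model in (StockModel(), PatchedModel()):
        name = type(model).__name__
        print(name, accepts_kwarg(model.forward, "is_causal"), try_forward(model))
```

The same accepts_kwarg check could be pointed at the forward method of whatever model class actually gets loaded, to see whether the training code's is_causal argument will be accepted.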
Just to confirm: if I want to train a Llama 3.1 model with GRIT, can I
- reuse the provided modeling file directly by putting modeling_gritlm7b.py into the Llama 3.1 model folder, or
- do I need to change the modeling file for Llama 3.1 so that it accepts the is_causal argument and thus changes the attention behavior?
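To make the second option concrete, here is a rough sketch of what an is_causal flag would typically control: whether each position may attend only to earlier positions (causal) or to every position (bidirectional). This is a pure-Python illustration under my own assumptions, not the code in modeling_gritlm7b.py; build_attention_mask and ToyModel are hypothetical names.

```python
def build_attention_mask(seq_len, is_causal=True):
    """Return a mask where mask[i][j] is True if position i may attend to j."""
    if is_causal:
        # Causal: position i sees itself and earlier positions only.
        return [[j <= i for j in range(seq_len)] for i in range(seq_len)]
    # Bidirectional: every position sees every other position.
    return [[True] * seq_len for _ in range(seq_len)]


class ToyModel:
    """Hypothetical model whose forward() accepts is_causal."""

    def forward(self, input_ids, is_causal=True):
        # A real model would apply this mask inside its attention layers;
        # here the mask is returned so the effect of the flag is visible.
        return build_attention_mask(len(input_ids), is_causal=is_causal)


if __name__ == "__main__":
    model = ToyModel()
    for flag in (True, False):
        print("is_causal =", flag)
        for row in model.forward([11, 12, 13], is_causal=flag):
            print("  ", ["x" if allowed else "." for allowed in row])
```

In a real modeling file the flag would also have to be passed down from the top-level forward through each decoder layer to the attention call, which is the part that differs between model architectures.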
Thank you so much!