
Train llama 3.1 with GRIT #60

@ThisisXXZ

I'm trying to train Llama 3.1 with the GRIT pipeline.

At first, I simply changed `--model_name_or_path` and ran the training code. Here is the training script I used:

```bash
#!/bin/bash
#SBATCH --time=6:00:00
#SBATCH --job-name=grit_train
#SBATCH --gres=gpu:h100-96:2
#SBATCH --mem=60G
#SBATCH --output=/home/e/e1347696/unified_encoder_decoder/logs/grit_train_out.log
#SBATCH --error=/home/e/e1347696/unified_encoder_decoder/logs/grit_train_err.log

source ~/.bashrc
conda activate grit_eval

export CUDA_HOME='/usr/local/cuda-12.1'
# CUDA_VISIBLE_DEVICES=$(python train/gritlm/mig_uuid_setup.py)
export CUDA_VISIBLE_DEVICES=0,1

cd /home/e/e1347696/unified_encoder_decoder

# nvidia-smi

deepspeed \
    --num_gpus=2 \
    --module train.gritlm.training.run \
    --output_dir results/GritLM-7B-training \
    --model_name_or_path model/Llama-3.1-8B \
    --train_data data/grit_training_data \
    --max_example_num_per_dataset 1000 \
    --learning_rate 2e-5 \
    --lr_scheduler_type linear \
    --warmup_ratio 0.03 \
    --max_steps 1253 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 256 \
    --per_device_generative_bs 32 \
    --dataloader_drop_last \
    --normalized \
    --temperature 0.02 \
    --train_group_size 2 \
    --negatives_cross_device \
    --query_max_len 256 \
    --passage_max_len 1024 \
    --mode unified \
    --logging_steps 1 \
    --bf16 \
    --pooling_method mean \
    --use_unique_indices \
    --loss_gen_type mixed \
    --attn bbcc \
    --attn_implementation sdpa \
    --no_gen_gas \
    --gradient_checkpointing \
    --save_steps 1000 \
    --split_emb \
    --deepspeed scripts/configs/config_8gpusds_m7.json
```

However, this fails with `TypeError: LlamaModel.forward() got an unexpected keyword argument 'is_causal'`. Looking into it, I found several related issues: #34, #32, and #19.
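As far as I can tell, the error itself just means that the stock `forward()` signature has no `is_causal` parameter, so the training code passing it as a keyword argument fails immediately. A minimal pure-Python sketch of that situation (the class names below are hypothetical stand-ins, not the actual transformers or GritLM classes):

```python
class StockModel:
    """Hypothetical stand-in for an unmodified model whose forward() lacks is_causal."""

    def forward(self, input_ids, attention_mask=None):
        return len(input_ids)


class PatchedModel(StockModel):
    """Hypothetical stand-in for a modified model whose forward() accepts is_causal."""

    def forward(self, input_ids, attention_mask=None, is_causal=True):
        # A real modeling file would use is_causal to pick the attention mask;
        # here we only echo it back to show the call now succeeds.
        return len(input_ids), is_causal


def call_with_is_causal(model, input_ids, is_causal):
    """Call forward() the way a training loop that passes is_causal would."""
    try:
        return model.forward(input_ids, is_causal=is_causal)
    except TypeError as err:
        return f"TypeError: {err}"


print(call_with_is_causal(StockModel(), [1, 2, 3], is_causal=False))
print(call_with_is_causal(PatchedModel(), [1, 2, 3], is_causal=False))
```

The first call reproduces the same kind of `unexpected keyword argument 'is_causal'` error; the second succeeds because the signature was extended.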
Just to confirm: to train a Llama 3.1 model with GRIT, can I

  • reuse the provided modeling file directly, by putting modeling_gritlm7b.py into the Llama 3.1 model folder,
    or do I need to
  • modify the Llama 3.1 modeling file so that it accepts the is_causal argument and uses it to control the attention behavior?
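For the second option, my understanding is that `is_causal` would switch between causal attention (for generation) and bidirectional attention (for embedding), which I believe is what `--attn bbcc` selects in GRIT. A minimal pure-Python sketch of that switch; `build_attention_mask` is a hypothetical helper, not part of GritLM or transformers:

```python
from typing import List


def build_attention_mask(seq_len: int, is_causal: bool) -> List[List[bool]]:
    """Return mask[i][j] == True if query position i may attend to key position j.

    is_causal=True  -> lower-triangular mask: each token sees itself and earlier
                       tokens only (causal, as for next-token generation).
    is_causal=False -> full mask: every token sees every token (bidirectional,
                       as assumed here for embedding).
    """
    return [[(j <= i) or not is_causal for j in range(seq_len)] for i in range(seq_len)]


for row in build_attention_mask(3, is_causal=True):
    print(row)
```

In an SDPA-based modeling file, a similar switch could plausibly map onto the `is_causal` argument of `torch.nn.functional.scaled_dot_product_attention`, but I haven't checked how modeling_gritlm7b.py actually implements it.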

Thank you so much!
