Multi-GPU training with accelerate #1246


Open · wants to merge 6 commits into main

Conversation

YushunXiang
Contributor

I have seen a lot of issues and PRs (#1176 #558 #317 #956 #876 #778) about data-parallel training with multiple GPUs. This shows that multi-GPU training is something this community, as well as I, need. So I wrote lerobot/scripts/train_accelerate.py to address it.

What this does

This pull request introduces a new training script leveraging the accelerate library for distributed and mixed-precision training. It also adds support for gradient accumulation and updates dependencies accordingly. Below are the most significant changes grouped by theme:

New Training Script

  • Added a comprehensive training script in lerobot/scripts/train_accelerate.py that integrates the accelerate library for distributed training, mixed-precision support, and gradient accumulation. The script includes features such as policy updates, checkpointing, evaluation, and integration with Weights & Biases for logging.
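
For readers who have not used accelerate before, the core pattern the script is built around looks roughly like this. This is a minimal, self-contained sketch with a toy model and synthetic data, not the PR's actual code:

    # Minimal sketch of the accelerate training pattern (toy model and data, not the PR's code).
    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from accelerate import Accelerator

    accelerator = Accelerator()  # e.g. Accelerator(mixed_precision="bf16") for mixed precision

    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    dataloader = DataLoader(TensorDataset(torch.randn(64, 8), torch.randn(64, 2)), batch_size=8)

    # prepare() wraps model/optimizer/dataloader for whatever distributed setup was launched.
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    model.train()
    for x, y in dataloader:
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        accelerator.backward(loss)  # used instead of loss.backward() so gradients sync correctly
        optimizer.step()

    accelerator.print("done")  # prints only on the main process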

Configuration Updates

  • Introduced a new configuration parameter gradient_accumulation_steps in the PreTrainedConfig class to support gradient accumulation during training.
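
This parameter is typically forwarded to the Accelerator constructor and paired with its accumulate() context manager. The sketch below shows the general wiring; the variable names are illustrative, not the PR's:

    # Sketch of gradient accumulation with accelerate (illustrative names, toy model).
    import torch
    from accelerate import Accelerator

    gradient_accumulation_steps = 4  # stands in for the new config parameter

    accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)

    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    model, optimizer = accelerator.prepare(model, optimizer)

    for step in range(16):
        x = torch.randn(8, 8, device=accelerator.device)
        y = torch.randn(8, 2, device=accelerator.device)
        # Inside accumulate(), gradient sync and the actual optimizer update only
        # happen every `gradient_accumulation_steps` iterations.
        with accelerator.accumulate(model):
            loss = torch.nn.functional.mse_loss(model(x), y)
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()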

Dependency Updates

  • Added accelerate>=1.7.0 to the pyproject.toml file to include the accelerate library as a dependency for distributed and mixed-precision training.

How to check out & try? (for the reviewer)


Examples:

python lerobot/scripts/train_accelerate.py \
--policy.type=pi0 \
--dataset.repo_id=danaaubakirova/koch_test
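
For an actual multi-GPU run, the script would presumably be started through the accelerate launcher instead of plain python; for example (the process count and mixed-precision flag here are illustrative):

accelerate launch --multi_gpu --num_processes=2 --mixed_precision=bf16 \
lerobot/scripts/train_accelerate.py \
--policy.type=pi0 \
--dataset.repo_id=danaaubakirova/koch_test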

@lucasjinreal

Please merge?

@lucasjinreal

This PR is not working:

return self._untyped_storage.data_ptr()
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: Attempted to access the data pointer on an invalid python storage.

The make_policy will fail

@YushunXiang
Contributor Author

This PR is not working:

return self._untyped_storage.data_ptr()
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: Attempted to access the data pointer on an invalid python storage.

The make_policy will fail

Could you give me some details about this error?

@lucasjinreal

lucasjinreal commented Jun 14, 2025

Hi, torch 2.6+ introduces the DTensor feature. Accelerate won't be able to load the model properly or prepare it for distributed training when DTensor is not disabled. From what I can see in train_accelerate.py, this case is not handled anywhere.

On my side, torch 2.7.1 with the latest transformers and accelerate hits this error when training on multiple GPUs.

Once you resolve the state-dict loading error, there is still another error:

[rank0]: AssertionError: found no DeviceMesh from dtensor args for c10d.broadcast_.default!

@zhangxp12345678

Can you please tell me what modifications need to be made to the saved weights for inference? I find that the multi-GPU weights perform very poorly compared to single-GPU training.

@YushunXiang
Contributor Author

YushunXiang commented Jun 16, 2025

Can you please tell me what modifications need to be made to the saved weights for inference? I find that the multi-GPU weights perform very poorly compared to single-GPU training.

I used model.safetensors as the inference checkpoint.
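
For reference, loading that checkpoint for inference would look roughly like the sketch below; the import path and directory layout are assumptions based on the lerobot version this PR targets and may differ on other branches:

    # Sketch: loading a saved pi0 policy for inference (import path and directory are assumptions).
    import torch
    from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy

    # Directory containing config.json and model.safetensors from the training run (hypothetical path).
    pretrained_dir = "outputs/train/checkpoints/last/pretrained_model"

    policy = PI0Policy.from_pretrained(pretrained_dir)
    policy.eval()
    policy.to("cuda" if torch.cuda.is_available() else "cpu")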

@zhangxp12345678

Can you please tell me what modifications need to be made to the saved weights for inference? I find that the multi-GPU weights perform very poorly compared to single-GPU training.

I used model.safetensors as the inference checkpoint.

My checkpoint file looks the same as the single-GPU one, but at inference time (pi0) it has almost no effect. Is there a problem with the checkpoint saving logic, or do I need to make other changes? Please advise!

@YushunXiang
Contributor Author

YushunXiang commented Jun 16, 2025

My checkpoint file looks the same as the single-GPU one, but at inference time (pi0) it has almost no effect. Is there a problem with the checkpoint saving logic, or do I need to make other changes? Please advise!

It is an interesting question.

At line 288 of lerobot/scripts/train_accelerate.py:

        if cfg.save_checkpoint and is_saving_step and accelerator.is_main_process:
            logging.info(f"Checkpoint policy after step {step}")
            checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)

            # Wait for all processes before saving
            accelerator.wait_for_everyone()

            # Unwrap model for saving
            unwrapped_policy = accelerator.unwrap_model(policy)
            save_checkpoint(checkpoint_dir, step, cfg, unwrapped_policy, optimizer, lr_scheduler)
            update_last_checkpoint(checkpoint_dir)

You can modify it to:

        if cfg.save_checkpoint and is_saving_step and accelerator.is_main_process:
            logging.info(f"Checkpoint policy after step {step}")
            checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)

            # Wait for all processes before saving
            accelerator.wait_for_everyone()
            accelerator.save_model(model, save_directory)

Then try again?

@lucasjinreal

I am training with multiple GPUs now.

@command-z-z

command-z-z commented Jun 17, 2025

My checkpoint file looks the same as the single-GPU one, but at inference time (pi0) it has almost no effect. Is there a problem with the checkpoint saving logic, or do I need to make other changes? Please advise!

It is an interesting question.

At line 288 of lerobot/scripts/train_accelerate.py:

        if cfg.save_checkpoint and is_saving_step and accelerator.is_main_process:
            logging.info(f"Checkpoint policy after step {step}")
            checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)

            # Wait for all processes before saving
            accelerator.wait_for_everyone()

            # Unwrap model for saving
            unwrapped_policy = accelerator.unwrap_model(policy)
            save_checkpoint(checkpoint_dir, step, cfg, unwrapped_policy, optimizer, lr_scheduler)
            update_last_checkpoint(checkpoint_dir)

You can modify it to:

        if cfg.save_checkpoint and is_saving_step and accelerator.is_main_process:
            logging.info(f"Checkpoint policy after step {step}")
            checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)

            # Wait for all processes before saving
            accelerator.wait_for_everyone()
            accelerator.save_model(model, save_directory)

Then try again?

I'm curious why you don't run into a deadlock due to the incorrect usage of accelerator.wait_for_everyone(). Only rank 0 (i.e., the main process) enters this if condition and then waits for the other ranks at accelerator.wait_for_everyone(). However, the other ranks never reach this code, which results in a deadlock.
It may also be that I know too little about the accelerate package; I look forward to your reply to my confusion.
Is the following better or correct?

       if cfg.save_checkpoint and is_saving_step:
            logging.info(f"Process {accelerator.process_index} waiting at barrier before saving.")
            accelerator.wait_for_everyone()
            logging.info(f"Process {accelerator.process_index} passed the barrier.")

            if accelerator.is_main_process:
                logging.info(f"Checkpoint policy after step {step}")
                checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)

                logging.info(colored("This is save model in ", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")

                # Unwrap model for saving
                unwrapped_policy = accelerator.unwrap_model(policy)
                save_checkpoint(checkpoint_dir, step, cfg, unwrapped_policy, optimizer, lr_scheduler)
                update_last_checkpoint(checkpoint_dir)
                # if wandb_logger:
                #     wandb_logger.log_policy(checkpoint_dir)

@zhangxp12345678

My checkpoint file looks the same as the single-GPU one, but at inference time (pi0) it has almost no effect. Is there a problem with the checkpoint saving logic, or do I need to make other changes? Please advise!

It is an interesting question.

At line 288 of lerobot/scripts/train_accelerate.py:

        if cfg.save_checkpoint and is_saving_step and accelerator.is_main_process:
            logging.info(f"Checkpoint policy after step {step}")
            checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)

            # Wait for all processes before saving
            accelerator.wait_for_everyone()

            # Unwrap model for saving
            unwrapped_policy = accelerator.unwrap_model(policy)
            save_checkpoint(checkpoint_dir, step, cfg, unwrapped_policy, optimizer, lr_scheduler)
            update_last_checkpoint(checkpoint_dir)

You can modify it to:

        if cfg.save_checkpoint and is_saving_step and accelerator.is_main_process:
            logging.info(f"Checkpoint policy after step {step}")
            checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)

            # Wait for all processes before saving
            accelerator.wait_for_everyone()
            accelerator.save_model(model, save_directory)

Then try again?

I'm curious why you don't run into a deadlock due to the incorrect usage of accelerator.wait_for_everyone(). Only rank 0 (i.e., the main process) enters this if condition and then waits for the other ranks at accelerator.wait_for_everyone(). However, the other ranks never reach this code, which results in a deadlock. It may also be that I know too little about the accelerate package; I look forward to your reply.

I've actually encountered this problem as well and look forward to the author's answer.

@xliu0105

I am training with multiple GPUs now.

You mentioned above that you encountered the make_policy problem and another problem. How did you solve them?

@YushunXiang
Contributor Author

I'm curious why you don't run into a deadlock due to the incorrect usage of accelerator.wait_for_everyone(). Only rank 0 (i.e., the main process) enters this if condition and then waits for the other ranks at accelerator.wait_for_everyone(). However, the other ranks never reach this code, which results in a deadlock. It may also be that I know too little about the accelerate package; I look forward to your reply. Is the following better or correct?

       if cfg.save_checkpoint and is_saving_step:
            logging.info(f"Process {accelerator.process_index} waiting at barrier before saving.")
            accelerator.wait_for_everyone()
            logging.info(f"Process {accelerator.process_index} passed the barrier.")

            if accelerator.is_main_process:
                logging.info(f"Checkpoint policy after step {step}")
                checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)

                logging.info(colored("This is save model in ", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")

                # Unwrap model for saving
                unwrapped_policy = accelerator.unwrap_model(policy)
                save_checkpoint(checkpoint_dir, step, cfg, unwrapped_policy, optimizer, lr_scheduler)
                update_last_checkpoint(checkpoint_dir)
                # if wandb_logger:
                #     wandb_logger.log_policy(checkpoint_dir)

You are right. I think the correct code is:

if cfg.save_checkpoint and is_saving_step:
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        logging.info(f"Checkpoint policy after step {step}")
        checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
        logging.info(colored("This is save model in ", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
        accelerator.save_model(model, save_directory)
    accelerator.wait_for_everyone() 

I will test it and commit it again. Thanks!
@command-z-z @zhangxp12345678
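
One note on the snippet above: model and save_directory are placeholders; in the script they would presumably be the prepared policy and the checkpoint_dir returned by get_step_checkpoint_dir. As far as I understand accelerate, save_model gathers the full state dict (unwrapping any DDP wrapper) and writes it as model.safetensors into the target directory; a tiny self-contained demonstration:

    # Minimal demonstration of accelerator.save_model (toy model, hypothetical output path).
    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()
    model = accelerator.prepare(torch.nn.Linear(8, 2))

    # Barrier reached by every rank, then save; save_model writes model.safetensors
    # and accelerate takes care of performing the write from the main process.
    accelerator.wait_for_everyone()
    accelerator.save_model(model, "outputs/demo_ckpt")
    accelerator.wait_for_everyone()

Note that save_model only writes the model weights, so the config and optimizer state that save_checkpoint handles would still need to be saved separately if they are needed for resuming.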

@xliu0105

My checkpoint file looks the same as the single-GPU one, but at inference time (pi0) it has almost no effect. Is there a problem with the checkpoint saving logic, or do I need to make other changes? Please advise!

It is an interesting question.

At line 288 of lerobot/scripts/train_accelerate.py:

        if cfg.save_checkpoint and is_saving_step and accelerator.is_main_process:
            logging.info(f"Checkpoint policy after step {step}")
            checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)

            # Wait for all processes before saving
            accelerator.wait_for_everyone()

            # Unwrap model for saving
            unwrapped_policy = accelerator.unwrap_model(policy)
            save_checkpoint(checkpoint_dir, step, cfg, unwrapped_policy, optimizer, lr_scheduler)
            update_last_checkpoint(checkpoint_dir)

You can modify it to:

        if cfg.save_checkpoint and is_saving_step and accelerator.is_main_process:
            logging.info(f"Checkpoint policy after step {step}")
            checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)

            # Wait for all processes before saving
            accelerator.wait_for_everyone()
            accelerator.save_model(model, save_directory)

Then try again?

When I saved the model using the save_checkpoint function provided by lerobot, I encountered the following problem (I saved the policy returned by unwrap_model):

[rank0]: Traceback (most recent call last):
[rank0]: File "/nas13/lx_folder/lx_vla/lerobot/lerobot/scripts/train_dis.py", line 389, in
[rank0]: train()
[rank0]: File "/nas13/lx_folder/lx_vla/lerobot/lerobot/configs/parser.py", line 226, in wrapper_inner
[rank0]: response = fn(cfg, *args, **kwargs)
[rank0]: File "/nas13/lx_folder/lx_vla/lerobot/lerobot/scripts/train_dis.py", line 333, in train
[rank0]: save_checkpoint(checkpoint_dir/"lerobot", step, cfg, unwarpped_policy, optimizer, lr_scheduler)
[rank0]: File "/nas13/lx_folder/lx_vla/lerobot/lerobot/common/utils/train_utils.py", line 100, in save_checkpoint
[rank0]: policy.save_pretrained(pretrained_dir)
[rank0]: File "/nas13/lx_folder/lx_vla/lerobot/lerobot/common/utils/hub.py", line 66, in save_pretrained
[rank0]: self._save_pretrained(save_directory)
[rank0]: File "/nas13/lx_folder/lx_vla/lerobot/lerobot/common/policies/pretrained.py", line 74, in _save_pretrained
[rank0]: save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
[rank0]: File "/home/liuxu/miniforge3/envs/lerobot/lib/python3.10/site-packages/safetensors/torch.py", line 155, in save_model
[rank0]: to_removes = _remove_duplicate_names(state_dict)
[rank0]: File "/home/liuxu/miniforge3/envs/lerobot/lib/python3.10/site-packages/safetensors/torch.py", line 102, in _remove_duplicate_names
[rank0]: raise RuntimeError(
[rank0]: RuntimeError: Error while trying to find names to remove to save state dict, but found no suitable name to keep for saving amongst: {'model.paligemma_with_expert.paligemma.model.multi_modal_projector.linear.weight'}. None is covering the entire storage.Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an issue.

Do you know how to solve it?
