
Failure with checkpoint resumption using the bf16-with-fp8-delayed-scaling-mixed precision recipe. #1730

@jstjohn

Description


Describe the bug

When I train with the precision recipe "bf16-with-fp8-delayed-scaling-mixed", I get the following error on checkpoint resumption that I do not see when training with the "bf16-mixed" recipe. Interestingly, I also do not see this error if I initially train with "bf16-with-fp8-delayed-scaling-mixed" but restart with "bf16-mixed". This suggests that the issue lies in checkpoint resumption with the "bf16-with-fp8-delayed-scaling-mixed" precision recipe. Note that this is with the dp_reshardable optimizer type, not the fully reshardable one (I have other issues with the fully reshardable optimizer that prevent me from testing it).

[rank7]: Traceback (most recent call last):                                                                                                                                                                                          
[rank7]:   File "/workspace/.venv/bin/train_evo2", line 10, in <module>                                                                                                                                                              
[rank7]:     sys.exit(main())                                                                                                                                                                                                        
[rank7]:              ^^^^^^                                                                                                                                                                                                         
[rank7]:   File "/workspace/bionemo/src/bionemo/evo2/run/train.py", line 676, in main                                                                                                                                                
[rank7]:     train(args=args)                                                                                                                                                                                                        
[rank7]:   File "/workspace/bionemo/src/bionemo/evo2/run/train.py", line 945, in train                                                                                                                                               
[rank7]:     pretrain(cfg, hyena_forward_step)                                                                                                                                                                                       
[rank7]:   File "/workspace/.venv/lib/python3.12/site-packages/megatron/bridge/utils/decorators.py", line 39, in wrapper                                                                                                             
[rank7]:     return func(*args, **kwargs)                                                                                                                                                                                            
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                            
[rank7]:   File "/workspace/.venv/lib/python3.12/site-packages/megatron/bridge/training/pretrain.py", line 92, in pretrain                                                                                                           
[rank7]:     _pretrain(state=state, forward_step_func=forward_step_func)                                                                                                                                                             
[rank7]:   File "/workspace/.venv/lib/python3.12/site-packages/megatron/bridge/training/pretrain.py", line 120, in _pretrain                                                                                                         
[rank7]:     setup_output = setup(state, dataset_provider, restart_store=store)                                                                                                                                                      
[rank7]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                      
[rank7]:   File "/workspace/.venv/lib/python3.12/site-packages/megatron/bridge/training/setup.py", line 251, in setup                                                                                                                
[rank7]:     load_checkpoint(                                                                                                                                                                                                        
[rank7]:   File "/workspace/.venv/lib/python3.12/site-packages/megatron/bridge/training/checkpointing.py", line 1251, in load_checkpoint                                                                                             
[rank7]:     return _load_checkpoint_from_path(                                                                                                                                                                                      
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                      
[rank7]:   File "/workspace/.venv/lib/python3.12/site-packages/megatron/bridge/training/checkpointing.py", line 1577, in _load_checkpoint_from_path                                                                                  
[rank7]:     optimizer.load_state_dict(state_dict["optimizer"])                                                                                                                                                                      
[rank7]:   File "/workspace/.venv/lib/python3.12/site-packages/megatron/core/optimizer/optimizer.py", line 1215, in load_state_dict                                                                                                  
[rank7]:     self.chained_optimizers[0].load_state_dict(state_dict)                                                                                                                                                                  
[rank7]:   File "/workspace/.venv/lib/python3.12/site-packages/megatron/core/optimizer/distrib_optimizer.py", line 866, in load_state_dict                                                                                           
[rank7]:     self.load_parameter_state_from_dp_reshardable(param_state)                                                                                                                                                              
[rank7]:   File "/workspace/.venv/lib/python3.12/site-packages/megatron/core/optimizer/distrib_optimizer.py", line 1784, in load_parameter_state_from_dp_reshardable                                                                 
[rank7]:     self._set_main_param_and_optimizer_states(model_param, src_tensors)                                                                                                                                                     
[rank7]:   File "/workspace/.venv/lib/python3.12/site-packages/megatron/core/optimizer/distrib_optimizer.py", line 923, in _set_main_param_and_optimizer_states                                                                      
[rank7]:     self.optimizer.set_scaled_state(sharded_model_param, k, v)                                                                                                                                                              
[rank7]:   File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/optimizers/fused_adam.py", line 349, in set_scaled_state                                                                                         
[rank7]:     assert unscaled_state.dtype == torch.float32                                                                                                                                                                            
[rank7]:            ^^^^^^^^^^^^^^^^^^^^                                                                                                                                                                                             
[rank7]: AttributeError: 'bool' object has no attribute 'dtype'                                                                                                                                                                      
(rank3 prints the same traceback)
[rank5]:[W1215 11:48:18.932344841 ProcessGroupNCCL.cpp:1566] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank7]:[W1215 11:48:18.133450842 ProcessGroupNCCL.cpp:1566] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank0]:[W1215 11:48:18.205772003 ProcessGroupNCCL.cpp:1566] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank1]:[W1215 11:48:18.325030003 ProcessGroupNCCL.cpp:1566] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank3]:[W1215 11:48:18.344597499 ProcessGroupNCCL.cpp:1566] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank6]:[W1215 11:48:18.575980747 ProcessGroupNCCL.cpp:1566] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank4]:[W1215 11:48:19.626290445 ProcessGroupNCCL.cpp:1566] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank2]:[W1215 11:48:19.626507003 ProcessGroupNCCL.cpp:1566] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
W1215 11:48:20.002000 3933610 torch/distributed/elastic/multiprocessing/api.py:940] Sending process 3933707 closing signal SIGTERM
W1215 11:48:20.004000 3933610 torch/distributed/elastic/multiprocessing/api.py:940] Sending process 3933708 closing signal SIGTERM
W1215 11:48:20.008000 3933610 torch/distributed/elastic/multiprocessing/api.py:940] Sending process 3933709 closing signal SIGTERM
W1215 11:48:20.010000 3933610 torch/distributed/elastic/multiprocessing/api.py:940] Sending process 3933710 closing signal SIGTERM
W1215 11:48:20.013000 3933610 torch/distributed/elastic/multiprocessing/api.py:940] Sending process 3933711 closing signal SIGTERM
W1215 11:48:20.015000 3933610 torch/distributed/elastic/multiprocessing/api.py:940] Sending process 3933713 closing signal SIGTERM
W1215 11:48:20.017000 3933610 torch/distributed/elastic/multiprocessing/api.py:940] Sending process 3933714 closing signal SIGTERM

In PDB we see that something is calling FusedAdam's set_scaled_state with a bool for unscaled_state rather than a tensor (a standalone illustration of why that blows up follows the listing below).

(Pdb) ll
327         def set_scaled_state(self, param, state_name, unscaled_state):
328             """Set the optimizer state.
329  
330             If the dtype of the corresponding optimizer state is not FP32,
331             it will do scaling automatically.
332  
333             Arguments:
334                 param (torch.nn.Parameter): One of parameters in this optimizer.
335                 state_name (string): Name of optimizer states, can be one of 'exp_avg', 'exp_avg_sq',
336                     and 'master_param`.
337                 unscaled_state (torch.Tensor): The original high-precision(FP32) state.
338             """
339  
340             store_param_remainders = (
341                 self.store_param_remainders
342                 and state_name == "master_param"
343                 and param.dtype == torch.bfloat16
344             )
345  
346             if store_param_remainders:
347                 assert unscaled_state.dtype == torch.int16
348             else:
349  ->             assert unscaled_state.dtype == torch.float32
350             state = self.state[param]
351             if state_name not in state:
352                 self._initialize_state(param, state_name, False, store_param_remainders)
353  
354             dtype = self.name_to_dtype_map[state_name]
355             if dtype != torch.float32:
356                 scale = self._scales[param]
357                 self._apply_scale(state_name, unscaled_state, state[state_name], scale[state_name])
358             else:
359                 state[state_name].copy_(unscaled_state)
(Pdb) unscaled_state
False
(Pdb) state_name
'padding'
(Pdb) param
tensor([ 0.0024,  0.0138,  0.0028,  ..., -0.0189,  0.0129, -0.0217],
       device='cuda:0', dtype=torch.bfloat16)
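
As a standalone illustration (this is not the Transformer Engine implementation), any non-tensor value handed to that assert dies with an AttributeError before the assertion itself can report anything useful:

import torch

# Standalone illustration (not the Transformer Engine code): the assert
# dereferences .dtype on whatever it receives, so a plain Python bool raises
# AttributeError instead of failing the assertion with a readable message.
def check_unscaled_state(unscaled_state):
    assert unscaled_state.dtype == torch.float32

check_unscaled_state(torch.zeros(4))  # passes
check_unscaled_state(False)           # AttributeError: 'bool' object has no attribute 'dtype'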

Looking at the caller, we see this comes from Megatron's distributed optimizer:

> /workspace/.venv/lib/python3.12/site-packages/megatron/core/optimizer/distrib_optimizer.py(923)_set_main_param_and_optimizer_states()
-> self.optimizer.set_scaled_state(sharded_model_param, k, v)
(Pdb) v
False
(Pdb) k
'padding'
(Pdb) sharded_model_param
tensor([ 0.0024,  0.0138,  0.0028,  ..., -0.0189,  0.0129, -0.0217],
       device='cuda:0', dtype=torch.bfloat16)
(Pdb) ll
900         def _set_main_param_and_optimizer_states(self, model_param, tensors):
901             """Set the main param and optimizer states corresponding to the input model_param.
902  
903             The structure of the input `tensors`:
904             tensors = {
905                 "param": torch.Tensor
906                 "exp_avg": torch.Tensor
907                 "exp_avg_sq": torch.Tensor
908             }
909             """
910             group_index, group_order = self.model_param_group_index_map[model_param]
911             if self.config.use_precision_aware_optimizer_no_fp8_or_ds_fp8:
912                 sharded_model_param = self.optimizer.param_groups[group_index]["params"][group_order]
913                 for k, v in tensors.items():
914                     if isinstance(self.optimizer, HybridDeviceOptimizer):
915                         if k == "param":
916                             k = "master_param"
917                         self.optimizer.state[sharded_model_param][k] = v
918                         continue
919  
920                     if k == "param":
921                         self.optimizer.set_scaled_state(sharded_model_param, "master_param", v)
922                     else:
923  ->                     self.optimizer.set_scaled_state(sharded_model_param, k, v)
924             else:
925                 main_param = self.optimizer.param_groups[group_index]["params"][group_order]
926                 optim_state = self.optimizer.state[main_param]
927                 dst_tensors = {"param": main_param, **optim_state}
928                 for key in dst_tensors:
929                     dst_tensors[key].copy_(tensors[key])
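
Per the docstring above, tensors is expected to hold only "param", "exp_avg" and "exp_avg_sq", but the PDB session shows it also carries a "padding" key whose value is the bool False. Below is a minimal sketch of how that key reaches set_scaled_state; the dict contents are reconstructed from the PDB output, not read from a real checkpoint:

import torch

# Hypothetical reconstruction of one checkpoint bucket entry, based only on the
# PDB output above (stand-in tensors; not read from a real checkpoint).
src_tensors = {
    "param": torch.zeros(8),       # saved master param
    "exp_avg": torch.zeros(8),     # Adam first moment
    "exp_avg_sq": torch.zeros(8),  # Adam second moment
    "padding": False,              # bucket-level flag stored alongside the states
}

for k, v in src_tensors.items():
    state_name = "master_param" if k == "param" else k
    # The real loop calls self.optimizer.set_scaled_state(param, state_name, v).
    # For k == "padding" that means unscaled_state=False, and False.dtype raises
    # the AttributeError shown in the traceback.
    print(state_name, type(v).__name__)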

Which comes from load_parameter_state_from_dp_reshardable:

> /workspace/.venv/lib/python3.12/site-packages/megatron/core/optimizer/distrib_optimizer.py(1784)load_parameter_state_from_dp_reshardable()
-> self._set_main_param_and_optimizer_states(model_param, src_tensors)
(Pdb) ll
1752        def load_parameter_state_from_dp_reshardable(self, state_dict):
1753            """Loads the parameter state from an internal representation.
1754 
1755            Inverse of the `get_parameter_state_dp_reshardable` method.
1756            """
1757            if state_dict is not None and "per_bucket_numel_unpadded" in state_dict:
1758                per_bucket_numel_unpadded_in_checkpoint = state_dict["per_bucket_numel_unpadded"]
1759                assert self.per_bucket_numel_unpadded == per_bucket_numel_unpadded_in_checkpoint, (
1760                    f"Number of unpadded elements in each bucket need to be the same in current run "
1761                    f"({self.per_bucket_numel_unpadded}) and checkpoint "
1762                    f"({per_bucket_numel_unpadded_in_checkpoint})"
1763                )
1764 
1765            for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges):
1766                assert len(gbuf_range_maps) == 1, "single dtype supported, for now."
1767                for dtype, gbuf_range_map_for_all_buckets in gbuf_range_maps.items():
1768                    for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
1769                        bucket_state = state_dict[gbuf_idx][dtype][bucket_idx]
1770                        bucket_state = [
1771                            bucket_state_elem
1772                            for bucket_state_elem in bucket_state
1773                            if not bucket_state_elem['padding']
1774                        ]
1775 
1776                        assert len(bucket_state) == len(gbuf_range_map["param_map"]), (
1777                            len(bucket_state),
1778                            len(gbuf_range_map["param_map"]),
1779                        )
1780                        for src_tensors, (model_param, param_range_map) in zip(
1781                            bucket_state, gbuf_range_map["param_map"].items()
1782                        ):
1783                            # Main param & optimizer states.
1784 ->                         self._set_main_param_and_optimizer_states(model_param, src_tensors)

It seems like one possible issue is that the bucket_state list comprehension filters out the entries that are pure padding (those whose 'padding' flag is True) but does not strip the 'padding' key itself from the surviving entries, and the failure above happens precisely on that leftover 'padding' key.

We also see, one level up in _set_main_param_and_optimizer_states, that the failure lands inside the if self.config.use_precision_aware_optimizer_no_fp8_or_ds_fp8: branch; the else branch copies into dst_tensors keyed by the existing optimizer state, so a stray 'padding' entry would simply be ignored there, which matches the code path outside that branch working (e.g. in bf16).
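
If the stray 'padding' flag is indeed the culprit, one minimal and untested workaround sketch (assuming the flag carries no optimizer state and nothing downstream depends on the key) would be to strip non-tensor entries from each surviving bucket element before the per-parameter loop:

import torch

# Untested workaround sketch: keep only tensor-valued entries from each
# non-padding bucket element, so the bool "padding" flag never reaches
# FusedAdam.set_scaled_state.
def strip_padding_flags(bucket_state):
    return [
        {k: v for k, v in elem.items() if torch.is_tensor(v)}
        for elem in bucket_state
        if not elem["padding"]
    ]

# Stand-in entry shaped like the PDB output above:
entry = {
    "param": torch.zeros(4),
    "exp_avg": torch.zeros(4),
    "exp_avg_sq": torch.zeros(4),
    "padding": False,
}
assert "padding" not in strip_padding_flags([entry])[0]

Whether such a cleanup belongs in load_parameter_state_from_dp_reshardable or in _set_main_param_and_optimizer_states (e.g. skipping non-tensor values there) is left to the maintainers.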

Steps/Code to reproduce bug

  1. Train with "bf16-with-fp8-delayed-scaling-mixed":
torchrun --nproc-per-node 8 --no-python \
  train_evo2 \
  --hf-tokenizer-model-path tokenizers/nucleotide_fast_tokenizer_256 \
  --model-size striped_hyena_1b_nv_parallel --max-steps 12 --eval-interval 10 \
  --eval-iters 3 --mock-data \
  --micro-batch-size 32 --global-batch-size 256 --seq-length 1024 \
  --tensor-model-parallel 1 \
  --use-precision-aware-optimizer --dataset-seed 33 \
  --seed 41 --ckpt-async-save  --spike-no-more-embedding-init \
  --no-weight-decay-embeddings --cross-entropy-loss-fusion \
  --align-param-gather --overlap-param-gather  --grad-reduce-in-fp32 \
  --decay-steps 100 --warmup-steps 10 \
  --mixed-precision-recipe bf16-with-fp8-delayed-scaling-mixed \
  --no-fp32-residual-connection --activation-checkpoint-recompute-num-layers 1 \
  --attention-dropout 0.001 --hidden-dropout 0.001 \
  --eod-pad-in-loss-mask --enable-preemption \
  --log-interval 5 --debug-ddp-parity-freq 10 \
  --wandb-project evo2-recipes-verification-tmp \
  --wandb-run-name tmp_workstation_run_mock_data \
  --result-dir tmpfp8 --no-renormalize-loss
  2. Resume training (e.g. increase --max-steps and run again), keeping the "bf16-with-fp8-delayed-scaling-mixed" recipe, and observe the failure:
torchrun --nproc-per-node 8 --no-python \
  train_evo2 \
  --hf-tokenizer-model-path tokenizers/nucleotide_fast_tokenizer_256 \
  --model-size striped_hyena_1b_nv_parallel --max-steps 22 --eval-interval 10 \
  --eval-iters 3 --mock-data \
  --micro-batch-size 32 --global-batch-size 256 --seq-length 1024 \
  --tensor-model-parallel 1 \
  --use-precision-aware-optimizer --dataset-seed 33 \
  --seed 41 --ckpt-async-save  --spike-no-more-embedding-init \
  --no-weight-decay-embeddings --cross-entropy-loss-fusion \
  --align-param-gather --overlap-param-gather  --grad-reduce-in-fp32 \
  --decay-steps 100 --warmup-steps 10 \
  --mixed-precision-recipe bf16-with-fp8-delayed-scaling-mixed \
  --no-fp32-residual-connection --activation-checkpoint-recompute-num-layers 1 \
  --attention-dropout 0.001 --hidden-dropout 0.001 \
  --eod-pad-in-loss-mask --enable-preemption \
  --log-interval 5 --debug-ddp-parity-freq 10 \
  --wandb-project evo2-recipes-verification-tmp \
  --wandb-run-name tmp_workstation_run_mock_data \
  --result-dir tmpfp8 --no-renormalize-loss

Repeating steps 1 and 2 with "bf16-mixed" does not produce the failure. Replacing only step 2 with "bf16-mixed" also does not produce the failure (i.e. you can train with FP8 and resume with BF16).

Expected behavior

Checkpoint saving and resumption work with the "bf16-with-fp8-delayed-scaling-mixed" recipe just as they do with "bf16-mixed".

Additional context

Reach out on Slack (NVIDIA internal) for example images/run scripts on EOS.

Labels

bug (Something isn't working)