Describe the bug
When I train with the precision recipe "bf16-with-fp8-delayed-scaling-mixed", checkpoint resumption fails with the error below, which I do not see when training with "bf16-mixed". Interestingly, I also do not see the error if I initially train with "bf16-with-fp8-delayed-scaling-mixed" but restart with "bf16-mixed". This suggests the issue is specifically in checkpoint resumption with the "bf16-with-fp8-delayed-scaling-mixed" precision recipe. Note that this is with the dp_reshardable optimizer type, not the fully reshardable one (I have other issues with the fully reshardable optimizer that prevent me from testing that).
[rank7]: Traceback (most recent call last):
[rank7]: File "/workspace/.venv/bin/train_evo2", line 10, in <module>
[rank7]: sys.exit(main())
[rank7]: ^^^^^^
[rank7]: File "/workspace/bionemo/src/bionemo/evo2/run/train.py", line 676, in main
[rank7]: train(args=args)
[rank7]: File "/workspace/bionemo/src/bionemo/evo2/run/train.py", line 945, in train
[rank7]: pretrain(cfg, hyena_forward_step)
[rank7]: File "/workspace/.venv/lib/python3.12/site-packages/megatron/bridge/utils/decorators.py", line 39, in wrapper
[rank7]: return func(*args, **kwargs)
[rank7]: ^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "/workspace/.venv/lib/python3.12/site-packages/megatron/bridge/training/pretrain.py", line 92, in pretrain
[rank7]: _pretrain(state=state, forward_step_func=forward_step_func)
[rank7]: File "/workspace/.venv/lib/python3.12/site-packages/megatron/bridge/training/pretrain.py", line 120, in _pretrain
[rank7]: setup_output = setup(state, dataset_provider, restart_store=store)
[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "/workspace/.venv/lib/python3.12/site-packages/megatron/bridge/training/setup.py", line 251, in setup
[rank7]: load_checkpoint(
[rank7]: File "/workspace/.venv/lib/python3.12/site-packages/megatron/bridge/training/checkpointing.py", line 1251, in load_checkpoint
[rank7]: return _load_checkpoint_from_path(
[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "/workspace/.venv/lib/python3.12/site-packages/megatron/bridge/training/checkpointing.py", line 1577, in _load_checkpoint_from_path
[rank7]: optimizer.load_state_dict(state_dict["optimizer"])
[rank7]: File "/workspace/.venv/lib/python3.12/site-packages/megatron/core/optimizer/optimizer.py", line 1215, in load_state_dict
[rank7]: self.chained_optimizers[0].load_state_dict(state_dict)
[rank7]: File "/workspace/.venv/lib/python3.12/site-packages/megatron/core/optimizer/distrib_optimizer.py", line 866, in load_state_dict
[rank7]: self.load_parameter_state_from_dp_reshardable(param_state)
[rank7]: File "/workspace/.venv/lib/python3.12/site-packages/megatron/core/optimizer/distrib_optimizer.py", line 1784, in load_parameter_state_from_dp_reshardable
[rank7]: self._set_main_param_and_optimizer_states(model_param, src_tensors)
[rank7]: File "/workspace/.venv/lib/python3.12/site-packages/megatron/core/optimizer/distrib_optimizer.py", line 923, in _set_main_param_and_optimizer_states
[rank7]: self.optimizer.set_scaled_state(sharded_model_param, k, v)
[rank7]: File "/usr/local/lib/python3.12/dist-packages/transformer_engine/pytorch/optimizers/fused_adam.py", line 349, in set_scaled_state
[rank7]: assert unscaled_state.dtype == torch.float32
[rank7]: ^^^^^^^^^^^^^^^^^^^^
[rank7]: AttributeError: 'bool' object has no attribute 'dtype'
The same traceback is raised on the other ranks (e.g. rank 3).
[rank5]:[W1215 11:48:18.932344841 ProcessGroupNCCL.cpp:1566] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank7]:[W1215 11:48:18.133450842 ProcessGroupNCCL.cpp:1566] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank0]:[W1215 11:48:18.205772003 ProcessGroupNCCL.cpp:1566] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank1]:[W1215 11:48:18.325030003 ProcessGroupNCCL.cpp:1566] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank3]:[W1215 11:48:18.344597499 ProcessGroupNCCL.cpp:1566] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank6]:[W1215 11:48:18.575980747 ProcessGroupNCCL.cpp:1566] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank4]:[W1215 11:48:19.626290445 ProcessGroupNCCL.cpp:1566] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank2]:[W1215 11:48:19.626507003 ProcessGroupNCCL.cpp:1566] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
W1215 11:48:20.002000 3933610 torch/distributed/elastic/multiprocessing/api.py:940] Sending process 3933707 closing signal SIGTERM
W1215 11:48:20.004000 3933610 torch/distributed/elastic/multiprocessing/api.py:940] Sending process 3933708 closing signal SIGTERM
W1215 11:48:20.008000 3933610 torch/distributed/elastic/multiprocessing/api.py:940] Sending process 3933709 closing signal SIGTERM
W1215 11:48:20.010000 3933610 torch/distributed/elastic/multiprocessing/api.py:940] Sending process 3933710 closing signal SIGTERM
W1215 11:48:20.013000 3933610 torch/distributed/elastic/multiprocessing/api.py:940] Sending process 3933711 closing signal SIGTERM
W1215 11:48:20.015000 3933610 torch/distributed/elastic/multiprocessing/api.py:940] Sending process 3933713 closing signal SIGTERM
W1215 11:48:20.017000 3933610 torch/distributed/elastic/multiprocessing/api.py:940] Sending process 3933714 closing signal SIGTERM
In PDB we see that something is calling Adam's set_scaled_state with a bool for unscaled_state rather than a tensor:
(Pdb) ll
327 def set_scaled_state(self, param, state_name, unscaled_state):
328 """Set the optimizer state.
329
330 If the dtype of the corresponding optimizer state is not FP32,
331 it will do scaling automatically.
332
333 Arguments:
334 param (torch.nn.Parameter): One of parameters in this optimizer.
335 state_name (string): Name of optimizer states, can be one of 'exp_avg', 'exp_avg_sq',
336 and 'master_param`.
337 unscaled_state (torch.Tensor): The original high-precision(FP32) state.
338 """
339
340 store_param_remainders = (
341 self.store_param_remainders
342 and state_name == "master_param"
343 and param.dtype == torch.bfloat16
344 )
345
346 if store_param_remainders:
347 assert unscaled_state.dtype == torch.int16
348 else:
349 -> assert unscaled_state.dtype == torch.float32
350 state = self.state[param]
351 if state_name not in state:
352 self._initialize_state(param, state_name, False, store_param_remainders)
353
354 dtype = self.name_to_dtype_map[state_name]
355 if dtype != torch.float32:
356 scale = self._scales[param]
357 self._apply_scale(state_name, unscaled_state, state[state_name], scale[state_name])
358 else:
359 state[state_name].copy_(unscaled_state)
(Pdb) unscaled_state
False
(Pdb) state_name
'padding'
(Pdb) param
tensor([ 0.0024, 0.0138, 0.0028, ..., -0.0189, 0.0129, -0.0217],
device='cuda:0', dtype=torch.bfloat16)
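For reference, the failing assert is trivially reproducible in isolation; this minimal standalone snippet (plain PyTorch, no FusedAdam involved) hits the same AttributeError when a bool is passed where a tensor is expected:
```python
import torch

def check_unscaled_state(unscaled_state):
    # Mirrors the failing assertion in FusedAdam.set_scaled_state.
    assert unscaled_state.dtype == torch.float32

check_unscaled_state(torch.zeros(4))  # OK: float32 tensor
check_unscaled_state(False)           # AttributeError: 'bool' object has no attribute 'dtype'
```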
Looking at the caller, we see this comes from Megatron's distributed optimizer:
> /workspace/.venv/lib/python3.12/site-packages/megatron/core/optimizer/distrib_optimizer.py(923)_set_main_param_and_optimizer_states()
-> self.optimizer.set_scaled_state(sharded_model_param, k, v)
(Pdb) v
False
(Pdb) k
'padding'
(Pdb) sharded_model_param
tensor([ 0.0024, 0.0138, 0.0028, ..., -0.0189, 0.0129, -0.0217],
device='cuda:0', dtype=torch.bfloat16)
(Pdb) ll
900 def _set_main_param_and_optimizer_states(self, model_param, tensors):
901 """Set the main param and optimizer states corresponding to the input model_param.
902
903 The structure of the input `tensors`:
904 tensors = {
905 "param": torch.Tensor
906 "exp_avg": torch.Tensor
907 "exp_avg_sq": torch.Tensor
908 }
909 """
910 group_index, group_order = self.model_param_group_index_map[model_param]
911 if self.config.use_precision_aware_optimizer_no_fp8_or_ds_fp8:
912 sharded_model_param = self.optimizer.param_groups[group_index]["params"][group_order]
913 for k, v in tensors.items():
914 if isinstance(self.optimizer, HybridDeviceOptimizer):
915 if k == "param":
916 k = "master_param"
917 self.optimizer.state[sharded_model_param][k] = v
918 continue
919
920 if k == "param":
921 self.optimizer.set_scaled_state(sharded_model_param, "master_param", v)
922 else:
923 -> self.optimizer.set_scaled_state(sharded_model_param, k, v)
924 else:
925 main_param = self.optimizer.param_groups[group_index]["params"][group_order]
926 optim_state = self.optimizer.state[main_param]
927 dst_tensors = {"param": main_param, **optim_state}
928 for key in dst_tensors:
929 dst_tensors[key].copy_(tensors[key])
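For clarity, here is a toy contrast of the two branches above (values are placeholders; the boolean 'padding' entry is the one observed in PDB). The precision-aware branch forwards every key of the incoming dict to set_scaled_state, while the other branch only reads the keys of the existing destination tensors, so a stray 'padding' key never reaches the optimizer there:
```python
# Toy contrast of the two branches (placeholder values, not real optimizer state).
src_tensors = {"param": "P", "exp_avg": "M", "exp_avg_sq": "V", "padding": False}

# Precision-aware (FP8) branch: loops over the *source* items, so
# ('padding', False) is forwarded to set_scaled_state and trips the assert.
fp8_keys_forwarded = list(src_tensors.keys())

# Other branch: loops over the *destination* keys (main param plus existing
# optimizer state), so the extra 'padding' key is never even read.
dst_tensors = {"param": None, "exp_avg": None, "exp_avg_sq": None}
bf16_keys_read = [key for key in dst_tensors]

print(fp8_keys_forwarded)  # ['param', 'exp_avg', 'exp_avg_sq', 'padding']
print(bf16_keys_read)      # ['param', 'exp_avg', 'exp_avg_sq']
```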
Which comes from load_parameter_state_from_dp_reshardable:
> /workspace/.venv/lib/python3.12/site-packages/megatron/core/optimizer/distrib_optimizer.py(1784)load_parameter_state_from_dp_reshardable()
-> self._set_main_param_and_optimizer_states(model_param, src_tensors)
(Pdb) ll
1752 def load_parameter_state_from_dp_reshardable(self, state_dict):
1753 """Loads the parameter state from an internal representation.
1754
1755 Inverse of the `get_parameter_state_dp_reshardable` method.
1756 """
1757 if state_dict is not None and "per_bucket_numel_unpadded" in state_dict:
1758 per_bucket_numel_unpadded_in_checkpoint = state_dict["per_bucket_numel_unpadded"]
1759 assert self.per_bucket_numel_unpadded == per_bucket_numel_unpadded_in_checkpoint, (
1760 f"Number of unpadded elements in each bucket need to be the same in current run "
1761 f"({self.per_bucket_numel_unpadded}) and checkpoint "
1762 f"({per_bucket_numel_unpadded_in_checkpoint})"
1763 )
1764
1765 for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges):
1766 assert len(gbuf_range_maps) == 1, "single dtype supported, for now."
1767 for dtype, gbuf_range_map_for_all_buckets in gbuf_range_maps.items():
1768 for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
1769 bucket_state = state_dict[gbuf_idx][dtype][bucket_idx]
1770 bucket_state = [
1771 bucket_state_elem
1772 for bucket_state_elem in bucket_state
1773 if not bucket_state_elem['padding']
1774 ]
1775
1776 assert len(bucket_state) == len(gbuf_range_map["param_map"]), (
1777 len(bucket_state),
1778 len(gbuf_range_map["param_map"]),
1779 )
1780 for src_tensors, (model_param, param_range_map) in zip(
1781 bucket_state, gbuf_range_map["param_map"].items()
1782 ):
1783 # Main param & optimizer states.
1784 -> self._set_main_param_and_optimizer_states(model_param, src_tensors)
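Note that the list comprehension above only removes entries that are entirely padding; it does not strip the 'padding' key from the entries that survive. A toy sketch (the exact contents of a padding-only entry are my assumption; the surviving entry's keys match what PDB showed):
```python
# Toy version of the filtering in load_parameter_state_from_dp_reshardable.
bucket_state = [
    {"param": "t0", "exp_avg": "m0", "exp_avg_sq": "v0", "padding": False},
    {"padding": True},  # assumed padding-only entry, dropped by the filter
]
bucket_state = [elem for elem in bucket_state if not elem["padding"]]
print(bucket_state)
# [{'param': 't0', 'exp_avg': 'm0', 'exp_avg_sq': 'v0', 'padding': False}]
# The boolean 'padding' key survives and is later handed to set_scaled_state.
```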
It seems like one possible issue is that the list comprehension over bucket_state only filters out entries that are entirely padding; the surviving entries still carry the boolean 'padding' key, and that boolean is what ends up being passed to set_scaled_state.
We also see that the failure lands inside the `if self.config.use_precision_aware_optimizer_no_fp8_or_ds_fp8:` branch, so things outside that branch seem to be working (e.g. in bf16); note that the else branch iterates over the destination tensors' keys rather than the checkpoint entry's keys, so an extra 'padding' entry would simply be ignored there.
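One possible direction for a fix, sketched here as a hypothetical standalone helper rather than an actual patch to distrib_optimizer.py (the helper name and placement are mine, and this assumes the 'padding' flag carries no state that needs restoring): filter non-tensor bookkeeping entries out before they reach set_scaled_state.
```python
import torch

# Hypothetical helper, not part of Megatron-Core: keep only tensor-valued
# entries so flags like 'padding' never reach set_scaled_state.
def filter_state_tensors(tensors: dict) -> dict:
    return {k: v for k, v in tensors.items() if isinstance(v, torch.Tensor)}

# Keys modeled on what the PDB session above showed for one bucket entry.
src_tensors = {
    "param": torch.zeros(4),
    "exp_avg": torch.zeros(4),
    "exp_avg_sq": torch.zeros(4),
    "padding": False,
}
print(list(filter_state_tensors(src_tensors)))  # ['param', 'exp_avg', 'exp_avg_sq']
```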
Steps/Code to reproduce bug
1. Train with `"bf16-with-fp8-delayed-scaling-mixed"`:
torchrun --nproc-per-node 8 --no-python \
train_evo2 \
--hf-tokenizer-model-path tokenizers/nucleotide_fast_tokenizer_256 \
--model-size striped_hyena_1b_nv_parallel --max-steps 12 --eval-interval 10 \
--eval-iters 3 --mock-data \
--micro-batch-size 32 --global-batch-size 256 --seq-length 1024 \
--tensor-model-parallel 1 \
--use-precision-aware-optimizer --dataset-seed 33 \
--seed 41 --ckpt-async-save --spike-no-more-embedding-init \
--no-weight-decay-embeddings --cross-entropy-loss-fusion \
--align-param-gather --overlap-param-gather --grad-reduce-in-fp32 \
--decay-steps 100 --warmup-steps 10 \
--mixed-precision-recipe bf16-with-fp8-delayed-scaling-mixed \
--no-fp32-residual-connection --activation-checkpoint-recompute-num-layers 1 \
--attention-dropout 0.001 --hidden-dropout 0.001 \
--eod-pad-in-loss-mask --enable-preemption \
--log-interval 5 --debug-ddp-parity-freq 10 \
--wandb-project evo2-recipes-verification-tmp \
--wandb-run-name tmp_workstation_run_mock_data \
--result-dir tmpfp8 --no-renormalize-loss
2. Resume training (e.g. increase max steps and run again), keeping the `"bf16-with-fp8-delayed-scaling-mixed"` recipe, and observe the failure:
torchrun --nproc-per-node 8 --no-python \
train_evo2 \
--hf-tokenizer-model-path tokenizers/nucleotide_fast_tokenizer_256 \
--model-size striped_hyena_1b_nv_parallel --max-steps 22 --eval-interval 10 \
--eval-iters 3 --mock-data \
--micro-batch-size 32 --global-batch-size 256 --seq-length 1024 \
--tensor-model-parallel 1 \
--use-precision-aware-optimizer --dataset-seed 33 \
--seed 41 --ckpt-async-save --spike-no-more-embedding-init \
--no-weight-decay-embeddings --cross-entropy-loss-fusion \
--align-param-gather --overlap-param-gather --grad-reduce-in-fp32 \
--decay-steps 100 --warmup-steps 10 \
--mixed-precision-recipe bf16-with-fp8-delayed-scaling-mixed \
--no-fp32-residual-connection --activation-checkpoint-recompute-num-layers 1 \
--attention-dropout 0.001 --hidden-dropout 0.001 \
--eod-pad-in-loss-mask --enable-preemption \
--log-interval 5 --debug-ddp-parity-freq 10 \
--wandb-project evo2-recipes-verification-tmp \
--wandb-run-name tmp_workstation_run_mock_data \
--result-dir tmpfp8 --no-renormalize-loss
3. Repeat steps 1 and 2 with `"bf16-mixed"` and observe no failure. Also, replace only step 2 above with `"bf16-mixed"` and observe no failure (i.e. you can train with FP8 and resume with BF16).
Expected behavior
Training and checkpoint resumption work with both the FP8 and BF16 precision recipes.
Additional context
Reach out on Slack (NVIDIA internal) for example images/run scripts on Eos.