We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent f97896b · commit 8e8acff (Copy full SHA for 8e8acff)
torchtune/training/_distributed.py
@@ -378,11 +378,11 @@ def gather_cpu_state_dict(
 378             if isinstance(param, NF4Tensor):
 379                 # upcasting NF4 to original dtype
 380                 param = param.to(param.dtype)
 381     -       if adapter_weights_only:
 382     -           cpu_state_dict = get_adapter_state_dict(cpu_state_dict, device=None)
 383             if is_rank_zero:
 384                 cpu_state_dict[param_name] = param.cpu()
 385             torch.distributed.barrier()
         +   if adapter_weights_only:
         +       cpu_state_dict = get_adapter_state_dict(cpu_state_dict, device=None)
 386         return cpu_state_dict
 387
 388
0 commit comments