Skip to content

Commit 8e8acff

Browse files
misc
1 parent f97896b commit 8e8acff

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

torchtune/training/_distributed.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -378,11 +378,11 @@ def gather_cpu_state_dict(
378378
if isinstance(param, NF4Tensor):
379379
# upcasting NF4 to original dtype
380380
param = param.to(param.dtype)
381-
if adapter_weights_only:
382-
cpu_state_dict = get_adapter_state_dict(cpu_state_dict, device=None)
383381
if is_rank_zero:
384382
cpu_state_dict[param_name] = param.cpu()
385383
torch.distributed.barrier()
384+
if adapter_weights_only:
385+
cpu_state_dict = get_adapter_state_dict(cpu_state_dict, device=None)
386386
return cpu_state_dict
387387

388388

0 commit comments

Comments
 (0)