diff --git a/torchtune/modules/loss/cross_entropy_loss.py b/torchtune/modules/loss/cross_entropy_loss.py index db33cf17ea..cefcdc70ca 100644 --- a/torchtune/modules/loss/cross_entropy_loss.py +++ b/torchtune/modules/loss/cross_entropy_loss.py @@ -72,8 +72,8 @@ def set_model_output(self, model: nn.Module) -> None: def patch_tp_plan(self, tp_plan) -> dict: if "output" not in tp_plan and "decoder.output" not in tp_plan: raise KeyError("`tp_plan` requires `output` key") - - tp_plan["output"] = ColwiseParallel( + key = "output" if "output" in tp_plan else "decoder.output" + tp_plan[key] = ColwiseParallel( input_layouts=Shard(1), output_layouts=Shard(-1), use_local_output=False,