Description
I followed this cookbook to fine-tune Llama 4 (https://github.com/meta-llama/llama-cookbook/blob/main/getting-started/finetuning/finetune_llama4.md).
Bug Description
I get RuntimeError: aten.mm.default: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators! when running distributed LoRA fine-tuning on the Llama4-Scout-17B-16E model.
Error Message
RuntimeError: aten.mm.default: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!
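For context (my understanding, not confirmed against the torchtune recipe code): DTensor's dispatcher raises this error whenever a distributed op such as aten.mm receives one operand that is a DTensor and another that is still a plain torch.Tensor, which suggests some trainable weight (e.g. a LoRA adapter or MoE expert parameter) is not being converted to a DTensor during sharding. A minimal, self-contained sketch that should reproduce the same class of error on a single CPU process:

```python
# Minimal sketch of the failure mode (an assumption on my side, not the recipe
# code): one operand of aten.mm is a sharded DTensor, the other a plain
# torch.Tensor. A single-process gloo group is used only so distribute_tensor
# can run on CPU.
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor  # torch.distributed._tensor on older PyTorch

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = init_device_mesh("cpu", (1,))
weight = distribute_tensor(torch.randn(16, 16), mesh, [Shard(0)])  # DTensor parameter
x = torch.randn(4, 16)                                             # plain torch.Tensor

try:
    # Expected to raise:
    # RuntimeError: aten.mm.default: got mixed torch.Tensor and DTensor, ...
    torch.mm(x, weight)
finally:
    dist.destroy_process_group()
```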
To Reproduce
Command:
tune run --nproc_per_node 8 lora_finetune_distributed --config llama4/scout_17B_16E_lora
Configuration (scout_17B_16E_lora.yaml)
# Config for multi-device LoRA finetuning in lora_finetune_distributed.py
# using a Llama4 17Bx16E MoE model
#
# This config assumes that you've run the following command before launching:
# tune download meta-llama/Llama-4-Scout-17B-16E-Instruct
#
# To launch on 8 devices, run the following command from root:
# tune run --nproc_per_node 8 lora_finetune_distributed --config llama4/scout_17B_16E_lora
#
# You can add specific overrides through the command line. For example, to use a larger bsz:
# tune run --nproc_per_node 8 lora_finetune_distributed --config llama4/scout_17B_16E_lora batch_size=8
#
# This config was only tested on an 8xA100 machine.
output_dir: /tmp/torchtune/llama4_17Bx16E/lora

# Modeling Arguments
model:
  _component_: torchtune.models.llama4.lora_llama4_scout_17b_16e
  decoder_trainable: "lora"
  encoder_trainable: "frozen"
  fusion_trainable: "lora"
  lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
  apply_lora_to_mlp: True
  apply_lora_to_output: False
  lora_rank: 16  # higher increases accuracy and memory
  lora_alpha: 32  # usually alpha=2*rank
  lora_dropout: 0.0

tokenizer:
  _component_: torchtune.models.llama4.llama4_transform
  path: /tmp/Llama-4-Scout-17B-16E-Instruct/tokenizer.model
  max_seq_len: null
  max_num_tiles: 16

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-4-Scout-17B-16E-Instruct
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: "00050"
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA4
save_adapter_weights_only: True  # use this for faster checkpoint save
resume_from_checkpoint: False

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  packed: False  # True increases speed
  split: train[:95%]
seed: null
shuffle: True

# Validation
run_val_every_n_steps: null  # Change to an integer to enable validation every N steps
dataset_val:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  split: train[95%:]
batch_size_val: ${batch_size}

# Training arguments
epochs: 1
batch_size: 2
max_steps_per_epoch: null
gradient_accumulation_steps: 1  # Use to increase effective batch size
optimizer:
  _component_: torch.optim.AdamW
  lr: 2e-5
  fused: False
optimizer_in_bwd: False
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100
loss:
  _component_: torchtune.modules.loss.LinearCrossEntropyLoss
clip_grad_norm: null

# cuda, cpu, rocm, xpu...
device: cuda

# Memory management / performance
enable_activation_checkpointing: True
enable_activation_offloading: False
custom_sharded_layers: ['tok_embeddings']
fsdp_cpu_offload: False
compile: False  # torch.compile, set to true for perf/memory improvement

# Reduced precision
dtype: bf16

# Log metrics during training
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
log_level: INFO  # DEBUG, WARN, etc.

# Useful for understanding how to optimize memory and performance
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False
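If it helps with triage, here is a small hypothetical helper (not part of torchtune) that could be called in the recipe right after the model is sharded. It lists trainable parameters that are still plain tensors, which is the condition that would trigger the mixed-tensor error above:

```python
# Hypothetical debugging helper, not part of torchtune: report trainable
# parameters that were NOT converted to DTensor during sharding. Under FSDP2
# every parameter should be a DTensor after wrapping.
from torch.distributed.tensor import DTensor


def report_unsharded_params(model) -> None:
    for name, param in model.named_parameters():
        if param.requires_grad and not isinstance(param, DTensor):
            print(f"plain torch.Tensor (not sharded): {name} {tuple(param.shape)}")
```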
Environment Information
PyTorch version
PyTorch: 2.8.0.dev20250625+cu126
CUDA version
CUDA: 12.6
CUDA devices: 8
python --version
Python 3.10.18
OS
Linux ip-10-207-79-120 6.8.0-1030-aws #32-Ubuntu SMP Wed May 28 19:48:56 UTC 2025 x86_64 x86_64 x86_64 GNU/Linux
nvidia-smi
(py310) ubuntu@ip-10-207-79-120:~/torchtune/recipes/configs/llama4$ nvidia-smi
Sun Jun 29 10:56:45 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.133.20 Driver Version: 570.133.20 CUDA Version: 12.8 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA A100-SXM4-80GB On | 00000000:10:1C.0 Off | 0 |
| N/A 49C P0 69W / 400W | 0MiB / 81920MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 1 NVIDIA A100-SXM4-80GB On | 00000000:10:1D.0 Off | 0 |
| N/A 46C P0 68W / 400W | 0MiB / 81920MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 2 NVIDIA A100-SXM4-80GB On | 00000000:20:1C.0 Off | 0 |
| N/A 49C P0 72W / 400W | 0MiB / 81920MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 3 NVIDIA A100-SXM4-80GB On | 00000000:20:1D.0 Off | 0 |
| N/A 43C P0 65W / 400W | 0MiB / 81920MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 4 NVIDIA A100-SXM4-80GB On | 00000000:90:1C.0 Off | 0 |
| N/A 47C P0 66W / 400W | 0MiB / 81920MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 5 NVIDIA A100-SXM4-80GB On | 00000000:90:1D.0 Off | 0 |
| N/A 44C P0 70W / 400W | 0MiB / 81920MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 6 NVIDIA A100-SXM4-80GB On | 00000000:A0:1C.0 Off | 0 |
| N/A 45C P0 69W / 400W | 0MiB / 81920MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 7 NVIDIA A100-SXM4-80GB On | 00000000:A0:1D.0 Off | 0 |
| N/A 42C P0 68W / 400W | 0MiB / 81920MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| No running processes found |
+-----------------------------------------------------------------------------------------+